目录

PCT_jittor

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

简介

本项目基于 Jittor 深度学习框架,实现了点云分类网络 **Point Cloud Transformer (PCT)**。在 ModelNet40 数据集上,PCT 通过自注意力机制建模点云中点的全局关系,获得具有高判别力的全局特征,最终完成 40 类三维形状分类任务。

项目结构

.
├── pct.py
├── README.md
├── .gitignore
└── result.json

依赖环境

快速安装 Jittor(以 Linux 为例):

pip install jittor

数据集准备

ModelNet40 数据集已预处理为 .npy 格式,目录结构应如下:

data/
├── train_points.npy      # 训练点云坐标,形状 (num_samples, num_points, 3)
├── train_labels.npy      # 训练标签,形状 (num_samples,)
└── test_points.npy       # 测试点云坐标,形状 (num_samples, num_points, 3)

脚本 pct.py 默认从 data/ 目录读取,如需修改请编辑代码中的路径变量。

使用方法

训练模型

运行以下命令开始训练(会自动使用 GPU,若无 GPU 则自动切换至 CPU):

python pct.py

训练结束后保存模型参数为 pct_model.pkl,脚本会自动对测试集进行预测,并生成 result.json 文件,格式为:

{
    "0": 12,
    "1": 5,
    ...
}

其中键为测试样本索引,值为预测类别 ID(0~39)。

开源协议

本项目采用 MIT 许可证。

关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification.

41.0 KB
邀请码