目录

PCT (Point Cloud Transformer) - Jittor 实现

本项目基于清华大学自主研发的 Jittor (计图) 深度学习框架,实现了 Point Cloud Transformer (PCT) 网络,用于在经典三维数据集 ModelNet40 上进行 40 类三维物体的形状分类任务。本仓库专为图形学实验 PA3 构建。

1. 项目特点

  • 基于 Transformer 架构:利用自注意力机制(Self-Attention)捕捉全局点点关系,不依赖显式邻域拓扑,提取更稳健的三维结构特征。
  • 多尺度融合:串联 4 层自注意力特征,捕获不同层级的几何语义。
  • 数据增强:内置随机 Y 轴旋转与高斯噪声抖动(Jittering),提高模型抗噪与防过拟合能力。
  • 余弦退火调度:使用 Cosine Annealing 学习率策略配合带动量的 SGD 优化器,保证模型收敛平稳。

2. 目录结构

.
├── pct.py              # 核心代码:含数据集加载、模型定义、训练与推理函数
├── result.json         # 测试集预测结果(由推理程序生成)
├── .gitignore          # 忽略大文件提交规则(如 *.npy 和 *.pkl)
└── README.md           # 项目说明文档

3. 环境依赖

请确保本地已安装以下环境:

  • Python >= 3.7
  • Jittor >= 1.3
  • NumPy

4. 数据集准备

请将 ModelNet40 数据集放置在项目根目录下:

  • 训练集:train_points.npytrain_labels.npy
  • 测试集:test_points.npy

5. 运行指南

5.1 完整运行(训练 + 预测)

默认情况下,运行主程序将依次执行模型训练测试集预测

python pct.py
  • 训练阶段:模型将训练 200 个 Epoch,并在训练准确率上升时将最佳权重保存为 pct_model.pkl
  • 推理阶段:训练结束后,程序自动载入 pct_model.pkltest_points.npy 进行预测,输出 result.json

5.2 仅运行推理(快速生成预测结果)

如果您本地已经拥有训练好的权重文件 pct_model.pkl,无需再次耗时训练,可以修改 pct.py 最下方的入口函数:

if __name__ == '__main__':
    # train()      # 注释掉训练函数
    predict()    # 仅保留预测函数

修改后在终端运行:

python pct.py

程序将在数秒内载入模型、读取测试点云,并更新生成的 result.json

关于
44.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号