目录

PCT_jittor

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

基于 Jittor 深度学习框架实现的 Point Cloud Transformer(PCT)模型,用于 ModelNet40 三维点云分类任务。本项目包含完整的数据加载、模型训练、断点续训与测试集推理流程。

功能特性

  • 使用 Jittor 实现 PCT 网络结构(自注意力层 + 多层特征融合)
  • 支持 ModelNet40 点云数据集的加载与数据增强(旋转、缩放、平移、抖动等)
  • 支持 Cosine Annealing 学习率调度与 SGD 优化
  • 支持断点续训(pct_checkpoint.json + pct_model_last.pkl
  • 训练完成后自动对测试集推理,输出 result.json 预测结果

项目结构

.
├── pct.py                  # 主程序:数据集、模型、训练与推理
├── test_pct_inference.py   # 推理与评估快速测试脚本
├── test_minimal_train.py   # 最小训练环境验证脚本
├── run_train_200.sh        # 自动断点续训脚本(200 epoch)
├── data/
│   ├── categories.txt      # ModelNet40 40 类类别名称
│   ├── train_points.npy    # 训练集点云(需自行下载,见下文)
│   ├── train_labels.npy    # 训练集标签(需自行下载)
│   └── test_points.npy     # 测试集点云(需自行下载)
├── pct_model.pkl           # 最终模型权重(训练后生成)
├── pct_model_best.pkl      # 最佳模型权重(训练后生成)
├── pct_checkpoint.json     # 训练断点信息(训练后生成)
└── result.json             # 测试集预测结果(训练后生成)

环境要求

依赖 说明
Python >= 3.7
Jittor 深度学习框架
NumPy 数据处理

推荐使用 Conda 创建独立环境:

conda create -n jittor_pct python=3.8
conda activate jittor_pct
pip install jittor numpy

若使用 GPU 训练,请确保已安装对应版本的 CUDA 驱动,Jittor 会在首次运行时自动编译 CUDA 后端。

数据集准备

本项目使用预处理好的 ModelNet40 点云数据(.npy 格式),数据集体积较大,未包含在仓库中。请按课程要求从指定网盘下载后,放置到 data/ 目录:

data/
├── train_points.npy   # 形状 (N, 2048, 3)
├── train_labels.npy   # 形状 (N,)
└── test_points.npy    # 形状 (M, 2048, 3)

快速开始

1. 验证 Jittor 环境

python test_minimal_train.py

若正常输出 loss 数值,说明 Jittor 安装成功。

2. 训练模型

默认配置:1024 采样点数、batch size 32、200 epoch、学习率 0.01。

# GPU 训练(默认)
python pct.py

# CPU 训练
python pct.py --cpu

# 划分 10% 验证集
python pct.py --use_val --val_ratio 0.1

# 从断点续训
python pct.py --resume

常用参数:

参数 默认值 说明
--data_dir ./data 数据目录
--n_points 1024 每帧点云采样点数
--batch_size 32 批大小
--epochs 200 训练轮数
--lr 0.01 初始学习率
--seed 42 随机种子
--num_workers 0 DataLoader 进程数(建议保持 0)
--cpu 强制使用 CPU
--resume 从 checkpoint 断点续训

也可使用脚本自动断点续训:

bash run_train_200.sh

3. 推理测试

训练完成后,可用以下脚本验证模型加载与推理:

python test_pct_inference.py

模型说明

PCT(Point Cloud Transformer)通过自注意力机制捕获点云中的全局与局部结构信息。本实现的主要结构如下:

  1. 输入嵌入:两层 1×1 卷积 + BatchNorm,将 (B, 3, N) 映射到 128 维特征
  2. 自注意力层:4 层 SA_Layer,逐层提取点间关系
  3. 特征融合:拼接 4 层输出后经 1×1 卷积升维至 1024 维
  4. 全局池化 + 分类头:Max Pooling + 两层全连接,输出 40 类 logits

模型参数量约 1.37M

训练结果

在 ModelNet40 训练集上(200 epoch,1024 点采样):

指标 数值
训练准确率 ~99.7%
最佳指标(best metric) 99.72%
测试集预测数 2468

训练日志、模型权重与数据集等大文件建议通过 .gitignore 排除,避免上传到 Gitlink(详见仓库根目录 .gitignore)。

输出文件

文件 说明
pct_model.pkl 最后一轮训练权重
pct_model_best.pkl 训练过程中指标最优的权重
pct_model_last.pkl 用于断点续训的最新权重
pct_checkpoint.json 断点信息(epoch、best_metric)
result.json 测试集预测结果,格式 {"样本编号": 类别索引}

常见问题

Q: 运行时提示 未找到 jittor

请先激活正确的 Conda 环境:conda activate jittor_pct

Q: 训练时出现 segfault?

请将 --num_workers 保持为 0,并设置环境变量:

export OMP_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export MKL_NUM_THREADS=1

Q: 找不到数据文件?

确认 data/train_points.npydata/train_labels.npydata/test_points.npy 已下载并放置正确。

参考资料

许可证

本项目仅供课程学习与交流使用。

关于
520.9 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号