目录

PCT — Point Cloud Transformer for ModelNet40 Classification

基于 Jittor 框架实现的 PCT(Point Cloud Transformer),用于 ModelNet40 三维形状分类任务。

原论文:PCT: Point Cloud Transformer (Guo et al., 2021)


环境安装

  • Python 3.7
  • CUDA 11.x 或以上(推荐)
pip install -r requirements.txt

数据准备

将数据文件放置于 data/ 目录下,结构如下:

data/
  train_points.npy   # (N, 2048, 3) 训练集点云
  train_labels.npy   # (N,) 训练集标签
  test_points.npy    # (M, 2048, 3) 测试集点云

通过 --data_dir 参数指定数据根目录,默认为 ./data


训练

python pct.py \
  --data_dir ./data \
  --n_points 1024 \
  --batch_size 32 \
  --epochs 200 \
  --lr 0.001 \
  --seed 42

训练过程每 50 epoch 自动保存 checkpoint 到 checkpoint_epochN.pkl,训练结束后保存完整模型到 pct_model.pkl


评测 / 推理

从已训练的模型生成测试集预测结果 result.json

# 使用最终模型
python predict.py

# 使用某个 checkpoint
python predict.py --ckpt checkpoint_epoch200.pkl

# 自定义输出文件名
python predict.py --ckpt pct_model.pkl --output result_v2.json

多卡训练(mpirun)下 Jittor 会自动切分测试集,因此训练脚本末尾的 in-training predict 会自动跳过,请用上述 predict.py 单独生成完整结果。

从 checkpoint 续训:

python pct.py --resume checkpoint_epoch100.pkl --epochs 200

结果说明

评测指标为测试集分类准确率(Overall Accuracy, OA),由头歌平台计算。

配置 OA
旋转 + jitter + 缩放 + Adam 84.28%

模型在 ModelNet40 40 类分类任务上评测,测试集共 2468 个样本,结果以 result.json 格式提交。


第三方声明

本代码参考自 PCT 原论文官方实现,基于 Jittor 框架重写。

关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

14.5 MB
邀请码