目录

ModelNet40 点云分类(PCT 优化实现)

基于 Jittor 的 Point Cloud Transformer,在 ModelNet40 上完成 40 类三维形状分类。本仓库在标准 PCT 流程上做了若干训练与采样层面的改进,单文件 pct.py 即可完成训练与测试集推理。

快速开始

依赖

pip install numpy jittor

运行

python pct.py --data_dir ./data

训练结束后会生成 pct_model_best.pkl(训练集准确率最高时保存)和 result.json(测试集预测)。

数据

下载

课程提供的预处理数据可从清华云盘获取:

ModelNet40 数据集(直链下载)

下载后解压到与 pct.py 同级data/ 目录,例如:

pct_jittor/
├── pct.py
├── data/              ← 解压到这里
│   ├── categories.txt
│   ├── train_points.npy
│   ├── train_labels.npy
│   └── test_points.npy
├── README.md
├── LICENSE
└── .gitignore
  • categories.txt:40 个类别名称,每行一个,行号与标签编号 0–39 对应
  • train_points.npy:训练点云 (N, 2048, 3)
  • train_labels.npy:训练标签 (N,),类别编号 0–39
  • test_points.npy:测试点云 (M, 2048, 3)(无标签)

每个样本含 2048 个 (x, y, z) 坐标;训练时通过 FPS 下采样到 n_points(默认 1024)。

  • 测试集无标签文件;推理时以样本在 test_points.npy 中的下标作为 result.json 的 key。

命令行参数

参数 默认 含义
--data_dir ./data 数据目录
--n_points 1024 每帧采样点数
--batch_size 32 批大小
--epochs 250 训练轮数
--lr 0.001 初始学习率
--seed 42 随机种子

示例:

python pct.py --data_dir ./data --n_points 1024 --batch_size 32 --epochs 250 --lr 0.001 --seed 42

实现要点

与基础 PCT 相比,本实现主要包含:

模块 说明
采样 最远点采样(FPS)替代纯随机下采样,保留几何结构
增强 绕 Y 轴旋转、随机缩放、高斯抖动,再归一化到单位球
骨干 4 层自注意力(SA_Layer)+ 多尺度特征拼接
池化 Max Pooling 与 Avg Pooling 拼接后接全连接
优化 AdamW(weight_decay=1e-4)+ 10 epoch 线性预热 + 余弦退火
损失 Label Smoothing(smoothing=0.2

训练日志按 epoch 输出 loss、准确率、当前学习率;每 20 个 batch 打印一次中间结果。程序会按训练集准确率保存最优权重,再用该权重对测试集推理。

网络结构(概览)

(B, 3, N)
  → Conv1d ×2 + BN + ReLU
  → SA_Layer ×4(共享 Q/K 权重的多头注意力)
  → concat → Conv1d(512→1024)
  → max pool ∥ mean pool → (B, 2048)
  → FC → BN → ReLU → Dropout ×2
  → FC(40)  # logits

参数量约 1.89M(以实际 pct.py 打印为准)。

输出文件

pct_model_best.pkl
训练过程中训练集准确率最高时写入的模型权重。

result.json
测试集预测,格式为「样本序号 → 类别 id」:

{
  "0": 12,
  "1": 5,
  ...
}

key 为字符串形式的样本索引,value 为 0–39 的预测类别。

参考

  • Guo, M. H. et al. PCT: Point Cloud Transformer. arXiv:2012.09688
  • Wu, Z. et al. 3D ShapeNets: A Deep Representation for Volumetric Shapes. ModelNet40. arXiv:1406.5670
关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification

35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号