目录

PCT_jittor_zengchen: 基于 Jittor 的 ModelNet40 点云分类

本项目是计算机图形学基础 PA3 的最小开源版本,实现了一个基于 Jittor 的 Point Cloud Transformer(PCT)点云分类模型,用于 ModelNet40 三维形状分类任务。

按本次开源要求,仓库包含以下文件:

PCT_jittor_zengchen/
├── LISENCE
├── README.md
├── .gitignore
└── pct.py

项目特点

  • 使用 Jittor 深度学习框架实现 PCT 点云分类模型。
  • 支持 ModelNet40 风格的点云输入,每个样本由若干三维点坐标组成。
  • 训练阶段包含随机采样、旋转、姿态扰动、缩放平移、jitter 和轻量点 dropout。
  • 模型主体包含逐点 Conv1d 特征嵌入、4 个自注意力模块、全局最大池化和全连接分类头。
  • 支持 --seed--data_dir--n_points--batch_size--epochs--lr 等关键参数。

环境安装

推荐环境:

  • Python >= 3.8
  • Jittor
  • NumPy
  • Linux/WSL + CUDA GPU 推荐

安装依赖:

python -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
python -m pip install jittor numpy

验证 Jittor 安装:

python -m jittor.test.test_example

若使用 CUDA,可进一步验证:

python -m jittor.test.test_cuda

注意:pct.py 当前默认使用 GPU:

jt.flags.use_cuda = 1

如果只在 CPU 环境运行,需要将其改为:

jt.flags.use_cuda = 0

数据准备

请在本地运行时自行准备数据,并放到仓库根目录下的 data/ 目录:

PCT_jittor_zengchen/
├── pct.py
└── data/
    ├── train_points.npy   # 训练点云,形状通常为 (N_train, 2048, 3)
    ├── train_labels.npy   # 训练标签,形状通常为 (N_train,)
    ├── test_points.npy    # 测试点云,形状通常为 (N_test, 2048, 3)
    └── categories.txt     # 40 个类别名称,每行一个类别,可选

训练与推理

在仓库根目录运行:

python pct.py \
  --data_dir ./data \
  --n_points 1024 \
  --batch_size 32 \
  --epochs 200 \
  --lr 0.001 \
  --seed 42

脚本会完成训练,并在训练结束后自动对测试集进行预测。

运行后会生成:

  • pct_model.pkl:训练得到的模型权重。
  • result.json:测试集预测结果。

结果格式

result.json 的格式为“测试样本编号 -> 预测类别编号”:

{
  "0": 0,
  "1": 5,
  "2": 12
}

可复现性说明

脚本通过 --seed 设置 NumPy 和 Jittor 的随机种子:

np.random.seed(args.seed)
jt.set_global_seed(args.seed)

由于 GPU 并行计算、底层算子实现和数据增强随机性等因素,不同硬件/驱动环境下的结果可能 存在小幅波动。训练日志中的 Train Acc 只表示训练集准确率,不能等同于线上测试准确率。

主要参数

参数 默认值 含义
--data_dir ./data 数据目录
--n_points 1024 每个点云样本采样的点数
--batch_size 32 批大小
--epochs 200 训练轮数
--lr 0.01 输入学习率;当取默认 0.01 时,脚本内部实际使用 0.001
--seed 42 随机种子

参考

在安装 Jittor 时若出现问题,可查阅 Jittor 官方安装文档:https://cg.cs.tsinghua.edu.cn/jittor/download/

关于

PA3: A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

40.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号