目录

jittor_by_czy - Point Cloud Transformer (PCT) for ModelNet40 Classification

基于 Jittor 框架实现 PCT(Point Cloud Transformer)模型,用于 ModelNet40 三维形状分类任务。

项目简介

本项目是计算机图形学课程 PA3 作业,目标是使用 PCT 网络对 ModelNet40 数据集中的三维点云进行分类。PCT 利用 Self-Attention 机制捕捉点云中各点之间的全局依赖关系,实现对三维形状的高效识别。

模型架构

PCT 模型的主要流程:

  1. 输入嵌入: 两个 1D 卷积层将输入点云 (B, 3, N) 映射到 128 维特征空间
  2. Self-Attention: 4 个堆叠的 SA_Layer 提取全局特征(每个 SA 层包含 Q/K/V 卷积 + 注意力机制 + 残差连接)
  3. 特征融合: 拼接 4 个 SA 层的输出 (B, 512, N),经过 1D 卷积得到 (B, 1024, N)
  4. 全局池化: Max Pooling 得到全局特征 (B, 1024)
  5. 分类头: 三层全连接网络输出 40 类的 logits

主要特点

  • 4 层 Self-Attention 捕捉点云全局关系
  • 数据增强:随机 Y 轴旋转、缩放、抖动、平移
  • Cosine Annealing 学习率调度
  • AdamW 优化器

目录

安装

环境要求

  • Python 3.7+
  • CUDA 11.0+(GPU 训练)
  • g++ 12(注意:g++ 14 与 CUDA 12.2 存在兼容性问题)

安装依赖

pip install jittor

g++ 版本问题

如果遇到 _Float32 is undefined 编译错误,需要安装 g++-12:

sudo apt install g++-12
rm -rf ~/.cache/jittor/
export cc_path=g++-12

注意:rm -rf ~/.cache/jittor/ 仅清除 Jittor 的编译缓存,不会删除已安装的包。

数据准备

将数据文件放置在 ./data/ 目录下,需要以下文件:

data/
├── train_points.npy   # 训练集点云 (N, 2048, 3)
├── train_labels.npy   # 训练集标签 (N,)
└── test_points.npy    # 测试集点云 (N, 2048, 3)

使用方法

训练并预测

python pct.py

默认参数:1024 个采样点、batch size 32、200 个 epoch、学习率 0.01。

自定义参数

python pct.py --n_points 1024 --batch_size 64 --epochs 300 --lr 0.01 --seed 42

参数说明

参数 默认值 说明
--data_dir ./data 数据目录路径
--n_points 1024 每个点云的采样点数
--batch_size 32 批大小
--epochs 200 训练轮数
--lr 0.01 初始学习率
--seed 42 随机种子

输出

  • pct_model.pkl: 训练好的模型权重
  • result.json: 测试集预测结果,格式为 {"样本编号": 预测类别}

项目结构

PA3/
├── pct.py              # 主程序(模型、训练、预测)
├── data/               # 数据目录
│   ├── train_points.npy
│   ├── train_labels.npy
│   └── test_points.npy
├── pct_model.pkl       # 模型权重(训练后生成)
├── result.json         # 预测结果(训练后生成)
└── README.md

TODO 说明

本作业需要完成以下两处 TODO:

  1. 数据增强策略 (pct.py 第 68 行): 在 ModelNet40Dataset.__getitem__ 中实现点云数据增强,包括随机 Y 轴旋转、缩放、抖动、平移等
  2. 优化器与学习率调度 (pct.py 第 286 行): 在 main() 中配置优化器(SGD / AdamW 等)和学习率调度策略(Cosine Annealing 等)

参考资料

许可证

本项目为课程作业,仅供学习参考。

关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification

28.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802032778号