Update pct.py
本项目基于 Jittor 深度学习框架实现 Point Cloud Transformer (PCT) 网络,用于 ModelNet40 三维形状分类任务。本版本在课程示例代码基础上加入了多种数据增强、warmup + 余弦退火学习率调度、测试时增强 (TTA) 等改进,最终成功实现了87.5%的准确率。
其中数据集data/和运行pct.py训练得到的模型参数文件pct_model.pkl体积较大,未上传至仓库,如需要可以在清华云盘链接进行下载
data/
pct.py
pct_model.pkl
ModelNet40 是一个常用的三维形状分类基准数据集,包含 40 类常见物体,每个样本是一个三维点云。本项目使用 PCT (Point Cloud Transformer) 网络对其进行分类。PCT 的核心思想是把 Transformer 的自注意力机制迁移到点云数据上,利用全局注意力建模点与点之间的几何关系,从而获得强大的形状表征能力。
相对课程提供的 baseline 代码,本版本主要做了以下改进:
--predict_only
安装 Python 依赖:
pip install jittor numpy
如需 GPU 加速,请参考 Jittor 官方安装文档 配置 CUDA 环境。
数据集请放置在 ./data/ 目录下,文件结构如下:
./data/
data/ ├── train_points.npy # (N_train, 2048, 3) ├── train_labels.npy # (N_train,) └── test_points.npy # (N_test, 2048, 3)
数据集来源于课程评测平台 头歌,本仓库不包含数据本身。
训练 + 预测(默认)
python pct.py
完整流程:训练 200 epoch;保存模型 pct_model.pkl;在测试集上做 TTA 推理;保存 result.json。
result.json
仅推理(已有训练好的模型文件pct_model.pkl)
python pct.py --predict_only
依据pct_model.pkl在测试集上做 TTA 推理;保存 result.json
--data_dir
./data
--n_points
--test_n_points
--batch_size
--epochs
--lr
--warmup_epochs
--n_votes
--seed
PCT_jittor/ ├── pct.py # 主程序 ├── README.md ├── .gitignore # git忽略规则 ├── data/ # 数据集(需自行下载) ├── pct_model.pkl # 训练好的模型权重(运行后生成) └── result.json # 测试集预测结果(运行后生成)
A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification
版权所有:中国计算机学会技术支持:开源发展技术委员会 京ICP备13000930号-9 京公网安备 11010802047560号
PCT_jittor
本项目基于 Jittor 深度学习框架实现 Point Cloud Transformer (PCT) 网络,用于 ModelNet40 三维形状分类任务。本版本在课程示例代码基础上加入了多种数据增强、warmup + 余弦退火学习率调度、测试时增强 (TTA) 等改进,最终成功实现了87.5%的准确率。
其中数据集
data/和运行pct.py训练得到的模型参数文件pct_model.pkl体积较大,未上传至仓库,如需要可以在清华云盘链接进行下载项目简介
ModelNet40 是一个常用的三维形状分类基准数据集,包含 40 类常见物体,每个样本是一个三维点云。本项目使用 PCT (Point Cloud Transformer) 网络对其进行分类。PCT 的核心思想是把 Transformer 的自注意力机制迁移到点云数据上,利用全局注意力建模点与点之间的几何关系,从而获得强大的形状表征能力。
主要改进
相对课程提供的 baseline 代码,本版本主要做了以下改进:
--predict_only仅推理环境依赖
安装 Python 依赖:
如需 GPU 加速,请参考 Jittor 官方安装文档 配置 CUDA 环境。
数据准备
数据集请放置在
./data/目录下,文件结构如下:使用方法
训练 + 预测(默认)
完整流程:训练 200 epoch;保存模型
pct_model.pkl;在测试集上做 TTA 推理;保存result.json。仅推理(已有训练好的模型文件pct_model.pkl)
依据
pct_model.pkl在测试集上做 TTA 推理;保存result.json参数
--data_dir./data--n_points--test_n_points--batch_size--epochs--lr--warmup_epochs--n_votes--seed--predict_only文件结构
其中数据集
data/和运行pct.py训练得到的模型参数文件pct_model.pkl体积较大,未上传至仓库,如需要可以在清华云盘链接进行下载