更新 README 内容
基于 Jittor 框架的 Point Cloud Transformer (PCT) 实现,用于 ModelNet40 三维点云分类任务。
本项目使用 Jittor 深度学习框架实现了 Point Cloud Transformer(PCT)点云分类网络。PCT 不包含生成器(Generator)和判别器(Discriminator),其目标是对输入的三维点云进行特征提取与类别预测。
模型以包含 N 个点的点云(每个点具有三维坐标)作为输入,首先使用两个逐点一维卷积层将每个点映射到高维特征空间,随后串联四个自注意力(Self-Attention)模块建模点与点之间的全局关系。多层特征拼接融合后,通过全局最大池化获得全局描述子,最后送入三层全连接分类头输出 40 个类别的预测结果。
(B, 3, N)
本项目在以下环境中测试通过:
1. 创建 Conda 虚拟环境(推荐)
conda create -n jittor_pa3 python=3.7 conda activate jittor_pa3
2. 安装 Jittor
pip install jittor
3. 验证安装
python -c "import jittor as jt; print('Jittor version:', jt.__version__)"
更多安装方式(Docker、手动安装)请参考 Jittor 官方安装指南。
本项目使用 ModelNet40 三维形状分类数据集。
预处理后的数据集(.npy 格式)可从以下地址下载:
.npy
https://cloud.tsinghua.edu.cn/d/7b27c35a7f5544389a80/
下载后将文件放入 data/ 目录:
data/
data/ ├── train_points.npy # 训练集点云 (9843, 2048, 3) ├── train_labels.npy # 训练集标签 (9843,) ├── test_points.npy # 测试集点云 (2468, 2048, 3) └── categories.txt # 类别名称(40 类)
从头训练 PCT 模型:
python pct.py --epochs 200 --batch_size 32 --lr 0.001
主要参数说明:
--data_dir
./data
--n_points
1024
--batch_size
32
--epochs
200
--lr
0.01
--seed
42
--resume
False
--checkpoint
checkpoint.pkl
--checkpoint_freq
10
python pct.py --resume --checkpoint checkpoint.pkl
训练完成后,模型参数自动保存为 pct_model.pkl,测试集预测结果导出为 result.json,格式为 {"样本编号": 预测类别}。
pct_model.pkl
result.json
{"样本编号": 预测类别}
PCT_jittor/ ├── pct.py # 主源码(模型定义、训练、推理) ├── data/ │ └── categories.txt # 类别名称 ├── .gitignore └── README.md
本项目基于 MIT License 开源。
A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification
版权所有:中国计算机学会技术支持:开源发展技术委员会 京ICP备13000930号-9 京公网安备 11010802047560号
PCT_jittor
基于 Jittor 框架的 Point Cloud Transformer (PCT) 实现,用于 ModelNet40 三维点云分类任务。
项目简介
本项目使用 Jittor 深度学习框架实现了 Point Cloud Transformer(PCT)点云分类网络。PCT 不包含生成器(Generator)和判别器(Discriminator),其目标是对输入的三维点云进行特征提取与类别预测。
模型以包含 N 个点的点云(每个点具有三维坐标)作为输入,首先使用两个逐点一维卷积层将每个点映射到高维特征空间,随后串联四个自注意力(Self-Attention)模块建模点与点之间的全局关系。多层特征拼接融合后,通过全局最大池化获得全局描述子,最后送入三层全连接分类头输出 40 个类别的预测结果。
模型架构
(B, 3, N)点云环境要求
本项目在以下环境中测试通过:
环境配置
1. 创建 Conda 虚拟环境(推荐)
2. 安装 Jittor
3. 验证安装
数据集
本项目使用 ModelNet40 三维形状分类数据集。
下载链接
预处理后的数据集(
.npy格式)可从以下地址下载:数据准备
下载后将文件放入
data/目录:使用方法
训练
从头训练 PCT 模型:
主要参数说明:
--data_dir./data--n_points1024--batch_size32--epochs200--lr0.01--seed42--resumeFalse--checkpointcheckpoint.pkl--checkpoint_freq10断点续训
测试
训练完成后,模型参数自动保存为
pct_model.pkl,测试集预测结果导出为result.json,格式为{"样本编号": 预测类别}。实验结果
项目结构
致谢
许可证
本项目基于 MIT License 开源。