目录

PCT 点云分类 - ModelNet40

基于 Jittor 框架的 Point Cloud Transformer (PCT) 实现,用于 ModelNet40 数据集的点云分类任务。

环境要求

  • Python 3.7+
  • Jittor

安装依赖

pip install jittor numpy

数据集

data 文件夹放在项目根目录下,包含以下文件:

data/
├── train_points.npy    # 训练集点云数据
├── train_labels.npy    # 训练集标签
└── test_points.npy     # 测试集点云数据

运行训练

python pct.py

参数说明

参数 默认值 说明
--data_dir ./data 数据目录路径
--n_points 1024 每个点云采样的点数
--batch_size 32 批大小
--epochs 200 训练轮数
--lr 0.01 学习率
--seed 42 随机种子

使用示例

# 使用 64 批大小,训练 50 轮
python pct.py --batch_size 64 --epochs 50

# 使用 512 个点,学习率 0.001
python pct.py --n_points 512 --lr 0.001

输出文件

  • pct_model.pkl - 训练好的模型权重
  • result.json - 测试集预测结果(格式:{"样本ID": "预测类别"}

注意事项

  • 首次运行会进行 JIT 编译,需要稍等片刻
  • 训练时间取决于 GPU 性能,建议使用 GPU 加速
  • 代码中 jt.flags.use_cuda = 0 表示使用 CPU,如需 GPU 请改为 1
关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

40.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号