目录

PCT (Point Cloud Transformer) for ModelNet40 Classification

基于 Jittor 框架的 PCT 模型,用于 ModelNet40 三维形状分类任务。

环境安装

  • Python >= 3.7
  • Jittor
pip install jittor

数据准备

数据文件应放在 data/ 目录下:

data/
├── train_points.npy   # (N, 2048, 3) 训练点云
├── train_labels.npy   # (N,) 训练标签
├── test_points.npy    # (M, 2048, 3) 测试点云
└── categories.txt     # 类别名称

可通过 --data-dir 参数指定数据目录。

训练

python scripts/train.py \
    --data_dir ./data \
    --n_points 1024 \
    --batch_size 32 \
    --epochs 200 \
    --lr 0.01 \
    --seed 42

训练完成后,模型保存至 outputs/pct_model.pkl

评测/推理

python scripts/eval.py \
    --data_dir ./data \
    --checkpoint outputs/pct_model.pkl \
    --output result.json

生成的 result.json 格式为 {"sample_id": predicted_class}

结果说明

  • 指标:Top-1 Accuracy(分类准确率)
  • 模型:PCT (Point Cloud Transformer)
  • 训练轮数默认 200 epochs
  • 使用 SGD 优化器 + Cosine Annealing 学习率调度

项目结构

.
├── configs/           # 配置文件
├── data/              # 数据目录
├── scripts/           # 训练与评测脚本
├── src/               # 核心代码
│   ├── data/          # 数据集
│   ├── models/        # 模型定义
│   ├── scheduler/    # 学习率调度
│   └── utils/         # 工具函数
├── outputs/           # 输出目录(训练产物)
├── LICENSE
└── README.md
关于
34.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802032778号