目录

jittor-Ayaka-PCT

环境安装

数据准备

本项目使用 ModelNet40 点云数据,下载链接为https://cloud.tsinghua.edu.cn/f/f003de5a2e914d1e9e0e/?dl=1

下载后解压至代码同级目录下的 data/ 文件夹中。

训练

训练时自动进行数据增强、学习率余弦退火,并保存模型与测试集预测结果。

# 使用默认参数训练(SGD, 200 epochs, batch 32)
python pct.py

# 自定义参数示例
python pct.py --epochs 250 --batch_size 64 --lr 0.005 --optimizer adam --seed 1234

主要参数说明:

  • --data_dir:数据集路径,默认为 ./data
  • --n_points: 样本采样点数,默认1024
  • --epochs:训练轮数,默认200
  • --batch_size:训练批次大小,默认32
  • --lr:初始学习率,默认0.001
  • --optimizer:优化器,默认sgd
  • --seed:随机种子,默认1234

训练过程中会输出每个 epoch 的训练损失、训练准确率,以及当前学习率和耗时。最终模型保存为 pct_model.pkl,测试集预测结果保存为 result.json。

评测

训练结束后,脚本会自动对测试集进行推理并生成 result.json。

输出文件说明:

  • pct_model.pkl: 训练好的模型
  • result.json: 测试集预测结果,格式为 {id: label},其中 id 为样本编号,label 为预测类别。
{
  "0": 12,
  "1": 25,
  ...
}

结果说明

评价指标:分类准确率(Accuracy),公式为 正确预测样本数 / 总样本数。

关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

249.9 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号