目录

PCT: Point Cloud Transformer for ModelNet40 Classification

基于 Jittor 框架的 PCT 模型,用于 ModelNet40 三维点云分类任务。

项目结构

.
├── README.md               # 本文件
├── LICENSE                 # MIT 许可证
├── requirements.txt        # Python 依赖
├── .gitignore              # Git 忽略规则
├── configs/
│   └── default.yaml        # 默认训练/评测配置
├── src/
│   ├── __init__.py
│   ├── dataset.py          # 数据集加载与增强
│   ├── model.py            # PCT 模型定义
│   ├── scheduler.py        # 学习率调度器
│   └── utils.py            # 工具函数(种子设置等)
├── scripts/
│   ├── train.py            # 训练入口脚本
│   └── eval.py             # 推理/评测入口脚本
├── data/                   # 数据目录(不提交,见下方说明)
└── outputs/                # 输出目录(不提交)

环境安装

系统要求

  • 操作系统: Linux (Ubuntu 18.04+)
  • Python: 3.7+
  • CUDA: 11.1+(GPU 训练需要)
  • GPU: 推荐 NVIDIA GPU,显存 ≥ 4GB

安装依赖

pip install -r requirements.txt

Jittor 会自动检测 CUDA 环境,无需额外配置。

数据准备

数据下载

从比赛平台下载 ModelNet40 点云数据,解压后放到 data/ 目录下。

数据目录结构

data/
├── train_points.npy    # 训练集点云 (9843, 2048, 3)
├── train_labels.npy    # 训练集标签 (9843,)
├── test_points.npy     # 测试集点云 (2468, 2048, 3)
└── categories.txt      # 40 类别名称

可通过 --data_dir 参数指定数据目录路径(默认 ./data)。

训练

# 使用默认配置训练(200 epochs)
python scripts/train.py --epochs 200

# 使用自定义配置训练
python scripts/train.py --config configs/default.yaml --epochs 200 --lr 0.01 --batch_size 32

# 指定数据路径和随机种子
python scripts/train.py --data_dir ./data --seed 42 --epochs 200

训练完成后,模型权重保存到 outputs/pct_model.pkl,预测结果保存到 outputs/result.json

评测/推理

# 使用已训练的模型进行推理
python scripts/eval.py --ckpt outputs/pct_model.pkl --data_dir ./data

# 指定输出路径
python scripts/eval.py --ckpt outputs/pct_model.pkl --output outputs/result.json

结果说明

指标 说明
Accuracy 分类准确率(正确预测数 / 总测试样本数)
通过标准 测试集准确率 ≥ 0.80

提交格式

result.json 打包为 result.zip 提交:

result.zip
  └── result.json

result.json 格式:

{
    "0": 4,
    "1": 35,
    "2": 10,
    ...
}

key 为测试集样本编号(字符串,从 "0" 开始),value 为预测类别编号(整数,0-39)。

结果复现

使用以下命令可复现结果:

python scripts/train.py --config configs/default.yaml --seed 42

由于数据增强包含随机操作,不同运行间可能存在微小差异(通常 < 1%),属正常现象。

参考文献

License

本项目采用 MIT License

关于
38.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号