目录

PCT (Point Cloud Transformer) for ModelNet40 Classification

本项目基于 Jittor 深度学习框架实现了 PCT (Point Cloud Transformer) 模型,用于 ModelNet40 数据集上的三维形状分类任务。

1. 环境安装

  • Python 版本: >= 3.7
  • 依赖框架: Jittor

安装依赖:

pip install -r requirements.txt

2. 数据准备

本项目使用预处理好的 ModelNet40 点云数据(npy 格式)。请确保数据文件放置在 data/ 目录下。

目录结构应如下所示:

data/
├── train_points.npy   # 训练集点云 (9843, 2048, 3)
├── train_labels.npy   # 训练集标签 (9843,)
├── test_points.npy    # 测试集点云 (2468, 2048, 3)
└── categories.txt     # 40 个类别名称

数据根目录可以通过命令行参数 --data_dir 进行配置,默认为 ./data

3. 训练

使用 scripts/train.py 脚本进行模型训练。可以通过命令行参数配置训练超参数。

一条可直接运行的训练命令:

python scripts/train.py --data_dir ./data --n_points 2048 --batch_size 32 --epochs 200 --lr 0.1 --seed 42 --output_dir ./outputs

训练过程中,模型权重将保存在 outputs/ 目录下(如 pct_model_latest.pklpct_model_best.pkl)。同时,实际使用的配置会保存到 outputs/config.json,运行命令保存到 outputs/command.txt,日志保存到 outputs/train.log

4. 评测/推理

使用 scripts/eval.py 脚本对测试集进行推理,并生成预测结果文件。

一条可直接运行的评测命令(请将 --ckpt_path 替换为实际的权重路径):

python scripts/eval.py --data_dir ./data --n_points 2048 --batch_size 32 --ckpt_path ./outputs/pct_model_best.pkl --output_file result.json --output_dir ./outputs

运行完成后,将在当前目录下生成 result.json 文件,包含测试集样本的预测分类。评测日志会保存到 outputs/eval.log

5. 结果说明

  • 评测指标: Accuracy (准确率),即预测正确的样本数占总测试样本数的比例。
  • 优化说明: 本项目在 Baseline 基础上增加了数据增强(随机旋转、缩放、平移、抖动),将采样点数提升至 2048,并使用了 Cosine Annealing 学习率调度策略,在 Self-Attention 层中加入了缩放因子以稳定训练。
  • 预期结果: 在测试集上可达到约 81.93% 的准确率。

6. 目录结构

.
├── configs/          # 配置文件目录 (预留)
├── data/             # 数据集目录
├── scripts/          # 运行脚本
│   ├── train.py      # 训练脚本
│   └── eval.py       # 评测/推理脚本
├── src/              # 核心代码
│   ├── dataset.py    # 数据集加载
│   ├── model.py      # PCT 模型定义
│   └── scheduler.py  # 学习率调度器
├── outputs/          # 训练输出目录 (权重、日志等,不提交)
├── README.md         # 项目说明文档
├── requirements.txt  # 依赖列表
├── LICENSE           # 开源许可证
└── .gitignore        # Git 忽略配置
关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification

73.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号