Improve PCT training with data augmentation, optimizer, and engineering enhancements
基于 Jittor 框架的 PCT 模型,用于 ModelNet40 三维点云分类(40 类)。
基础代码和训练数据源于第六届计图人工智能挑战赛——热身赛二,包含原始 PCT 模型结构和训练流程。
# 安装依赖 pip install jittor # 训练模型 python pct.py # 指定超参数 python pct.py --epochs 500 --lr 0.001 --batch_size 32 # 从 checkpoint 继续训练 python pct.py --resume pct_model.pkl
输出:
pct_model.pkl
checkpoints/pct_epoch_*.pkl
result.json
我们在 baseline 基础上完成了两个 TODO,并增加了训练工程化改进:
Baseline: 仅随机绕 Y 轴旋转 0~360°。
我们的改动:
三轴旋转的设计选择:初始尝试了全角度 0~360° 旋转,导致训练集过于困难,准确率停滞在 70% 左右。限制 ±20° 后在保持增强效果的同时让任务可学。
Baseline: SGD (lr=0.01) + Cosine Annealing。
选择 ReduceLROnPlateau 而非 Cosine Annealing 的设计理由:余弦退火在配合 early stopping 时存在脱节问题——若训练在 epoch 80 提前停止,lr 几乎未衰减。Plateau 策略让 lr 衰减与模型实际表现挂钩。
--resume
A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification.
PCT — Point Cloud Transformer for ModelNet40 Classification
基于 Jittor 框架的 PCT 模型,用于 ModelNet40 三维点云分类(40 类)。
基础代码和训练数据源于第六届计图人工智能挑战赛——热身赛二,包含原始 PCT 模型结构和训练流程。
Quickstart
输出:
pct_model.pkl— 验证准确率最高的模型checkpoints/pct_epoch_*.pkl— 每 50 epoch 周期保存的 checkpointresult.json— 测试集预测结果改动概览
我们在 baseline 基础上完成了两个 TODO,并增加了训练工程化改进:
TODO 1: 数据增强
Baseline: 仅随机绕 Y 轴旋转 0~360°。
我们的改动:
三轴旋转的设计选择:初始尝试了全角度 0~360° 旋转,导致训练集过于困难,准确率停滞在 70% 左右。限制 ±20° 后在保持增强效果的同时让任务可学。
TODO 2: 优化器与学习率调度
Baseline: SGD (lr=0.01) + Cosine Annealing。
我们的改动:
选择 ReduceLROnPlateau 而非 Cosine Annealing 的设计理由:余弦退火在配合 early stopping 时存在脱节问题——若训练在 epoch 80 提前停止,lr 几乎未衰减。Plateau 策略让 lr 衰减与模型实际表现挂钩。
新增: 验证集与 Early Stopping
pct_model.pkl)checkpoints/pct_epoch_*.pkl)--resume从任意 checkpoint 继续训练结果