目录

PCT_jittor

本项目基于 Jittor 深度学习框架实现 Point Cloud Transformer (PCT) 网络,用于 ModelNet40 三维形状分类任务。本版本在课程示例代码基础上加入了多种数据增强、warmup + 余弦退火学习率调度、测试时增强 (TTA) 等改进,最终成功实现了87.5%的准确率。

其中数据集data/和运行pct.py训练得到的模型参数文件pct_model.pkl体积较大,未上传至仓库,如需要可以在清华云盘链接进行下载

项目简介

ModelNet40 是一个常用的三维形状分类基准数据集,包含 40 类常见物体,每个样本是一个三维点云。本项目使用 PCT (Point Cloud Transformer) 网络对其进行分类。PCT 的核心思想是把 Transformer 的自注意力机制迁移到点云数据上,利用全局注意力建模点与点之间的几何关系,从而获得强大的形状表征能力。

主要改进

相对课程提供的 baseline 代码,本版本主要做了以下改进:

改进项 Baseline 本版本
数据增强 仅 Y 轴随机旋转 各向异性缩放 + 随机平移 + 高斯抖动 + Y 轴旋转 + 点序打乱
优化器 SGD, wd=1e-4 SGD, momentum=0.9, wd=5e-4
学习率调度 余弦退火 线性 warmup (10 epoch) + 余弦退火
测试时增强 TTA,10 次随机旋转投票
测试点数 1024 2048(训练仍用 1024)
工程改进 一次性训练 每 20 epoch 存 checkpoint,支持 --predict_only 仅推理

环境依赖

  • Python ≥ 3.7
  • Linux / WSL
  • C++ 编译器 (g++)

安装 Python 依赖:

pip install jittor numpy

如需 GPU 加速,请参考 Jittor 官方安装文档 配置 CUDA 环境。

数据准备

数据集请放置在 ./data/ 目录下,文件结构如下:

data/
├── train_points.npy   # (N_train, 2048, 3)
├── train_labels.npy   # (N_train,)
└── test_points.npy    # (N_test,  2048, 3)

数据集来源于课程评测平台 头歌,本仓库不包含数据本身。

使用方法

  1. 训练 + 预测(默认)

    python pct.py

    完整流程:训练 200 epoch;保存模型 pct_model.pkl;在测试集上做 TTA 推理;保存 result.json

  2. 仅推理(已有训练好的模型文件pct_model.pkl)

    python pct.py --predict_only

    依据pct_model.pkl在测试集上做 TTA 推理;保存 result.json

参数

参数 默认值 说明
--data_dir ./data 数据集目录
--n_points 1024 训练时每个点云的点数
--test_n_points 2048 测试时每个点云的点数
--batch_size 32 batch 大小
--epochs 200 训练轮数
--lr 0.01 初始学习率
--warmup_epochs 10 warmup 轮数
--n_votes 10 TTA 投票次数
--seed 42 随机种子
--predict_only 跳过训练,仅推理

文件结构

PCT_jittor/
├── pct.py            # 主程序
├── README.md
├── .gitignore        # git忽略规则
├── data/             # 数据集(需自行下载)
├── pct_model.pkl     # 训练好的模型权重(运行后生成)
└── result.json       # 测试集预测结果(运行后生成)

其中数据集data/和运行pct.py训练得到的模型参数文件pct_model.pkl体积较大,未上传至仓库,如需要可以在清华云盘链接进行下载

关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification

51.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号