目录

PCT Point Cloud Classification

本项目基于 pct.py 实现了一个用于三维点云分类的 Point Cloud Transformer(PCT)基线模型,目标是对 ModelNet40 风格的点云数据进行类别预测,并生成测试集结果文件 result.json

项目功能

pct.py 主要包含以下部分:

  1. ModelNet40Dataset 数据集读取与增强模块

    • 读取训练集和测试集的点云 .npy 文件
    • 训练时支持随机采样固定数量点
    • 训练时启用数据增强,包括随机旋转、随机缩放、随机平移和高斯抖动
  2. PCT 分类模型

    • 先用一维卷积提取点特征
    • 通过多层自注意力模块聚合全局信息
    • 最终输出 40 类分类 logits
  3. 训练与推理流程

    • 使用 SGD 优化器和余弦退火学习率调度
    • 训练完成后保存模型到 pct_model.pkl
    • 对测试集进行推理并生成 result.json

依赖环境

运行前需要安装以下依赖:

  • Python 3
  • jittor
  • numpy

安装示例:

pip install jittor numpy

数据准备

数据集下载:

ModelNet40 点云数据

默认数据目录为 ./data,脚本会读取以下文件:

  • data/train_points.npy
  • data/train_labels.npy
  • data/test_points.npy

其中:

  • train_points.npy:训练集点云,形状通常为 (N, 2048, 3)
  • train_labels.npy:训练集标签,形状通常为 (N,)
  • test_points.npy:测试集点云,形状通常为 (N, 2048, 3)

使用方法

直接运行 pct.py 即可开始训练并生成预测结果:

python pct.py

也可以通过命令行参数调整训练配置:

python pct.py --data_dir ./data --n_points 1024 --batch_size 32 --epochs 200 --lr 0.01 --seed 42

参数说明

  • --data_dir:数据集目录,默认 ./data
  • --n_points:每个样本采样的点数,默认 1024
  • --batch_size:批量大小,默认 32
  • --epochs:训练轮数,默认 200
  • --lr:初始学习率,默认 0.01
  • --seed:随机种子,默认 42

输出文件

脚本运行结束后会生成以下文件:

  • pct_model.pkl:训练好的模型参数
  • result.json:测试集预测结果,格式为 {样本编号: 预测类别}

运行流程

脚本执行顺序如下:

  1. 读取训练集和测试集点云数据
  2. 构建 PCT 模型
  3. 使用训练集进行多轮训练
  4. 保存模型参数
  5. 对测试集进行预测并导出 result.json
关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification

41.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号