目录

基于 Jittor 的 Point Cloud Transformer 点云分类

1. 项目简介

本项目为《计算机图形学基础》课程作业,基于Jittor 深度学习框架实现了一个用于三维点云分类的Point Cloud Transformer(PCT)模型,并在数据集ModelNet40上完成训练与预测任务。本项目采用 PCT 网络结构,通过一维卷积提取点级特征,并利用自注意力机制建模点与点之间的全局关系,最终通过全局池化和全连接分类器输出 ModelNet40 中 40 个类别的预测结果。

2. 项目特点

课程给出的样例代码主要由如下特点:

  • 实现了基于自注意力机制的 Point Cloud Transformer(PCT) 网络;
  • 训练过程中使用绕 Y 轴随机旋转的点云数据增强;
  • 使用交叉熵损失函数进行多分类训练;
  • 使用 SGD 优化器和余弦退火学习率调度策略。

本项目针对于课程给出的样例代码做出了如下修改:

  • 数据增强方面在绕Y轴旋转操作外添加了缩放和微小扰动;
  • 将SGD修改为Adam方法,同时降低学习率保证收敛。

3. 数据集说明

本项目使用的数据集为 ModelNet40 点云分类数据集。该数据集包含 40 类常见三维物体,每个样本由若干个三维点组成。本课程作业使用的数据文件主要包括:

train_points.npy    # 训练集点云数据
train_labels.npy    # 训练集类别标签
test_points.npy     # 测试集点云数据

数据集下载链接:https://cloud.tsinghua.edu.cn/d/00e5101bd08347a5a43e/

4. 输出说明

在项目根目录下运行:

python pct.py

程序会自动读取 data/ 文件夹下的训练数据,并开始训练 PCT 分类模型。训练完成后,模型参数会保存为:

pct_model.pkl

程序还会使用训练好的模型对测试集点云进行预测,并生成预测文件:

result.json

其中result.json 的基本格式为:

{
  "0": 预测类别编号,
  "1": 预测类别编号,
  "2": 预测类别编号
}
关于

A Jittor implementation of Point Cloud Transformer(PCT) for ModelNet40 classification

39.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号