目录

PCT_jittor

基于 Jittor 框架的 Point Cloud Transformer (PCT) 实现,用于 ModelNet40 三维点云分类任务。

Framework Task License

项目简介

本项目使用 Jittor 深度学习框架实现了 Point Cloud Transformer(PCT)点云分类网络。PCT 不包含生成器(Generator)和判别器(Discriminator),其目标是对输入的三维点云进行特征提取与类别预测。

模型以包含 N 个点的点云(每个点具有三维坐标)作为输入,首先使用两个逐点一维卷积层将每个点映射到高维特征空间,随后串联四个自注意力(Self-Attention)模块建模点与点之间的全局关系。多层特征拼接融合后,通过全局最大池化获得全局描述子,最后送入三层全连接分类头输出 40 个类别的预测结果。

模型架构

  • 输入: (B, 3, N) 点云
  • 特征嵌入: 两层 Conv1d (3→128→128) + BatchNorm + ReLU
  • 自注意力: 4 个 SA 模块串联,含残差连接
  • 特征融合: 四层输出拼接 → Conv1d (512→1024) + LeakyReLU
  • 全局池化: 点维度最大池化 → 1024 维全局特征
  • 分类器: 全连接层 (1024→512→256→40) + Dropout

环境要求

本项目在以下环境中测试通过:

环境 版本
操作系统 Ubuntu 24.04 (WSL2)
Python 3.7.16
Jittor 1.3.11
NumPy 1.21.6
g++ 12.4.0
CUDA (可选) 12.2

环境配置

1. 创建 Conda 虚拟环境(推荐)

conda create -n jittor_pa3 python=3.7
conda activate jittor_pa3

2. 安装 Jittor

pip install jittor

3. 验证安装

python -c "import jittor as jt; print('Jittor version:', jt.__version__)"

更多安装方式(Docker、手动安装)请参考 Jittor 官方安装指南

数据集

本项目使用 ModelNet40 三维形状分类数据集。

下载链接

预处理后的数据集(.npy 格式)可从以下地址下载:

https://cloud.tsinghua.edu.cn/d/7b27c35a7f5544389a80/

数据准备

下载后将文件放入 data/ 目录:

data/
├── train_points.npy    # 训练集点云 (9843, 2048, 3)
├── train_labels.npy    # 训练集标签 (9843,)
├── test_points.npy     # 测试集点云 (2468, 2048, 3)
└── categories.txt      # 类别名称(40 类)

使用方法

训练

从头训练 PCT 模型:

python pct.py --epochs 200 --batch_size 32 --lr 0.001

主要参数说明:

参数 默认值 说明
--data_dir ./data 数据集路径
--n_points 1024 每个点云采样的点数
--batch_size 32 训练批大小
--epochs 200 训练轮数
--lr 0.01 初始学习率
--seed 42 随机种子
--resume False 从 checkpoint 恢复训练
--checkpoint checkpoint.pkl Checkpoint 文件路径
--checkpoint_freq 10 每 N 个 epoch 保存一次 checkpoint

断点续训

python pct.py --resume --checkpoint checkpoint.pkl

测试

训练完成后,模型参数自动保存为 pct_model.pkl,测试集预测结果导出为 result.json,格式为 {"样本编号": 预测类别}

实验结果

指标 数值
训练准确率 ~80%
模型参数量 1.37M
训练耗时 ~100 min (RTX 4060)

项目结构

PCT_jittor/
├── pct.py              # 主源码(模型定义、训练、推理)
├── data/
│   └── categories.txt  # 类别名称
├── .gitignore
└── README.md

致谢

许可证

本项目基于 MIT License 开源。

关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

37.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号