目录

jittor PCT

本项目基于 Jittor (计图) 深度学习框架,实现了一个增强版的 Point Cloud Transformer (PCT) 模型,用于在 ModelNet40 数据集上进行 3D 点云分类任务。

主要特性

相比于基础 PCT 模型,本代码进行了以下核心优化:

  • **多头偏移自注意力机制 (Multi-Head Offset Attention)**:计算输入特征与聚合特征之间的偏移量 (Offset),并结合残差连接,提升特征提取能力。
  • 扩大的网络容量:提升了前端卷积的隐藏层维度(最高至 256),特征融合通道数扩大至 1024。
  • 强数据增强策略:在训练阶段自动应用随机 Y 轴旋转、各向异性缩放、平移、随机抖动 (Jitter) 以及点云随机丢弃 (Point Dropout,模拟遮挡)。
  • 高级训练策略
    • 使用带余弦退火和预热 (Cosine Annealing with Warmup) 的 SGD 优化器。
    • 引入标签平滑 (Label Smoothing) 防止模型过拟合。
    • 应用梯度裁剪 (Gradient Clipping) 确保训练过程的稳定性。

环境依赖

运行代码前,需确保环境中安装了以下依赖:

  • Python >= 3.7
  • NumPy
  • Jittor (建议配置 CUDA 环境以获得 GPU 加速)

安装 Jittor:

python -m pip install jittor

(更多 Jittor 安装及 GPU 配置指南请参考 Jittor 官方文档)

数据集准备

项目期望 ModelNet40 数据集以预处理好的 .npy 格式存放。在项目根目录下创建一个 data/ 文件夹,并确保包含以下文件:

├── data/
│   ├── train_points.npy   # 形状: (N_train, 2048, 3)
│   ├── train_labels.npy   # 形状: (N_train,)
│   └── test_points.npy    # 形状: (N_test, 2048, 3)
├── pct.py
└── README.md

(注:测试集标签 test_labels.npy 在预测生成脚本中是不需要的,代码将直接输出预测索引)。 数据集建议前往https://data.educoder.net/api/attachments/att-19d687a579c1e1252?type=application/x-zip-compressed下载。

快速开始

通过命令行可直接运行 pct.py。代码将自动执行训练,并在训练完成后对测试集进行推理预测。

基础运行

python pct.py

自定义参数运行

通过传入不同的命令行参数可调整训练超参:

python pct.py --data_dir ./data --n_points 1024 --batch_size 32 --epochs 200 --lr 0.01 --seed 42

参数说明:

  • --data_dir: 数据集存放路径 (默认: ./data)
  • --n_points: 每个样本采样的点云数量 (默认: 1024)
  • --batch_size: 训练和测试的批次大小 (默认: 32)
  • --epochs: 训练的总轮数 (默认: 200)
  • --lr: 初始学习率 (默认: 0.01)
  • --seed: 随机数种子,用于固定结果 (默认: 42)

输出产物

训练和推理完成后,项目会在当前目录下生成以下两个文件:

  1. pct_model.pkl: 训练好的 Jittor 模型权重文件。
  2. result.json: 测试集的预测结果。格式为键值对 { "样本ID": 预测类别ID },可直接用于下游的评估或提交。

模型结构简述 (EnhancedPCT)

  1. Input: (B, 3, N) 的 3D 坐标点。
  2. Input Embedding: 两层 1D 卷积 (128 -> 256) 提取初始点云局部特征。
  3. Attention Blocks: 4 层串联的 Multi-Head Offset Attention 层。
  4. Feature Fusion: 将 4 层注意力模块的输出拼接为 1024 维特征,通过 1D 卷积和全局最大池化 (Global Max Pooling) 生成全局形状特征向量。
  5. Classification Head: 由全连接层、Dropout (p=0.4) 和 ReLU 组成的 MLP 网络,最终输出 40 个类别的预测对数 (Logits)。
关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

40.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号