目录

PCT Using Jittor

项目概览

本项目基于 Jittor 实现 PCT(Point Cloud Transformer) 网络,用于 ModelNet40 三维形状分类任务。训练完成后会保存模型权重,并对测试集生成预测结果文件。

动机与目标

  • 复现实验性基线:实现 PCT 的训练与推理流程。
  • 明确数据处理与增强流程,便于后续优化与对比。
  • 输出可提交的测试集预测结果。

特性

  • 基于 Jittor 的 PCT 分类模型与自注意力层实现。
  • 训练时包含旋转、缩放、平移、抖动等数据增强。
  • 余弦退火学习率调度。
  • 自动保存模型与预测结果。

环境与依赖

Jittor 环境配置

在 WSL 里执行:

conda create -n jittor python=3.9 -y
conda activate jittor
conda install -c conda-forge gcc=10 gxx=10 -y
conda install -c conda-forge libgomp -y
python -m pip install jittor

验证:

python3.9 -m  jittor.test.test_example
python3.9 -m  jittor.test.test_cudnn_op

运行依赖

  • Python 3.9
  • Jittor
  • Numpy
  • CUDA 环境

安装与运行

  1. 进入仓库目录:
cd mzrpctjittor
  1. 启动训练并生成预测:
python pct.py
  1. 常用参数:
python pct.py --data_dir ./data --n_points 1024 --batch_size 32 --epochs 200 --lr 0.01 --seed 42

数据说明

数据放在 data/ 目录下,包含以下文件:

  • train_points.npy:训练点云,形状为 (N, 2048, 3)
  • train_labels.npy:训练标签,形状为 (N,)
  • test_points.npy:测试点云,形状为 (M, 2048, 3)
  • categories.txt:类别名称(对应 40 类)

输出结果

  • pct_model.pkl:训练后的模型权重。
  • result.json:测试集预测结果,格式为 {“样本编号”: 预测类别ID}。

目录结构

mzrpctjittor/
    pct.py
    README.md
    result.json
    data/
        categories.txt
        test_points.npy
        train_labels.npy
        train_points.npy

友情链接

贡献与协作

如需协作或提交改进,请与项目维护者联系。

关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

39.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802032778号