目录

PCT 点云分类作业

本项目使用 Jittor 实现 Point Cloud Transformer(PCT),用于 ModelNet40 三维点云分类任务。程序会读取 data/ 下的预处理数据,训练模型,并生成测试集预测结果 result.json

环境要求

  • Python 3.8+
  • Jittor
  • NumPy

建议先安装依赖:

pip install jittor numpy

数据说明

数据文件位于 data/ 目录:

  • train_points.npy:训练集点云数据
  • train_labels.npy:训练集标签
  • test_points.npy:测试集点云数据
  • categories.txt:类别说明

目录结构

HW/
├── pct.py
├── result.json(未上传)
├── data/
│   ├── categories.txt
│   ├── test_points.npy
│   ├── train_labels.npy
│   └── train_points.npy
└── README.md

运行方式

直接执行训练与预测脚本:

python pct.py

常用参数:

  • --data_dir:数据目录,默认 ./data
  • --n_points:每个点云采样点数,默认 1024
  • --batch_size:批大小,默认 32
  • --epochs:训练轮数,默认 200
  • --lr:初始学习率,默认 0.01
  • --seed:随机种子,默认 42

示例:

python pct.py --epochs 100 --batch_size 16 --lr 0.005

输出文件

运行结束后会生成:

  • pct_model.pkl:训练得到的模型权重
  • result.json:测试集预测结果,格式为 {"样本编号": 类别编号}

注意事项

  • 脚本默认启用 GPU:jt.flags.use_cuda = 1。如果本机没有可用 CUDA 环境,需要按需修改脚本。
  • 训练前请确认 data/ 下文件名与 README 中一致。
  • 生成的模型文件和中间缓存不建议提交到仓库。
关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

89.4 MB
邀请码