目录

pct_jittor

功能

该项目使用 Jittor 深度学习框架,在经典的三维形状数据集 ModelNet40 上 训练点云分类模型(如 PCT, Point Cloud Transformer),完成三维形状分类任务。

环境配置

推荐使用Python 3.7和Jittor。
安装 Jittor:在 Jittor 官网安装指南跟着指令进行安装。 安装依赖:在终端敲入下面两行,补齐 GPU 支持和 Numpy。

python3 -m jittor_utils.install_cuda
python3 -m pip install numpy

数据准备

由于点云数据集体积较大,未包含在当前 Git 仓库中。请在此下载所需的4个数据文件,并如下图放入data文件夹中。

项目文件夹/
├── src/
│   ├── pct.py             # 核心训练代码
│   └── eval.py            # 用来测试的代码
├── data/                  # 数据文件夹
│   ├── categories.txt     # 40个类别的名字
│   ├── train_points.npy   # 训练用的点云数据
│   ├── train_labels.npy   # 训练用的标签
│   └── test_points.npy    # 测试用的点云数据
└── outputs/               # 运行代码后,自动生成该文件夹
    ├── pct_model.pkl      # 训练好的模型
    └── result.json        # 预测的答案

运行

配好环境,并将数据放入 ./data 文件夹后,在终端输入以下命令:

python3 src/pct.py --data_dir ./data --batch_size 32 --epochs 200 --lr 0.01 --seed 42

或者不加参数(使用默认参数值)直接运行:

python3 src/pct.py

运行结束之后,代码会建立 outputs 文件夹,里面有模型文件pct_model.pkl和预测答案result.json。有训练好的模型文件之后可以直接进行测试。

python3 src/eval.py
关于

A Jittor implementation of Point Cloud Transformer (PCT) for ModelNet40 classification

42.0 KB
邀请码