目录

Jittor-Point2Image-ResNet

基于 Jittor 框架与 ResNet‑50 的 ModelNet40 点云三视图分类方案(计图挑战赛赛道二热身赛)


📁 项目结构

.
├── train.py          # 训练脚本(含数据预处理、数据集定义、模型、训练流程)
├── predict.py        # 预测脚本(加载最佳模型,对测试集输出 JSON 结果)
├── ksh.py            # 可视化脚本(展示点云及其三视图渲染效果)
├── data/             # 数据集目录(需自行放置,见下文)
│   ├── train_points.npy    # 训练集点云 (9843, 2048, 3)
│   ├── train_labels.npy    # 训练集标签 (9843,)
│   ├── test_points.npy     # 测试集点云 (2468, 2048, 3)  (预测时使用)
│   └── categories.txt      # 40 个类别名称(一行一个)
└── best_model.pkl    # 训练后保存的最佳模型参数(训练时自动生成)

🚀 快速开始

1. 环境依赖

  • Python 3.8+
  • Jittor (建议最新版)
  • NumPy, scikit-learn, wandb (可选,用于实验记录)
  • 其他:Pillow(用于可视化)

安装依赖:

pip install jittor numpy scikit-learn wandb pillow

首次运行训练时,Jittor会自动下载预训练ResNet-50权重(约98MB),请确保网络畅通。

2. 数据准备

将官方提供的 train_points.npy, train_labels.npy, test_points.npy, categories.txt 放入 data/ 目录。

3. 训练模型

python train.py

训练过程中将:

  • 使用分层抽样将训练集拆分为 95% 训练 + 5% 验证(val_split=0.05
  • 将每个点云渲染为三个正交视图(正视、侧视、顶视),尺寸 224×224
  • 使用预训练的 ResNet‑50 提取各视图特征,拼接后经全连接层分类
  • 自动保存验证集准确率最高的模型为 best_model.pkl
  • 支持 wandb 日志记录(需先 wandb login,若无需可注释掉相关代码)

超参数默认值:

  • batch_size = 16
  • learning_rate = 8e-5
  • epochs = 50
  • 学习率调度:每 10 个 epoch 衰减 0.5

4. 预测并生成提交文件

python predict.py

该脚本会加载 best_model.pkl,对 data/test_points.npy 中的测试样本进行预测,并输出 submission.json,格式为:

{"0": 12, "1": 5, ...}

其中 key 为样本编号(字符串),value 为预测类别索引(int,0~39)。

5. 可视化点云与三视图

python ksh.py

可查看随机样本的点云三维散点图及其对应的三视图(RGB 伪彩色渲染),便于理解数据预处理效果。


🧠 方法简述

  1. 数据表示:将每个点云(2048 个三维点)投影到三个正交平面(X‑Y, X‑Z, Y‑Z),并根据深度分三个区间赋予 RGB 通道亮度,生成 224×224 的彩色伪图像。三个视图分别对应三个观察方向。

  2. 模型架构:使用在 ImageNet 上预训练的 ResNet‑50 作为特征提取骨干(移除最后的分类层),每个视图独立通过共享权重的 ResNet‑50 得到 2048 维特征,三个视图的特征拼接为 6144 维向量,最后由一个全连接层映射到 40 个类别。

  3. 训练策略:采用 Adam 优化器,学习率 8e-5,交叉熵损失,并加入随机水平翻转作为数据增强(三个视图同步翻转)。训练 50 个 epoch,每 10 轮学习率减半,保存验证集最佳模型。


📊 结果

  • 实际训练10轮
  • 最终测试集提交准确率:90.44%

📝 注意事项

  • 代码默认启用 CUDA(jt.flags.use_cuda = 1),若无 GPU 可将其改为 0。
  • 若需关闭 wandb,请注释掉 wandb.init()wandb.log() 相关行。
  • 数据集已归一化至单位球,无需额外缩放;预处理中的坐标缩放仅用于图像渲染,不影响原始数据。

📄 参考


如有问题,欢迎提 Issue 或联系作者。

关于

本项目针对计图(Jittor)挑战赛赛道二热身赛,旨在完成 ModelNet40 三维点云数据集的形状分类任务。我们将每个点云样本(2048 个三维点)通过投影映射转换为二维伪图像,并采用在 ImageNet 上预训练的 ResNet‑50 作为骨干网络进行特征提取与分类。

43.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号