Update README.md
基于 Jittor 框架与 ResNet‑50 的 ModelNet40 点云三视图分类方案(计图挑战赛赛道二热身赛)
. ├── train.py # 训练脚本(含数据预处理、数据集定义、模型、训练流程) ├── predict.py # 预测脚本(加载最佳模型,对测试集输出 JSON 结果) ├── ksh.py # 可视化脚本(展示点云及其三视图渲染效果) ├── data/ # 数据集目录(需自行放置,见下文) │ ├── train_points.npy # 训练集点云 (9843, 2048, 3) │ ├── train_labels.npy # 训练集标签 (9843,) │ ├── test_points.npy # 测试集点云 (2468, 2048, 3) (预测时使用) │ └── categories.txt # 40 个类别名称(一行一个) └── best_model.pkl # 训练后保存的最佳模型参数(训练时自动生成)
安装依赖:
pip install jittor numpy scikit-learn wandb pillow
首次运行训练时,Jittor会自动下载预训练ResNet-50权重(约98MB),请确保网络畅通。
将官方提供的 train_points.npy, train_labels.npy, test_points.npy, categories.txt 放入 data/ 目录。
train_points.npy
train_labels.npy
test_points.npy
categories.txt
data/
python train.py
训练过程中将:
val_split=0.05
best_model.pkl
wandb login
超参数默认值:
python predict.py
该脚本会加载 best_model.pkl,对 data/test_points.npy 中的测试样本进行预测,并输出 submission.json,格式为:
data/test_points.npy
submission.json
{"0": 12, "1": 5, ...}
其中 key 为样本编号(字符串),value 为预测类别索引(int,0~39)。
python ksh.py
可查看随机样本的点云三维散点图及其对应的三视图(RGB 伪彩色渲染),便于理解数据预处理效果。
数据表示:将每个点云(2048 个三维点)投影到三个正交平面(X‑Y, X‑Z, Y‑Z),并根据深度分三个区间赋予 RGB 通道亮度,生成 224×224 的彩色伪图像。三个视图分别对应三个观察方向。
模型架构:使用在 ImageNet 上预训练的 ResNet‑50 作为特征提取骨干(移除最后的分类层),每个视图独立通过共享权重的 ResNet‑50 得到 2048 维特征,三个视图的特征拼接为 6144 维向量,最后由一个全连接层映射到 40 个类别。
训练策略:采用 Adam 优化器,学习率 8e-5,交叉熵损失,并加入随机水平翻转作为数据增强(三个视图同步翻转)。训练 50 个 epoch,每 10 轮学习率减半,保存验证集最佳模型。
jt.flags.use_cuda = 1
wandb.init()
wandb.log()
如有问题,欢迎提 Issue 或联系作者。
本项目针对计图(Jittor)挑战赛赛道二热身赛,旨在完成 ModelNet40 三维点云数据集的形状分类任务。我们将每个点云样本(2048 个三维点)通过投影映射转换为二维伪图像,并采用在 ImageNet 上预训练的 ResNet‑50 作为骨干网络进行特征提取与分类。
版权所有:中国计算机学会技术支持:开源发展技术委员会 京ICP备13000930号-9 京公网安备 11010802047560号
Jittor-Point2Image-ResNet
基于 Jittor 框架与 ResNet‑50 的 ModelNet40 点云三视图分类方案(计图挑战赛赛道二热身赛)
📁 项目结构
🚀 快速开始
1. 环境依赖
安装依赖:
首次运行训练时,Jittor会自动下载预训练ResNet-50权重(约98MB),请确保网络畅通。
2. 数据准备
将官方提供的
train_points.npy,train_labels.npy,test_points.npy,categories.txt放入data/目录。3. 训练模型
训练过程中将:
val_split=0.05)best_model.pklwandb login,若无需可注释掉相关代码)超参数默认值:
4. 预测并生成提交文件
该脚本会加载
best_model.pkl,对data/test_points.npy中的测试样本进行预测,并输出submission.json,格式为:其中 key 为样本编号(字符串),value 为预测类别索引(int,0~39)。
5. 可视化点云与三视图
可查看随机样本的点云三维散点图及其对应的三视图(RGB 伪彩色渲染),便于理解数据预处理效果。
🧠 方法简述
数据表示:将每个点云(2048 个三维点)投影到三个正交平面(X‑Y, X‑Z, Y‑Z),并根据深度分三个区间赋予 RGB 通道亮度,生成 224×224 的彩色伪图像。三个视图分别对应三个观察方向。
模型架构:使用在 ImageNet 上预训练的 ResNet‑50 作为特征提取骨干(移除最后的分类层),每个视图独立通过共享权重的 ResNet‑50 得到 2048 维特征,三个视图的特征拼接为 6144 维向量,最后由一个全连接层映射到 40 个类别。
训练策略:采用 Adam 优化器,学习率 8e-5,交叉熵损失,并加入随机水平翻转作为数据增强(三个视图同步翻转)。训练 50 个 epoch,每 10 轮学习率减半,保存验证集最佳模型。
📊 结果
📝 注意事项
jt.flags.use_cuda = 1),若无 GPU 可将其改为 0。wandb.init()及wandb.log()相关行。📄 参考
如有问题,欢迎提 Issue 或联系作者。