目录

MoCoViT 轻量化图像分类:训练 + 安卓端部署全流程

基于 MoCoViT 轻量化 CNN-Transformer 混合架构,实现 CIFAR-10 图像分类,完整适配 Windows AMD ROCm 训练环境,含训练、ONNX 导出简化、安卓真机部署全链路。

项目特性

  • 轻量化混合架构:参数量 5.3M,计算量 147M FLOPs
  • AMD ROCm 深度适配:GroupNorm 替代 BatchNorm,精度损失 < 0.2%
  • 一键 ONNX 转换:算子融合、常量折叠,推理误差 < 3e-6
  • 零 NDK 安卓部署:ONNX Runtime + Kotlin,支持 NNAPI 加速
  • 全链路验证:AMD 训练 + PC 校验 + 小米15 真机测试

前置要求

硬件

系统

  • Windows 10 21H2 / Windows 11
  • Python 3.12、Android Studio

一、AMD ROCm 安装(Windows)

1. 更新 AMD 显卡驱动

官网下载最新肾上腺素驱动,管理员安装后重启。

2. 安装 ROCm

HIP SDK 官方安装教程

二、Python 训练环境

官方 PyTorch ROCm 安装教程

1. 创建虚拟环境

python -m venv venv
venv\Scripts\activate

2. 安装 ROCm SDK

pip install --no-cache-dir ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm_sdk_core-7.2.1-py3-none-win_amd64.whl ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm_sdk_devel-7.2.1-py3-none-win_amd64.whl ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm_sdk_libraries_custom-7.2.1-py3-none-win_amd64.whl ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm-7.2.1.tar.gz

3. 安装 PyTorch ROCm 版

pip install --no-cache-dir ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torch-2.9.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torchaudio-2.9.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl ^
    https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torchvision-0.24.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl

4. 验证 GPU 可用性

python -c "import torch;print(torch.cuda.is_available());print(f'device name [0]:', torch.cuda.get_device_name(0))"

5. 安装项目依赖

pip install numpy pillow onnx onnxsim onnxruntime onnxscript

三、代码下载与运行顺序

1. 获取代码

从 GitLink 仓库下载训练脚本:

  • 仓库地址:zjc518/MoCoViT
  • 下载 mocovit_train.pyexport_onnx.py 两个文件

2. 目录结构

将两个脚本放在同一个文件夹内,例如:

D:\code\MoCoViT\
├── mocovit_train.py    # 训练脚本
└── export_onnx.py      # 导出脚本

3. 按顺序运行

打开 CMD,进入代码目录,激活虚拟环境后依次执行:

第一步:训练模型

python mocovit_train.py
  • 自动下载 CIFAR-10 数据集
  • 训练完成后生成 mocovit_best.pth 权重文件

第二步:导出并简化 ONNX 模型

python export_onnx.py
  • 读取上一步生成的 mocovit_best.pth
  • 输出 mocovit.onnxmocovit_simplified.onnx

⚠️ 必须按顺序运行,第二步依赖第一步生成的权重文件。

四、模型训练说明

脚本内置 ROCm 兼容环境变量:

import os
os.environ["MIOPEN_DISABLE_CACHE"] = "1"
os.environ["PYTORCH_HIP_USE_MIOPEN"] = "0"
os.environ["PYTORCH_CUDA_FUSER_DISABLE"] = "1"
  • 数据集:CIFAR-10(自动下载至 ./data
  • BatchSize=32,SGD + 余弦退火
  • 输出:mocovit_best.pth

五、ONNX 导出与简化

python export_onnx.py

输出:

  • mocovit.onnx:原始模型
  • mocovit_simplified.onnx:简化版(部署用)

六、安卓 APP 部署

1. 工程准备

  • Android Studio 打开 android/ 目录
  • app/src/main/assets/ 放入 mocovit_simplified.onnx

2. 依赖配置(Kotlin DSL)

implementation("com.microsoft.onnxruntime:onnxruntime-android:1.16.0")

3. 真机运行

  1. 手机开启 USB 调试 + USB 安装应用
  2. USB 连接电脑,选择「传输文件」
  3. Android Studio 选择真机,点击运行

七、ROCm 适配说明

  1. 环境变量禁用 MIOpen 缓存与调用,规避编译报错
  2. 动态分组 GroupNorm 替代 BatchNorm,解决 BN 算子兼容问题
  3. 小 Batch 训练更稳定,精度仅下降 0.16%

八、实验数据

CIFAR-10 训练结果(12 轮实测)

训练精度 测试精度 测试损失
86.34% 83.94% 0.4711

小米 15 推理性能

推理模式 单帧耗时
CPU 4 线程 ~120ms
NNAPI 加速 ~80ms

模型对比

模型 参数量 FLOPs ImageNet 精度
MobileNetV3 Large 5.4M 219M 75.2%
GhostNet 1.0× 5.2M 141M 73.9%
MoCoViT(本项目) 5.3M 147M 74.5%

九、项目结构

MoCoViT-Android-Demo/
├── train/
│   ├── mocovit_train.py
│   └── export_onnx.py
├── android/
│   └── app/src/main/
│       ├── assets/
│       ├── java/.../MainActivity.kt
│       ├── res/layout/activity_main.xml
│       └── AndroidManifest.xml
├── docs/技术报告.md
├── .gitignore
└── README.md

十、常见问题

Q:MIOpen 相关报错? A:ROCm 与 PyTorch 版本匹配;删除 ~/.cache/miopen 缓存;确认三条环境变量。

Q:缺少 onnxscript? A:pip install onnxscript 或导出时加 use_dynamo=False

Q:识别不到手机? A:开启 USB 调试 + 传输文件模式,换原装数据线。

Q:APP 闪退? A:检查 assets 模型文件名大小写、存储权限、Gradle 同步状态。

十一、模型下载

参考文献

[1] Ma H, et al. MoCoViT: Mobile Convolutional Vision Transformer[J]. arXiv:2205.12635, 2022. [2] Mehta S, Rastegari M. MobileViT[C]//ICLR, 2022. [3] Han K, et al. GhostNet[C]//CVPR, 2020. [4] AMD Official. ROCm HIP SDK Windows Documentation

开源协议

MIT License

关于
86.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

版权所有:中国计算机学会技术支持:开源发展技术委员会
京ICP备13000930号-9 京公网安备 11010802047560号