目录
目录README.md

Jittor 挑战热身赛 Baseline

结果

简介

本项目包含第五届计图挑战赛 - 热身赛的代码实现。本项目基于 Jittor 框架,使用数字图片数据集 MNIST 训练一个条件生成对抗网络(Conditional GAN)模型,并通过输入一个随机向量 z 和额外的辅助信息 y(如类别标签),生成给定的特定数字序列的图像。本项目的特点是:采用标签嵌入机制使得模型能够根据输入的数字类别生成对应的数字图像。最终实现了高质量数字图像的生成。

安装

本项目可在 1 张 NVIDIA GeForce RTX 4090 上运行,训练时间约为 10-15 分钟。

运行环境

  • Ubuntu 22.04 LTS
  • python 3.9
  • jittor >= 1.3.9

安装依赖

执行以下命令创建 conda 环境并安装 python 依赖:

conda create -n jittor python=3.9
conda activate jittor
conda install -c conda-forge gcc_linux-64==11.2.0
conda install -c conda-forge gxx_impl_linux-64==11.2.0
conda install -c conda-forge libstdcxx-ng=12
python3 -m pip install scipy trimesh matplotlib
python3 -m pip install jittor

数据预处理

运行训练程序后,程序将自动调用 Jittor 框架自带接口加载 MNIST 数据集,并设置相应的图像变换,无需进行其他操作。

训练

CGAN.py 中的 number 替换为目标数字序列字符串,运行:

python CGAN.py

如需调整训练参数,可通过命令行传入,命令行参数详见 CGAN.py。例如:

python CGAN.py --n_epochs 150 --batch_size 64 --lr 0.001

在模型训练过程中,每 sample_interval 个 batch 将保存一次生成的图像样本,每 10 个 epoch 将保存一次模型生成器(generator_last.pkl)和判别器(discriminator_last.pkl)。

推理

训练完成后,程序加载保存的模型,并根据设定的数字序列字符串生成对应的手写数字图像,保存至 result.png

可以通过 此云盘链接 下载训练得到的模型文件和生成的图像结果示例。

致谢

本项目基于比赛官方提供的 代码框架 完成。

关于

第五届计图人工智能挑战赛-计图挑战热身赛

31.0 KB
邀请码