Jittor挑战赛热身赛-生成手写数字baseline(CGAN)
[主要结果]
简介
本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。
安装
硬件需求
CPU和GPU安装版本均可使用。本项目使用一张A10显卡。
运行环境
本项目使用
- Ubuntu 20.04.6 LTS
- python = 3.8.16
- cuda = 11.7
- jittor==1.3.9.14
- numpy==1.22.0
安装依赖
使用pip安装python相关依赖
pip install -r requirements.txt
数据预处理
本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。
模型训练和推理
训练指令:python CGAN.py
在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。
致谢
- 本项目在比赛官方所提供的CGAN网络架构基础上修改完成的。
- 本项目CGAN网络搭建基于论文Conditional Generative Adversarial Nets实现,部分代码参考 jittor-gan
- 本项目均基于Jittor框架
Jittor挑战赛热身赛-生成手写数字baseline(CGAN)
[主要结果]
简介
本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。
安装
硬件需求
CPU和GPU安装版本均可使用。本项目使用一张A10显卡。
运行环境
本项目使用
安装依赖
使用pip安装python相关依赖
数据预处理
本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。
模型训练和推理
训练指令:
python CGAN.py
在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。
致谢