Jittor挑战赛热身赛-生成手写数字baseline(CGAN)
[主要结果]
简介
本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。
安装
硬件需求
CPU和GPU安装版本均可使用。本项目使用一张RTX 3050Ti显卡。
运行环境
使用Anaconda创建虚拟环境,本项目使用
python = 3.9.7、jittor = 1.3.1.18、 numpy==1.26.4
安装依赖
使用pip安装python相关依赖(pip install -r requirements.txt)
以下是Jittor框架的安装步骤:
1、进入Jittor官网,选中Windows->pip->Cuda
2、使用Anaconda创建虚拟环境,python=3.9,激活虚拟环境
3、运行以下指令安装Jittor框架:
python –version
python -m pip install jittor
python -m jittor.test.test_core
python -m jittor.test.test_example
python -m jittor.test.test_cudnn_op
数据预处理
本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。
模型训练和推理
训练指令:python CGAN.py
在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。
本项目生成器权重下载generator_last.pkl 提取码: tf6d
判别器下载discriminator_last.pkl 提取码: r74d
致谢
本项目是在比赛官方所提供的CGAN网络架构基础上修改完成的。
本项目CGAN网络搭建基于论文Conditional Generative Adversarial Nets实现,部分代码参考了 jittor-gan
本项目均基于Jittor框架,其中MNIST数据集接口使用了Jittor框架中的 jittor.dataset.mnist 模块。
Jittor挑战赛热身赛-生成手写数字baseline(CGAN)
[主要结果]
简介
本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。
安装
硬件需求
CPU和GPU安装版本均可使用。本项目使用一张RTX 3050Ti显卡。
运行环境
使用Anaconda创建虚拟环境,本项目使用 python = 3.9.7、jittor = 1.3.1.18、 numpy==1.26.4
安装依赖
使用pip安装python相关依赖(pip install -r requirements.txt) 以下是Jittor框架的安装步骤: 1、进入Jittor官网,选中Windows->pip->Cuda 2、使用Anaconda创建虚拟环境,python=3.9,激活虚拟环境 3、运行以下指令安装Jittor框架: python –version python -m pip install jittor python -m jittor.test.test_core python -m jittor.test.test_example python -m jittor.test.test_cudnn_op
数据预处理
本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。
模型训练和推理
训练指令:python CGAN.py 在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。 本项目生成器权重下载generator_last.pkl 提取码: tf6d 判别器下载discriminator_last.pkl 提取码: r74d
致谢
本项目是在比赛官方所提供的CGAN网络架构基础上修改完成的。 本项目CGAN网络搭建基于论文Conditional Generative Adversarial Nets实现,部分代码参考了 jittor-gan 本项目均基于Jittor框架,其中MNIST数据集接口使用了Jittor框架中的 jittor.dataset.mnist 模块。