ADD file via upload
[主要结果]
本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。
CPU和GPU安装版本均可使用。本项目使用一张A10显卡。
本项目使用
使用pip安装python相关依赖
pip install -r requirements.txt
本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。
训练指令:python CGAN.py
python CGAN.py
在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Jittor挑战赛热身赛-生成手写数字baseline(CGAN)
[主要结果]
简介
本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。
安装
硬件需求
CPU和GPU安装版本均可使用。本项目使用一张A10显卡。
运行环境
本项目使用
安装依赖
使用pip安装python相关依赖
数据预处理
本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。
模型训练和推理
训练指令:
python CGAN.py在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。
致谢