目录
目录README.md

Jittor挑战赛热身赛-生成手写数字baseline(CGAN)

[主要结果]

简介

本项目基于Jittor框架构建条件生成对抗网络(cGAN)模型,并使用数字图片数据集 MNIST对该模型进行训练,使得该模型能够通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了类别标签映射与条件生成对抗的方法对图像类别进行了控制,并取得了可控且质量较高的手写数字图像生成效果。

安装

硬件需求

CPU和GPU安装版本均可使用。本项目使用一张A10显卡。

运行环境

本项目使用

  • Ubuntu 20.04.6 LTS
  • python = 3.8.16
  • cuda = 11.7
  • jittor==1.3.9.14
  • numpy==1.22.0

安装依赖

使用pip安装python相关依赖

pip install -r requirements.txt

数据预处理

本项目使用Jittor自带的MNIST数据集接口,可直接自动调用无需手动下载,首次运行该接口时会自动下载该数据集并进行预处理操作。预处理过后的图像会被灰度化并归一化到[-1,1]区间。

模型训练和推理

训练指令:python CGAN.py

在模型训练过程中,每隔一段时间会自动保存一次模型生成的图像,同时保存最新模型的参数至生成器(generator_last.pkl)和判别器(discriminator_last.pkl)的权重中,在该模型训练结束后,会自动加载最后一次保存的生成器和判别器的权重,并根据所输入的用户随机ID序列生成对应的手写数字图像。

致谢

  • 本项目在比赛官方所提供的CGAN网络架构基础上修改完成的。
  • 本项目CGAN网络搭建基于论文Conditional Generative Adversarial Nets实现,部分代码参考 jittor-gan
  • 本项目均基于Jittor框架
关于
37.0 KB
邀请码