目录
目录README.md

CGAN_jittor

项目简介

本项目实现了一个基于条件生成对抗网络(Conditional GAN, CGAN)的模型,旨在根据随机噪声和类别标签生成MNIST手写数字图片。通过使用Jittor深度学习框架,设计并训练了生成器和判别器两个主要组件,实现了从随机输入生成特定类别数字的功能。最终,模型能够根据给定的用户随机ID生成对应的数字图片,展示了条件GAN在图像生成任务中的有效性。

环境要求

  • 操作系统:Linux
  • 编程语言:Python 3.7
  • 深度学习框架:Jittor
  • 主要依赖库
    • NumPy
    • Pillow (PIL)

安装

1. 克隆项目

首先,将本项目克隆到本地:

git clone https://gitlink.org.cn/sinco/221368_CGAN_jittor.git
cd conditional-gan-mnist

2. 安装依赖

确保已安装Jittor。可参考Jittor官方安装指南进行安装。安装完成后,继续安装其他依赖库:

pip install numpy pillow

使用方法

1. 训练模型

运行以下命令开始训练条件GAN模型:

python CGAN.py

训练脚本支持多种可调参数,以下是默认参数及其说明:

  • --n_epochs:训练轮数,默认值为100
  • --batch_size:批量大小,默认值为64
  • --lr:学习率,默认值为0.0002
  • --b1:Adam优化器的第一个动量参数,默认值为0.5
  • --b2:Adam优化器的第二个动量参数,默认值为0.999
  • --n_cpu:用于数据加载的CPU线程数,默认值为8
  • --latent_dim:潜在空间维度,默认值为100
  • --n_classes:类别数量,默认值为10
  • --img_size:图像尺寸,默认值为32
  • --channels:图像通道数,默认值为1
  • --sample_interval:生成样本图片的间隔步数,默认值为1000

2. 生成图片

训练完成后,模型会根据给定的用户随机ID生成对应的数字图片。生成过程已集成在训练脚本末尾,具体步骤如下:

number ='2213628'  # 指定数字序列
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)

img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")

运行后,生成的图片将保存为result.png

代码结构

  • 主脚本main.py
    • 包含生成器(Generator)和判别器(Discriminator)的定义
    • 数据加载与预处理
    • 训练循环,包括生成器和判别器的交替训练
    • 模型保存与加载
    • 图片生成与保存

生成器(Generator)

生成器接收随机噪声和类别标签,通过嵌入层和多层全连接网络逐步将输入映射到图像空间,最终生成指定类别的手写数字图片。网络结构包括若干线性层、批归一化层(BatchNorm)、LeakyReLU激活函数,最后一层使用Tanh激活函数输出图像。

判别器(Discriminator)

判别器接收图片和类别标签,通过嵌入层和多层全连接网络提取特征,最终输出一个实数值,表示图片的真实性。网络结构包括若干线性层、Dropout层、LeakyReLU激活函数,最后一层输出单一实数。

损失函数与优化器

采用均方误差损失(MSELoss)作为对抗损失,用于衡量生成器和判别器的性能。优化器选择Adam算法,分别为生成器和判别器设置学习率和动量参数。

训练流程

  1. 数据加载:加载MNIST数据集,进行预处理,包括调整图像尺寸、转换为灰度图及标准化。
  2. 模型初始化:定义生成器和判别器模型。
  3. 优化器与损失函数定义:设置Adam优化器和均方误差损失函数。
  4. 训练循环
    • 训练生成器:生成假图片,计算生成器损失,反向传播并更新生成器参数。
    • 训练判别器:使用真实图片和生成图片计算判别器损失,反向传播并更新判别器参数。
    • 定期保存模型参数和生成样本图片。
  5. 模型保存与加载:每隔若干轮保存一次模型,训练结束后加载最后保存的模型用于生成最终结果。
  6. 图片生成:根据指定的用户随机ID生成对应的数字图片,并保存为result.png

结果

在训练过程中,生成器和判别器的损失逐渐趋于稳定,生成的数字图片清晰且具有较高的辨识度。

最终,根据给定的用户随机ID“2213628”,模型成功生成了对应的数字图片,保存为result.png

关于

A Jittor implementation of Conditional GAN (CGAN).

833.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号