目录
目录README.md

ConditionalGAN_Jittor

一、实验综述

本实验使用 Jittor 深度学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。

二、Conditioanl GAN 网络架构

Generative Adversarial Nets(GAN)提出了一种新的方法来训练生成模型:输入为一个随机向量 z, 生成器 G 输出一幅图像 G(z), 而判别器 D 需要将真实图像 x 与合成图像 G(z) 区分开来。然而,GAN 对于要生成的图片缺少控制。Conditional GAN(CGAN)通过添加显式的条件或标签,来控制生成的图像。在生成器 generator 和判别器 discriminator 中添加相同的额外信息 y,GAN 就可以扩展为一个 conditional 模型。y 可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。我们可以通过将 y 作为额外输入层,添加到生成器和判别器来完成条件控制。 ![[Pasted image 20240517105802.png]] GAN 模型的损失函数设计为:$min_{G}\space max_{D} \space V(D,G)=E_{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_{z}(z)}[\log(1- D(G(z)))] \tag{1}$对于判别器 D,我们要训练最大化这个 loss。如果 D 的输入是来自真实样本的数据 x, 则 D 的输出 D(x) 要尽可能地大,log(D(x)) 也会尽可能大。如果 D 的输入是来自 G 生成的假图片 G(z),则 D 的输出 D(G(z)) 应尽可能地小,从而 log(1-D(G(z)) 会尽可能地大。这样可以达到 max D 的目的。

对于生成器 G,我们要训练最小化这个 loss。对于 G 生成的假图片 G(z),我们希望尽可能地骗过 D,让它觉得我们生成的图片就是真的图片,这样就达到了 G“以假乱真”的目的。那么 D 的输出 D(G(z)) 应尽可能地大,从而 log(1-D(G(z)) 会尽可能地小。这样可以达到 min G 的目的。

D 和 G 以这样的方式联合训练,最终达到 G 的生成能力越来越强,D 的判别能力越来越强的目的。 在 CGAN 中,我们增加了限定条件 y,即数字 0-9 的类别标签, 因此生成器和判别器的输入都需要增加类别标签的维度,若真实图片为 x,对应标签为 y1,随机向量为 z,随机标签为 y2,则生成器的输出为 G(z, y2),判别器的输出为 D(G(z, y2), y2) 及 D(x, y1)。在本次实验中,我们采用平方误差函数替代对数函数来计算损失。记合成图片为第 0 类,真实图片为第 1 类,则分类器的损失函数为: $L(D)=\frac{1}{2}(D(G(z,y_{2}),y_{2})^2 + (1-D(x,y_{1})^2)) \tag{2}$ 生成器的目标则是希望合成图片能欺骗判别器,使其被分为第 1 类,因此生成器的损失函数为: $L(G)=(1-D(G(z,y_{2}),y_{2}))^2 \tag{3}$

三、代码说明

模型定义

生成器 Generator 和判别器 Discriminator 中的 init 函数用于定义模型架构,execute 函数给定网络输入返回网络输出。模型中主要使用的模块有:

  • nn.Embedding(num, dim):用于将 num 类整数标签转换为 dim 维向量;
  • nn.Linear(in_features, out_features):全连接层,输入向量维度 in_features,输出向量维度 out_features
  • nn.Drouout(p):将比例为 p 的特征置为 0;
  • nn.LeakyReLU(scale):ReLU 函数的变种,输入为负值时输出乘以 scale; 因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。

数据加载与预处理

  • 导入 MNIST 数据集,应用了图像缩放、灰度化和标准化等预处理步骤。
  • 创建了数据加载器,用于在训练过程中分批次提供数据。

辅助函数

  • save_image函数:用于将生成的图像保存到文件。

  • sample_image函数:在训练过程中定期保存生成样本的图像。

    模型训练循环

  • 主循环遍历每个训练周期(epoch),内部循环处理每个数据批次。

    • 训练生成器,使其生成更接近真实图像的样本。
    • 训练判别器,使其更好地区分真实和伪造图像。
    • 计算并打印损失,每一定间隔采样生成图像。
  • 每隔一定数量的epoch保存模型权重。

生成指定数字序列的图像

  • 在模型训练结束后,切换到评估模式,加载训练好的模型权重。
  • 根据给定的数字序列生成对应的图像,并保存为文件 result.png

四、使用方式

1. 安装 Jittor: Jittor 框架目前支持 Linux 或 Windows(包括 WSL),mac 系统请安装虚拟机解决。需要使用 Python 及 C++ 编译器(g++ 或 clang)。Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考:https://cg.cs.tsinghua.edu.cn/jittor/download/

2. 使用 requirement.txt 文件配置环境: 请使用以下命令配置环境:

pip install -r requirements.txt

3. 运行程序

python CGAN.py
关于

A Jittor implementation of Conditional GAN (CGAN).

34.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号