本实验使用 Jittor 深度学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。
二、Conditioanl GAN 网络架构
Generative Adversarial Nets(GAN)提出了一种新的方法来训练生成模型:输入为一个随机向量 z, 生成器 G 输出一幅图像 G(z), 而判别器 D 需要将真实图像 x 与合成图像 G(z) 区分开来。然而,GAN 对于要生成的图片缺少控制。Conditional GAN(CGAN)通过添加显式的条件或标签,来控制生成的图像。在生成器 generator 和判别器 discriminator 中添加相同的额外信息 y,GAN 就可以扩展为一个 conditional 模型。y 可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。我们可以通过将 y 作为额外输入层,添加到生成器和判别器来完成条件控制。
![[Pasted image 20240517105802.png]]
GAN 模型的损失函数设计为:$min_{G}\space max_{D} \space V(D,G)=E_{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_{z}(z)}[\log(1- D(G(z)))] \tag{1}$对于判别器 D,我们要训练最大化这个 loss。如果 D 的输入是来自真实样本的数据 x, 则 D 的输出 D(x) 要尽可能地大,log(D(x)) 也会尽可能大。如果 D 的输入是来自 G 生成的假图片 G(z),则 D 的输出 D(G(z)) 应尽可能地小,从而 log(1-D(G(z)) 会尽可能地大。这样可以达到 max D 的目的。
对于生成器 G,我们要训练最小化这个 loss。对于 G 生成的假图片 G(z),我们希望尽可能地骗过 D,让它觉得我们生成的图片就是真的图片,这样就达到了 G“以假乱真”的目的。那么 D 的输出 D(G(z)) 应尽可能地大,从而 log(1-D(G(z)) 会尽可能地小。这样可以达到 min G 的目的。
ConditionalGAN_Jittor
一、实验综述
本实验使用 Jittor 深度学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。
二、Conditioanl GAN 网络架构
Generative Adversarial Nets(GAN)提出了一种新的方法来训练生成模型:输入为一个随机向量 z, 生成器 G 输出一幅图像 G(z), 而判别器 D 需要将真实图像 x 与合成图像 G(z) 区分开来。然而,GAN 对于要生成的图片缺少控制。Conditional GAN(CGAN)通过添加显式的条件或标签,来控制生成的图像。在生成器 generator 和判别器 discriminator 中添加相同的额外信息 y,GAN 就可以扩展为一个 conditional 模型。y 可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。我们可以通过将 y 作为额外输入层,添加到生成器和判别器来完成条件控制。 ![[Pasted image 20240517105802.png]] GAN 模型的损失函数设计为:$min_{G}\space max_{D} \space V(D,G)=E_{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_{z}(z)}[\log(1- D(G(z)))] \tag{1}$对于判别器 D,我们要训练最大化这个 loss。如果 D 的输入是来自真实样本的数据 x, 则 D 的输出 D(x) 要尽可能地大,log(D(x)) 也会尽可能大。如果 D 的输入是来自 G 生成的假图片 G(z),则 D 的输出 D(G(z)) 应尽可能地小,从而 log(1-D(G(z)) 会尽可能地大。这样可以达到 max D 的目的。
对于生成器 G,我们要训练最小化这个 loss。对于 G 生成的假图片 G(z),我们希望尽可能地骗过 D,让它觉得我们生成的图片就是真的图片,这样就达到了 G“以假乱真”的目的。那么 D 的输出 D(G(z)) 应尽可能地大,从而 log(1-D(G(z)) 会尽可能地小。这样可以达到 min G 的目的。
D 和 G 以这样的方式联合训练,最终达到 G 的生成能力越来越强,D 的判别能力越来越强的目的。 在 CGAN 中,我们增加了限定条件 y,即数字 0-9 的类别标签, 因此生成器和判别器的输入都需要增加类别标签的维度,若真实图片为 x,对应标签为 y1,随机向量为 z,随机标签为 y2,则生成器的输出为 G(z, y2),判别器的输出为 D(G(z, y2), y2) 及 D(x, y1)。在本次实验中,我们采用平方误差函数替代对数函数来计算损失。记合成图片为第 0 类,真实图片为第 1 类,则分类器的损失函数为: $L(D)=\frac{1}{2}(D(G(z,y_{2}),y_{2})^2 + (1-D(x,y_{1})^2)) \tag{2}$ 生成器的目标则是希望合成图片能欺骗判别器,使其被分为第 1 类,因此生成器的损失函数为: $L(G)=(1-D(G(z,y_{2}),y_{2}))^2 \tag{3}$
三、代码说明
模型定义
生成器
Generator
和判别器Discriminator
中的init
函数用于定义模型架构,execute
函数给定网络输入返回网络输出。模型中主要使用的模块有:nn.Embedding(num, dim)
:用于将 num 类整数标签转换为 dim 维向量;nn.Linear(in_features, out_features)
:全连接层,输入向量维度in_features
,输出向量维度out_features
;nn.Drouout(p)
:将比例为p
的特征置为 0;nn.LeakyReLU(scale)
:ReLU 函数的变种,输入为负值时输出乘以scale
; 因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。数据加载与预处理
辅助函数
save_image
函数:用于将生成的图像保存到文件。sample_image
函数:在训练过程中定期保存生成样本的图像。模型训练循环
主循环遍历每个训练周期(epoch),内部循环处理每个数据批次。
每隔一定数量的epoch保存模型权重。
生成指定数字序列的图像
result.png
。四、使用方式
1. 安装 Jittor: Jittor 框架目前支持 Linux 或 Windows(包括 WSL),mac 系统请安装虚拟机解决。需要使用 Python 及 C++ 编译器(g++ 或 clang)。Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考:https://cg.cs.tsinghua.edu.cn/jittor/download/
2. 使用 requirement.txt 文件配置环境: 请使用以下命令配置环境:
3. 运行程序