目录
目录README.md

CGAN_jittor 实验项目

  • 本项目实现了一个 条件生成对抗网络(CGAN),能够根据给定的类别标签和随机噪声生成对应的数字图像。
  • 实验数据集使用 MNIST,基础代码来自头歌平台提供的示例;
  • 我在基础代码上补全了 TODO 部分,并且针对部分内容进行了一些修改优化。

1、 模型训练环境

  • 操作系统版本:Ubuntu 24.04 LTS;
  • 运行CPU环境: Intel Xeon Silver 4314,16 核 16 线程,主频 2.4GHz;
  • Python 版本: Python 3.7.17;
  • Jittor 版本:Jittor 1.3.10.0;

    2 、 训练参数设置

  • n_epochs:训练轮数设置为 100 轮,在保证训练效果的同时还要考虑避免过拟合;
  • batch_size:数据集 Batch 大小设置为 64,梯度估计更稳定,使训练收敛更平滑;
  • lr:Adam 学习率一开始设置为 0.00015,但发现结果不太理想,故仍采用 0.0002;
  • b1=0.5, b2=0.999:经典 GAN 中常用的 Adam 参数,与 DCGAN 论文配置一致;
  • latent_dim=128:噪声维度设置为128,略大于源码给出的100,希望能覆盖到丰富的潜在空间;
  • sample_interval从 1000 减小到 500:更频繁地可视化生成结果,便于更早发现问题和模型变化趋势。

    3、 代码实现阐述

  1. Generator 生成器实现:
  • 核心目标:在给定类别标签的条件下,将随机噪声映射为对应数字的图像。
  • 实现阐述:生成器首先使用 nn.Embedding 将离散的类别标签(0–9)映射为与类别数相同维度的向量,实现标签的可学习表示;随后定义了一个内部的 block 函数,用于构建多层全连接模块,每个模块包含 Linear等激活函数;生成器的输入由“随机噪声向量 z 与类别嵌入向量拼接”构成,经过多层特征扩展后,最终通过一层全连接映射到图像像素总数,并使用 Tanh 激活函数将输出限制在 [-1, 1] 范围内,以匹配后续数据归一化方式。
    class Generator(nn.Module):
      def __init__(self):
          super(Generator, self).__init__()
          # 将标签索引(0~9)映射为 one-hot 风格的嵌入向量,维度为 n_classes
          self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
          # 定义一个全连接层 block:Linear -> (BatchNorm) -> LeakyReLU
          # normalize=True 时加入 BatchNorm,首层通常不加 BatchNorm,训练更稳定
          def block(in_feat, out_feat, normalize=True):
              layers = [nn.Linear(in_feat, out_feat)]
              if normalize:
                  layers.append(nn.BatchNorm1d(out_feat, 0.8))
              layers.append(nn.LeakyReLU(0.2))
              return layers
          # 噪声维度 + 类别嵌入维度 作为整体输入
          self.model = nn.Sequential(
              *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
              *block(128, 256),
              *block(256, 512),
              *block(512, 1024),
              # 映射到图像所有像素的数量(C*H*W)
              nn.Linear(1024, int(np.prod(img_shape))),
              # 输出经过 Tanh,范围约束在 [-1, 1],与后面的归一化方式匹配
              nn.Tanh(),
          )
      def execute(self, noise, labels):
          # noise: [N, latent_dim] 的随机噪声
          # labels: [N] 的类别标签索引,与 noise 一一对应
          # 将标签索引通过 Embedding 映射为向量后,与噪声在特征维度上拼接
          # 拼接后张量形状为 [N, latent_dim + n_classes]
          gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
          # 通过多层全连接网络生成平铺的一维图像向量
          img = self.model(gen_input)
          # 将一维向量 reshape 为真正的图像张量 [N, C, H, W]
          img = img.view((img.shape[0], *img_shape))
          return img
  1. Discriminator 判别器实现:
  • 核心目标:给定图像和对应类别标签的条件,判断该图像是否为真实样本。
  • 实现阐述:判别器首先使用 nn.Embedding 将类别标签映射为向量,并将其与输入图像展平后的向量进行拼接,从而使判别器“同时看到图像内容和条件标签”。网络主体由多层全连接层、LeakyReLU 激活函数和 Dropout 组成,用于增强判别能力并防止过拟合。最后一层线性层输出一个标量,表示图像为真实的置信度值。由于使用的是 LSGAN,因此直接与均方误差损失配合使用。
    class Discriminator(nn.Module):
      def __init__(self):
          super(Discriminator, self).__init__()
          # 与生成器类似,将标签索引映射到 n_classes 维空间
          self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
          # 判别器同样采用多层全连接网络,将图像向量与标签嵌入向量拼接后进行判别
          self.model = nn.Sequential(
              nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
              nn.LeakyReLU(0.2),
              nn.Linear(512, 512),
              nn.Dropout(0.4),
              nn.LeakyReLU(0.2),
              nn.Linear(512, 512),
              nn.Dropout(0.4),
              nn.LeakyReLU(0.2),
              # 最后一层线性层输出一个标量,表示真实度,无需 Sigmoid,直接配合 MSELoss 即可
              nn.Linear(512, 1),
          )
      def execute(self, img, labels):
          # img: [N, C, H, W] 的图像张量
          # labels: [N] 的类别标签索引(int 类型)
          # 将图像展平成一维向量后,与标签嵌入向量在特征维度上拼接
          d_in = jt.contrib.concat((img.view((img.shape[0], -1)), self.label_embedding(labels)), dim=1)
          # 判别器输出一个实数,表示图像在给定标签条件下的真实性
          validity = self.model(d_in)
          return validity
  1. 对抗损失函数定义、生成器判别器实例化、MNIST 数据集加载、优化器定义:
  • 通过 nn.MSELoss() 定义对抗损失函数,用均方误差来度量判别器输出与目标标签之间的差距;
  • 实例化生成器 generator 和判别器 discriminator,为后续推理阶段准备好可训练的模型对象。
  • 使用 MNIST 数据集接口加载训练数据,然后将图像 resize 到指定大小、转换为单通道灰度图,以及使用均值 0.5、方差 0.5 的归一化方式将像素值映射到 [-1, 1] 区间。最终通过 set_attrs 设置 batch 大小和数据随机打乱,使得训练过程中每个 batch 的数据分布更加均匀,有利于模型收敛。
  • 生成器和判别器分别定义 Adam 优化器,并使用相同的学习率和动量参数。
    # 对抗损失:使用 LSGAN 风格的均方误差损失;调用方式:adversarial_loss(网络输出, 目标标签)
    adversarial_loss = nn.MSELoss()
    # 实例化生成器与判别器
    generator = Generator()
    discriminator = Discriminator()
    # 数据准备,导入 MNIST 数据集
    from jittor.dataset.mnist import MNIST
    import jittor.transform as transform
    transform = transform.Compose(
      [
          transform.Resize(opt.img_size),                  # 1. Resize 到指定分辨率
          transform.Gray(),                                # 2. 转为单通道灰度图
          transform.ImageNormalize(mean=[0.5], std=[0.5]), # 3. 归一化到 [-1, 1] 区间,与生成器输出 Tanh 对应
      ]
    )
    # dataloader 每次返回一个 batch 的 (imgs, labels)
    dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
    # 优化器定义,生成器和判别器分别使用 Adam 优化器
    optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
  1. GAN 训练主循环部分: 这是整个程序的核心训练逻辑。外层循环遍历 epoch,内层循环遍历每个 batch。在每个 batch 中,首先构造真实标签(使用 label smoothing,将 1 改为 0.9)和虚假标签(0),然后先训练生成器,使其生成的假图像在判别器中尽可能接近真实标签;随后训练判别器,分别计算其在真实图像和假图像上的损失,并取平均作为判别器总损失;每次更新前调用 sync() 触发反向传播计算。训练过程中定期打印损失信息并按设定间隔保存生成样本;每隔 10 个 epoch 将当前生成器和判别器参数保存为 .pkl 文件。

    for epoch in range(opt.n_epochs):
     for i, (imgs, labels) in enumerate(dataloader):
         # 当前 batch 的实际大小
         batch_size = imgs.shape[0]
         # 真实标签(valid 接近 1)与虚假标签(fake=0),用于对抗损失
         # 一个 GAN 稳定技巧 label smoothing ,将真实标签从 1 改为 0.9,减少判别器过拟合,让生成器在训练中更稳定
         real_label = 0.9
         valid = (jt.ones([batch_size, 1]) * real_label).float32().stop_grad()
         fake = jt.zeros([batch_size, 1]).float32().stop_grad()
         # 真实图片及其类别标签
         real_imgs = jt.array(imgs)
         # MNIST 数据集的 labels 本身就是整数类别索引,直接转为 jittor 张量
         labels = jt.array(labels).int32()
    
         # 训练生成器
         # 采样随机噪声和随机类别标签作为生成器输入
         z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
         gen_labels_np = np.random.randint(0, opt.n_classes, batch_size)
         gen_labels = jt.array(gen_labels_np).int32()
         # 生成一组假图片
         gen_imgs = generator(z, gen_labels)
         # 让判别器评估这些假图片在对应标签条件下的真实性
         validity = discriminator(gen_imgs, gen_labels)
         # 生成器的目标是“骗过”判别器,因此希望 validity 接近 1(valid)
         g_loss = adversarial_loss(validity, valid)
         # 同步计算图,触发反向传播所需的数值计算
         g_loss.sync()
         # 根据生成器损失更新生成器参数
         optimizer_G.step(g_loss)
    
         # 训练判别器
         # 判别器在真实图片上的输出,希望接近 1
         validity_real = discriminator(real_imgs, labels)
         d_real_loss = adversarial_loss(validity_real, valid)
         # 判别器在假图片上的输出,希望接近 0
         # 这里对 gen_imgs 使用 stop_grad,阻断梯度,防止更新到生成器
         validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
         d_fake_loss = adversarial_loss(validity_fake, fake)
         # 判别器总损失是对真实与虚假两部分损失的平均
         d_loss = (d_real_loss + d_fake_loss) / 2
         d_loss.sync()
         optimizer_D.step(d_loss)
    
         # 每隔若干 batch 打印一次当前的损失情况,方便监控训练状态
         if i % 50 == 0:
             print(
                 "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                 % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)
             )
         # 根据训练进度定期采样保存生成图片
         batches_done = epoch * len(dataloader) + i
         if batches_done % opt.sample_interval == 0:
             sample_image(n_row=10, batches_done=batches_done)
    
     # 每隔 10 个 epoch 保存一次模型(包含第 0 个 epoch)
     if epoch % 10 == 0:
         generator.save("generator_last.pkl")
         discriminator.save("discriminator_last.pkl")
  2. 模型加载并指定数字序列生成结果:

  • 训练完成后,代码将生成器和判别器切换到 eval() 模式,并加载最近一次保存的模型参数;
  • 之后给定数字序列(这里是我的学号”2312326”),为每个数字生成一个对应的图像。
  • 程序为每个数字采样独立噪声向量,并将标签转换为整型张量输入生成器,最终将生成的多张数字图像在宽度方向拼接成一张长条图片,并保存为 result.png,展示生成效果。
    # 切换到评估模式
    generator.eval()
    discriminator.eval()
    # 加载最近一次保存的模型参数
    generator.load("generator_last.pkl")
    discriminator.load("discriminator_last.pkl")
    # 指定的数字序列
    number = "2312326"
    # 序列长度决定了需要生成的图片张数
    n_row = len(number)
    # 为每一个数字采样一个噪声向量
    z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
    # 将字符型的数字序列转为整型类别索引
    label_list = [int(ch) for ch in number]
    labels = jt.array(np.array(label_list)).int32().stop_grad()
    # 生成对应数字的图片
    gen_imgs = generator(z, labels)
    # 将多张生成图片在宽度方向拼接成一张长条图,依次对应 number 中的数字
    img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1))
    min_ = img_array.min()
    max_ = img_array.max()
    img_array = (img_array - min_) / (max_ - min_) * 255
    Image.fromarray(np.uint8(img_array)).save("result.png")

    4、 模型训练结果

1) 过程采样结果

在训练过程中,每隔 sample_interval 个 batch(我们设置sample_interval=500),程序就会让生成器生成一批图像并保存,以便我们直观观察生成器学习效果和训练进度。我们在整个过程中共获得178张采样结果,我们挑几个有代表性的阶段进行展示:

  • Batch = 500:第一张生成的采样结果,未具雏形;

图片

  • Batch = 18000:此时1~9数字图像已具雏形,但细节仍模糊不清;

图片

  • Batch = 93500:最后一张采样图片,此时1~9数字图像已基本清晰完整,细节上略有抽象;

图片

2) 最终生成结果

目标生成数字串是我的学号 2312326,最终生成结果如下:

图片

可以看到,生成结果非常清晰准确,说明我的模型训练效果还不错。

关于

A Jittor implementation of Conditional GAN (CGAN).

37.5 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号