CGAN model,code, README, and result image
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # 将标签索引(0~9)映射为 one-hot 风格的嵌入向量,维度为 n_classes self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes) # 定义一个全连接层 block:Linear -> (BatchNorm) -> LeakyReLU # normalize=True 时加入 BatchNorm,首层通常不加 BatchNorm,训练更稳定 def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2)) return layers # 噪声维度 + 类别嵌入维度 作为整体输入 self.model = nn.Sequential( *block(opt.latent_dim + opt.n_classes, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), # 映射到图像所有像素的数量(C*H*W) nn.Linear(1024, int(np.prod(img_shape))), # 输出经过 Tanh,范围约束在 [-1, 1],与后面的归一化方式匹配 nn.Tanh(), ) def execute(self, noise, labels): # noise: [N, latent_dim] 的随机噪声 # labels: [N] 的类别标签索引,与 noise 一一对应 # 将标签索引通过 Embedding 映射为向量后,与噪声在特征维度上拼接 # 拼接后张量形状为 [N, latent_dim + n_classes] gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1) # 通过多层全连接网络生成平铺的一维图像向量 img = self.model(gen_input) # 将一维向量 reshape 为真正的图像张量 [N, C, H, W] img = img.view((img.shape[0], *img_shape)) return img
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() # 与生成器类似,将标签索引映射到 n_classes 维空间 self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes) # 判别器同样采用多层全连接网络,将图像向量与标签嵌入向量拼接后进行判别 self.model = nn.Sequential( nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512), nn.LeakyReLU(0.2), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2), # 最后一层线性层输出一个标量,表示真实度,无需 Sigmoid,直接配合 MSELoss 即可 nn.Linear(512, 1), ) def execute(self, img, labels): # img: [N, C, H, W] 的图像张量 # labels: [N] 的类别标签索引(int 类型) # 将图像展平成一维向量后,与标签嵌入向量在特征维度上拼接 d_in = jt.contrib.concat((img.view((img.shape[0], -1)), self.label_embedding(labels)), dim=1) # 判别器输出一个实数,表示图像在给定标签条件下的真实性 validity = self.model(d_in) return validity
# 对抗损失:使用 LSGAN 风格的均方误差损失;调用方式:adversarial_loss(网络输出, 目标标签) adversarial_loss = nn.MSELoss() # 实例化生成器与判别器 generator = Generator() discriminator = Discriminator() # 数据准备,导入 MNIST 数据集 from jittor.dataset.mnist import MNIST import jittor.transform as transform transform = transform.Compose( [ transform.Resize(opt.img_size), # 1. Resize 到指定分辨率 transform.Gray(), # 2. 转为单通道灰度图 transform.ImageNormalize(mean=[0.5], std=[0.5]), # 3. 归一化到 [-1, 1] 区间,与生成器输出 Tanh 对应 ] ) # dataloader 每次返回一个 batch 的 (imgs, labels) dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True) # 优化器定义,生成器和判别器分别使用 Adam 优化器 optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
GAN 训练主循环部分: 这是整个程序的核心训练逻辑。外层循环遍历 epoch,内层循环遍历每个 batch。在每个 batch 中,首先构造真实标签(使用 label smoothing,将 1 改为 0.9)和虚假标签(0),然后先训练生成器,使其生成的假图像在判别器中尽可能接近真实标签;随后训练判别器,分别计算其在真实图像和假图像上的损失,并取平均作为判别器总损失;每次更新前调用 sync() 触发反向传播计算。训练过程中定期打印损失信息并按设定间隔保存生成样本;每隔 10 个 epoch 将当前生成器和判别器参数保存为 .pkl 文件。
for epoch in range(opt.n_epochs): for i, (imgs, labels) in enumerate(dataloader): # 当前 batch 的实际大小 batch_size = imgs.shape[0] # 真实标签(valid 接近 1)与虚假标签(fake=0),用于对抗损失 # 一个 GAN 稳定技巧 label smoothing ,将真实标签从 1 改为 0.9,减少判别器过拟合,让生成器在训练中更稳定 real_label = 0.9 valid = (jt.ones([batch_size, 1]) * real_label).float32().stop_grad() fake = jt.zeros([batch_size, 1]).float32().stop_grad() # 真实图片及其类别标签 real_imgs = jt.array(imgs) # MNIST 数据集的 labels 本身就是整数类别索引,直接转为 jittor 张量 labels = jt.array(labels).int32() # 训练生成器 # 采样随机噪声和随机类别标签作为生成器输入 z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32() gen_labels_np = np.random.randint(0, opt.n_classes, batch_size) gen_labels = jt.array(gen_labels_np).int32() # 生成一组假图片 gen_imgs = generator(z, gen_labels) # 让判别器评估这些假图片在对应标签条件下的真实性 validity = discriminator(gen_imgs, gen_labels) # 生成器的目标是“骗过”判别器,因此希望 validity 接近 1(valid) g_loss = adversarial_loss(validity, valid) # 同步计算图,触发反向传播所需的数值计算 g_loss.sync() # 根据生成器损失更新生成器参数 optimizer_G.step(g_loss) # 训练判别器 # 判别器在真实图片上的输出,希望接近 1 validity_real = discriminator(real_imgs, labels) d_real_loss = adversarial_loss(validity_real, valid) # 判别器在假图片上的输出,希望接近 0 # 这里对 gen_imgs 使用 stop_grad,阻断梯度,防止更新到生成器 validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels) d_fake_loss = adversarial_loss(validity_fake, fake) # 判别器总损失是对真实与虚假两部分损失的平均 d_loss = (d_real_loss + d_fake_loss) / 2 d_loss.sync() optimizer_D.step(d_loss) # 每隔若干 batch 打印一次当前的损失情况,方便监控训练状态 if i % 50 == 0: print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data) ) # 根据训练进度定期采样保存生成图片 batches_done = epoch * len(dataloader) + i if batches_done % opt.sample_interval == 0: sample_image(n_row=10, batches_done=batches_done) # 每隔 10 个 epoch 保存一次模型(包含第 0 个 epoch) if epoch % 10 == 0: generator.save("generator_last.pkl") discriminator.save("discriminator_last.pkl")
模型加载并指定数字序列生成结果:
# 切换到评估模式 generator.eval() discriminator.eval() # 加载最近一次保存的模型参数 generator.load("generator_last.pkl") discriminator.load("discriminator_last.pkl") # 指定的数字序列 number = "2312326" # 序列长度决定了需要生成的图片张数 n_row = len(number) # 为每一个数字采样一个噪声向量 z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad() # 将字符型的数字序列转为整型类别索引 label_list = [int(ch) for ch in number] labels = jt.array(np.array(label_list)).int32().stop_grad() # 生成对应数字的图片 gen_imgs = generator(z, labels) # 将多张生成图片在宽度方向拼接成一张长条图,依次对应 number 中的数字 img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1)) min_ = img_array.min() max_ = img_array.max() img_array = (img_array - min_) / (max_ - min_) * 255 Image.fromarray(np.uint8(img_array)).save("result.png")
在训练过程中,每隔 sample_interval 个 batch(我们设置sample_interval=500),程序就会让生成器生成一批图像并保存,以便我们直观观察生成器学习效果和训练进度。我们在整个过程中共获得178张采样结果,我们挑几个有代表性的阶段进行展示:
目标生成数字串是我的学号 2312326,最终生成结果如下:
可以看到,生成结果非常清晰准确,说明我的模型训练效果还不错。
A Jittor implementation of Conditional GAN (CGAN).
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN_jittor 实验项目
1、 模型训练环境
2 、 训练参数设置
3、 代码实现阐述
GAN 训练主循环部分: 这是整个程序的核心训练逻辑。外层循环遍历 epoch,内层循环遍历每个 batch。在每个 batch 中,首先构造真实标签(使用 label smoothing,将 1 改为 0.9)和虚假标签(0),然后先训练生成器,使其生成的假图像在判别器中尽可能接近真实标签;随后训练判别器,分别计算其在真实图像和假图像上的损失,并取平均作为判别器总损失;每次更新前调用 sync() 触发反向传播计算。训练过程中定期打印损失信息并按设定间隔保存生成样本;每隔 10 个 epoch 将当前生成器和判别器参数保存为 .pkl 文件。
模型加载并指定数字序列生成结果:
4、 模型训练结果
1) 过程采样结果
在训练过程中,每隔 sample_interval 个 batch(我们设置sample_interval=500),程序就会让生成器生成一批图像并保存,以便我们直观观察生成器学习效果和训练进度。我们在整个过程中共获得178张采样结果,我们挑几个有代表性的阶段进行展示:
2) 最终生成结果
目标生成数字串是我的学号 2312326,最终生成结果如下:
可以看到,生成结果非常清晰准确,说明我的模型训练效果还不错。