update readme
本赛题将会提供数字图片数据集 MNIST,参赛选手需要训练一个将随机噪声和类别标签映射为数字图片的Conditional GAN模型,并生成比赛页面指定数字序列。
本赛题提供示例代码框架,提供数据下载、模型定义、训练步骤等功能。
选手可以基于示例代码填充注释为 TODO 的部分完成该赛题。
在示例代码 CGAN.py 中,共有三处需要选手自行补全。下面对每个位置及其功能做具体说明。
补全内容:在 # TODO 注释处,添加
nn.Linear(512, 1)
这样可以将上一层的 512 维度特征映射为单个实数(形状为 (batch_size, 1)),表示判别器对该图像的“真”或“假”的打分。
out = self.model(d_in) return out
这一步将拼接好的输入 d_in 送入 self.model(多层全连接 + Dropout + LeakyReLU + 最后一层线性),并返回一个 (batch_size, 1) 的分数张量
让判别器对真实图像的输出 validity_real 尽量等于 valid(全 1 向量):
d_real_loss = adversarial_loss(validity_real, valid)
让判别器对假图像的输出 validity_fake 尽量等于 fake(全 0 向量):
d_fake_loss = adversarial_loss(validity_fake, fake)
将两部分损失取平均,得到判别器整体损失 d_loss,并通过 optimizer_D.step(d_loss) 完成参数更新。
# 运行程序 python CGAN.py
程序将进行训练并输出result.png作为最终结果
A Jittor implementation of Conditional GAN (CGAN).
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN_jittor
第三届计图人工智能挑战赛
计图挑战热身赛
本赛题将会提供数字图片数据集 MNIST,参赛选手需要训练一个将随机噪声和类别标签映射为数字图片的Conditional GAN模型,并生成比赛页面指定数字序列。
本赛题提供示例代码框架,提供数据下载、模型定义、训练步骤等功能。
选手可以基于示例代码填充注释为 TODO 的部分完成该赛题。
CGAN.py 中的 TODO 部分说明
在示例代码 CGAN.py 中,共有三处需要选手自行补全。下面对每个位置及其功能做具体说明。
1. Discriminator 模型定义中的最后一层输出
补全内容:在 # TODO 注释处,添加
这样可以将上一层的 512 维度特征映射为单个实数(形状为 (batch_size, 1)),表示判别器对该图像的“真”或“假”的打分。
2. Discriminator 前向计算的输出
这一步将拼接好的输入 d_in 送入 self.model(多层全连接 + Dropout + LeakyReLU + 最后一层线性),并返回一个 (batch_size, 1) 的分数张量
3. 训练循环中判别器损失计算
让判别器对真实图像的输出 validity_real 尽量等于 valid(全 1 向量):
让判别器对假图像的输出 validity_fake 尽量等于 fake(全 0 向量):
将两部分损失取平均,得到判别器整体损失 d_loss,并通过 optimizer_D.step(d_loss) 完成参数更新。
程序将进行训练并输出result.png作为最终结果