目录
目录README.md

CGAN_jittor

第三届计图人工智能挑战赛

计图挑战热身赛

本赛题将会提供数字图片数据集 MNIST,参赛选手需要训练一个将随机噪声和类别标签映射为数字图片的Conditional GAN模型,并生成比赛页面指定数字序列。

本赛题提供示例代码框架,提供数据下载、模型定义、训练步骤等功能。

选手可以基于示例代码填充注释为 TODO 的部分完成该赛题。

CGAN.py 中的 TODO 部分说明

在示例代码 CGAN.py 中,共有三处需要选手自行补全。下面对每个位置及其功能做具体说明。

1. Discriminator 模型定义中的最后一层输出

补全内容:在 # TODO 注释处,添加

nn.Linear(512, 1)

这样可以将上一层的 512 维度特征映射为单个实数(形状为 (batch_size, 1)),表示判别器对该图像的“真”或“假”的打分。

2. Discriminator 前向计算的输出

out = self.model(d_in)
return out

这一步将拼接好的输入 d_in 送入 self.model(多层全连接 + Dropout + LeakyReLU + 最后一层线性),并返回一个 (batch_size, 1) 的分数张量

3. 训练循环中判别器损失计算

让判别器对真实图像的输出 validity_real 尽量等于 valid(全 1 向量):

d_real_loss = adversarial_loss(validity_real, valid)

让判别器对假图像的输出 validity_fake 尽量等于 fake(全 0 向量):

d_fake_loss = adversarial_loss(validity_fake, fake)

将两部分损失取平均,得到判别器整体损失 d_loss,并通过 optimizer_D.step(d_loss) 完成参数更新。

# 运行程序
python CGAN.py

程序将进行训练并输出result.png作为最终结果

关于

A Jittor implementation of Conditional GAN (CGAN).

39.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号