目录
目录readme.md

使用 Jittor 训练 GAN 模型

题目来自2024计图挑战热身赛,传送门:https://www.educoder.net/competitions/index/Jittor-5

如何安装Jittor

对于windows平台官方给出了pip安装和docker安装两种方法,根据个人实际情况安装即可

PIP安装

Windows 请准备好Python>=3.8,安装方法如下(conda安装需要额外命令):

# check your python version(>=3.8)
python --version
python -m pip install jittor
# if conda is used
conda install pywin32

Windows 下,jittor会自动检测显卡并安装对应的 CUDA, 请确保您的NVIDIA驱动支持CUDA 10.2 以上,您还可以使用如下命令手动为Jittor安装CUDA:

python -m jittor_utils.install_cuda

但是实际使用过程中难免出现配置cuda的各种各样的问题,貌似并不支持最新的cuda,也可能是我自己电脑的问题。

因此更推荐下面的Docker安装方法,省时省力。

Docker安装

安装Docker:点击此处下载Docker的Windows安装包, 双击运行Docker for Windows Installer。

安装过程中可能需要重启,请遵循Docker安装程序的指示完成安装。

运行如下命令(Windows请使用PowerShell

docker run jittor/jittor python3.7 -m jittor.test.test_example

输出结果:

……
step 996, loss = 0.0016367514617741108 {'hold_vars': 14, 'lived_vars': 64, 'lived_ops': 57}
step 997, loss = 0.0011712713167071342 {'hold_vars': 14, 'lived_vars': 64, 'lived_ops': 57}
step 998, loss = 0.0010918093612417579 {'hold_vars': 14, 'lived_vars': 64, 'lived_ops': 57}
step 999, loss = 0.0009948197985067964 {'hold_vars': 14, 'lived_vars': 64, 'lived_ops': 57}

如果您的输出结果如上图所示,那么恭喜您,计图镜像已经安装成功了!

接下来,我们启动一个notebook server,从命令行里面运行如下命令:

docker run -it -p 8888:8888 jittor/jittor

输出结果如下

 To access the notebook, open this file in a browser:
        file:///root/.local/share/jupyter/runtime/nbserver-133-open.html
    Or copy and paste one of these URLs:
        http://a730ebb9a5ec:8888/?token=b6c6215bd0faa750833a6c81bbebd2d021248d43338fec94
     or http://127.0.0.1:8888/?token=b6c6215bd0faa750833a6c81bbebd2d021248d43338fec94

复制到浏览器打开。 这样可以看到一些jittor的基本教程。

Mac/Linux

jittor官方安装教程

Conditional GAN(条件生成对抗网络)

Conditional GAN (cGAN) 是生成对抗网络(GAN)的一种扩展,它通过引入额外的条件信息来生成特定的样本。与标准 GAN 不同,cGAN 在生成数据时不仅依赖于随机噪声,还包括一个额外的输入条件(例如标签、类别信息、图像等),使得生成的样本能够根据给定条件有所变化。

主要概念

  • 生成器 (Generator): 生成器的目标是生成逼真的样本,它不仅接收随机噪声(通常是高维的向量),还接收一个额外的条件信息(如类别标签)。生成器根据条件信息生成与之对应的样本。

  • 判别器 (Discriminator): 判别器的任务是判断一个样本是否为真实样本。在 cGAN 中,判别器也会接收条件信息,并判断给定条件下的样本是否为真实数据。

实验数据集

jittor官方提供的 MNIST

题目要求

  • 在Discriminator中添加最后一个线性层,最终输出为一个实数;
  self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
                                    nn.LeakyReLU(0.2),
                                      nn.Linear(512, 512), 
                                      nn.Dropout(0.4), 
                                      nn.LeakyReLU(0.2), 
                                      nn.Linear(512, 512), 
                                      nn.Dropout(0.4), 
                                      nn.LeakyReLU(0.2), 
                                      nn.Linear(512, 1))
  • 将d_in输入到模型中并返回计算结果。
 validity = self.model(d_in)
  • 分别计算真实类别和虚假类别的损失函数
  # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

Reference

https://github.com/Jittor/JGAN/blob/master/models/cgan/cgan.py https://github.com/Jittor/jittor

关于

A Jittor implementation of Conditional GAN (CGAN).

34.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号