cgan_jittor
本项目使用 Jittor 机器学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。
环境说明
Jittor 框架目前支持 Linux 或 Windows,需要使用 Python 及 C++编译器(g++或 clang)。 Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考: https://cg.cs.tsinghua.edu.cn/jittor/download/。
代码框架
本次代码仅包含一个文件 CGAN.py。
生成器Generator和判别器Discriminator 中的 init 函数用于定义模型架构,execute 函数给定网络输入返回网络输出。
模型中主要使用 的模块有
- nn.Embedding(num, dim):用于将 num 类整数标签转换为 dim 维向量。
- nn.Linear(in_features, out_features):全连接层,输入向量维度 in_features,输出向量 维度 out_features。
- nn.Drouout(p):将比例为 p 的特征置为 0。
- nn.LeakyReLU(scale):ReLU 函数的变种,输入为负值时输出乘以 scale。
因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。
代码会自动下载 MNIST 数据集。每轮迭代 中,我们枚举数据集中的图片(imgs)和类别标签(labels)对,并随机生成一组输入向量,计算生成器和判别器损失函数,回传梯度并更新网络参数。每迭代若干轮会随机采样生成一批数字图片,如下:

运行说明
在根目录下运行python CGAN.py
,可以添加的参数如下:
参数 |
默认值 |
含义 |
–n_epochs |
50 |
number of epochs of training |
–batch_size |
64 |
size of the batches |
–lr |
0.0002 |
adam: learning rate |
–b1 |
0.5 |
adam: decay of first order momentum of gradient |
–b2 |
0.999 |
adam: decay of first order momentum of gradient |
–n_cpu |
8 |
number of cpu threads to use during batch generation |
–latent_dim |
100 |
dimensionality of the latent space |
–n_classes |
10 |
number of classes for dataset |
–img_size |
32 |
size of each image dimension |
–channels |
1 |
number of image channels |
–sample_interval |
1000 |
interval between image sampling |
如python CGAN.py --n_epochs 100
。
示例结果
更改CGAN.py中的number变量值,可以改变输出的数字序列。下面为一个样例输出:

cgan_jittor
本项目使用 Jittor 机器学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。
环境说明
Jittor 框架目前支持 Linux 或 Windows,需要使用 Python 及 C++编译器(g++或 clang)。 Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考: https://cg.cs.tsinghua.edu.cn/jittor/download/。
代码框架
本次代码仅包含一个文件 CGAN.py。
生成器Generator和判别器Discriminator 中的 init 函数用于定义模型架构,execute 函数给定网络输入返回网络输出。
模型中主要使用 的模块有
因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。
代码会自动下载 MNIST 数据集。每轮迭代 中,我们枚举数据集中的图片(imgs)和类别标签(labels)对,并随机生成一组输入向量,计算生成器和判别器损失函数,回传梯度并更新网络参数。每迭代若干轮会随机采样生成一批数字图片,如下:
运行说明
在根目录下运行
python CGAN.py
,可以添加的参数如下:如
python CGAN.py --n_epochs 100
。示例结果
更改CGAN.py中的number变量值,可以改变输出的数字序列。下面为一个样例输出: