基于 Jittor 框架的 Conditional GAN(CGAN)
这是一个使用 Jittor 框架实现的 Conditional Generative Adversarial Network(条件生成对抗网络)项目,基于 MNIST 数据集训练模型生成指定类别的手写数字图像,用于第三届计图人工智能挑战赛热身赛任务。
配置
环境依赖
- Python 3.x
- Jittor(推荐使用 GPU 加速)
安装 Jittor(GPU)
pip install jittor
python -m jittor.test.test_example # 测试是否安装成功
运行训练与生成
python CGAN.py
运行后程序将自动:
- 下载 MNIST 数据集;
- 训练 CGAN 模型;
- 持续保存生成图像;
- 根据指定数字字符串(如
"28192092811985"
)输出最终结果图 result.png
。
模型结构
Generator(生成器)
将随机噪声和类别标签拼接后输入,通过全连接层逐步扩展为图像。
输入:100维噪声 + 类别标签(one-hot)
层次结构:
Linear -> LeakyReLU
Linear -> BatchNorm -> LeakyReLU
Linear -> BatchNorm -> LeakyReLU
Linear -> BatchNorm -> LeakyReLU
Linear -> Tanh
输出:32x32 灰度图像(1通道)
Discriminator(判别器)
输入图像与类别标签拼接后,判断是否为真实图像。
输入:图像(展平)+ 类别标签(one-hot)
层次结构:
Linear -> LeakyReLU
Linear -> Dropout -> LeakyReLU
Linear -> Dropout -> LeakyReLU
Linear -> 实数输出
输出:真/伪分值(非概率)
参数说明
参数名 |
描述 |
--n_epochs |
训练轮数 |
--batch_size |
每批样本大小 |
--lr |
学习率(Adam优化器) |
--b1 |
Adam优化器参数 beta1 |
--b2 |
Adam优化器参数 beta2 |
--latent_dim |
噪声向量维度 |
--n_classes |
类别数量(MNIST为10) |
--img_size |
图像尺寸(32x32) |
--channels |
图像通道数(灰度图为1) |
--sample_interval |
每隔多少步保存一次生成图像样本 |
--n_cpu |
加载数据使用的CPU线程数 |
数据说明
基于 Jittor 框架的 Conditional GAN(CGAN)
这是一个使用 Jittor 框架实现的 Conditional Generative Adversarial Network(条件生成对抗网络)项目,基于 MNIST 数据集训练模型生成指定类别的手写数字图像,用于第三届计图人工智能挑战赛热身赛任务。
配置
环境依赖
安装 Jittor(GPU)
运行训练与生成
运行后程序将自动:
"28192092811985"
)输出最终结果图result.png
。模型结构
Generator(生成器)
将随机噪声和类别标签拼接后输入,通过全连接层逐步扩展为图像。
Discriminator(判别器)
输入图像与类别标签拼接后,判断是否为真实图像。
参数说明
--n_epochs
--batch_size
--lr
--b1
--b2
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
--n_cpu
数据说明