Add .gitignore
这是一个使用 Jittor 框架实现的 Conditional Generative Adversarial Network(条件生成对抗网络)项目,基于 MNIST 数据集训练模型生成指定类别的手写数字图像,用于第三届计图人工智能挑战赛热身赛任务。
pip install jittor python -m jittor.test.test_example # 测试是否安装成功
python CGAN.py
运行后程序将自动:
"28192092811985"
result.png
将随机噪声和类别标签拼接后输入,通过全连接层逐步扩展为图像。
输入:100维噪声 + 类别标签(one-hot) 层次结构: Linear -> LeakyReLU Linear -> BatchNorm -> LeakyReLU Linear -> BatchNorm -> LeakyReLU Linear -> BatchNorm -> LeakyReLU Linear -> Tanh 输出:32x32 灰度图像(1通道)
输入图像与类别标签拼接后,判断是否为真实图像。
输入:图像(展平)+ 类别标签(one-hot) 层次结构: Linear -> LeakyReLU Linear -> Dropout -> LeakyReLU Linear -> Dropout -> LeakyReLU Linear -> 实数输出 输出:真/伪分值(非概率)
--n_epochs
--batch_size
--lr
--b1
--b2
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
--n_cpu
from jittor.dataset.mnist import MNIST
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
基于 Jittor 框架的 Conditional GAN(CGAN)
这是一个使用 Jittor 框架实现的 Conditional Generative Adversarial Network(条件生成对抗网络)项目,基于 MNIST 数据集训练模型生成指定类别的手写数字图像,用于第三届计图人工智能挑战赛热身赛任务。
配置
环境依赖
安装 Jittor(GPU)
运行训练与生成
运行后程序将自动:
"28192092811985"
)输出最终结果图result.png
。模型结构
Generator(生成器)
将随机噪声和类别标签拼接后输入,通过全连接层逐步扩展为图像。
Discriminator(判别器)
输入图像与类别标签拼接后,判断是否为真实图像。
参数说明
--n_epochs
--batch_size
--lr
--b1
--b2
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
--n_cpu
数据说明