目录
目录readme.md

第三届计图人工智能挑战赛热身赛 - Conditional GAN数字生成 这是一个基于Jittor框架实现的Conditional GAN模型,用于生成MNIST手写数字图像。本代码是第三届计图人工智能挑战赛热身赛的解决方案,通过训练GAN模型生成指定数字序列的图像。

项目描述 本赛题使用MNIST数据集训练一个Conditional GAN模型,该模型能够根据给定的数字标签生成对应的手写数字图像。项目的主要任务包括:

实现Conditional GAN的生成器和判别器网络

训练模型学习MNIST数据分布

生成比赛页面指定的数字序列图像

环境依赖 Python 3.7+

Jittor 1.3.6+

NumPy

Pillow

安装命令:

bash pip install jittor numpy Pillow 项目结构 text cgan-mnist-jittor/ ├── cgan.py # 主程序代码 ├── README.md # 说明文件 └── requirements.txt # 依赖列表 使用说明 训练模型 直接运行主程序即可开始训练:

bash python cgan.py 训练过程中会输出损失信息并定期保存生成的样本图像:

text [Epoch 0/100] [Batch 0/938] [D loss: 0.242849] [G loss: 1.135298] [Epoch 0/100] [Batch 50/938] [D loss: 0.047651] [G loss: 4.679907] … 参数配置 可通过命令行参数调整训练配置:

bash python cgan.py
–n_epochs 100 \ # 训练轮数 –batch_size 64 \ # 批大小 –lr 0.0002 \ # 学习率 –latent_dim 100 \ # 噪声向量维度 –n_classes 10 \ # 类别数 –img_size 32 \ # 图像尺寸 –channels 1 \ # 图像通道数 –sample_interval 1000 # 采样间隔 生成结果 训练完成后,程序会自动:

保存生成器和判别器模型

生成最终结果图像result.png

技术细节 模型架构 生成器(Generator):

输入:噪声向量(100维) + 类别标签(10维)

结构:全连接层(1024→512→256→128) + BatchNorm + LeakyReLU

输出:32×32的灰度图像

判别器(Discriminator):

输入:图像(1024维) + 类别标签(10维)

结构:全连接层(1034→512→512→512→1)

输出:图像真实性的概率

损失函数 使用均方误差(MSE)作为对抗损失

生成器目标:使判别器将生成图像分类为真实

判别器目标:正确区分真实图像和生成图像

结果展示 训练过程中生成的样本图像示例: https://generated_sample.png

最终生成的数字序列图像: https://result.png

注意事项 在代码中需要将number变量替换为比赛页面指定的数字序列

训练过程中会定期保存模型检查点

最终结果保存在result.png中

致谢 感谢计图(Jittor)团队提供的高性能深度学习框架

感谢MNIST数据集提供者Yann LeCun等人

关于
31.0 KB
邀请码