第三届计图人工智能挑战赛热身赛 - Conditional GAN数字生成
这是一个基于Jittor框架实现的Conditional GAN模型,用于生成MNIST手写数字图像。本代码是第三届计图人工智能挑战赛热身赛的解决方案,通过训练GAN模型生成指定数字序列的图像。
项目描述
本赛题使用MNIST数据集训练一个Conditional GAN模型,该模型能够根据给定的数字标签生成对应的手写数字图像。项目的主要任务包括:
实现Conditional GAN的生成器和判别器网络
训练模型学习MNIST数据分布
生成比赛页面指定的数字序列图像
环境依赖
Python 3.7+
Jittor 1.3.6+
NumPy
Pillow
安装命令:
bash
pip install jittor numpy Pillow
项目结构
text
cgan-mnist-jittor/
├── cgan.py # 主程序代码
├── README.md # 说明文件
└── requirements.txt # 依赖列表
使用说明
训练模型
直接运行主程序即可开始训练:
bash
python cgan.py
训练过程中会输出损失信息并定期保存生成的样本图像:
text
[Epoch 0/100] [Batch 0/938] [D loss: 0.242849] [G loss: 1.135298]
[Epoch 0/100] [Batch 50/938] [D loss: 0.047651] [G loss: 4.679907]
…
参数配置
可通过命令行参数调整训练配置:
bash
python cgan.py
–n_epochs 100 \ # 训练轮数
–batch_size 64 \ # 批大小
–lr 0.0002 \ # 学习率
–latent_dim 100 \ # 噪声向量维度
–n_classes 10 \ # 类别数
–img_size 32 \ # 图像尺寸
–channels 1 \ # 图像通道数
–sample_interval 1000 # 采样间隔
生成结果
训练完成后,程序会自动:
保存生成器和判别器模型
生成最终结果图像result.png
技术细节
模型架构
生成器(Generator):
输入:噪声向量(100维) + 类别标签(10维)
结构:全连接层(1024→512→256→128) + BatchNorm + LeakyReLU
输出:32×32的灰度图像
判别器(Discriminator):
输入:图像(1024维) + 类别标签(10维)
结构:全连接层(1034→512→512→512→1)
输出:图像真实性的概率
损失函数
使用均方误差(MSE)作为对抗损失
生成器目标:使判别器将生成图像分类为真实
判别器目标:正确区分真实图像和生成图像
结果展示
训练过程中生成的样本图像示例:
https://generated_sample.png
最终生成的数字序列图像:
https://result.png
注意事项
在代码中需要将number变量替换为比赛页面指定的数字序列
训练过程中会定期保存模型检查点
最终结果保存在result.png中
致谢
感谢计图(Jittor)团队提供的高性能深度学习框架
感谢MNIST数据集提供者Yann LeCun等人
第三届计图人工智能挑战赛热身赛 - Conditional GAN数字生成 这是一个基于Jittor框架实现的Conditional GAN模型,用于生成MNIST手写数字图像。本代码是第三届计图人工智能挑战赛热身赛的解决方案,通过训练GAN模型生成指定数字序列的图像。
项目描述 本赛题使用MNIST数据集训练一个Conditional GAN模型,该模型能够根据给定的数字标签生成对应的手写数字图像。项目的主要任务包括:
实现Conditional GAN的生成器和判别器网络
训练模型学习MNIST数据分布
生成比赛页面指定的数字序列图像
环境依赖 Python 3.7+
Jittor 1.3.6+
NumPy
Pillow
安装命令:
bash pip install jittor numpy Pillow 项目结构 text cgan-mnist-jittor/ ├── cgan.py # 主程序代码 ├── README.md # 说明文件 └── requirements.txt # 依赖列表 使用说明 训练模型 直接运行主程序即可开始训练:
bash python cgan.py 训练过程中会输出损失信息并定期保存生成的样本图像:
text [Epoch 0/100] [Batch 0/938] [D loss: 0.242849] [G loss: 1.135298] [Epoch 0/100] [Batch 50/938] [D loss: 0.047651] [G loss: 4.679907] … 参数配置 可通过命令行参数调整训练配置:
bash python cgan.py
–n_epochs 100 \ # 训练轮数 –batch_size 64 \ # 批大小 –lr 0.0002 \ # 学习率 –latent_dim 100 \ # 噪声向量维度 –n_classes 10 \ # 类别数 –img_size 32 \ # 图像尺寸 –channels 1 \ # 图像通道数 –sample_interval 1000 # 采样间隔 生成结果 训练完成后,程序会自动:
保存生成器和判别器模型
生成最终结果图像result.png
技术细节 模型架构 生成器(Generator):
输入:噪声向量(100维) + 类别标签(10维)
结构:全连接层(1024→512→256→128) + BatchNorm + LeakyReLU
输出:32×32的灰度图像
判别器(Discriminator):
输入:图像(1024维) + 类别标签(10维)
结构:全连接层(1034→512→512→512→1)
输出:图像真实性的概率
损失函数 使用均方误差(MSE)作为对抗损失
生成器目标:使判别器将生成图像分类为真实
判别器目标:正确区分真实图像和生成图像
结果展示 训练过程中生成的样本图像示例: https://generated_sample.png
最终生成的数字序列图像: https://result.png
注意事项 在代码中需要将number变量替换为比赛页面指定的数字序列
训练过程中会定期保存模型检查点
最终结果保存在result.png中
致谢 感谢计图(Jittor)团队提供的高性能深度学习框架
感谢MNIST数据集提供者Yann LeCun等人