项目简介
本项目实现了一个条件生成对抗网络(Conditional Generative Adversarial Network, CGAN),用于生成MNIST手写数字图像。该项目使用Jittor深度学习框架实现,可以根据指定的数字标签生成对应的手写数字图像。
环境要求
- Python 3.7+
- Jittor
- NumPy
- PIL (Python Imaging Library)
项目结构
.
├── CGAN.py # 主程序文件
├── result.png # 生成结果示例
├── generator_last.pkl # 生成器模型权重
└── discriminator_last.pkl # 判别器模型权重
模型架构
生成器 (Generator)
- 输入:随机噪声向量(100维)和条件标签(10维)
- 网络结构:多层全连接网络
- 输出:32x32的灰度图像
判别器 (Discriminator)
- 输入:图像(32x32)和条件标签(10维)
- 网络结构:多层全连接网络
- 输出:真实概率(0-1之间的标量)
使用方法
训练模型:
python CGAN.py
模型参数:
- –n_epochs: 训练轮数 (默认: 100)
- –batch_size: 批次大小 (默认: 64)
- –lr: 学习率 (默认: 0.0002)
- –latent_dim: 潜在空间维度 (默认: 100)
- –img_size: 图像大小 (默认: 32)
示例:
python CGAN.py –n_epochs 200 –batch_size 128
训练过程
- 每个epoch会显示判别器和生成器的损失
- 每隔1000个batch会保存一次生成的样本图片
- 每10个epoch会保存一次模型权重
生成结果
训练完成后,模型可以根据指定的数字序列生成对应的手写数字图像。生成的结果将保存为result.png。
注意事项
- 确保有足够的GPU内存进行训练
- 训练过程中会自动保存模型权重
- 可以通过调整超参数来优化生成效果
参考
项目简介
本项目实现了一个条件生成对抗网络(Conditional Generative Adversarial Network, CGAN),用于生成MNIST手写数字图像。该项目使用Jittor深度学习框架实现,可以根据指定的数字标签生成对应的手写数字图像。
环境要求
项目结构
. ├── CGAN.py # 主程序文件 ├── result.png # 生成结果示例 ├── generator_last.pkl # 生成器模型权重 └── discriminator_last.pkl # 判别器模型权重
模型架构
生成器 (Generator)
判别器 (Discriminator)
使用方法
训练模型: python CGAN.py
模型参数:
示例: python CGAN.py –n_epochs 200 –batch_size 128
训练过程
生成结果
训练完成后,模型可以根据指定的数字序列生成对应的手写数字图像。生成的结果将保存为result.png。
注意事项
参考