条件生成对抗网络 (Conditional GAN) 实现
简介
本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN),基于MNIST数据集生成指定类别的手写数字图像。通过训练一个生成器和一个判别器,生成器可以从随机噪声和类别标签生成逼真的手写数字图像,而判别器则用于区分真实图像和生成的图像。
环境要求
- Python 3.6 及以上版本
- Jittor 1.3.0 及以上版本 (需自行安装,安装指南请参考Jittor官方文档)
- PIL (Python Imaging Library) (用于图像保存)
项目结构
CGAN/
├── CGAN.py # 主代码文件,包含模型定义、训练逻辑和图像生成逻辑
└── result.png # 训练完成后生成的图像结果
使用说明
克隆项目:
git clone <repository_url>
cd CGAN
安装依赖:
pip install jittor pillow numpy
运行训练:
python CGAN.py
- 项目默认训练参数包括:训练轮数
n_epochs
,批量大小batch_size
,学习率lr
,类别数n_classes
等。
- 训练过程中,每隔一定数量的批次会保存生成的图像样本,并在每10轮训练后保存模型参数。
生成指定数字:
- 在训练完成后,代码会自动生成并保存一个由指定数字组成的图像。
- 需要更改生成的数字时,可以修改
number
变量中的字符串。
条件生成对抗网络 (Conditional GAN) 实现
简介
本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN),基于MNIST数据集生成指定类别的手写数字图像。通过训练一个生成器和一个判别器,生成器可以从随机噪声和类别标签生成逼真的手写数字图像,而判别器则用于区分真实图像和生成的图像。
环境要求
项目结构
使用说明
克隆项目:
安装依赖:
运行训练:
n_epochs
,批量大小batch_size
,学习率lr
,类别数n_classes
等。生成指定数字:
number
变量中的字符串。