CGAN - 条件生成对抗网络
基于Jittor框架实现的条件生成对抗网络(Conditional GAN),用于生成MNIST手写数字图像。
项目简介
本项目实现了一个条件生成对抗网络,可以根据指定的数字标签生成对应的手写数字图像。与传统GAN不同,CGAN可以控制生成内容的类别,实现有条件的图像生成。
技术特点
- 框架: 使用Jittor深度学习框架
- 网络结构: 生成器和判别器都采用全连接神经网络
- 条件控制: 通过标签嵌入层实现条件生成
- 数据集: MNIST手写数字数据集
环境要求
pip install jittor
pip install pillow
pip install numpy
网络架构
生成器 (Generator)
- 输入: 随机噪声向量 + 类别标签
- 标签嵌入: 将类别标签转换为嵌入向量
- 网络层: 多层全连接网络 (128→256→512→1024→1024)
- 输出: 32×32的灰度图像
判别器 (Discriminator)
- 输入: 图像 + 类别标签
- 网络层: 多层全连接网络,包含Dropout防止过拟合
- 输出: 单个实数值,表示图像真实性判断
使用方法
1. 训练模型
取消注释训练代码部分,运行:
python CGAN.py --n_epochs 100 --batch_size 64 --lr 0.0002
2. 生成指定数字序列
修改代码中的number
变量,指定要生成的数字序列:
number = '28806112873230' # 修改为你想要的数字序列
运行代码后,会在当前目录生成result.png
文件,包含指定数字序列的手写数字图像。
参数说明
--n_epochs
: 训练轮数 (默认: 100)
--batch_size
: 批次大小 (默认: 64)
--lr
: 学习率 (默认: 0.0002)
--latent_dim
: 潜在空间维度 (默认: 100)
--n_classes
: 类别数量 (默认: 10)
--img_size
: 图像尺寸 (默认: 32)
--sample_interval
: 采样间隔 (默认: 1000)
文件说明
CGAN.py
: 主程序文件,包含网络定义、训练和生成代码
generator_last.pkl
: 训练好的生成器模型
discriminator_last.pkl
: 训练好的判别器模型
result.png
: 生成的数字序列图像
核心功能
条件生成
通过标签控制生成特定数字的手写数字图像
自定义序列生成
可以生成任意数字序列的连续手写数字图像
模型保存与加载
支持模型的保存和加载,便于后续使用
实现细节
- 标签嵌入: 使用Embedding层将离散的类别标签转换为连续向量
- 条件拼接: 将标签嵌入向量与噪声/图像拼接作为网络输入
- 对抗训练: 生成器和判别器交替训练,形成对抗关系
- 损失函数: 使用MSE损失函数进行训练
输出示例
运行程序后会生成一张包含指定数字序列的图像,每个数字以32×32像素的灰度图像显示,按水平方向排列。
CGAN - 条件生成对抗网络
基于Jittor框架实现的条件生成对抗网络(Conditional GAN),用于生成MNIST手写数字图像。
项目简介
本项目实现了一个条件生成对抗网络,可以根据指定的数字标签生成对应的手写数字图像。与传统GAN不同,CGAN可以控制生成内容的类别,实现有条件的图像生成。
技术特点
环境要求
网络架构
生成器 (Generator)
判别器 (Discriminator)
使用方法
1. 训练模型
取消注释训练代码部分,运行:
2. 生成指定数字序列
修改代码中的
number
变量,指定要生成的数字序列:运行代码后,会在当前目录生成
result.png
文件,包含指定数字序列的手写数字图像。参数说明
--n_epochs
: 训练轮数 (默认: 100)--batch_size
: 批次大小 (默认: 64)--lr
: 学习率 (默认: 0.0002)--latent_dim
: 潜在空间维度 (默认: 100)--n_classes
: 类别数量 (默认: 10)--img_size
: 图像尺寸 (默认: 32)--sample_interval
: 采样间隔 (默认: 1000)文件说明
CGAN.py
: 主程序文件,包含网络定义、训练和生成代码generator_last.pkl
: 训练好的生成器模型discriminator_last.pkl
: 训练好的判别器模型result.png
: 生成的数字序列图像核心功能
条件生成
通过标签控制生成特定数字的手写数字图像
自定义序列生成
可以生成任意数字序列的连续手写数字图像
模型保存与加载
支持模型的保存和加载,便于后续使用
实现细节
输出示例
运行程序后会生成一张包含指定数字序列的图像,每个数字以32×32像素的灰度图像显示,按水平方向排列。