目录
目录README.md

CGAN - 条件生成对抗网络

基于Jittor框架实现的条件生成对抗网络(Conditional GAN),用于生成MNIST手写数字图像。

项目简介

本项目实现了一个条件生成对抗网络,可以根据指定的数字标签生成对应的手写数字图像。与传统GAN不同,CGAN可以控制生成内容的类别,实现有条件的图像生成。

技术特点

  • 框架: 使用Jittor深度学习框架
  • 网络结构: 生成器和判别器都采用全连接神经网络
  • 条件控制: 通过标签嵌入层实现条件生成
  • 数据集: MNIST手写数字数据集

环境要求

pip install jittor
pip install pillow
pip install numpy

网络架构

生成器 (Generator)

  • 输入: 随机噪声向量 + 类别标签
  • 标签嵌入: 将类别标签转换为嵌入向量
  • 网络层: 多层全连接网络 (128→256→512→1024→1024)
  • 输出: 32×32的灰度图像

判别器 (Discriminator)

  • 输入: 图像 + 类别标签
  • 网络层: 多层全连接网络,包含Dropout防止过拟合
  • 输出: 单个实数值,表示图像真实性判断

使用方法

1. 训练模型

取消注释训练代码部分,运行:

python CGAN.py --n_epochs 100 --batch_size 64 --lr 0.0002

2. 生成指定数字序列

修改代码中的number变量,指定要生成的数字序列:

number = '28806112873230'  # 修改为你想要的数字序列

运行代码后,会在当前目录生成result.png文件,包含指定数字序列的手写数字图像。

参数说明

  • --n_epochs: 训练轮数 (默认: 100)
  • --batch_size: 批次大小 (默认: 64)
  • --lr: 学习率 (默认: 0.0002)
  • --latent_dim: 潜在空间维度 (默认: 100)
  • --n_classes: 类别数量 (默认: 10)
  • --img_size: 图像尺寸 (默认: 32)
  • --sample_interval: 采样间隔 (默认: 1000)

文件说明

  • CGAN.py: 主程序文件,包含网络定义、训练和生成代码
  • generator_last.pkl: 训练好的生成器模型
  • discriminator_last.pkl: 训练好的判别器模型
  • result.png: 生成的数字序列图像

核心功能

条件生成

通过标签控制生成特定数字的手写数字图像

自定义序列生成

可以生成任意数字序列的连续手写数字图像

模型保存与加载

支持模型的保存和加载,便于后续使用

实现细节

  1. 标签嵌入: 使用Embedding层将离散的类别标签转换为连续向量
  2. 条件拼接: 将标签嵌入向量与噪声/图像拼接作为网络输入
  3. 对抗训练: 生成器和判别器交替训练,形成对抗关系
  4. 损失函数: 使用MSE损失函数进行训练

输出示例

运行程序后会生成一张包含指定数字序列的图像,每个数字以32×32像素的灰度图像显示,按水平方向排列。

关于

在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像

3.8 MB
邀请码