目录
目录README.md

Conditional GAN for MNIST Digit Generation

项目简介

本项目实现了一个基于条件生成对抗网络(Conditional GAN)的手写数字生成模型。该模型能够根据输入的类别标签(0-9的数字)和随机噪声,生成对应的MNIST风格手写数字图像。

核心功能

  • 条件生成:根据指定数字标签生成对应数字图像
  • 随机生成:随机生成各种手写数字样本
  • 训练与评估:完整的训练、验证、测试流程
  • 量化评估:提供多种统计指标评估生成质量

开始项目

环境配置

# 创建并激活虚拟环境
conda create --name jittor_env python=3.9
conda activate jittor_env

# 安装Jittor
pip install jittor

数据准备

  1. 自动下载:首次运行时会自动下载MNIST数据集
  2. 手动下载:如果网络问题,可从以下链接手动下载并放置到 ~/.cache/jittor/dataset/mnist_data/

运行训练

# 基本训练(100个epoch,默认参数)
python CGAN.py

# 自定义参数训练
python CGAN.py --n_epochs 50 --batch_size 32 --lr 0.0001

# 减少训练轮数,以快速测试
python CGAN.py --n_epochs 20 --batch_size 64

参数说明

参数 默认值 说明
--n_epochs 100 训练轮数
--batch_size 64 批次大小
--lr 0.0002 学习率
--latent_dim 100 噪声向量维度
--img_size 32 生成图像尺寸
--sample_interval 1000 生成样本保存间隔
--eval_interval 5 评估间隔(epoch)

项目结构

CGAN-MNIST/
├── CGAN.py                    # 主程序文件
├── README.md                  # 项目说明文档
├── generator_last.pkl         # 训练好的生成器模型
├── discriminator_last.pkl     # 训练好的判别器模型
├── training_evaluation.txt    # 训练评估报告
├── result.png                 # 生成的数字序列图像
├── images/                    # 训练过程中生成的样本
│   ├── 0.png                  # 第0批次的生成样本
│   ├── 1000.png               # 第1000批次的生成样本
│   └── ...
└── data/                      # MNIST数据集

模型架构

生成器(Generator)

输入: [噪声(100维) + 标签嵌入(10维)] → 110维
结构: 全连接层(110→128→256→512→1024→1024)
激活: LeakyReLU(0.2) + BatchNorm + Tanh(输出层)
输出: 32×32×1 灰度图像

判别器(Discriminator)

输入: [图像展平(1024维) + 标签嵌入(10维)] → 1034维
结构: 全连接层(1034→512→512→512→1)
激活: LeakyReLU(0.2) + Dropout(0.4) + Sigmoid(输出层)
输出: 图像为真的概率 [0, 1]

评估指标

训练过程中会自动评估以下指标:

  1. 判别器准确率:区分真实与生成图像的能力

    • 理想值:约50%(生成图像足够逼真)
    • 实际值:记录在 training_evaluation.txt
  2. 生成图像统计

    • 均值:反映图像整体亮度
    • 标准差:反映图像对比度
  3. 类别一致性:不同数字类别生成特征的稳定性

自定义使用

1. 生成特定数字序列

修改代码中的 number 变量:

# 修改生成序列
number = "1234567"  # 生成数字"1234567"对应的图像

2. 使用训练好的模型

# 加载预训练模型
generator.load('generator_last.pkl')

# 生成单个数字
z = jt.array(np.random.normal(0, 1, (1, 100))).float32()
label = jt.array([5]).float32()  # 生成数字5
generated_img = generator(z, label)

3. 批量生成数字网格

# 生成10×10的数字网格(0-9每个数字10个样本)
z = jt.array(np.random.normal(0, 1, (100, 100))).float32()
labels = jt.array([[i] * 10 for i in range(10)]).flatten().float32()
grid_imgs = generator(z, labels)

实验结果

训练100个epoch后,模型表现如下:

  • 判别器准确率:88.91%
  • 生成图像均值:-0.7318
  • 生成图像标准差:0.5773
  • 类别一致性:0.0688

分析

  • 判别器准确率偏高,表明生成质量有提升空间
  • 生成图像整体偏暗(均值为负)
  • 各类别生成一致性较好

训练监控

实时查看训练进度

# 训练过程中会显示:
[Epoch 1/100] [Batch 50/937] [D loss: 0.254] [G loss: 0.812]
[Epoch 5/100] 测试评估
判别器测试准确率: 82.66%
生成图像均值: -0.7312
生成图像标准差: 0.5812

生成样本查看

  • 每1000个批次会自动保存生成样本到 0.png, 1000.png, 2000.png
  • 最终结果保存为 result.png
关于

A Jittor implementation of Conditional GAN(CGAN)

35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号