重新提交
本项目实现了一个基于条件生成对抗网络(Conditional GAN)的手写数字生成模型。该模型能够根据输入的类别标签(0-9的数字)和随机噪声,生成对应的MNIST风格手写数字图像。
# 创建并激活虚拟环境 conda create --name jittor_env python=3.9 conda activate jittor_env # 安装Jittor pip install jittor
~/.cache/jittor/dataset/mnist_data/
# 基本训练(100个epoch,默认参数) python CGAN.py # 自定义参数训练 python CGAN.py --n_epochs 50 --batch_size 32 --lr 0.0001 # 减少训练轮数,以快速测试 python CGAN.py --n_epochs 20 --batch_size 64
--n_epochs
--batch_size
--lr
--latent_dim
--img_size
--sample_interval
--eval_interval
CGAN-MNIST/ ├── CGAN.py # 主程序文件 ├── README.md # 项目说明文档 ├── generator_last.pkl # 训练好的生成器模型 ├── discriminator_last.pkl # 训练好的判别器模型 ├── training_evaluation.txt # 训练评估报告 ├── result.png # 生成的数字序列图像 ├── images/ # 训练过程中生成的样本 │ ├── 0.png # 第0批次的生成样本 │ ├── 1000.png # 第1000批次的生成样本 │ └── ... └── data/ # MNIST数据集
输入: [噪声(100维) + 标签嵌入(10维)] → 110维 结构: 全连接层(110→128→256→512→1024→1024) 激活: LeakyReLU(0.2) + BatchNorm + Tanh(输出层) 输出: 32×32×1 灰度图像
输入: [图像展平(1024维) + 标签嵌入(10维)] → 1034维 结构: 全连接层(1034→512→512→512→1) 激活: LeakyReLU(0.2) + Dropout(0.4) + Sigmoid(输出层) 输出: 图像为真的概率 [0, 1]
训练过程中会自动评估以下指标:
判别器准确率:区分真实与生成图像的能力
training_evaluation.txt
生成图像统计:
类别一致性:不同数字类别生成特征的稳定性
修改代码中的 number 变量:
number
# 修改生成序列 number = "1234567" # 生成数字"1234567"对应的图像
# 加载预训练模型 generator.load('generator_last.pkl') # 生成单个数字 z = jt.array(np.random.normal(0, 1, (1, 100))).float32() label = jt.array([5]).float32() # 生成数字5 generated_img = generator(z, label)
# 生成10×10的数字网格(0-9每个数字10个样本) z = jt.array(np.random.normal(0, 1, (100, 100))).float32() labels = jt.array([[i] * 10 for i in range(10)]).flatten().float32() grid_imgs = generator(z, labels)
训练100个epoch后,模型表现如下:
分析:
# 训练过程中会显示: [Epoch 1/100] [Batch 50/937] [D loss: 0.254] [G loss: 0.812] [Epoch 5/100] 测试评估 判别器测试准确率: 82.66% 生成图像均值: -0.7312 生成图像标准差: 0.5812
0.png
1000.png
2000.png
result.png
A Jittor implementation of Conditional GAN(CGAN)
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Conditional GAN for MNIST Digit Generation
项目简介
本项目实现了一个基于条件生成对抗网络(Conditional GAN)的手写数字生成模型。该模型能够根据输入的类别标签(0-9的数字)和随机噪声,生成对应的MNIST风格手写数字图像。
核心功能
开始项目
环境配置
数据准备
~/.cache/jittor/dataset/mnist_data/:运行训练
参数说明
--n_epochs--batch_size--lr--latent_dim--img_size--sample_interval--eval_interval项目结构
模型架构
生成器(Generator)
判别器(Discriminator)
评估指标
训练过程中会自动评估以下指标:
判别器准确率:区分真实与生成图像的能力
training_evaluation.txt生成图像统计:
类别一致性:不同数字类别生成特征的稳定性
自定义使用
1. 生成特定数字序列
修改代码中的
number变量:2. 使用训练好的模型
3. 批量生成数字网格
实验结果
训练100个epoch后,模型表现如下:
分析:
训练监控
实时查看训练进度
生成样本查看
0.png,1000.png,2000.png…result.png