目录
目录README.md

JT_5th_Warmup_cjd

基于条件生成对抗网络(CGAN)的MNIST数字生成模型,使用Jittor框架实现。

项目特点

  1. 条件生成对抗网络(CGAN):结合类别标签信息,生成指定数字(0-9)的图像。
  2. MNIST数据集:使用经典手写数字数据集,输入图像尺寸为32x32像素。
  3. Jittor框架:支持动态图与静态图混合编程,高效利用GPU加速(需CUDA环境)。

环境要求

依赖项

  • Jittor:深度学习框架,支持CPU/GPU运行。
    pip install jittor -f https://cg.cs.tsinghua.edu.cn/jittor/download/index.html  
  • 其他库numpy, argparse, Pillow, matplotlib(可选,用于可视化)。
    pip install numpy pillow  

硬件建议

  • GPU(推荐):支持CUDA的NVIDIA显卡,加速训练过程。
  • CPU:可运行但训练速度较慢。

使用方法

1. 训练模型

python your_script_name.py  # 替换为实际文件名(如 cgan_mnist.py)  

可选参数

参数名 说明 默认值
--n_epochs 训练轮数 100
--batch_size 批量大小 64
--lr Adam优化器学习率 0.0002
--sample_interval 图像采样间隔(每生成n批次保存一次) 1000

训练过程

  • 控制台实时输出判别器(D)和生成器(G)的损失值。
  • sample_interval批次生成并保存一次示例图像(如1000.png)。
  • 每10个epoch保存模型参数(generator_last.pkldiscriminator_last.pkl)。

2. 生成指定数字图像

  1. 打开代码,找到以下部分:

    number = "TODO: 写入比赛页面中指定的数字序列(字符串类型)"  

    number替换为目标数字字符串(如"1234567890")。

  2. 运行脚本,生成结果保存在result.png,展示指定数字的生成图像。

目录结构

.  
├── your_script_name.py  # 主代码文件  
├── generator_last.pkl   # 最后一次保存的生成器模型  
├── discriminator_last.pkl  # 最后一次保存的判别器模型  
├── *.png                # 训练过程中生成的示例图像(如 1000.png, 2000.png 等)  
└── result.png           # 最终生成的指定数字图像  

代码说明

核心模块

  1. 生成器(Generator)

    • 输入:随机噪声(100维)+ 类别标签(通过Embedding层编码)。
    • 结构:多层全连接层,使用LeakyReLU激活函数和BatchNorm,最终输出32x32单通道图像(通过Tanh归一化到[-1, 1])。
  2. 判别器(Discriminator)

    • 输入:图像(展平为1024维向量)+ 类别标签(Embedding编码)。
    • 结构:多层全连接层,使用LeakyReLU和Dropout防止过拟合,最终输出单个实数表示真实性概率。
  3. 损失函数

    • 均方误差(MSELoss):判别器区分真实/生成图像(标签1/0),生成器欺骗判别器(目标标签1)。

数据处理

  • 使用Jittor内置的MNIST数据集,预处理包括Resize、灰度化、归一化(均值0.5,标准差0.5)。

结果示例

  • 训练过程中生成的图像示例(如1000.png)展示不同epoch下的生成效果,随着训练轮数增加,数字清晰度逐渐提升。
  • result.png为指定数字序列的生成结果,每个数字对应输入的类别标签。

贡献与反馈

欢迎通过GitHub提交Issue或Pull Request,优化代码、修复问题或添加新功能。
提交Issue时请注明环境配置、复现步骤及具体问题描述。

许可证

本项目采用 MIT License,详见LICENSE文件。允许自由使用、修改和分发,但需保留原作者声明。

关于
37.0 KB
邀请码