目录
目录README.md

Conditional GAN (CGAN) - MNIST 数字生成

第三届计图人工智能挑战赛 - 计图挑战热身赛

项目简介

本项目实现了一个 Conditional Generative Adversarial Network (CGAN) 模型,用于在 MNIST 数据集上训练并生成指定数字的图像。通过输入随机噪声向量和类别标签,模型可以生成对应数字的高质量图像。

环境要求

  • Python 3.x
  • Jittor 深度学习框架
  • 其他依赖包:
    pip install jittor numpy pillow

项目结构

pub/
├── CGAN.py              # 主训练脚本(包含模型定义、训练和生成代码)
├── generator_last.pkl   # 训练好的生成器模型(不提交到git)
├── discriminator_last.pkl # 训练好的判别器模型(不提交到git)
├── result.png          # 最终生成的结果图片
└── README.md           # 项目说明文档

使用方法

1. 训练模型

运行主训练脚本:

python CGAN.py

训练参数可以通过命令行参数调整:

python CGAN.py --n_epochs 100 --batch_size 64 --lr 0.0002

主要参数说明:

  • --n_epochs: 训练轮数(默认:100)
  • --batch_size: 批次大小(默认:64)
  • --lr: 学习率(默认:0.0002)
  • --latent_dim: 潜在空间维度(默认:100)
  • --img_size: 图像尺寸(默认:32)
  • --sample_interval: 采样间隔(默认:1000)

2. 生成图片

修改 CGAN.py 文件末尾的 number 变量:

number = "2313743"  # 修改为你想要的数字序列

然后运行:

python CGAN.py

模型架构

Generator (生成器)

  • 输入:随机噪声向量 (100维) + 类别标签嵌入 (10维)
  • 结构:多层全连接网络,包含 BatchNorm 和 LeakyReLU 激活函数
  • 输出:32×32 的灰度图像

Discriminator (判别器)

  • 输入:图像 (32×32) + 类别标签嵌入 (10维)
  • 结构:多层全连接网络,包含 Dropout 和 LeakyReLU 激活函数
  • 输出:单个实数(判断图像真假的概率)

训练过程

  1. 数据准备:自动下载 MNIST 数据集(如果本地不存在)
  2. 模型初始化:创建生成器和判别器
  3. 对抗训练
    • 训练生成器:使生成图像能够欺骗判别器
    • 训练判别器:正确区分真实图像和生成图像
  4. 模型保存:每 10 个 epoch 保存一次模型
  5. 结果生成:训练完成后生成指定数字序列的图像

注意事项

  • 首次运行会自动下载 MNIST 数据集,请确保网络连接正常
  • 训练过程可能需要较长时间,建议使用 GPU 加速
  • 模型文件(.pkl)较大,已添加到 .gitignore,不会提交到仓库
  • 训练过程中生成的中间图片(*.png)也不会提交,只保留最终结果

比赛要求

  • 用户随机ID:2313743
  • 生成结果需要被 MNIST 分类器正确识别
  • 每个数字的平均正确率需大于 0.7

开源说明

本项目遵循比赛开源要求,代码已完整开源。更多 Jittor 框架资料请参考:Jittor 官方文档

许可证

本项目仅用于比赛目的。

关于

A Jittor implementation of Conditional GAN (CGAN)

35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号