目录
目录readme.md

Conditional GAN for MNIST (Jittor) 🧠✨

本项目基于 Jittor 实现一个 Conditional GAN(条件生成对抗网络):输入随机噪声 z 与类别标签 y,生成对应类别的手写数字图像;并支持按指定数字序列生成并保存拼接结果图 result.png


1. 任务说明 🎯

  • 数据集:MNIST(Jittor 内置 jittor.dataset.mnist 自动下载)
  • 目标:训练条件生成器 G(z, y) 生成指定类别数字图像,同时训练判别器 D(x, y) 判断“图像-标签”对是否真实
  • 产出:根据比赛给定字符串数字序列(如 学号)生成对应数字图像并拼接保存

2. 方法概述 🧪

  • Generator(生成器) 🏗️
    将噪声向量 z 与标签嵌入(Embedding)后的条件向量拼接,经多层全连接网络生成图像(输出为 C×H×W)。

  • Discriminator(判别器) 🕵️‍♂️
    将输入图像 x 展平为向量,与标签嵌入向量拼接,经过多层全连接网络输出一个标量,表示该“图像-标签”对为真的概率/置信度 。

  • 损失函数 📉
    使用 MSELoss(最小二乘 GAN / LSGAN 风格):

    • 判别器:真实样本逼近 1,生成样本逼近 0
    • 生成器:让生成样本在判别器输出上逼近 1(“骗过判别器”)

3. 环境依赖 🧰

  • 操作系统:Ubuntu / Debian 系(推荐)
  • Python:3.7
  • 框架:Jittor
  • 额外依赖:OpenMP(libomp-dev),用于加速/编译相关

4. 安装与配置 🛠️

4.1 安装系统依赖(Ubuntu/Debian)

sudo apt update
sudo apt install -y libomp-dev build-essential git

说明:如果你的系统源中找不到 python3.7-dev 属于正常现象(新版本 Ubuntu 往往不再提供 3.7 的 dev 包)。推荐直接使用 conda 创建 Python 3.7 环境

4.2 使用 Conda 创建虚拟环境并安装 Jittor

conda create -n jittor37 python=3.7 -y
conda activate jittor37

python -m pip install -U pip setuptools wheel
python -m pip install jittor
# 或安装 github 最新版(二选一)
# python -m pip install git+https://github.com/Jittor/jittor.git

4.3 运行 Jittor 自检

python -m jittor.test.test_example

第一次运行 Jittor 可能会触发编译缓存,耗时较长属于正常现象


5. 训练与生成 🏋️‍♀️

5.1 开始训练

项目主脚本为:

  • CGAN.py

直接运行(使用默认参数训练):

python CGAN.py

5.2 常用参数

脚本支持命令行参数(argparse):

  • --n_epochs:训练轮数(默认 100)🔁
  • --batch_size:批大小(默认 64)📦
  • --lr:学习率(默认 0.0002)📈
  • --b1 / --b2:Adam 的 beta 参数(默认 0.5 / 0.999)⚙️
  • --latent_dim:噪声维度(默认 100)🌫️
  • --n_classes:类别数(默认 10)🔟
  • --img_size:图像尺寸(默认 32,会对 MNIST 做 Resize)📐
  • --channels:通道数(默认 1)🎛️
  • --sample_interval:采样间隔(默认 1000 step 保存一次采样图)⏱️

示例(训练 30 epoch):

python CGAN.py --n_epochs 30 --batch_size 64 --lr 0.0002

5.3 生成指定数字序列(用于提交)

在脚本末尾修改:

number = "2213119"  # 改成比赛页面要求的字符串数字序列

训练结束后脚本会生成并保存:

  • result.png:按 number 生成的数字序列拼接图(提交文件)

6. 输出文件说明 📂

训练过程中/训练结束后会产生常见输出:

  • *.png:按 --sample_interval 采样生成的网格图(用于观察训练过程)
  • generator_last.pkl:生成器参数(每 10 epoch 保存一次)
  • discriminator_last.pkl:判别器参数(每 10 epoch 保存一次)
  • result.png:最终按指定序列生成的结果图(用于提交)

7. 常见问题(FAQ)🧩

7.1 Embedding 报错:索引类型不对?

如果出现类似 “Embedding index must be int” 的问题,确保标签是整型(如 int32),而不是 float32。例如:

labels = labels.int32()

7.2 训练初期生成很糊/噪声很大?

GAN 在早期通常不稳定是正常的。可尝试:

  • 增加训练 epoch 🔁
  • 调整 --lr(后期降低学习率更平稳)
  • 增大 batch size(显存允许的情况下)

8. 目录结构(示例)🗂️

.
├── CGAN.py
├── generator_last.pkl
├── discriminator_last.pkl
├── result.png
└── (训练过程中生成的若干 *.png)

9. 参考 📚

  • Jittor 官方文档与示例
  • MNIST 数据集
  • Conditional GAN / LSGAN 相关论文与资料(概念实现参考)
关于

A Jittor implementation of Conditional GAN (CGAN).

1000.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号