目录
目录readme.md

条件生成对抗网络(CGAN)——MNIST 手写数字生成

使用 Jittor 深度学习框架实现的条件生成对抗网络,可按类别生成 MNIST 手写数字图像。


📖 项目简介

本仓库实现了一个 Conditional GAN,能够根据指定数字类别生成对应的手写数字图像。
模型由 生成器 (Generator)判别器 (Discriminator) 组成,二者在对抗训练中不断提升,最终生成器能够“欺骗”判别器,输出以假乱真的数字图片。


✨ 特性

  • 按类别生成:输入随机噪声与数字标签即可生成指定数字。
  • 纯全连接网络:更易阅读和复现的 MLP 结构。
  • MSE 对抗损失:使用平方误差作为对抗目标。
  • 自动保存模型/样本:按设定周期保存网络权重与生成示例。

📦 安装

需 Python ≥ 3.7

  1. 克隆仓库

    git clone https://gitlink.org.cn/Felixccc/jittor_contest_2025_warmup.git
  2. 安装依赖

    # 安装 Jittor(CUDA 用户可参考官网说明)
    pip install jittor
    
    # 其余依赖
    pip install -r requirements.txt

MNIST 数据集会在首次运行时自动下载,无需手动操作。


🚀 快速开始

训练

python CGAN.py

训练过程会:

  • sample_interval 步将 1000 张生成图保存为 *.png
  • 每 10 轮将模型权重保存在 generator_last.pkldiscriminator_last.pkl

根据数字序列生成图片

python CGAN.py --generate_digits "28170102809796"

脚本将读取最近一次保存的权重,按序列 "28170102809796" 生成对应数字图像,并拼接保存为 result.png


🗂️ 目录结构

.
├── CGAN.py               # 主训练 / 推理脚本
├── generator_last.pkl    # 最近一次保存的生成器权重
├── discriminator_last.pkl# 最近一次保存的判别器权重
├── result.png            # 示例生成结果
├── .gitignore             # Git 忽略文件
└── requirements.txt      # Python 依赖列表

🏗️ 模型架构

Generator

输入: [噪声向量 z, 类别嵌入 y] → 128 → 256 → 512 → 1024 → 32×32×1 图像
  • 每层使用 LeakyReLU 激活;
  • 中间层带 BatchNorm(首层除外);
  • 输出层用 Tanh 将像素归一化至 [-1, 1]。

Discriminator

输入: [图像展平, 类别嵌入 y] → 512 → 512 → 512 → 1 (真/假得分)
  • 多层 LeakyReLU + **Dropout(0.4)**;
  • 输出为单个实数,取值越接近 1 表示越“真实”。

📐 损失函数

  • Generator:希望判别器将生成图判为真 → MSE(validity, 1)
  • Discriminator:同时区分真实图与生成图 → 0.5*(MSE(real, 1) + MSE(fake, 0))

🙏 鸣谢

  • Jittor —— 高性能深度学习框架
  • MNIST —— 手写数字数据集

📬 联系方式

如有问题或改进建议,请提交 Issue,或邮件联系 siyuan-c23@mails.tsinghua.edu.cn
欢迎 Star ⭐ / Fork 🍴 / PR 🚀!

关于

A Jittor implementation of Conditional GAN (CGAN).

354.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号