目录
目录README.md

CGAN_jittor_MNIST

简述

我们用 Jittor (计图)MNIST 上训练了一个 Conditional GAN(CGAN):输入随机噪声 z 和类别标签 y,生成指定数字的手写体图像。 代码入口只有一个 CGAN.py,因为数据集直接通过 Jittor 内置的 MNIST loader 下载/读取,训练、生成、保存都写在同一个脚本里。 最终提交产物是 result.png(按随机 ID 逐位生成并拼成一张图),以及可选的模型权重 generator_last.pkl / discriminator_last.pkl


任务

  • 目标:训练 CGAN,使得给定随机 ID(本作业使用:2310500 2312313)时,能够生成对应数字序列的图片 result.png
  • 通过标准(题面理解):评测端用 MNIST 分类器判别生成结果,平均正确率 > 0.7 视为通过;我以“数字清晰可辨、类别稳定”为训练停止标准。

特性

  • 一键训练 + 一键生成:脚本末尾自动生成 result.png,不需要额外写推理脚本。
  • 条件生成:G/D 都显式引入标签 y(Embedding + concat),可控生成指定数字而不是随机数字。
  • 图片质量优化:保存时按 tanh -> [0,255] 固定映射,避免 min-max 拉伸导致对比度漂移。

环境与依赖

  • OS:WSL2 Ubuntu(Windows 也能跑)
  • Python:3.9.25
  • Jittor:Jittor(1.3.10.0)
  • 依赖:jittornumpyPillow(以及可选的 matplotlib/tqdm

安装(WSL2 / Ubuntu)

# 1)(推荐)创建 conda 环境:隔离依赖,方便复现
conda create -n cgan-jittor python=3.9 -y
conda activate cgan-jittor

# 2) 安装依赖
pip install -U pip
pip install jittor numpy pillow
# 如果是 GPU 环境(推荐 T4/3060 等),请运行 CUDA 配置:
# python -m jittor_utils.install_cuda

Windows 差异:Windows 本机运行时,首次 import jittor 也会触发编译/缓存;相比 WSL2 更容易遇到编译链/路径问题,所以我把主流程写成 WSL2 版本。


使用方法

1)最短可跑(先确认环境没问题)

python CGAN.py --n_epochs 1 --batch_size 64

2)正式训练

python CGAN.py --n_epochs 50 --batch_size 64 --lr 0.0002

训练过程中会按 sample_interval 定期保存采样图(用于观察训练质量)。

备注:CGAN 训练轮数多的时候会比较吃时间(尤其是 CPU 环境)。如果本地跑得慢,为了提高效率也可以把 CGAN.py 放到云服务器/云 Notebook 上训练,最后把生成的 result.png 下载回来打包提交即可。

3)参数说明

–n_epochs:训练轮数(建议 30~50 以上) –batch_size:批大小 –log_interval:日志间隔(batch) –d_dropout:判别器 Dropout(适当减小可提升生成质量)

4)生成指定随机 ID

CGAN.py 末尾找到:

number = "2310500"

把它改成你自己的随机 ID 字符串即可。脚本运行结束会输出:

  • result.png(最终提交用图)
  • (可选)generator_last.pkl / discriminator_last.pkl(训练权重)

项目结构

  • CGAN.py:模型定义(G/D)+ MNIST 数据加载 + 训练循环 + 最终生成并保存 result.png
  • result.png:最终提交用的结果图(按随机 ID 逐位生成并拼接)
  • generator_last.pkl / discriminator_last.pkl:训练过程中保存的权重(可选)

复现说明

  • 输出位置:默认在当前目录生成 result.png
  • 随机性:脚本默认不固定随机种子(如果需要严格可复现,可补:np.random.seed(...) + jt.set_global_seed(...),待补充)
  • 图片格式:生成器输出经 tanh,保存时按 (x * 0.5 + 0.5) * 255 映射到灰度图(避免 min-max 导致发灰)

常见问题 FAQ

  1. 第一次 import jittor 很慢 / 像卡住
    原因:Jittor 首次会编译并缓存算子(正常现象)。处理:耐心等一次,后面运行会快很多。
  2. Embedding 报错(labels dtype 不对)
    标签 labels 需要是整数类型(例如 int32)。如果你把标签转成了 float,Embedding 可能会直接报错。处理:确保训练与生成阶段的 labels 都是 int32()
  3. 生成图发灰 / 对比度飘
    不用 min-max 归一化去保存每张图(不同 batch 的 min/max 会变化,观感不稳定)。我这里使用固定映射 tanh -> [0,1] -> [0,255]
  4. 为什么只有一个 py 文件也能训练?数据在哪?
    MNIST 通过 Jittor 的 dataset 接口创建 dataloader,数据下载/缓存由框架处理,所以项目目录里不一定有独立的数据文件夹。

合作者 / 友情链接


许可 / 声明

  • 本仓库用于课程作业/比赛提交与学习交流;
  • Jittor Python
关于

A Jittor implementation of Conditional GAN (CGAN)

2.9 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号