目录
目录readme.md

Conditional GAN (Jittor) for MNIST

一个基于 Jittor 的简单条件 GAN(cGAN),将随机噪声和类别标签映射为 32x32 灰度 MNIST 数字,并在训练过程中与推理阶段保存生成结果到 sample_img/

特性

  • 条件生成器和判别器,使用标签 embedding 融合类别信息。
  • MSE 对抗损失,G/D 各自使用 Adam 优化。
  • 训练中定期保存 10x10 网格采样图;推理阶段按自定义数字序列生成长条图。

环境依赖

  • Python 3.7
  • Jittor
  • NumPy、Pillow

安装(示例,CUDA 请参考官方文档):

pip install jittor

训练

直接运行:

python CGAN.py

常用参数(脚本内有默认值):

  • --n_epochs 训练轮数(默认 100)
  • --batch_size batch 大小(默认 64)
  • --lr Adam 学习率(默认 0.0002)
  • --latent_dim 噪声维度(默认 100)
  • --n_classes 类别数(默认 10)
  • --img_size 输出尺寸(默认 32)
  • --sample_interval 采样保存间隔(按 batch,默认 1000)

输出

  • 训练采样:sample_img/<batches_done>.png,10x10 网格展示 0–9。
  • 推理结果:./result.png

自定义推理

CGAN.py 末尾修改 number 字符串为目标数字序列,然后运行:

python CGAN.py

脚本会加载 generator_last.pkldiscriminator_last.pkl,并根据序列生成 result.png

额外说明

  • 输入图片归一化到 [-1, 1],生成器输出使用 Tanh。
  • 检测到 CUDA 时自动启用 jt.flags.use_cuda = 1
  • 如出现 G loss 高而 D loss 低,可适当调节 G/D 学习率或应用 label smoothing。
关于

A Jittor implementation of Conditional GAN (CGAN).

20.1 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号