目录
目录readme.md

CGAN (Jittor) for MNIST

基于 Jittor 的 Conditional GAN,在 MNIST 上按指定数字序列生成图片。默认生成顺序为 2311469,可通过参数覆盖。

1. 环境要求

  • Python 3.7+(示例使用 3.7)
  • Jittor(CPU/GPU 均可,GPU 需已装好 CUDA/CUDNN/驱动)
  • 其余依赖:numpy, pillow(随 Jittor 一同安装即可)

2. 安装依赖

pip install --upgrade pip
pip install jittor==1.2.2.59 pillow numpy

如需指定镜像:-i https://pypi.tuna.tsinghua.edu.cn/simple

3. 主要文件

4. 运行示例

主机直接跑(默认序列 2311469):

python pub/CGAN.py --n_epochs 1 --sample_interval 200

覆盖输出序列,例如 0123456789:

python pub/CGAN.py --n_epochs 1 --sample_interval 200 --sequence "0123456789"

使用 jittor 官方镜像(一次性运行):

docker run --rm -v /绝对路径/pub:/workspace jittor/jittor:cuda11.8 \
    bash -c "python3 /workspace/CGAN.py --n_epochs 1 --sample_interval 200 --sequence '2311469'"

已运行的容器内(容器名示例 9141634d8e40):

docker exec 9141634d8e40 bash -lc "cd /workspace && python3.7 CGAN.py --n_epochs 1 --sample_interval 200 --sequence '2311469'"

5. 训练输出

  • 采样图:每隔 sample_interval 步保存一张 *.png,用于查看生成效果。
  • 模型权重:每 10 个 epoch 保存 generator_last.pkldiscriminator_last.pkl
  • 最终拼接图:按 --sequence 生成的横向拼接 result.png

6. 常见问题

  • 判别器过强(D loss 低、G loss 高):降低判别器学习率,或每步多训 G 少训 D,可加标签平滑(真实标签设为 0.9)。
  • tqdm 报 RuntimeError(Set changed size during iteration):设置环境变量 TQDM_DISABLE=1 运行即可。
关于

A Jittor implementation of Conditional GAN(CGAN)

33.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号