目录
目录README.md

CGAN with Jittor

基于 Jittor 深度学习框架实现的条件生成对抗网络(CGAN),用于生成手写数字图像。

项目简介

本项目使用 Jittor 框架实现 CGAN 模型,在 MNIST 数据集上训练,能够根据指定的数字标签生成相应的手写数字图像。

环境要求

  • Python 3.10+
  • Jittor 1.3+
  • PIL/Pillow
  • NumPy

安装依赖

pip install jittor pillow numpy

项目结构

CG_Jittor/
├── pub/
│   ├── CGAN.py          # 主训练脚本
│   ├── readme.md        # 说明文档
│   └── result.png       # 生成的结果图片
├── .gitignore
└── README.md

使用方法

训练模型

cd pub
python3 CGAN.py

训练参数可以通过命令行修改:

python3 CGAN.py --n_epochs 100 --batch_size 64 --lr 0.0002

主要参数说明:

  • --n_epochs: 训练轮数(默认100)
  • --batch_size: 批次大小(默认64)
  • --lr: 学习率(默认0.0002)
  • --latent_dim: 潜在空间维度(默认100)
  • --sample_interval: 采样间隔(默认1000)

生成图片

在训练过程中,模型会:

  1. 每1000个batch保存一次生成的样本图片
  2. 每10个epoch保存一次模型权重
  3. 训练结束后生成指定数字序列的图片(result.png)

修改数字序列:编辑 CGAN.py 中的 number 变量:

number = "2311366"  # 修改为你想生成的数字序列

模型架构

生成器(Generator)

  • 输入:噪声向量 + 数字标签
  • 结构:全连接层 + BatchNorm + LeakyReLU
  • 输出:32×32 的灰度图像

判别器(Discriminator)

  • 输入:图像 + 数字标签
  • 结构:全连接层 + Dropout + LeakyReLU
  • 输出:真假判别分数

注意事项

  1. 首次运行会自动下载 MNIST 数据集
  2. 模型文件(.pkl)较大,已在 .gitignore 中排除
  3. 训练过程中生成的大量样本图片也已排除,仅保留最终结果

License

MIT

关于

A Jittor implementation of Conditional GAN (CGAN).

38.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号