目录
目录README.md

CGAN

项目简介

本项目在数字图片数据集 MNIST 上训练 Conditional GAN(条件生成对抗网络)模型。通过输入随机噪声向量 z 和条件信息 y(类别标签),生成特定数字的手写体图像。

任务目标:

  • 训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型
  • 根据给定的数字序列生成对应的手写数字图像
  • 使用 Jittor 深度学习框架实现模型训练和推理

技术特点:

  • 基于条件生成对抗网络(CGAN)架构
  • 支持指定数字类别的图像生成
  • 使用 MNIST 数据集进行训练
  • 支持 GPU 加速训练

运行环境

python 3.8
argon2-cffi         20.1.0
astunparse          1.6.3
async-generator     1.10
attrs               20.3.0
backcall            0.2.0
bleach              3.3.0
cffi                1.14.5
cycler              0.10.0
decorator           5.0.6
defusedxml          0.7.1
entrypoints         0.3
ipykernel           5.5.3
ipython             7.22.0
ipython-genutils    0.2.0
jedi                0.18.0
Jinja2              2.11.3
jittor              1.2.2.59
jsonschema          3.2.0
jupyter-client      6.1.13
jupyter-core        4.7.1
jupyterlab-pygments 0.1.2
kiwisolver          1.3.1
MarkupSafe          1.1.1
matplotlib          3.4.1
mistune             0.8.4
nbclient            0.5.3
nbconvert           6.0.7
nbformat            5.1.3
nest-asyncio        1.5.1
notebook            6.3.0
numpy               1.20.2
packaging           20.9
pandocfilters       1.4.3
parso               0.8.2
pexpect             4.8.0
pickleshare         0.7.5
Pillow              8.2.0
pip                 20.0.2
prometheus-client   0.10.1
prompt-toolkit      3.0.18
ptyprocess          0.7.0
pybind11            2.6.2
pycparser           2.20
Pygments            2.8.1
pyparsing           2.4.7
pyrsistent          0.17.3
python-dateutil     2.8.1
pyzmq               22.0.3
Send2Trash          1.5.0
setuptools          45.2.0
six                 1.15.0
terminado           0.9.4
testpath            0.4.4
tornado             6.1
tqdm                4.60.0
traitlets           5.0.5
wcwidth             0.2.5
webencodings        0.5.1
wheel               0.34.2

使用方法

cd /path/to/project/
docker run -it --gpus all -v /path/to/project:/workspace jittor-cuda-12-2 bash
cd /workspace
python3.8 -m CGAN.py

部分可配置参数如下:

--n_epochs: 训练轮数,默认100
--batch_size: 批次大小,默认64
--lr: Adam优化器学习率,默认0.0002
--b1: Adam优化器一阶动量衰减率,默认0.5
--b2: Adam优化器二阶动量衰减率,默认0.999
--latent_dim: 潜在空间维度,默认100
--n_classes: 数据集类别数,默认10
--img_size: 图像尺寸,默认32x32
--channels: 图像通道数,默认1(灰度图)
--sample_interval: 图像采样间隔,默认1000

模型简介

生成器(Generator)

输入:

  • noise: 随机噪声向量,维度为 [batch_size, latent_dim](默认100维)
  • labels: 类别标签,维度为 [batch_size],取值范围0-9

网络结构:

  • 标签嵌入层:将类别标签映射为10维向量
  • 全连接网络:输入维度110(100维噪声+10维标签),通过多层全连接层逐步扩展
    • 110 → 128 → 256 → 512 → 1024 → 1024(图像像素数)
  • 激活函数:LeakyReLU(除输出层)+ Tanh(输出层)
  • 批归一化:应用于隐藏层

输出:

  • 生成的图像,维度为 [batch_size, 1, 32, 32],像素值范围[-1, 1]

判别器(Discriminator)

输入:

  • img: 图像数据,维度为 [batch_size, 1, 32, 32]
  • labels: 对应的类别标签,维度为 [batch_size]

网络结构:

  • 标签嵌入层:将类别标签映射为10维向量
  • 图像展平:将32x32图像展平为1024维向量
  • 全连接网络:输入维度1034(1024维图像+10维标签)
    • 1034 → 512 → 512 → 512 → 1
  • 激活函数:LeakyReLU
  • 正则化:Dropout(0.4)应用于隐藏层

输出:

  • 真假判别结果,维度为 [batch_size, 1],表示输入图像为真实图像的概率

损失函数

  • 对抗损失: 均方误差损失(MSE Loss)
  • 生成器损失: 希望判别器将生成图像判定为真实图像
  • 判别器损失: 正确区分真实图像和生成图像的平均损失

项目结构

.
├── CGAN.py # 主代码文件
├── Dockerfile # 自定义 Dockerfile
├── README.md # README 文件
├── result.png # 生成的数字序列图像(最终结果)
├── generator_last.pkl # 训练好的生成器模型
└── discriminator_last.pkl # 训练好的判别器模型

参考资料

Jittor文档

感谢王潇策同学提供的镜像包

关于
14.8 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号