目录
目录README

cGAN (Conditional Generative Adversarial Network) for MNIST

项目概述

本项目实现了一个基于条件生成对抗网络(cGAN)的模型,用于生成MNIST数据集中的手写数字图像。通过训练生成器和判别器,该模型能够根据给定的类别标签生成逼真的手写数字。其中框架代码来自第四届计图人工智能挑战赛热身赛。

环境与框架

本项目采用jittor框架。安装教程可参考https://cg.cs.tsinghua.edu.cn/jittor/download/

依赖库

  • Jittor: 一个基于Jittor框架的深度学习库。
  • NumPy: 用于数值计算。
  • Matplotlib: 用于绘制损失曲线。
  • PIL: 用于保存生成的图像。
  • argparse: 用于命令行参数解析。

安装依赖

确保你已经安装了所有必要的库。你可以使用以下命令来安装这些库:

pip install jittor numpy matplotlib pillow

项目结构

  • cGAN.py: 主程序文件,包含生成器、判别器的定义及训练过程。
  • output.txt: 训练过程中记录的输出日志。
  • generator_last.pkldiscriminator_last.pkl: 保存的生成器和判别器模型。
  • loss_curve.png: 训练过程中生成的损失曲线图。
  • result.png: 使用最后保存的模型生成的手写数字图像。

如何运行

  1. 确保你已经安装了所有依赖库。
  2. 运行主程序文件:
    python cGAN.py
  3. 训练过程中的日志会输出到output.txt中。
  4. 训练完成后,生成的损失曲线会保存为loss_curve.png
  5. 使用最后保存的模型生成的手写数字图像会保存为result.png

参数设置

可以通过命令行参数来调整训练过程中的参数。例如:

python cGAN.py --n_epochs 50 --batch_size 64 --lr 0.0002

可用的参数包括:

  • --n_epochs: 训练的总轮数。
  • --batch_size: 每个批次的大小。
  • --lr: 学习率。
  • --b1, --b2: Adam优化器的动量衰减系数。
  • --latent_dim: 隐藏空间的维度。
  • --n_classes: 数据集的类别数量。
  • --img_size: 图像的尺寸。
  • --channels: 图像的通道数。
  • --sample_interval: 采样间隔。

模型架构

  • 生成器:

    • 输入:随机噪声向量和类别标签。
    • 输出:生成的图像。
    • 架构:全连接层 + 批归一化 + LeakyReLU激活函数 + Tanh激活函数。
  • 判别器:

    • 输入:图像和类别标签。
    • 输出:判别结果(实数)。
    • 架构:全连接层 + LeakyReLU激活函数 + Dropout层。

结果

  • 训练过程中,每opt.sample_interval个批次会生成一组图像并保存。
  • 训练完成后,生成的损失曲线会保存为loss_curve.png
  • 使用最后保存的模型生成的手写数字图像会保存为result.png

联系方式

如果你有任何问题或建议,请联系我。

```

关于

A Jittor implementation of Conditional GAN (CGAN).

35.9 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号