gitignore
本项目实现了一个基于条件生成对抗网络(cGAN)的模型,用于生成MNIST数据集中的手写数字图像。通过训练生成器和判别器,该模型能够根据给定的类别标签生成逼真的手写数字。其中框架代码来自第四届计图人工智能挑战赛热身赛。
本项目采用jittor框架。安装教程可参考https://cg.cs.tsinghua.edu.cn/jittor/download/
确保你已经安装了所有必要的库。你可以使用以下命令来安装这些库:
pip install jittor numpy matplotlib pillow
cGAN.py
output.txt
generator_last.pkl
discriminator_last.pkl
loss_curve.png
result.png
python cGAN.py
可以通过命令行参数来调整训练过程中的参数。例如:
python cGAN.py --n_epochs 50 --batch_size 64 --lr 0.0002
可用的参数包括:
--n_epochs
--batch_size
--lr
--b1
--b2
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
生成器:
判别器:
opt.sample_interval
如果你有任何问题或建议,请联系我。
```
A Jittor implementation of Conditional GAN (CGAN).
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
cGAN (Conditional Generative Adversarial Network) for MNIST
项目概述
本项目实现了一个基于条件生成对抗网络(cGAN)的模型,用于生成MNIST数据集中的手写数字图像。通过训练生成器和判别器,该模型能够根据给定的类别标签生成逼真的手写数字。其中框架代码来自第四届计图人工智能挑战赛热身赛。
环境与框架
本项目采用jittor框架。安装教程可参考https://cg.cs.tsinghua.edu.cn/jittor/download/
依赖库
安装依赖
确保你已经安装了所有必要的库。你可以使用以下命令来安装这些库:
项目结构
cGAN.py
: 主程序文件,包含生成器、判别器的定义及训练过程。output.txt
: 训练过程中记录的输出日志。generator_last.pkl
和discriminator_last.pkl
: 保存的生成器和判别器模型。loss_curve.png
: 训练过程中生成的损失曲线图。result.png
: 使用最后保存的模型生成的手写数字图像。如何运行
output.txt
中。loss_curve.png
。result.png
。参数设置
可以通过命令行参数来调整训练过程中的参数。例如:
可用的参数包括:
--n_epochs
: 训练的总轮数。--batch_size
: 每个批次的大小。--lr
: 学习率。--b1
,--b2
: Adam优化器的动量衰减系数。--latent_dim
: 隐藏空间的维度。--n_classes
: 数据集的类别数量。--img_size
: 图像的尺寸。--channels
: 图像的通道数。--sample_interval
: 采样间隔。模型架构
生成器:
判别器:
结果
opt.sample_interval
个批次会生成一组图像并保存。loss_curve.png
。result.png
。联系方式
如果你有任何问题或建议,请联系我。
```