Update README.md
本项目使用条件生成对抗网络(CGAN)生成手写数字图像。CGAN在MNIST数据集上进行训练,并根据输入的标签生成对应类别的手写数字图像。项目使用Jittor框架进行实现。
参考官方文档进行安装。
运行以下命令进行模型训练:
python CGAN.py --n_epochs 100 --batch_size 64 --lr 0.0002 --b1 0.5 --b2 0.999 --n_cpu 8 --latent_dim 100 --n_classes 10 --img_size 32 --channels 1 --sample_interval 1000
训练完成后,模型将保存为generator_last.pkl和discriminator_last.pkl。可以通过以下命令生成特定标签的图像:
generator_last.pkl
discriminator_last.pkl
python CGAN.py
生成的图像将保存为result.png。
result.png
--n_epochs
--batch_size
--lr
--b1
--b2
--n_cpu
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
CGAN.py
本项目使用MIT许可证。
A Jittor implementation of Conditional GAN (CGAN)
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN_jittor
项目描述
本项目使用条件生成对抗网络(CGAN)生成手写数字图像。CGAN在MNIST数据集上进行训练,并根据输入的标签生成对应类别的手写数字图像。项目使用Jittor框架进行实现。
安装依赖
参考官方文档进行安装。
运行步骤
1. 训练模型
运行以下命令进行模型训练:
2. 生成图像
训练完成后,模型将保存为
generator_last.pkl
和discriminator_last.pkl
。可以通过以下命令生成特定标签的图像:生成的图像将保存为
result.png
。参数说明
--n_epochs
: 训练的总epoch数--batch_size
: 每个batch的大小--lr
: Adam优化器的学习率--b1
: Adam优化器的一阶动量衰减--b2
: Adam优化器的二阶动量衰减--n_cpu
: 用于数据加载的CPU线程数--latent_dim
: 潜在空间的维度--n_classes
: 数据集的类别数--img_size
: 图像的尺寸--channels
: 图像的通道数--sample_interval
: 图像采样间隔项目结构
CGAN.py
: 主代码文件,包括模型定义、训练和生成图像的代码。许可证
本项目使用MIT许可证。