cgan
本项目使用 Jittor 框架实现了一个条件生成对抗网络(Conditional GAN),用于生成带有类别标签的图像。示例数据集为 MNIST 手写数字图像,可通过输入数字标签生成相应图像。
pip install jittor
更多安装详情见:Jittor 安装文档
CGAN.py
result.png
运行以下命令开始训练(默认参数见下方):
python CGAN.py
支持的命令行参数:
--n_epochs
--batch_size
--lr
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
你可以通过指定参数来自定义训练,例如:
python CGAN.py --n_epochs 50 --batch_size 128
训练期间,每隔 sample_interval 步数将会生成一次图像,保存为 x.png,其中 x 是当前步数。
sample_interval
x.png
x
在训练完成后,程序会自动保存模型为 generator_last.pkl 和 discriminator_last.pkl。你可以使用如下代码加载模型并生成指定数字图像序列:
generator_last.pkl
discriminator_last.pkl
number = "28262472819004" # 指定数字序列
运行后将自动生成并保存为 result.png,其中图像按顺序拼接显示每个数字生成的图片。
channels × img_size × img_size
Tanh
jt.flags.use_cuda = 1
channels
A Jittor implementation of Conditional GAN (CGAN)
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN_jittor
本项目使用 Jittor 框架实现了一个条件生成对抗网络(Conditional GAN),用于生成带有类别标签的图像。示例数据集为 MNIST 手写数字图像,可通过输入数字标签生成相应图像。
环境依赖
安装 Jittor
更多安装详情见:Jittor 安装文档
文件结构说明
CGAN.py
:主程序文件,包含模型构建、训练、采样和生成流程。result.png
:根据给定数字序列生成的图像拼接结果。使用说明
训练模型
运行以下命令开始训练(默认参数见下方):
支持的命令行参数:
--n_epochs
--batch_size
--lr
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
你可以通过指定参数来自定义训练,例如:
查看生成图像
训练期间,每隔
sample_interval
步数将会生成一次图像,保存为x.png
,其中x
是当前步数。使用训练好的模型生成图像
在训练完成后,程序会自动保存模型为
generator_last.pkl
和discriminator_last.pkl
。你可以使用如下代码加载模型并生成指定数字图像序列:运行后将自动生成并保存为
result.png
,其中图像按顺序拼接显示每个数字生成的图片。模型结构
Generator
channels × img_size × img_size
的图像张量Tanh
归一化输出至 [-1, 1]Discriminator
注意事项
jt.flags.use_cuda = 1
。channels
参数。