Update .gitignore
这个项目实现了一个基于 Jittor 框架的条件生成对抗网络 (Conditional Generative Adversarial Network, CGAN)CGAN 是一种生成模型,它不仅学习生成逼真的数据(在这个例子中是 MNIST 手写数字图像),还学习根据给定的条件(数字类别标签)来生成特定类别的数据。本项目旨在训练一个 CGAN 模型,使其能够根据输入的数字标签生成对应的手写数字图像。
argparse
n_epochs
batch_size
lr
img_size
channels
img_shape
Generator
Discriminator
adversarial_loss
MNIST
save_image
sample_image
epoch
dataloader
number = "28164482809235"
result.png
pip install jittor
pip install numpy pillow
.py
mnist_cgan.py
python mnist_cgan.py
--n_epochs 200
--batch_size 128
sample_interval
<步数>.png
jt.has_cuda
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
基于 Jittor 实现的条件生成对抗网络 (CGAN)
项目概述
这个项目实现了一个基于 Jittor 框架的条件生成对抗网络 (Conditional Generative Adversarial Network, CGAN)CGAN 是一种生成模型,它不仅学习生成逼真的数据(在这个例子中是 MNIST 手写数字图像),还学习根据给定的条件(数字类别标签)来生成特定类别的数据。本项目旨在训练一个 CGAN 模型,使其能够根据输入的数字标签生成对应的手写数字图像。
主要组件
代码结构
argparse
定义并解析训练参数,如训练轮数 (n_epochs
)、批次大小 (batch_size
)、学习率 (lr
) 等。img_size
) 和通道数 (channels
),构建img_shape
。定义生成器使用的标签嵌入层。Generator
类:定义生成器的网络结构,包含嵌入层、多个全连接层块(带批归一化和 LeakyReLU 激活)以及最终的输出层(Tanh 激活以输出 [-1, 1] 范围的像素值)。Discriminator
类:定义判别器的网络结构,包含嵌入层、多个全连接层(带 Dropout 和 LeakyReLU 激活)以及最终的输出层(输出一个实数,代表“真实性”得分)。adversarial_loss
)。MNIST
数据集类,并应用预处理变换(调整大小、转为灰度、归一化)。save_image
: 将一批生成的图像拼接并保存为一张图片。sample_image
: 在训练过程中定期生成并保存一批图像样本,用于监控训练进度。epoch
)。dataloader
),获取真实图像和标签。sample_image
保存生成的图像样本。number = "28164482809235"
)。result.png
文件。如何运行
pip install jittor
。pip install numpy pillow
。.py
文件(例如mnist_cgan.py
)。python mnist_cgan.py
argparse
中定义的参数来调整训练过程,例如增加训练轮数 (--n_epochs 200
) 或改变批次大小 (--batch_size 128
)。sample_interval
步会生成并保存一批图像,文件名格式为<步数>.png
。result.png
的图像,其中包含根据数字序列 “28164482809235” 生成的对应数字图像。注意事项
jt.has_cuda
),如果可用会自动使用 GPU。n_epochs
) 越多,通常生成质量越好,但也会增加训练时间。