基于Jittor的条件生成对抗网络手写数字生成项目
本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN, cGAN),旨在基于随机噪声和标签生成手写数字图像,并通过判别器进行训练,以使生成的图像尽可能接近真实的手写数字图像。该实现基于MNIST数据集,可通过标签控制生成特定数字的图像。
项目结构
.
├── generator_last.pkl # 训练后的生成器模型
├── discriminator_last.pkl # 训练后的判别器模型
├── result.png # 生成的图像
├── CGAN.py # 主训练脚本
└── README.md # 项目的说明文件
依赖
该项目依赖以下Python包:
- Jittor:用于高效的深度学习计算。
- Numpy:用于数组操作和数值计算。
- Pillow(PIL):用于图像保存和处理。
- argparse:用于命令行参数解析。
可以通过以下命令安装依赖:
pip install jittor numpy pillow
数据集
本项目使用MNIST数据集,该数据集包含手写数字的28x28像素图像。Jittor提供了MNIST数据集的加载器,支持数据的自动下载和预处理。
参数说明
训练参数
以下是训练脚本支持的命令行参数:
--n_epochs
(默认值:100):训练的总轮数。
--batch_size
(默认值:64):每个批次的图像数量。
--lr
(默认值:0.0002):Adam优化器的学习率。
--b1
(默认值:0.5):Adam优化器的一阶矩动量衰减。
--b2
(默认值:0.999):Adam优化器的二阶矩动量衰减。
--n_cpu
(默认值:8):用于批量生成的CPU线程数。
--latent_dim
(默认值:100):隐变量的维度。
--n_classes
(默认值:10):数据集的类别数量(MNIST数据集中为10个数字类别)。
--img_size
(默认值:32):图像的尺寸(宽度/高度,本项目将MNIST图像调整为此尺寸)。
--channels
(默认值:1):图像的通道数(MNIST为灰度图像,通道数为1)。
--sample_interval
(默认值:1000):每多少步生成并保存一次图像。
生成图像相关参数
在生成图像时,可在CGAN.py
文件中修改number
变量来指定生成图像的标签序列(例如,可替换为电话号码、自定义数字序列等)。
生成的图像
每隔sample_interval
训练步数,生成器会生成并保存当前训练阶段的图像,这些图像会保存为.png
格式(如result.png
)。
训练过程
在训练过程中,生成器和判别器交替训练。生成器尝试生成逼真的手写数字图像,以欺骗判别器将其判断为真实图像;判别器则努力区分生成图像与真实图像的差异,通过不断调整两者的参数,使生成器生成的图像质量逐渐提高。
如何运行
- 下载并安装依赖项:
pip install jittor numpy pillow
- 下载MNIST数据集:
训练脚本会自动下载MNIST数据集并进行预处理,无需手动下载。
- 运行训练脚本:
python3 CGAN.py
训练过程中,模型会自动保存生成器和判别器的状态。每1000步,会生成一个新图像并保存。
- 使用生成的模型生成图像:
训练完成后,若要使用保存的生成器模型生成图像,需修改
CGAN.py
文件中的以下代码:generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')
number = "1234567890" # 替换为你自己的数字序列(例如电话号码等)
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)
img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1))
min_ = img_array.min()
max_ = img_array.max()
img_array = (img_array - min_) / (max_ - min_) * 255
Image.fromarray(np.uint8(img_array)).save("result.png")
将number
替换为你自己的数字序列(如手机号或其他自定义数字序列),然后运行代码即可生成与标签对应的图像。
生成图像示例
result.png
:生成的图像文件,基于输入的标签和噪声生成。例如,当number
为”12345”时,生成的图像将尝试呈现出与数字1、2、3、4、5相关的手写数字特征。
可能遇到的问题
- CUDA错误(如果使用GPU训练):确保Jittor正确安装并且CUDA配置正常。可通过检查Jittor是否能正确检测到CUDA(运行
jt.has_cuda
)以及CUDA相关驱动和环境变量是否设置正确来排查。
- 内存问题:训练过程中可能会占用较多内存,建议根据系统内存情况使用合适的
batch_size
。如果内存不足,可以尝试减小batch_size
或者关闭其他占用内存的程序。
仓库信息
项目仓库链接如下:https://gitlink.org.cn/yifan_personal/yifan_jitu_hw.git
基于Jittor的条件生成对抗网络手写数字生成项目
本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN, cGAN),旨在基于随机噪声和标签生成手写数字图像,并通过判别器进行训练,以使生成的图像尽可能接近真实的手写数字图像。该实现基于MNIST数据集,可通过标签控制生成特定数字的图像。
项目结构
. ├── generator_last.pkl # 训练后的生成器模型 ├── discriminator_last.pkl # 训练后的判别器模型 ├── result.png # 生成的图像 ├── CGAN.py # 主训练脚本 └── README.md # 项目的说明文件
依赖
该项目依赖以下Python包:
可以通过以下命令安装依赖:
数据集
本项目使用MNIST数据集,该数据集包含手写数字的28x28像素图像。Jittor提供了MNIST数据集的加载器,支持数据的自动下载和预处理。
参数说明
训练参数
以下是训练脚本支持的命令行参数:
--n_epochs
(默认值:100):训练的总轮数。--batch_size
(默认值:64):每个批次的图像数量。--lr
(默认值:0.0002):Adam优化器的学习率。--b1
(默认值:0.5):Adam优化器的一阶矩动量衰减。--b2
(默认值:0.999):Adam优化器的二阶矩动量衰减。--n_cpu
(默认值:8):用于批量生成的CPU线程数。--latent_dim
(默认值:100):隐变量的维度。--n_classes
(默认值:10):数据集的类别数量(MNIST数据集中为10个数字类别)。--img_size
(默认值:32):图像的尺寸(宽度/高度,本项目将MNIST图像调整为此尺寸)。--channels
(默认值:1):图像的通道数(MNIST为灰度图像,通道数为1)。--sample_interval
(默认值:1000):每多少步生成并保存一次图像。生成图像相关参数
在生成图像时,可在
CGAN.py
文件中修改number
变量来指定生成图像的标签序列(例如,可替换为电话号码、自定义数字序列等)。生成的图像
每隔
sample_interval
训练步数,生成器会生成并保存当前训练阶段的图像,这些图像会保存为.png
格式(如result.png
)。训练过程
在训练过程中,生成器和判别器交替训练。生成器尝试生成逼真的手写数字图像,以欺骗判别器将其判断为真实图像;判别器则努力区分生成图像与真实图像的差异,通过不断调整两者的参数,使生成器生成的图像质量逐渐提高。
如何运行
CGAN.py
文件中的以下代码: 将number
替换为你自己的数字序列(如手机号或其他自定义数字序列),然后运行代码即可生成与标签对应的图像。生成图像示例
result.png
:生成的图像文件,基于输入的标签和噪声生成。例如,当number
为”12345”时,生成的图像将尝试呈现出与数字1、2、3、4、5相关的手写数字特征。可能遇到的问题
jt.has_cuda
)以及CUDA相关驱动和环境变量是否设置正确来排查。batch_size
。如果内存不足,可以尝试减小batch_size
或者关闭其他占用内存的程序。仓库信息
项目仓库链接如下:https://gitlink.org.cn/yifan_personal/yifan_jitu_hw.git