Merge branch ‘dev’ into master
本项目基于 Jittor 框架 实现,使用数字图片数据集 MNIST 训练一个条件生成对抗网络(Conditional GAN)模型,并通过输入一个随机向量 z 和额外的辅助信息 y(如类别标签),生成给定的特定数字序列的图像。
使用 Anaconda 创建虚拟环境,名称为 jittor:
jittor
conda create -n jittor python=3.9 conda activate jittor conda install -c conda-forge gcc_linux-64==11.2.0 conda install -c conda-forge gxx_impl_linux-64==11.2.0 conda install -c conda-forge libstdcxx-ng=12 python3 -m pip install scipy trimesh matplotlib python3 -m pip install jittor
首先将 CGAN.py 中的 number 替换为目标数字序列字符串,运行:
CGAN.py
number
python CGAN.py
如需调整训练参数,可通过命令行传入,如:
python CGAN.py --n_epochs 150 --batch_size 64 --lr 0.001
更多命令行参数详见 CGAN.py。
数据处理:程序将调用 Jittor 框架自带接口加载 MNIST 数据集,并设置相应的图像变换。
模型训练:在模型训练过程中,每 sample_interval 个 batch 将保存一次生成的图像样本,每 10 个 epoch 将保存一次模型生成器(generator_last.pkl)和判别器(discriminator_last.pkl)。
sample_interval
generator_last.pkl
discriminator_last.pkl
模型推理:训练完成后,程序加载保存的模型,并根据设定的数字序列字符串生成对应的手写数字图像。
可以通过 此云盘链接 下载训练得到的模型文件和生成的图像结果。
本项目基于比赛官方提供的 代码框架 完成。
A Jittor implementation of Conditional GAN (CGAN).
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN-Jittor:计图挑战热身赛
简介
本项目基于 Jittor 框架 实现,使用数字图片数据集 MNIST 训练一个条件生成对抗网络(Conditional GAN)模型,并通过输入一个随机向量 z 和额外的辅助信息 y(如类别标签),生成给定的特定数字序列的图像。
环境配置
使用 Anaconda 创建虚拟环境,名称为
jittor
:运行指令
首先将
CGAN.py
中的number
替换为目标数字序列字符串,运行:如需调整训练参数,可通过命令行传入,如:
更多命令行参数详见
CGAN.py
。运行过程
数据处理:程序将调用 Jittor 框架自带接口加载 MNIST 数据集,并设置相应的图像变换。
模型训练:在模型训练过程中,每
sample_interval
个 batch 将保存一次生成的图像样本,每 10 个 epoch 将保存一次模型生成器(generator_last.pkl
)和判别器(discriminator_last.pkl
)。模型推理:训练完成后,程序加载保存的模型,并根据设定的数字序列字符串生成对应的手写数字图像。
训练结果下载
可以通过 此云盘链接 下载训练得到的模型文件和生成的图像结果。
致谢
本项目基于比赛官方提供的 代码框架 完成。