chore: refactor the structure of files
A Jittor implementation of Conditional GAN (CGAN).
本项目为第五届计图挑战赛热身赛解答,将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
本代码框架依赖于 jittor,你可以通过 https://cg.cs.tsinghua.edu.cn/jittor/download/ 查询 jittor 安装教程。
jittor
本项目可在 CPU ,CUDA 等多个环境上运行、训练,详细步骤可以参考 https://cg.cs.tsinghua.edu.cn/jittor/download/ 。
本代码运行于 WSL2 系统。参考的运行环境如下:
本代码继承了数据导入、预处理、训练与生成,仅需要通过下述命令执行:
python3 CGAN.py
训练过程的结果每 1000 轮记录生成一次图像,直接保存于根目录下,最后根据用户 number 生成 result.png 同样存放于根目录下。
number
result.png
可以通过命令行给定参数指定训练部分参数:
--n_epochs
int
100
--batch_size
64
--lr
float
0.0002
| --b1 | float | 0.5 | Adam的β₁参数 | | --b2 | float | 0.999 | Adam的β₂参数 |
--b1
0.5
--b2
0.999
| --img_size | int | 32 | 图像尺寸(px) | | --channels | int | 1 | 图像通道数 |
--img_size
32
--channels
1
| --latent_dim | int | 100 | 噪声向量维度 | | --n_classes | int | 10 | 分类类别数 |
--latent_dim
--n_classes
10
| --n_cpu | int | 8 | CPU线程数 | | --sample_interval | int | 1000 | 采样间隔步数 |
--n_cpu
8
--sample_interval
1000
python main.py --n_epochs 200 --batch_size 128
此项目基于论文 Conditional Generative Adversarial Nets 实现,部分代码参考了 jittor-gan。
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Jittor 热身赛 Conditional GAN
简介
A Jittor implementation of Conditional GAN (CGAN).
本项目为第五届计图挑战赛热身赛解答,将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
安装
安装依赖
本代码框架依赖于
jittor
,你可以通过 https://cg.cs.tsinghua.edu.cn/jittor/download/ 查询jittor
安装教程。本项目可在 CPU ,CUDA 等多个环境上运行、训练,详细步骤可以参考 https://cg.cs.tsinghua.edu.cn/jittor/download/ 。
运行环境
本代码运行于 WSL2 系统。参考的运行环境如下:
训练与结果生成
本代码继承了数据导入、预处理、训练与生成,仅需要通过下述命令执行:
训练过程的结果每 1000 轮记录生成一次图像,直接保存于根目录下,最后根据用户
number
生成result.png
同样存放于根目录下。训练参数
可以通过命令行给定参数指定训练部分参数:
--n_epochs
int
100
--batch_size
int
64
--lr
float
0.0002
优化器参数
|
--b1
|float
|0.5
| Adam的β₁参数 | |--b2
|float
|0.999
| Adam的β₂参数 |图像参数
|
--img_size
|int
|32
| 图像尺寸(px) | |--channels
|int
|1
| 图像通道数 |模型参数
|
--latent_dim
|int
|100
| 噪声向量维度 | |--n_classes
|int
|10
| 分类类别数 |系统参数
|
--n_cpu
|int
|8
| CPU线程数 | |--sample_interval
|int
|1000
| 采样间隔步数 |使用示例
致谢
此项目基于论文 Conditional Generative Adversarial Nets 实现,部分代码参考了 jittor-gan。