目录
目录README.md

Jittor 热身赛 Conditional GAN

主要结果

简介

A Jittor implementation of Conditional GAN (CGAN).

本项目为第五届计图挑战赛热身赛解答,将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

安装

安装依赖

本代码框架依赖于 jittor,你可以通过 https://cg.cs.tsinghua.edu.cn/jittor/download/ 查询 jittor 安装教程。

本项目可在 CPU ,CUDA 等多个环境上运行、训练,详细步骤可以参考 https://cg.cs.tsinghua.edu.cn/jittor/download/

运行环境

本代码运行于 WSL2 系统。参考的运行环境如下:

  • ubuntu 22.04 LTS
  • python >= 3.7
  • jittor >= 1.3.0

训练与结果生成

本代码继承了数据导入、预处理、训练与生成,仅需要通过下述命令执行:

python3 CGAN.py

训练过程的结果每 1000 轮记录生成一次图像,直接保存于根目录下,最后根据用户 number 生成 result.png 同样存放于根目录下。

训练参数

可以通过命令行给定参数指定训练部分参数:

参数 类型 默认值 说明
--n_epochs int 100 训练总轮次
--batch_size int 64 每批数据量
--lr float 0.0002 学习率

优化器参数

| --b1 | float | 0.5 | Adam的β₁参数 | | --b2 | float | 0.999 | Adam的β₂参数 |

图像参数

| --img_size | int | 32 | 图像尺寸(px) | | --channels | int | 1 | 图像通道数 |

模型参数

| --latent_dim | int | 100 | 噪声向量维度 | | --n_classes | int | 10 | 分类类别数 |

系统参数

| --n_cpu | int | 8 | CPU线程数 | | --sample_interval | int | 1000 | 采样间隔步数 |

使用示例

python main.py --n_epochs 200 --batch_size 128

致谢

此项目基于论文 Conditional Generative Adversarial Nets 实现,部分代码参考了 jittor-gan

关于

A Jittor implementation of Conditional GAN (CGAN).

41.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号