目录
目录README.md

CGAN for Jittor

在计图平台上基于 MNIST 数据集训练的条件对抗生成网络 (CGAN) 模型。该模型可以根据给定的数字序列生成对应的手写数字图片。

GAN 能够通过生成器和判别器的对抗训练来获得拟真的生成能力。CGAN 则在此基础上引入了条件控制,使得我们可以控制需要生成的图片的某些特征。因此,本项目采用 CGAN 作为模型结构,并通过手写数字数据集 MNIST 训练,得到了效果优秀的手写数字图片生成器。

安装依赖

本项目基于计图框架,你需要先跟随 计图安装 来完成框架的安装。

运行

使用如下命令来运行程序:

usage: CGAN.py [-h] [--n_epochs N_EPOCHS] [--batch_size BATCH_SIZE] [--lr LR] [--b1 B1] [--b2 B2] [--n_cpu N_CPU]
               [--latent_dim LATENT_DIM] [--n_classes N_CLASSES] [--img_size IMG_SIZE] [--channels CHANNELS]
               [--sample_interval SAMPLE_INTERVAL]

各参数的含义如下:

参数 说明
-h, –help 显示帮助信息并退出
–n_epochs N_EPOCHS 训练轮次数
–batch_size BATCH_SIZE 批次大小
–lr LR Adam 学习率
–b1 B1 Adam 梯度一阶动量衰减率
–b2 B2 Adam 梯度二阶动量衰减率
–n_cpu N_CPU 批次生成使用的 CPU 线程数
–latent_dim LATENT_DIM 生成器输入的随机向量维度
–n_classes N_CLASSES 数据集类别数
–img_size IMG_SIZE 生成图片每个维度的大小
–channels CHANNELS 生成图片的通道数
–sample_interval SAMPLE_INTERVAL 保存生成图片的间隔批次数

运行后,程序会自动下载 MNIST 数据集并开始训练。训练过程中,程序会每隔 sample_interval 批次保存一张生成的图片。每 10 轮训练结束后,程序会自动保存模型到 generator_last.pkldiscriminator_last.pkl 文件中。训练结束后,程序将生成的图片保存到 result.png 文件中。

以下是生成图片的示例:

图片生成示例

许可

本项目采用 WTFPLv2 许可证。

关于

一个条件对抗生成网络 (CGAN) 的计图实现 A Jittor implementation of Conditional GAN (CGAN)

596.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号