目录
目录README.md

第五届计图人工智能挑战赛 - 热身赛:Conditional GAN 手写数字生成

主要结果

简介

本项目包含了第五届计图人工智能挑战赛 - 热身赛:Conditional GAN 手写数字生成的代码实现。本项目的特点是:使用 Jittor 深度学习框架,采用 Conditional GAN 方法在 MNIST 手写数字图片数据集上进行训练,取得了 0.9986 的准确率。

安装

本项目可在 1 张 GeForce MX450 上运行,训练时间约为 15 分钟。

运行环境

  • Ubuntu 20.04 LTS
  • python >= 3.7
  • jittor >= 1.3.0

安装依赖

执行以下命令安装 python 依赖:

pip install -r requirements.txt

预训练模型

预训练模型下载地址,下载后放入 <root> 目录下。

训练

可运行以下命令开始训练:

python CGAN.py

训练过程中:

  • 每个 epoch 结束后,将在 <root> 目录下保存生成器参数 generator_last.pkl 和判别器参数 discriminator_last.pkl
  • 每训练 sample_interval(默认值:1000)步,将生成一张样本图片保存在 <root>/samples 目录下,供调试用。

训练结束后,将在 <root> 目录下保存生成的数字图片 result.png

CGAN.py 的可选参数如下:

参数 默认值 描述
--number "0123456789" 指定生成的数字(字符串类型)
--n_epochs 50 训练的 epoch 数
--batch_size 64 批大小
--lr 0.0002 Adam 优化器的学习率
--b1 0.5 Adam 优化器的一阶动量衰减率
--b2 0.999 Adam 优化器的二阶动量衰减率
--n_cpu 8 生成批次时使用的 CPU 线程数
--latent_dim 100 潜在空间的维数
--n_classes 10 数据集的类别数
--img_size 32 图像的尺寸(长和宽)
--channels 1 图像的通道数
--sample_interval 1000 生成样本图片的间隔

推理

使用预训练模型生成指定的数字图片,可以运行以下命令:

python gen.py --number="9876543210"

gen.py 的可选参数如下:

参数 默认值 描述
--number "0123456789" 指定生成的数字(字符串类型)
--latent_dim 100 潜在空间的维数
--n_classes 10 数据集的类别数
--img_size 32 图像的尺寸(长和宽)
--channels 1 图像的通道数

运行 gen.py 时,应保持 --latent_dim--n_classes--img_size--channels 参数与运行 CGAN.py 时的参数一致。

致谢

此项目基于第五届计图人工智能挑战赛提供的示例代码实现。

相关链接

关于

A Jittor implementation of Conditional GAN (CGAN). 第五届计图人工智能挑战赛 - 热身赛 项目开源

55.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号