目录
目录readme.md

第五届计图人工智能挑战赛热身赛 手写数字生成 Conditional GAN

主要结果

简介

本项目包含了第五届计图人工智能挑战赛热身赛 - 手写数字生成的代码实现。本项目的特点是:使用 Conditional GAN 在 MNIST 数据集上进行对抗训练,取得了0.9981的准确率。

安装

本项目在 1 张 4090 上运行,训练时间约为 15 分钟。

运行环境

  • ubuntu 24.04 LTS
  • python >= 3.7
  • jittor >= 1.3.0

安装依赖

执行以下命令安装 python 依赖

pip install -r requirements.txt

预训练模型

预训练模型模型下载地址,解压后放置到<root>目录下。

训练

可运行以下命令开始训练:

python CGAN.py

训练将在当前目录下生成:

  • discriminator_last.pkl 和 generator_last.pkl:保存模型的参数
  • nnnnn.png:模型训练过程中每隔1000步随机采样,用于观察训练效果
  • result.png:根据CGAN.py 中numbers字符串采样的输出

CGAN.py的可选参数如下:

参数 描述 默认值
--n_epochs N_EPOCHS 训练的 epoch 数量 100
--batch_size BATCH_SIZE 批大小 64
--lr LR Adam 优化器的学习率 0.0002
--b1 B1 Adam 优化器的一阶动量衰减率 0.5
--b2 B2 Adam 优化器的二阶动量衰减率 0.999
--n_cpu N_CPU 生成批次时使用的 CPU 线程数 8
--latent_dim LATENT_DIM 潜在空间的维度 100
--n_classes N_CLASSES 数据集的类别数量 10
--img_size IMG_SIZE 图像的尺寸(长和宽) 32
--channels CHANNELS 图像的通道数 1
--sample_interval SAMPLE_INTERVAL 图像采样的间隔 1000

推理

运行以下命令基于预训练模型进行推理:

python infer.py

致谢

此项目基于示例代码实现。

有关计图的更多信息,参见

关于

第五届计图人工智能挑战赛热身赛代码开源 jittor

54.0 KB
邀请码