第五届计图人工智能挑战赛热身赛 手写数字生成 Conditional GAN

简介
本项目包含了第五届计图人工智能挑战赛热身赛 - 手写数字生成的代码实现。本项目的特点是:使用 Conditional GAN 在 MNIST 数据集上进行对抗训练,取得了0.9981的准确率。
安装
本项目在 1 张 4090 上运行,训练时间约为 15 分钟。
运行环境
- ubuntu 24.04 LTS
- python >= 3.7
- jittor >= 1.3.0
安装依赖
执行以下命令安装 python 依赖
pip install -r requirements.txt
预训练模型
预训练模型模型下载地址,解压后放置到<root>
目录下。
训练
可运行以下命令开始训练:
python CGAN.py
训练将在当前目录下生成:
- discriminator_last.pkl 和 generator_last.pkl:保存模型的参数
- nnnnn.png:模型训练过程中每隔1000步随机采样,用于观察训练效果
- result.png:根据CGAN.py 中numbers字符串采样的输出
CGAN.py的可选参数如下:
参数 |
描述 |
默认值 |
--n_epochs N_EPOCHS |
训练的 epoch 数量 |
100 |
--batch_size BATCH_SIZE |
批大小 |
64 |
--lr LR |
Adam 优化器的学习率 |
0.0002 |
--b1 B1 |
Adam 优化器的一阶动量衰减率 |
0.5 |
--b2 B2 |
Adam 优化器的二阶动量衰减率 |
0.999 |
--n_cpu N_CPU |
生成批次时使用的 CPU 线程数 |
8 |
--latent_dim LATENT_DIM |
潜在空间的维度 |
100 |
--n_classes N_CLASSES |
数据集的类别数量 |
10 |
--img_size IMG_SIZE |
图像的尺寸(长和宽) |
32 |
--channels CHANNELS |
图像的通道数 |
1 |
--sample_interval SAMPLE_INTERVAL |
图像采样的间隔 |
1000 |
推理
运行以下命令基于预训练模型进行推理:
python infer.py
致谢
此项目基于示例代码实现。
有关计图的更多信息,参见
第五届计图人工智能挑战赛热身赛 手写数字生成 Conditional GAN
简介
本项目包含了第五届计图人工智能挑战赛热身赛 - 手写数字生成的代码实现。本项目的特点是:使用 Conditional GAN 在 MNIST 数据集上进行对抗训练,取得了0.9981的准确率。
安装
本项目在 1 张 4090 上运行,训练时间约为 15 分钟。
运行环境
安装依赖
执行以下命令安装 python 依赖
预训练模型
预训练模型模型下载地址,解压后放置到
<root>
目录下。训练
可运行以下命令开始训练:
训练将在当前目录下生成:
CGAN.py的可选参数如下:
--n_epochs N_EPOCHS
--batch_size BATCH_SIZE
--lr LR
--b1 B1
--b2 B2
--n_cpu N_CPU
--latent_dim LATENT_DIM
--n_classes N_CLASSES
--img_size IMG_SIZE
--channels CHANNELS
--sample_interval SAMPLE_INTERVAL
推理
运行以下命令基于预训练模型进行推理:
致谢
此项目基于示例代码实现。
有关计图的更多信息,参见