ADD file via upload
本项目将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
本项目包含了第五届计图挑战赛-热身赛的代码实现。本项目的特点是:训练一个将随机噪声和类别标签映射为数字图片的Conditiona1GAN模型,并生成给定用户随机ID对应的数字图片结果。
使用argparse解析训练的参数。 加载MNIST数据集,并进行预处理。
使用均方误差(MSE)作为损失函数。 定义生成器和判别器的优化器。
在每个训练周期(epoch)中,执行以下步骤:
获取批次数据。 训练生成器: 采样随机噪声和类别标签作为输入。 生成图片并计算生成器的损失。 更新生成器参数。 训练判别器: 对真实图片计算判别器的损失。 对生成图片计算判别器的损失。 总的判别器损失为两者的平均。 更新判别器参数。
加载训练好的生成器模型。 输入随机噪声和类别标签生成图片。 判别器的推理:
输入生成的图片和相应的标签,输出图片的真实性得分。
python CGAN.py
本项目使用jittor框架,将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Jittor 热身赛 CGAN
本项目将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
简介
本项目包含了第五届计图挑战赛-热身赛的代码实现。本项目的特点是:训练一个将随机噪声和类别标签映射为数字图片的Conditiona1GAN模型,并生成给定用户随机ID对应的数字图片结果。
安装
运行环境
训练
参数解析和数据加载:
使用argparse解析训练的参数。 加载MNIST数据集,并进行预处理。
定义生成器和判别器:
使用均方误差(MSE)作为损失函数。 定义生成器和判别器的优化器。
训练过程:
在每个训练周期(epoch)中,执行以下步骤:
获取批次数据。 训练生成器: 采样随机噪声和类别标签作为输入。 生成图片并计算生成器的损失。 更新生成器参数。 训练判别器: 对真实图片计算判别器的损失。 对生成图片计算判别器的损失。 总的判别器损失为两者的平均。 更新判别器参数。
推理
生成器的推理:
加载训练好的生成器模型。 输入随机噪声和类别标签生成图片。 判别器的推理:
加载训练好的判别器模型。
输入生成的图片和相应的标签,输出图片的真实性得分。
生成测试集上的结果可以运行以下命令: