Jittor 热身赛


简介
本项目包含了第二届计图挑战赛计图 -热身赛的代码实现。本项目基于图片数据集 MNIST,训练了一个将随机噪声和类别标签映射为数字图片的Conditional GAN模型,并生成注册时绑定的手机号。
安装
本项目可在 标压九代i9上运行,训练时间约为 1.5 小时。
运行环境
- ubuntu 20.04 LTS 或 Windows 11/10
- python >= 3.7
- jittor >= 1.3.0
- (可选) cuda >= 11.0
安装依赖
执行以下命令安装 python 依赖
pip install jittor
ToDo部分
添加一个线性层,输出为一个实数
self.model = nn.Sequential(
nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512,1))
将d_in输入模型进行计算,返回计算结果
def execute(self, img, labels):
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
validity = self.model(d_in)
return validity
计算真实数据与生成数据的损失
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real,valid)
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake,fake)
设置要生成的数字
number = '15533237079'
训练
可设置
jt.flags.use_cuda = 1
选择开启显卡加速训练
致谢
此项目代码参考了 jittor-gan。
Jittor 热身赛
简介
本项目包含了第二届计图挑战赛计图 -热身赛的代码实现。本项目基于图片数据集 MNIST,训练了一个将随机噪声和类别标签映射为数字图片的Conditional GAN模型,并生成注册时绑定的手机号。
安装
本项目可在 标压九代i9上运行,训练时间约为 1.5 小时。
运行环境
安装依赖
执行以下命令安装 python 依赖
ToDo部分
添加一个线性层,输出为一个实数
将d_in输入模型进行计算,返回计算结果
计算真实数据与生成数据的损失
设置要生成的数字
训练
可设置
jt.flags.use_cuda = 1
选择开启显卡加速训练
致谢
此项目代码参考了 jittor-gan。