delete
本项目是第三届计图人工智能挑战赛的参赛代码。项目基于 Jittor 深度学习框架实现了一个条件生成对抗网络(CGAN),能够根据给定的类别标签(0-9 的数字)生成相应的手写数字图像。模型在 MNIST 数据集上进行训练。
本项目运行依赖以下环境:
关于 Jittor 的安装和配置,请参考 Jittor 官网文档。
直接运行 CGAN.py 脚本即可开始训练,并在训练结束后生成特定数字序列的图像。
CGAN.py
python3 CGAN.py
模型训练: 脚本会加载 MNIST 数据集,并训练 Generator 和 Discriminator。 默认训练 100 个 Epoch。 每隔一定步数会保存中间采样图片。
模型保存: 训练过程中每 10 个 Epoch 会保存一次模型权重。 最终模型权重保存在:
generator_last.pkl
discriminator_last.pkl
结果生成: 训练完成后,脚本会自动加载最终保存的模型,并根据代码中预设的数字序列(如 2312479)生成对应的拼接图像,保存为 result.png。
2312479
result.png
你可以通过命令行参数自定义训练配置:
python CGAN.py --n_epochs 200 --batch_size 64 --lr 0.0002
主要参数:
--n_epochs
--batch_size
--lr
--b1
--b2
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
运行结束后生成的 result.png 展示了模型根据特定数字序列生成的图像结果。
热身赛作业
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
第三届计图人工智能挑战赛 - Jittor 赛题
简介
本项目是第三届计图人工智能挑战赛的参赛代码。项目基于 Jittor 深度学习框架实现了一个条件生成对抗网络(CGAN),能够根据给定的类别标签(0-9 的数字)生成相应的手写数字图像。模型在 MNIST 数据集上进行训练。
环境依赖
本项目运行依赖以下环境:
关于 Jittor 的安装和配置,请参考 Jittor 官网文档。
快速开始
训练模型与生成结果
直接运行
CGAN.py脚本即可开始训练,并在训练结束后生成特定数字序列的图像。脚本功能说明
模型训练: 脚本会加载 MNIST 数据集,并训练 Generator 和 Discriminator。 默认训练 100 个 Epoch。 每隔一定步数会保存中间采样图片。
模型保存: 训练过程中每 10 个 Epoch 会保存一次模型权重。 最终模型权重保存在:
generator_last.pkldiscriminator_last.pkl结果生成: 训练完成后,脚本会自动加载最终保存的模型,并根据代码中预设的数字序列(如
2312479)生成对应的拼接图像,保存为result.png。参数说明
你可以通过命令行参数自定义训练配置:
主要参数:
--n_epochs: 训练的总轮数 (默认: 100)--batch_size: 批次大小 (默认: 64)--lr: Adam 优化器的学习率 (默认: 0.0002)--b1: Adam 优化器的 beta1 (默认: 0.5)--b2: Adam 优化器的 beta2 (默认: 0.999)--latent_dim: 隐向量的维度 (默认: 100)--n_classes: 类别数量 (默认: 10,对应 MNIST 数字 0-9)--img_size: 图片尺寸 (默认: 32)--channels: 图片通道数 (默认: 1)--sample_interval: 保存采样图片的间隔步数 (默认: 1000)结果展示
运行结束后生成的
result.png展示了模型根据特定数字序列生成的图像结果。