JT_5th_Warmup_cjd
基于条件生成对抗网络(CGAN)的MNIST数字生成模型,使用Jittor框架实现。
项目特点
- 条件生成对抗网络(CGAN):结合类别标签信息,生成指定数字(0-9)的图像。
- MNIST数据集:使用经典手写数字数据集,输入图像尺寸为32x32像素。
- Jittor框架:支持动态图与静态图混合编程,高效利用GPU加速(需CUDA环境)。
环境要求
依赖项
硬件建议
- GPU(推荐):支持CUDA的NVIDIA显卡,加速训练过程。
- CPU:可运行但训练速度较慢。
使用方法
1. 训练模型
python your_script_name.py # 替换为实际文件名(如 cgan_mnist.py)
可选参数
参数名 |
说明 |
默认值 |
--n_epochs |
训练轮数 |
100 |
--batch_size |
批量大小 |
64 |
--lr |
Adam优化器学习率 |
0.0002 |
--sample_interval |
图像采样间隔(每生成n 批次保存一次) |
1000 |
训练过程
- 控制台实时输出判别器(D)和生成器(G)的损失值。
- 每
sample_interval
批次生成并保存一次示例图像(如1000.png
)。
- 每10个epoch保存模型参数(
generator_last.pkl
和discriminator_last.pkl
)。
2. 生成指定数字图像
打开代码,找到以下部分:
number = "TODO: 写入比赛页面中指定的数字序列(字符串类型)"
将number
替换为目标数字字符串(如"1234567890"
)。
运行脚本,生成结果保存在result.png
,展示指定数字的生成图像。
目录结构
.
├── your_script_name.py # 主代码文件
├── generator_last.pkl # 最后一次保存的生成器模型
├── discriminator_last.pkl # 最后一次保存的判别器模型
├── *.png # 训练过程中生成的示例图像(如 1000.png, 2000.png 等)
└── result.png # 最终生成的指定数字图像
代码说明
核心模块
生成器(Generator)
- 输入:随机噪声(100维)+ 类别标签(通过Embedding层编码)。
- 结构:多层全连接层,使用LeakyReLU激活函数和BatchNorm,最终输出32x32单通道图像(通过Tanh归一化到[-1, 1])。
判别器(Discriminator)
- 输入:图像(展平为1024维向量)+ 类别标签(Embedding编码)。
- 结构:多层全连接层,使用LeakyReLU和Dropout防止过拟合,最终输出单个实数表示真实性概率。
损失函数
- 均方误差(MSELoss):判别器区分真实/生成图像(标签1/0),生成器欺骗判别器(目标标签1)。
数据处理
- 使用Jittor内置的MNIST数据集,预处理包括Resize、灰度化、归一化(均值0.5,标准差0.5)。
结果示例
- 训练过程中生成的图像示例(如
1000.png
)展示不同epoch下的生成效果,随着训练轮数增加,数字清晰度逐渐提升。
result.png
为指定数字序列的生成结果,每个数字对应输入的类别标签。
贡献与反馈
欢迎通过GitHub提交Issue或Pull Request,优化代码、修复问题或添加新功能。
提交Issue时请注明环境配置、复现步骤及具体问题描述。
许可证
本项目采用 MIT License,详见LICENSE文件。允许自由使用、修改和分发,但需保留原作者声明。
JT_5th_Warmup_cjd
基于条件生成对抗网络(CGAN)的MNIST数字生成模型,使用Jittor框架实现。
项目特点
环境要求
依赖项
numpy
,argparse
,Pillow
,matplotlib
(可选,用于可视化)。硬件建议
使用方法
1. 训练模型
可选参数
--n_epochs
--batch_size
--lr
--sample_interval
n
批次保存一次)训练过程
sample_interval
批次生成并保存一次示例图像(如1000.png
)。generator_last.pkl
和discriminator_last.pkl
)。2. 生成指定数字图像
打开代码,找到以下部分:
将
number
替换为目标数字字符串(如"1234567890"
)。运行脚本,生成结果保存在
result.png
,展示指定数字的生成图像。目录结构
代码说明
核心模块
生成器(Generator)
判别器(Discriminator)
损失函数
数据处理
结果示例
1000.png
)展示不同epoch下的生成效果,随着训练轮数增加,数字清晰度逐渐提升。result.png
为指定数字序列的生成结果,每个数字对应输入的类别标签。贡献与反馈
欢迎通过GitHub提交Issue或Pull Request,优化代码、修复问题或添加新功能。
提交Issue时请注明环境配置、复现步骤及具体问题描述。
许可证
本项目采用 MIT License,详见LICENSE文件。允许自由使用、修改和分发,但需保留原作者声明。