CGAN_Jittor
A Conditional GAN (CGAN) implementation in Jittor for generating MNIST digit images.
清华大学《计算机图形学基础》课程实验三(PA3)项目
🌟 项目简介
本项目基于 Jittor 深度学习框架实现了 Conditional Generative Adversarial Network(CGAN),用于在 MNIST 数据集上生成手写数字图像。通过引入类别标签,CGAN 能够生成指定数字的图像。该项目为清华大学《计算机图形学基础》课程 PA3 实验的实践任务。
📁 项目结构
.
├── CGAN.py # 主程序,包含网络定义、训练与推理
├── generator_last.pkl # 训练保存的生成器模型(示例)
├── discriminator_last.pkl # 训练保存的判别器模型(示例)
├── result.png # 生成的指定数字图像拼接图(示例)
└── README.md # 项目说明文件
🧠 模型架构
Generator(生成器):
- 输入:随机噪声 z + 标签 embedding
- 输出:32×32 的灰度图像
Discriminator(判别器):
- 输入:图像 + 标签 embedding
- 输出:实数(越接近 1 表示越可能是真实图像)
损失函数:采用均方误差(MSE)
🚀 运行说明
依赖安装
请确保系统安装了 Jittor 框架,可使用如下命令:
pip install jittor
确保安装版本在 1.2.2.59 及以上。
训练模型
直接运行:
python CGAN.py
你可以使用以下参数进行自定义:
--n_epochs # 训练轮数(默认:100)
--batch_size # 每批图像数量(默认:64)
--lr # 学习率(默认:0.0002)
--latent_dim # 隐空间维度(默认:100)
--img_size # 图像尺寸(默认:32)
--sample_interval # 每训练多少步保存一次生成图像(默认:1000)
例如:
python CGAN.py --n_epochs 200 --batch_size 128
推理生成
请在代码末尾填写指定数字序列,例如:
number = "3141592653"
然后重新运行 CGAN.py 生成 result.png。
CGAN_Jittor
🌟 项目简介
本项目基于 Jittor 深度学习框架实现了 Conditional Generative Adversarial Network(CGAN),用于在 MNIST 数据集上生成手写数字图像。通过引入类别标签,CGAN 能够生成指定数字的图像。该项目为清华大学《计算机图形学基础》课程 PA3 实验的实践任务。
📁 项目结构
. ├── CGAN.py # 主程序,包含网络定义、训练与推理 ├── generator_last.pkl # 训练保存的生成器模型(示例) ├── discriminator_last.pkl # 训练保存的判别器模型(示例) ├── result.png # 生成的指定数字图像拼接图(示例) └── README.md # 项目说明文件
🧠 模型架构
Generator(生成器):
Discriminator(判别器):
损失函数:采用均方误差(MSE)
🚀 运行说明
依赖安装
请确保系统安装了 Jittor 框架,可使用如下命令: