Delete README.md
本项目基于 Jittor 实现一个 Conditional GAN(条件生成对抗网络):输入随机噪声 z 与类别标签 y,生成对应类别的手写数字图像;并支持按指定数字序列生成并保存拼接结果图 result.png 。
z
y
result.png
jittor.dataset.mnist
G(z, y)
D(x, y)
学号
Generator(生成器) 🏗️将噪声向量 z 与标签嵌入(Embedding)后的条件向量拼接,经多层全连接网络生成图像(输出为 C×H×W)。
C×H×W
Discriminator(判别器) 🕵️♂️将输入图像 x 展平为向量,与标签嵌入向量拼接,经过多层全连接网络输出一个标量,表示该“图像-标签”对为真的概率/置信度 。
x
损失函数 📉使用 MSELoss(最小二乘 GAN / LSGAN 风格):
MSELoss
libomp-dev
sudo apt update sudo apt install -y libomp-dev build-essential git
说明:如果你的系统源中找不到 python3.7-dev 属于正常现象(新版本 Ubuntu 往往不再提供 3.7 的 dev 包)。推荐直接使用 conda 创建 Python 3.7 环境
python3.7-dev
conda create -n jittor37 python=3.7 -y conda activate jittor37 python -m pip install -U pip setuptools wheel python -m pip install jittor # 或安装 github 最新版(二选一) # python -m pip install git+https://github.com/Jittor/jittor.git
python -m jittor.test.test_example
第一次运行 Jittor 可能会触发编译缓存,耗时较长属于正常现象
项目主脚本为:
CGAN.py
直接运行(使用默认参数训练):
python CGAN.py
脚本支持命令行参数(argparse):
--n_epochs
--batch_size
--lr
--b1
--b2
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
示例(训练 30 epoch):
python CGAN.py --n_epochs 30 --batch_size 64 --lr 0.0002
在脚本末尾修改:
number = "2213119" # 改成比赛页面要求的字符串数字序列
训练结束后脚本会生成并保存:
number
训练过程中/训练结束后会产生常见输出:
*.png
generator_last.pkl
discriminator_last.pkl
如果出现类似 “Embedding index must be int” 的问题,确保标签是整型(如 int32),而不是 float32。例如:
int32
float32
labels = labels.int32()
GAN 在早期通常不稳定是正常的。可尝试:
. ├── CGAN.py ├── generator_last.pkl ├── discriminator_last.pkl ├── result.png └── (训练过程中生成的若干 *.png)
A Jittor implementation of Conditional GAN (CGAN).
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Conditional GAN for MNIST (Jittor) 🧠✨
本项目基于 Jittor 实现一个 Conditional GAN(条件生成对抗网络):输入随机噪声
z与类别标签y,生成对应类别的手写数字图像;并支持按指定数字序列生成并保存拼接结果图result.png。1. 任务说明 🎯
jittor.dataset.mnist自动下载)G(z, y)生成指定类别数字图像,同时训练判别器D(x, y)判断“图像-标签”对是否真实学号)生成对应数字图像并拼接保存2. 方法概述 🧪
Generator(生成器) 🏗️
将噪声向量
z与标签嵌入(Embedding)后的条件向量拼接,经多层全连接网络生成图像(输出为C×H×W)。Discriminator(判别器) 🕵️♂️
将输入图像
x展平为向量,与标签嵌入向量拼接,经过多层全连接网络输出一个标量,表示该“图像-标签”对为真的概率/置信度 。损失函数 📉
使用
MSELoss(最小二乘 GAN / LSGAN 风格):3. 环境依赖 🧰
libomp-dev),用于加速/编译相关4. 安装与配置 🛠️
4.1 安装系统依赖(Ubuntu/Debian)
4.2 使用 Conda 创建虚拟环境并安装 Jittor
4.3 运行 Jittor 自检
5. 训练与生成 🏋️♀️
5.1 开始训练
项目主脚本为:
CGAN.py直接运行(使用默认参数训练):
5.2 常用参数
脚本支持命令行参数(argparse):
--n_epochs:训练轮数(默认 100)🔁--batch_size:批大小(默认 64)📦--lr:学习率(默认 0.0002)📈--b1/--b2:Adam 的 beta 参数(默认 0.5 / 0.999)⚙️--latent_dim:噪声维度(默认 100)🌫️--n_classes:类别数(默认 10)🔟--img_size:图像尺寸(默认 32,会对 MNIST 做 Resize)📐--channels:通道数(默认 1)🎛️--sample_interval:采样间隔(默认 1000 step 保存一次采样图)⏱️示例(训练 30 epoch):
5.3 生成指定数字序列(用于提交)
在脚本末尾修改:
训练结束后脚本会生成并保存:
result.png:按number生成的数字序列拼接图(提交文件)6. 输出文件说明 📂
训练过程中/训练结束后会产生常见输出:
*.png:按--sample_interval采样生成的网格图(用于观察训练过程)generator_last.pkl:生成器参数(每 10 epoch 保存一次)discriminator_last.pkl:判别器参数(每 10 epoch 保存一次)result.png:最终按指定序列生成的结果图(用于提交)7. 常见问题(FAQ)🧩
7.1 Embedding 报错:索引类型不对?
如果出现类似 “Embedding index must be int” 的问题,确保标签是整型(如
int32),而不是float32。例如:7.2 训练初期生成很糊/噪声很大?
GAN 在早期通常不稳定是正常的。可尝试:
--lr(后期降低学习率更平稳)8. 目录结构(示例)🗂️
9. 参考 📚