实验报告pdf
我们用 Jittor (计图) 在 MNIST 上训练了一个 Conditional GAN(CGAN):输入随机噪声 z 和类别标签 y,生成指定数字的手写体图像。 代码入口只有一个 CGAN.py,因为数据集直接通过 Jittor 内置的 MNIST loader 下载/读取,训练、生成、保存都写在同一个脚本里。 最终提交产物是 result.png(按随机 ID 逐位生成并拼成一张图),以及可选的模型权重 generator_last.pkl / discriminator_last.pkl。
z
y
CGAN.py
result.png
generator_last.pkl / discriminator_last.pkl
2310500 2312313
tanh -> [0,255]
jittor
numpy
Pillow
matplotlib/tqdm
# 1)(推荐)创建 conda 环境:隔离依赖,方便复现 conda create -n cgan-jittor python=3.9 -y conda activate cgan-jittor # 2) 安装依赖 pip install -U pip pip install jittor numpy pillow # 如果是 GPU 环境(推荐 T4/3060 等),请运行 CUDA 配置: # python -m jittor_utils.install_cuda
Windows 差异:Windows 本机运行时,首次 import jittor 也会触发编译/缓存;相比 WSL2 更容易遇到编译链/路径问题,所以我把主流程写成 WSL2 版本。
import jittor
python CGAN.py --n_epochs 1 --batch_size 64
python CGAN.py --n_epochs 50 --batch_size 64 --lr 0.0002
训练过程中会按 sample_interval 定期保存采样图(用于观察训练质量)。
sample_interval
备注:CGAN 训练轮数多的时候会比较吃时间(尤其是 CPU 环境)。如果本地跑得慢,为了提高效率也可以把 CGAN.py 放到云服务器/云 Notebook 上训练,最后把生成的 result.png 下载回来打包提交即可。
–n_epochs:训练轮数(建议 30~50 以上) –batch_size:批大小 –log_interval:日志间隔(batch) –d_dropout:判别器 Dropout(适当减小可提升生成质量)
在 CGAN.py 末尾找到:
number = "2310500"
把它改成你自己的随机 ID 字符串即可。脚本运行结束会输出:
np.random.seed(...)
jt.set_global_seed(...)
tanh
(x * 0.5 + 0.5) * 255
labels
int32
int32()
tanh -> [0,1] -> [0,255]
A Jittor implementation of Conditional GAN (CGAN)
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN_jittor_MNIST
简述
我们用 Jittor (计图) 在 MNIST 上训练了一个 Conditional GAN(CGAN):输入随机噪声
z和类别标签y,生成指定数字的手写体图像。 代码入口只有一个CGAN.py,因为数据集直接通过 Jittor 内置的 MNIST loader 下载/读取,训练、生成、保存都写在同一个脚本里。 最终提交产物是result.png(按随机 ID 逐位生成并拼成一张图),以及可选的模型权重generator_last.pkl / discriminator_last.pkl。任务
2310500 2312313)时,能够生成对应数字序列的图片result.png。特性
result.png,不需要额外写推理脚本。y(Embedding + concat),可控生成指定数字而不是随机数字。tanh -> [0,255]固定映射,避免 min-max 拉伸导致对比度漂移。环境与依赖
jittor、numpy、Pillow(以及可选的matplotlib/tqdm)安装(WSL2 / Ubuntu)
使用方法
1)最短可跑(先确认环境没问题)
2)正式训练
训练过程中会按
sample_interval定期保存采样图(用于观察训练质量)。3)参数说明
–n_epochs:训练轮数(建议 30~50 以上) –batch_size:批大小 –log_interval:日志间隔(batch) –d_dropout:判别器 Dropout(适当减小可提升生成质量)
4)生成指定随机 ID
在
CGAN.py末尾找到:把它改成你自己的随机 ID 字符串即可。脚本运行结束会输出:
result.png(最终提交用图)generator_last.pkl / discriminator_last.pkl(训练权重)项目结构
CGAN.py:模型定义(G/D)+ MNIST 数据加载 + 训练循环 + 最终生成并保存result.pngresult.png:最终提交用的结果图(按随机 ID 逐位生成并拼接)generator_last.pkl / discriminator_last.pkl:训练过程中保存的权重(可选)复现说明
result.pngnp.random.seed(...)+jt.set_global_seed(...),待补充)tanh,保存时按(x * 0.5 + 0.5) * 255映射到灰度图(避免 min-max 导致发灰)常见问题 FAQ
原因:Jittor 首次会编译并缓存算子(正常现象)。处理:耐心等一次,后面运行会快很多。
标签
labels需要是整数类型(例如int32)。如果你把标签转成了 float,Embedding 可能会直接报错。处理:确保训练与生成阶段的 labels 都是int32()。不用 min-max 归一化去保存每张图(不同 batch 的 min/max 会变化,观感不稳定)。我这里使用固定映射
tanh -> [0,1] -> [0,255]。MNIST 通过 Jittor 的 dataset 接口创建 dataloader,数据下载/缓存由框架处理,所以项目目录里不一定有独立的数据文件夹。
合作者 / 友情链接
许可 / 声明