number = "your_number_sequence" # 替换为需要生成的数字序列
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)
JittorConditionalGAN
CGAN 图像生成项目 README
| 标题名称包含赛题、方法 | 说明:CGAN(Conditional Generative Adversarial Network)是基于 GAN 的一种条件生成对抗网络实现,用于根据给定条件(如标签)生成图像
简介
本项目实现了一个基于 Jittor 框架的条件生成对抗网络(CGAN),用于根据给定的标签生成手写数字图像(以 MNIST 数据集为例)。项目特点包括:
安装
运行环境
本项目可在主流 GPU 环境下运行,推荐配置:
Linux 和 macOS 环境要求
Python:版本 >=3.7 C++编译器(需要下列至少一个): g++:(Linux)>=5.4.0 clang:(macOS)>=8.0 GPU 编译器(可选):nvcc >=10.0 GPU 加速库(可选):cudnn-dev (推荐使用 tar 安装方法) Jittor 目前还支持主流国产 Linux 操作系统,如统信、麒麟、普华、龙芯 Loongnix。安装方式可参考 Linux pip 安装方法,准备好 python 和 g++ 即可。
Windows 环境要求
Python:版本 >=3.8 处理器:x86_64 操作系统:Windows 10 及以上
安装项目依赖
数据准备
本项目使用 MNIST 数据集,Jittor 将自动下载数据集。数据集将被存储在
~/.jittor/datasets/mnist
目录下。训练过程中将对图像进行以下预处理:训练
单卡训练
运行以下命令开始单卡训练:
训练过程中将输出每个批次的生成器和判别器损失,并在指定间隔保存生成的图像样本。
推理与测试
在训练过程中,模型会定期采样生成图像并保存。可以通过以下代码生成特定数字序列的图像:
生成的图像将被保存为
result.png
文件。模型与结果
模型结构
结果展示
训练过程中生成的图像样本将展示生成器逐渐学习到生成逼真数字图像的能力。最终生成的图像将根据给定的标签序列生成对应的数字图像。
致谢
本项目基于以下资源实现:
注意事项