目录
目录README.md

JittorConditionalGAN

CGAN 图像生成项目 README

| 标题名称包含赛题、方法 | 说明:CGAN(Conditional Generative Adversarial Network)是基于 GAN 的一种条件生成对抗网络实现,用于根据给定条件(如标签)生成图像

简介

本项目实现了一个基于 Jittor 框架的条件生成对抗网络(CGAN),用于根据给定的标签生成手写数字图像(以 MNIST 数据集为例)。项目特点包括:

  • 采用 Jittor 框架高效实现 CGAN 模型
  • 基于 MNIST 数据集进行数字图像生成任务
  • 通过条件约束提高生成图像的相关性和质量
  • 提供完整的训练、采样和测试流程

安装

运行环境

本项目可在主流 GPU 环境下运行,推荐配置:

Linux 和 macOS 环境要求

Python:版本 >=3.7 C++编译器(需要下列至少一个): g++:(Linux)>=5.4.0 clang:(macOS)>=8.0 GPU 编译器(可选):nvcc >=10.0 GPU 加速库(可选):cudnn-dev (推荐使用 tar 安装方法) Jittor 目前还支持主流国产 Linux 操作系统,如统信、麒麟、普华、龙芯 Loongnix。安装方式可参考 Linux pip 安装方法,准备好 python 和 g++ 即可。

Windows 环境要求

Python:版本 >=3.8 处理器:x86_64 操作系统:Windows 10 及以上

安装项目依赖



 check your python version(>=3.8)
python --version
python -m pip install jittor
 if conda is used
conda install pywin32# Windows

数据准备

本项目使用 MNIST 数据集,Jittor 将自动下载数据集。数据集将被存储在 ~/.jittor/datasets/mnist 目录下。训练过程中将对图像进行以下预处理:

  • 调整图像大小为指定尺寸(默认 32x32)
  • 转换为灰度图像
  • 进行归一化处理(均值 0.5,标准差 0.5)

训练

单卡训练

运行以下命令开始单卡训练:

python CGAN.py

训练过程中将输出每个批次的生成器和判别器损失,并在指定间隔保存生成的图像样本。

推理与测试

在训练过程中,模型会定期采样生成图像并保存。可以通过以下代码生成特定数字序列的图像:

number = "your_number_sequence"  # 替换为需要生成的数字序列
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)

生成的图像将被保存为 result.png 文件。

模型与结果

模型结构

  • 生成器(Generator):将随机噪声和条件标签作为输入,通过多层全连接网络和激活函数生成图像。
  • 判别器(Discriminator):将图像和条件标签作为输入,通过多层全连接网络和激活函数判断图像的真实性。

结果展示

训练过程中生成的图像样本将展示生成器逐渐学习到生成逼真数字图像的能力。最终生成的图像将根据给定的标签序列生成对应的数字图像。

致谢

本项目基于以下资源实现:

注意事项

  • 确保安装正确版本的 Jittor 和依赖库
  • 如果使用 GPU,确保安装正确版本的 CUDA Toolkit 并配置好环境变量
  • 可以通过调整超参数(如学习率、批次大小等)来优化训练效果
关于

A Jittor implementation of Conditional GAN (CGAN).

35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号