目录
目录README.md

CGAN_jittor

本项目使用 Jittor 框架实现了一个条件生成对抗网络(Conditional GAN),用于生成带有类别标签的图像。示例数据集为 MNIST 手写数字图像,可通过输入数字标签生成相应图像。

环境依赖

  • Python 3.7+
  • Jittor
  • NumPy
  • Pillow

安装 Jittor

pip install jittor

更多安装详情见:Jittor 安装文档

文件结构说明

  • CGAN.py:主程序文件,包含模型构建、训练、采样和生成流程。
  • result.png:根据给定数字序列生成的图像拼接结果。

使用说明

训练模型

运行以下命令开始训练(默认参数见下方):

python CGAN.py

支持的命令行参数:

参数 默认值 说明
--n_epochs 100 训练轮数
--batch_size 64 每个批次图像数量
--lr 0.0002 学习率
--latent_dim 100 隐变量维度(输入噪声)
--n_classes 10 图像类别数量(MNIST 为 10 类)
--img_size 32 图像尺寸(32x32)
--channels 1 图像通道数(灰度图为 1)
--sample_interval 1000 生成样本图像的间隔步数

你可以通过指定参数来自定义训练,例如:

python CGAN.py --n_epochs 50 --batch_size 128

查看生成图像

训练期间,每隔 sample_interval 步数将会生成一次图像,保存为 x.png,其中 x 是当前步数。

使用训练好的模型生成图像

在训练完成后,程序会自动保存模型为 generator_last.pkldiscriminator_last.pkl。你可以使用如下代码加载模型并生成指定数字图像序列:

number = "28262472819004"  # 指定数字序列

运行后将自动生成并保存为 result.png,其中图像按顺序拼接显示每个数字生成的图片。

模型结构

Generator

  • 输入:随机噪声向量 + 类别标签
  • 多层全连接层 + LeakyReLU + BatchNorm
  • 输出大小为 channels × img_size × img_size 的图像张量
  • 使用 Tanh 归一化输出至 [-1, 1]

Discriminator

  • 输入:图像张量 + 类别标签
  • 多层全连接层 + LeakyReLU + Dropout
  • 最终输出一个实数,代表真假判断值
  • 使用均方误差损失(MSELoss)

注意事项

  • 若使用 GPU,请确保 CUDA 环境正确,并自动开启 jt.flags.use_cuda = 1
  • 生成图像默认输出为灰度图,可以根据需要修改 channels 参数。
  • 项目基于 Jittor,因此不支持 PyTorch 等框架的模型互换。
关于

A Jittor implementation of Conditional GAN (CGAN)

379.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号