目录
目录README.md

CGAN (Conditional Generative Adversarial Network) 实现

这是一个基于Jittor框架实现的Conditional GAN (CGAN)项目,用于生成手写数字图像。CGAN通过引入条件信息(如类别标签)来控制生成器的输出。

功能概述

模型架构

**生成器(Generator)**:接收随机噪声向量 z 和类别标签 y,生成指定类别的数字图像。

**判别器(Discriminator)**:判断输入图像是真实数据还是生成数据,并结合类别标签进行条件判别。

训练过程

  • 使用 MNIST 手写数字数据集 进行对抗训练。
  • 采用 平方误差损失函数 替代传统对数损失(公式见实验说明)。
  • 每训练若干批次后,保存生成的中间结果图片。

生成功能

  • 支持生成 指定数字序列 的图像(如 [3, 7, 0, 9])。
  • 最终生成结果保存为 result.png

使用说明

环境要求

  • Python 3.6+
  • Jittor 框架
  • NumPy
  • Pillow (PIL)

安装依赖

pip install jittor numpy pillow

模型训练

python CGAN.py

程序会定期输出训练进度和损失值,格式如下:

[Epoch x/y] [Batch x/y] [D loss: z] [G loss: z]

同时会在本目录下保存生成的样本图片。

注意事项

  • 首次运行会自动下载 MNIST 数据集。

  • 训练过程中每 100 批次保存一次模型检查点。

  • 生成的图像默认保存在当前目录下(result.png)。

  • 如需启用 GPU 加速,请确保已安装 CUDA 版本的 Jittor。

关于

A Jittor implementation of Conditional GAN (CGAN).

32.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号