目录
目录README.md

基于 Jittor 实现的条件生成对抗网络 (CGAN)

项目概述

这个项目实现了一个基于 Jittor 框架的条件生成对抗网络 (Conditional Generative Adversarial Network, CGAN)CGAN 是一种生成模型,它不仅学习生成逼真的数据(在这个例子中是 MNIST 手写数字图像),还学习根据给定的条件(数字类别标签)来生成特定类别的数据。本项目旨在训练一个 CGAN 模型,使其能够根据输入的数字标签生成对应的手写数字图像。

主要组件

  1. 生成器 (Generator): 接收一个随机噪声向量 (latent vector) 和一个类别标签作为输入,并输出一张对应标签的、看起来像真实 MNIST 图像的图片。
  2. 判别器 (Discriminator): 接收一张图片(真实的或生成的)和一个类别标签作为输入,并判断这张图片是否真实,以及它属于哪个类别。
  3. 训练过程: 通过对抗训练的方式,生成器和判别器相互竞争、共同进化。生成器努力生成更逼真的图片以欺骗判别器,而判别器则努力提高辨别真伪的能力。

    代码结构

  • 导入库: 导入必要的 Jittor 库、NumPy、数学库、argparse 用于解析命令行参数、os 用于文件操作以及 PIL 用于图像保存。
  • 参数配置: 使用 argparse 定义并解析训练参数,如训练轮数 (n_epochs)、批次大小 (batch_size)、学习率 (lr) 等。
  • 数据准备: 定义图像尺寸 (img_size) 和通道数 (channels),构建 img_shape。定义生成器使用的标签嵌入层。
  • 模型定义:
    • Generator 类:定义生成器的网络结构,包含嵌入层、多个全连接层块(带批归一化和 LeakyReLU 激活)以及最终的输出层(Tanh 激活以输出 [-1, 1] 范围的像素值)。
    • Discriminator 类:定义判别器的网络结构,包含嵌入层、多个全连接层(带 Dropout 和 LeakyReLU 激活)以及最终的输出层(输出一个实数,代表“真实性”得分)。
  • 损失函数: 使用均方误差 (MSE) 作为对抗损失函数 (adversarial_loss)。
  • 数据加载: 使用 Jittor 的 MNIST 数据集类,并应用预处理变换(调整大小、转为灰度、归一化)。
  • 优化器: 为生成器和判别器分别定义 Adam 优化器。
  • 辅助函数:
    • save_image: 将一批生成的图像拼接并保存为一张图片。
    • sample_image: 在训练过程中定期生成并保存一批图像样本,用于监控训练进度。
  • 训练循环:
    • 外层循环遍历训练轮数 (epoch)。
    • 内层循环遍历数据加载器 (dataloader),获取真实图像和标签。
    • 生成器训练: 生成随机噪声和标签,生成图像,计算生成器损失(判别器被欺骗的程度),并更新生成器参数。
    • 判别器训练: 计算判别器在真实图像上的损失,计算判别器在生成图像上的损失,合并两者得到总损失,并更新判别器参数。
    • 定期打印训练损失信息。
    • 定期调用 sample_image 保存生成的图像样本。
    • 定期保存生成器和判别器的模型参数。
  • 模型评估与生成:
    • 将模型设置为评估模式。
    • 加载之前保存的模型参数。
    • 定义一个特定的数字序列 (number = "28164482809235")。
    • 为该序列中的每个数字生成随机噪声,并生成对应的图像。
    • 将生成的图像数组进行后处理(归一化、转置维度),并保存为 result.png 文件。

      如何运行

  1. 环境准备:
    • 确保已安装 Python 3。
    • 安装 Jittor 框架。可以通过 pip 安装:pip install jittor
    • 安装 NumPy 和 Pillow (PIL):pip install numpy pillow
  2. 运行训练:
    • 将代码保存为 .py 文件(例如 mnist_cgan.py)。
    • 在终端中运行:python mnist_cgan.py
    • 你可以通过修改 argparse 中定义的参数来调整训练过程,例如增加训练轮数 (--n_epochs 200) 或改变批次大小 (--batch_size 128)。
  3. 查看结果:
    • 训练过程中,每隔 sample_interval 步会生成并保存一批图像,文件名格式为 <步数>.png
    • 训练结束后,会生成一个名为 result.png 的图像,其中包含根据数字序列 “28164482809235” 生成的对应数字图像。

      注意事项

  • 训练 GAN 模型可能需要较长时间,尤其是在没有 GPU 加速的情况下。代码中检查了 CUDA 是否可用 (jt.has_cuda),如果可用会自动使用 GPU。
  • 训练 GAN 时可能会遇到模式崩溃 (mode collapse) 等问题,即生成器可能只生成少数几种类型的图像。调整超参数(如学习率、网络结构、损失函数权重)或使用更先进的训练技巧可能有助于改善这种情况。
  • 生成的图像质量取决于训练的充分程度和模型设计。训练轮数 (n_epochs) 越多,通常生成质量越好,但也会增加训练时间。
关于
10.4 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号