目录
目录readme.md

一、Jittor(计图)框架介绍

​ 该项目的运行基于Jittor框架。下面是对该框架的讲解。

​ Jittor 是一个基于即时编译和元算子的高性能深度学习框架,整个框架在即时编译的同时,还集成了强大的Op编译器和调优器。Jittor前端语言为Python。前端使用了模块化的设计,类似于PyTorch,Keras,后端则使用高性能语言编写,如CUDA,C++。

​ Jittor框架的使用样例如下,我们可以依赖此框架快速的搭建神经网络应用。

import jittor as jt
from jittor import Module
from jittor import nn
class Model(Module):
    def __init__(self):
        self.layer1 = nn.Linear(1, 10)
        self.relu = nn.Relu() 
        self.layer2 = nn.Linear(10, 1)
    def execute (self,x) :
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

def get_data(n): # generate random data for training test.
    for i in range(n):
        x = np.random.rand(batch_size, 1)
        y = x*x
        yield jt.float32(x), jt.float32(y)

model = Model()
learning_rate = 0.1
optim = nn.SGD(model.parameters(), learning_rate)

for i,(x,y) in enumerate(get_data(n)):
    pred_y = model(x)
    loss = ((pred_y - y)**2)
    loss_mean = loss.mean()
    optim.step (loss_mean)
    print(f"step {i}, loss = {loss_mean.data.sum()}")

二、项目介绍

​ 这是一个基于条件生成对抗网络(Conditional GAN)的手写数字生成项目,使用Jittor深度学习框架实现。项目的核心是训练一个能够生成指定数字手写体的生成器网络。整个系统由生成器和判别器两个主要网络组成:生成器接收随机噪声和目标数字标签作为输入,通过多层全连接网络将其转换为手写数字图像;判别器则负责区分图像是真实的MNIST数据集图像还是生成器生成的假图像。

​ 在训练过程中,生成器和判别器相互对抗,不断优化:生成器试图生成越来越逼真的手写数字图像来”欺骗”判别器,而判别器则努力提高自己区分真假图像的能力。项目使用MNIST数据集作为训练数据,通过多轮训练后,生成器能够学习到手写数字的特征分布,最终能够根据输入的数字序列(如”2211585”)生成对应的手写数字图像。项目还实现了图像保存功能,可以定期保存训练过程中的生成结果,以及保存最终生成的图像。

​ 整个项目展示了GAN在图像生成领域的应用,特别是在条件控制下生成特定类别图像的能力,这种技术可以扩展应用到更多领域,如艺术创作、数据增强等方面。

三、部署方法

(一)基础环境要求:

  • Python 3.7或以上版本
  • CUDA支持(如果要使用GPU)
  • 足够的内存和存储空间

(二)不同操作系统的安装步骤:

Ubuntu/Debian系统

sudo apt install python3.7-dev libomp-dev
python3.7 -m pip install jittor
python3.7 -m jittor.test.test_example

CentOS/RedHat系统

sudo yum install devtoolset-7-gcc-c++ -y
python3.7 -m pip install jittor
python3.7 -m jittor.test.test_example

macOS

brew install libomp
python3.7 -m pip install jittor
python3.7 -m jittor.test.test_example

Windows

# 检查python版本大于等于3.8
python --version
conda install pywin32
python -m pip install jittor
python -m jittor.test.test_core
python -m jittor.test.test_example

(三)运行项目:

# 克隆项目(如果是从git仓库)
git clone [项目地址]
cd [项目目录]

# 运行训练
python CGAN.py

四、原理讲解

​ 什么是GAN(生成对抗网络)。GAN是由Ian Goodfellow在2014年提出的一种深度学习模型,它的核心思想是通过两个神经网络的相互对抗来完成学习任务。这个想法源于博弈论中的零和博弈,就像造假者和鉴定者之间的较量。造假者不断改进造假技术,而鉴定者则不断提升鉴别能力,两者相互促进,最终造假者能够制造出以假乱真的赝品。在GAN中,”造假者”就是生成器(Generator),而”鉴定者”则是判别器(Discriminator)。

Overview of GAN Structure | Machine Learning | Google for Developers

​ GAN的工作原理。生成器接收随机噪声作为输入,试图生成看起来真实的数据(比如图像);判别器则接收真实数据和生成器生成的假数据作为输入,努力区分它们的真伪。这两个网络通过不断的对抗训练来提升各自的能力:生成器试图生成越来越逼真的假数据来欺骗判别器,而判别器则试图提高自己的判别能力来识破生成器的”诡计”。这个过程可以用数学方式表达为一个极小极大博弈问题:判别器要最大化正确判别的概率,而生成器要最小化被判别器识破的概率。

Generative Adversarial Network (GAN) - Semiconductor Engineering

​ 在训练过程中,GAN采用交替训练的方式。首先固定生成器的参数,训练判别器。判别器的训练目标是:对于真实数据,输出接近1的概率;对于生成器生成的假数据,输出接近0的概率。完成判别器的训练后,再固定判别器的参数,训练生成器。生成器的训练目标是生成能够欺骗判别器的数据,也就是让判别器对生成数据的输出接近1。这个交替训练的过程会持续进行,直到达到纳什均衡,即生成器生成的数据分布接近真实数据分布,判别器无法准确区分真假数据(输出概率接近0.5)。

​ 然而,GAN的训练过程并不像理论描述的那样简单。实际训练中会遇到很多挑战,最主要的是训练不稳定和模式崩溃问题。训练不稳定是指生成器和判别器的能力需要维持适当的平衡:如果判别器太强,生成器得不到有效的梯度信息来改进;如果判别器太弱,生成器又得不到正确的引导。模式崩溃是指生成器可能只学会生成有限几种样本,而不是学习到完整的数据分布。为了解决这些问题,研究者提出了多种改进版本的GAN,比如使用Wasserstein距离的WGAN,使用条件信息的CGAN,以及在计算机视觉领域广泛使用的DCGAN等。

​ 在实际应用中,GAN已经在多个领域展现出强大的能力。在图像生成领域,它可以生成逼真的人脸、风景、艺术作品等;在图像转换领域,它可以实现风格迁移、图像修复、超分辨率重建等任务;在其他领域,如文本生成、音频合成、视频生成等,GAN也有着广泛的应用。近年来,随着技术的发展,像StyleGAN这样的模型已经能够生成极其逼真的图像,几乎可以以假乱真。

​ 要成功训练一个GAN模型,需要注意很多细节。网络架构的选择很关键,通常生成器使用反卷积网络,判别器使用卷积网络。在训练过程中,需要合理设置学习率,使用批标准化等技术来稳定训练,同时要注意防止过拟合。此外,还要定期保存模型检查点,因为GAN的训练过程可能不是单调改进的,有时候中间状态可能比最终状态更好。

五、代码分析

  1. 首先是导入和参数配置部分:
python
import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn

if jt.has_cuda:
    jt.flags.use_cuda = 1

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
# ... 其他参数配置
opt = parser.parse_args()

这部分代码:

  • 导入必要的库,包括深度学习框架Jittor
  • 检测并启用CUDA加速
  • 设置训练参数,包括训练轮数、批次大小、学习率等超参数
  1. 生成器网络结构:
python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers
            
        self.model = nn.Sequential(*block((opt.latent_dim + opt.n_classes), 128, normalize=False), 
                                   *block(128, 256), 
                                   *block(256, 512), 
                                   *block(512, 1024), 
                                   nn.Linear(1024, int(np.prod(img_shape))), 
                                   nn.Tanh())

生成器网络:

  • 包含一个标签嵌入层,将数字标签转换为向量
  • 定义了一个基本块block,包含全连接层、批归一化和LeakyReLU激活
  • 主体是多个block的序列,逐步将输入转换为图像大小
  1. 判别器网络结构:
python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512), 
                                   nn.LeakyReLU(0.2), 
                                   nn.Linear(512, 512), 
                                   nn.Dropout(0.4), 
                                   nn.LeakyReLU(0.2), 
                                   nn.Linear(512, 512), 
                                   nn.Dropout(0.4), 
                                   nn.LeakyReLU(0.2), 
                                   nn.Linear(512, 1))

判别器网络:

  • 也包含标签嵌入层
  • 将图像和标签信息拼接后输入网络
  • 通过多个全连接层和Dropout层
  • 最终输出一个值表示真假程度
  1. 数据加载部分:
python
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
    transform.Resize(opt.img_size),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

这部分:

  • 导入MNIST数据集
  • 定义图像预处理操作:调整大小、转灰度、标准化
  • 创建数据加载器
  1. 训练循环:
python
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # 训练生成器
        z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
        gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
        gen_imgs = generator(z, gen_labels)
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        
        # 训练判别器
        validity_real = discriminator(real_imgs, labels)
        validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
        d_loss = (d_real_loss + d_fake_loss) / 2

训练过程:

  • 每个epoch遍历数据集
  • 交替训练生成器和判别器
  • 生成器试图生成真实的图像
  • 判别器学习区分真假图像
  1. 最后是应用部分:
python
number = "2211585"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)

这部分展示了如何使用训练好的模型:

  • 输入一串数字
  • 为每个数字生成对应的手写数字图像
  • 保存生成结果

四、运行结果

可以看到我们正确的生成了手写数字。

img

关于

A Jittor implemention of Conditional GAN

48.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号