目录
目录README.md

CGAN_jittor

项目概述

本项目为清华大学图形学实验 PA3:Conditional GAN 的实现,使用 Jittor 深度学习框架在 MNIST 数据集上训练一个条件生成对抗网络 (Conditional GAN, CGAN)。通过输入随机噪声和指定数字标签(0-9),生成器能够生成对应数字的图片,判别器则区分真实图片与生成图片。本项目完成了一个 CGAN 模型的训练与推理,生成比赛指定数字序列的图片,并满足 EduCoder 平台的评测要求。

环境要求

  • 操作系统:Linux(推荐 Ubuntu 24.04)或 Windows(含 WSL)。macOS 用户需使用虚拟机。
  • Python 版本:Python 3.7(推荐通过 DeadSnakes PPA 安装)。
  • 依赖
    • Jittor(深度学习框架)
    • NumPy
    • Pillow(PIL)
    • C++ 编译器(g++ 或 clang)
  • 可选:CUDA(用于 GPU 加速)

安装步骤

  1. 安装 Python 3.7

    sudo add-apt-repository ppa:deadsnakes/ppa
    sudo apt update
    sudo apt install python3.7 python3.7-venv python3.7-dev
  2. 创建虚拟环境

    python3.7 -m venv jittor-venv
    source jittor-venv/bin/activate
  3. 安装 Jittor 和依赖

    pip install jittor numpy pillow -i https://pypi.tuna.tsinghua.edu.cn/simple
  4. 验证 Jittor 安装

    python -c "import jittor as jt; print(jt.__version__)"
  5. 安装 libomp-dev(Jittor 依赖)

    sudo apt install libomp-dev

项目结构

CGAN_jittor/
├── CGAN.py           # 主代码文件,包含 CGAN 模型定义、训练和推理逻辑
├── result.png        # 最终生成的图片(比赛指定数字序列)
├── generator_last.pkl # 保存的生成器模型权重
├── discriminator_last.pkl # 保存的判别器模型权重
├── README.md         # 项目说明文档
└── *.png             # 训练过程中生成的中间图片(如 0.png, 1000.png)

使用说明

1. 配置比赛数字序列

  • 打开 CGAN.py,找到以下行:
    number = "YOUR_NUMBER_SEQUENCE"  # 替换为比赛页面指定的数字序列
  • 从 EduCoder 比赛页面(https://www.educoder.net/competitions/index/Jittor-6)获取数字序列(例如 "0123456789"),替换 YOUR_NUMBER_SEQUENCE

2. 运行训练

在虚拟环境中运行:

python CGAN.py
  • 默认训练 100 个 epoch(可通过 --n_epochs 参数调整,例如 --n_epochs 10)。
  • 每 1000 个批次保存中间图片(如 0.png, 1000.png)。
  • 每 10 个 epoch 保存模型权重(generator_last.pkl, discriminator_last.pkl)。
  • 训练完成后生成 result.png,包含指定数字序列的图片。

3. 仅生成结果

若模型已训练(存在 generator_last.pkldiscriminator_last.pkl),可直接生成 result.png

python -c "
import jittor as jt
from jittor import nn
from PIL import Image
import numpy as np
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers
        self.model = nn.Sequential(
            *block((100 + 10), 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 1 * 32 * 32),
            nn.Tanh()
        )
    def execute(self, noise, labels):
        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
        img = self.model(gen_input)
        img = img.view((img.shape[0], 1, 32, 32))
        return img
generator = Generator()
generator.load('generator_last.pkl')
generator.eval()
number = 'YOUR_NUMBER_SEQUENCE'  # 替换为比赛页面指定的数字序列
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, 100))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)
img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((32, -1))
min_ = img_array.min()
max_ = img_array.max()
img_array = (img_array - min_) / (max_ - min_) * 255
Image.fromarray(np.uint8(img_array)).save('result.png')
"

4. 检查结果

  • 检查 result.png,确认是否包含指定数字的图片。
  • 如果图片质量不佳,可增加训练轮数(例如 --n_epochs 200)。

参考资源

  • Jittor 官网:https://cg.cs.tsinghua.edu.cn/jittor/
  • Jittor 安装指南:https://cg.cs.tsinghua.edu.cn/jittor/download/
  • Jittor 教程:https://www.educoder.net/paths/89rcg6jn
  • 比赛页面:https://www.educoder.net/competitions/index/Jittor-6
关于

A Jittor implementation of Conditional GAN (CGAN).

10.0 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号