目录
目录README.md

CGAN_Jittor

A Conditional GAN (CGAN) implementation in Jittor for generating MNIST digit images.
清华大学《计算机图形学基础》课程实验三(PA3)项目

🌟 项目简介

本项目基于 Jittor 深度学习框架实现了 Conditional Generative Adversarial Network(CGAN),用于在 MNIST 数据集上生成手写数字图像。通过引入类别标签,CGAN 能够生成指定数字的图像。该项目为清华大学《计算机图形学基础》课程 PA3 实验的实践任务。

📁 项目结构

. ├── CGAN.py # 主程序,包含网络定义、训练与推理 ├── generator_last.pkl # 训练保存的生成器模型(示例) ├── discriminator_last.pkl # 训练保存的判别器模型(示例) ├── result.png # 生成的指定数字图像拼接图(示例) └── README.md # 项目说明文件

🧠 模型架构

  • Generator(生成器)

    • 输入:随机噪声 z + 标签 embedding
    • 输出:32×32 的灰度图像
  • Discriminator(判别器)

    • 输入:图像 + 标签 embedding
    • 输出:实数(越接近 1 表示越可能是真实图像)
  • 损失函数:采用均方误差(MSE)

🚀 运行说明

依赖安装

请确保系统安装了 Jittor 框架,可使用如下命令:

pip install jittor

确保安装版本在 1.2.2.59 及以上。

训练模型

直接运行:

python CGAN.py

你可以使用以下参数进行自定义:

--n_epochs            # 训练轮数(默认:100)
--batch_size          # 每批图像数量(默认:64)
--lr                  # 学习率(默认:0.0002)
--latent_dim          # 隐空间维度(默认:100)
--img_size            # 图像尺寸(默认:32)
--sample_interval     # 每训练多少步保存一次生成图像(默认:1000)

例如:

python CGAN.py --n_epochs 200 --batch_size 128

推理生成

请在代码末尾填写指定数字序列,例如:

number = "3141592653"

然后重新运行 CGAN.py 生成 result.png。

关于

清华大学《计算机图形学基础》PA3

36.0 KB
邀请码