目录
目录README.md

GGAN_jittor

CGAN_jittor - 基于Jittor的条件生成对抗网络

一个基于Jittor深度学习框架实现的条件生成对抗网络(Conditional GAN, CGAN)项目,用于生成MNIST手写数字图像。

🎯 项目简介

本项目使用Jittor框架实现了条件生成对抗网络(CGAN),能够根据指定的数字标签生成对应的MNIST手写数字图像。通过条件标签的控制,可以精确生成0-9中任意数字的图像。

✨ 功能特性

  • 基于Jittor框架:利用Jittor的元算子与统一计算图优势,提升训练效率
  • 条件生成:可控制生成特定数字的图像
  • 完整训练流程:包含数据加载、模型训练、损失计算、模型保存等完整流程
  • 可视化生成:支持生成图像的可视化保存
  • 易于使用:简单的命令行参数配置

📁 项目结构

CGAN_jittor/
├── CGAN.py              # 主程序文件 - 模型训练与图像生成
├── README.md            # 项目说明文档
├── .gitignore           # Git忽略文件配置
├── result.png           # 生成的数字序列图像(示例)
├── requirements.txt     # Python依赖包列表
└── assets/              # 资源文件目录
    └── training_progress/ # 训练过程中的生成图像

🛠️ 环境要求

  • Python 3.7+
  • Jittor 1.3.0+
  • NumPy 1.19.0+
  • Pillow 8.0.0+

📦 安装与运行

1. 安装依赖

pip install jittor numpy pillow

2. 训练模型

# 使用默认参数训练
python CGAN.py

# 自定义训练参数
python CGAN.py --n_epochs 200 --batch_size 128 --lr 0.0001

3. 生成指定数字图像

修改代码中的数字序列并运行:

# 在CGAN.py文件末尾修改number变量
number = "2313804"  # 修改为你想要的数字序列

⚙️ 命令行参数

python CGAN.py --n_epochs 100    # 训练轮数,默认100
               --batch_size 64   # 批次大小,默认64
               --lr 0.0002       # 学习率,默认0.0002
               --latent_dim 100  # 隐变量维度,默认100
               --img_size 32     # 图像尺寸,默认32×32

🏗️ 模型架构

生成器(Generator)

  • 输入:噪声向量 + 数字标签嵌入
  • 结构:全连接层 → BatchNorm → LeakyReLU
  • 输出:32×32的灰度图像

判别器(Discriminator)

  • 输入:图像 + 数字标签嵌入
  • 结构:全连接层 → Dropout → LeakyReLU
  • 输出:判别结果(真实/虚假)

📊 训练过程

  1. 数据准备:加载MNIST数据集,进行归一化处理
  2. 交替训练
    • 训练判别器:最大化真实图像与生成图像的判别准确率
    • 训练生成器:最小化判别器对生成图像的判别准确率
  3. 损失函数:均方误差(MSE)损失
  4. 优化器:Adam优化器

🖼️ 生成示例

训练完成后,模型可以生成清晰的MNIST数字图像: 生成结果示例

📝 关键代码说明

1. 模型定义

class Generator(nn.Module):
    def __init__(self):
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(...)

2. 训练循环

# 训练生成器
g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)

# 训练判别器
d_loss = (d_real_loss + d_fake_loss) / 2

3. 图像生成

gen_imgs = generator(z, labels)  # 生成图像
save_image(gen_imgs.numpy(), "result.png")  # 保存图像

🎓 实验结果

  • 训练轮数:100 epochs
  • 批次大小:64
  • 生成质量:生成的数字图像清晰可辨
  • 训练时间:在GPU上约30分钟完成训练
  • 生成效果:成功生成学号”2313804”对应的数字序列

🤝 贡献指南

  1. Fork本仓库
  2. 创建功能分支 (git checkout -b feature/AmazingFeature)
  3. 提交更改 (git commit -m 'Add some AmazingFeature')
  4. 推送到分支 (git push origin feature/AmazingFeature)
  5. 开启Pull Request

📄 开源协议

本项目基于 MIT License 开源。详见 LICENSE 文件。

🙏 致谢

  • Jittor团队:提供优秀的深度学习框架
  • 清华大学计算机系:组织计图挑战赛
  • MNIST数据集:提供训练数据
关于

A Jittor implementation of Conditional GAN(CGAN)

38.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号