目录
liuxm2325天前7次提交
目录readme.md

Conditional GAN (CGAN) for MNIST Digit Generation

项目简介

基于Jittor框架实现的条件生成对抗网络(CGAN),通过输入数字标签(0-9)生成对应手写数字图像,支持自定义数字序列生成指定图像(如比赛要求的数字串)。模型包含生成器与判别器,利用MNIST数据集训练,实现条件约束下的图像生成任务。

核心功能

  • 条件生成:结合潜在向量与数字标签,生成指定类别的图像。
  • 模型架构
    • 生成器:接收噪声向量和标签嵌入,通过多层全连接网络输出32x32的单通道图像。
    • 判别器:同时输入图像和标签,判断图像是真实样本还是生成样本。
  • 训练支持:支持GPU加速(需Jittor开启CUDA),可自定义训练参数(批次大小、学习率等)。

安装与依赖

# 安装Jittor(支持CPU/GPU)
pip install jittor 

依赖:numpy, argparse, Pillow(内置MNIST数据集由Jittor提供)。

使用方法

1. 训练模型

python CGAN.py --n_epochs 100  # 可自定义参数

训练过程中每1000次迭代生成示例图像,每10个epoch保存生成器/判别器模型(generator_last.pkl/discriminator_last.pkl)。

2. 生成指定数字图像

修改代码末尾的number变量为目标数字序列(如比赛指定字符串),运行脚本生成result.png,展示按顺序生成的数字图像。

代码结构

  • 模型定义GeneratorDiscriminator类实现网络结构,结合标签嵌入层处理条件输入。
  • 数据加载:使用Jittor内置MNIST数据集,预处理包括缩放、灰度化和归一化。
  • 训练循环:交替优化生成器与判别器,采用均方误差(MSE)作为对抗损失。
  • 图像生成save_image函数处理多图像拼接,sample_image生成示例图像用于可视化。

注意事项

  • 确保Jittor正确安装并配置CUDA(若使用GPU)。
  • 生成图像质量与训练轮次、参数设置相关,可调整--lr--batch_size等优化效果。
  • 最终生成图像保存在脚本运行目录,命名为result.png

通过本项目可快速复现条件图像生成任务,适用于学习CGAN原理及Jittor框架实践。

关于

A Jittor implementation of Conditional GAN (CGAN)

15.0 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号