目录
目录README.md

基于 Jittor 框架的 Conditional GAN(CGAN)

这是一个使用 Jittor 框架实现的 Conditional Generative Adversarial Network(条件生成对抗网络)项目,基于 MNIST 数据集训练模型生成指定类别的手写数字图像,用于第三届计图人工智能挑战赛热身赛任务。

配置

环境依赖

  • Python 3.x
  • Jittor(推荐使用 GPU 加速)

安装 Jittor(GPU)

pip install jittor
python -m jittor.test.test_example  # 测试是否安装成功

运行训练与生成

python CGAN.py

运行后程序将自动:

  • 下载 MNIST 数据集;
  • 训练 CGAN 模型;
  • 持续保存生成图像;
  • 根据指定数字字符串(如 "28192092811985")输出最终结果图 result.png

模型结构

Generator(生成器)

将随机噪声和类别标签拼接后输入,通过全连接层逐步扩展为图像。

输入:100维噪声 + 类别标签(one-hot)
层次结构:
    Linear -> LeakyReLU
    Linear -> BatchNorm -> LeakyReLU
    Linear -> BatchNorm -> LeakyReLU
    Linear -> BatchNorm -> LeakyReLU
    Linear -> Tanh
输出:32x32 灰度图像(1通道)

Discriminator(判别器)

输入图像与类别标签拼接后,判断是否为真实图像。

输入:图像(展平)+ 类别标签(one-hot)
层次结构:
    Linear -> LeakyReLU
    Linear -> Dropout -> LeakyReLU
    Linear -> Dropout -> LeakyReLU
    Linear -> 实数输出
输出:真/伪分值(非概率)

参数说明

参数名 描述
--n_epochs 训练轮数
--batch_size 每批样本大小
--lr 学习率(Adam优化器)
--b1 Adam优化器参数 beta1
--b2 Adam优化器参数 beta2
--latent_dim 噪声向量维度
--n_classes 类别数量(MNIST为10)
--img_size 图像尺寸(32x32)
--channels 图像通道数(灰度图为1)
--sample_interval 每隔多少步保存一次生成图像样本
--n_cpu 加载数据使用的CPU线程数

数据说明

  • 使用 Jittor 内置接口加载 MNIST 数据集:
    from jittor.dataset.mnist import MNIST
  • 图像预处理包括:
    • Resize 到 32x32
    • 灰度转换
    • 归一化到 [-1, 1]
关于
35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号