目录
目录README.md

| 第五届计图挑战热身赛

MNIST训练 Conditional GAN生成特定数字的图像

image-20250520153748247

简介

本项目包含了第五届计图挑战热身赛 - CGAN生成特定数字的图像。在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

安装

| 介绍基本的硬件需求、运行环境、依赖安装方法

运行环境

  • WSL
  • python = 3.12.3
  • jittor = 1.3.9.14
  • numpy
  • Pillow

数据预处理

  • 使用 Jittor 自带的 MNIST 数据集

  • 预处理步骤包括:

    • 图片尺寸调整为指定大小(默认为 32×32)

    • 灰度化处理

    • 归一化(均值 0.5,标准差 0.5),将像素值映射到 [-1, 1]

  • 具体数据加载示例如下:

import jittor.transform as transform
from jittor.dataset.mnist import MNIST

transform = transform.Compose([
    transform.Resize(opt.img_size),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])

dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

训练

  1. 初始化模型
    • 生成器(Generator)和判别器(Discriminator)均为多层全连接神经网络
    • 生成器输入:随机噪声 + 类别标签
    • 判别器输入:图像 + 类别标签
  2. 优化器
    • 使用 Adam 优化器
    • 学习率、动量参数可通过命令行参数调整
  3. 训练流程
    • 对判别器:计算真实图片和生成图片的判别损失,更新判别器参数
    • 对生成器:计算生成图片的判别结果对“真实”标签的误差,更新生成器参数
    • 训练采用均方误差(MSE)作为对抗损失函数
  4. 模型保存和采样
    • 每隔一定训练步数(默认1000步)保存生成的图像样本
    • 每隔若干epoch保存一次模型权重
  5. 运行示例
python CGAN.py --n_epochs 100 --batch_size 64 --lr 0.0002 --latent_dim 100 --n_classes 10 --img_size 32 --channels 1

致谢

  • 感谢 Jittor 团队提供高效灵活的深度学习框架支持

  • 参考自经典的 CGAN 论文与 MNIST 数据集标准预处理流程

  • 感谢开源社区提供的各类工具和资源

注意事项

  • 本代码支持 GPU 加速,前提是系统安装好CUDA并正确配置Jittor

  • 训练时间依赖于设备性能和参数设置

  • 图片大小和类别数需根据具体数据集调整

  • 生成器和判别器模型可根据需求进一步优化或更换

  • 训练时需确保样本数量与批量大小兼容

  • 采样图片保存格式为 PNG,路径和名称在代码中定义

关于

在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。(代码实例的随机ID是:28125622805359)

38.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号