目录
目录README.md

CGAN-Jittor: 条件生成对抗网络实现手写数字生成

Python Jittor License

第四届计图人工智能挑战赛—计图挑战热身赛参赛项目,基于 Jittor 框架实现条件生成对抗网络(CGAN),可生成指定数字类别(0-9)的手写数字图像。

📋 项目简介

传统 GAN 生成的图像缺乏类别可控性,本项目实现的 CGAN 通过在生成器和判别器中引入数字标签作为条件约束,使模型能够精准生成指定类别的 MNIST 手写数字图像。

核心特性

  • ✅ 基于 Jittor 深度学习框架开发,适配计图挑战赛环境
  • ✅ 使用 MSE Loss 替代传统对数损失,提升训练稳定性
  • ✅ 支持指定数字类别生成,生成结果可视化
  • ✅ 模型断点保存与加载,训练过程实时监控

🔧 环境配置

1. 克隆项目

git clone https://gitlink.org.cn/Abdaaa/CGAN_jittor.git
cd CGAN_jittor

2. 安装 Jittor

根据官方文档安装适配的 Jittor 版本:

# 创建 Python 3.7 环境(命名为 jittor-env)
conda create -n jittor-env python=3.7 -y
# 激活环境
conda activate jittor-env
# 安装计图
python -m pip install jittor
# 验证安装
python -m jittor.test.test_example
python -m jittor.test.test_cudnn_op

详细安装指南:Jittor 官方下载页

🚀 快速使用

训练模型

直接运行主程序开始训练:

python CGAN.py

训练过程说明

  • 训练日志实时输出:EpochBatchD_loss(判别器损失)、G_loss(生成器损失)
  • 每迭代指定轮次自动生成示例图像(默认路径:images/xxx.png
  • 训练结束后自动保存模型:
    • 判别器:discriminator_last.pkl
    • 生成器:generator_last.pkl
  • 最终生成指定数字的示例图像:result.png

生成指定数字

修改 CGAN.py 中的 number 变量(0-9),运行后即可生成对应数字的手写图像。

📊 结果展示

训练日志示例

Epoch [1/100], Batch [50/600], D_loss: 0.234, G_loss: 1.567
Epoch [1/100], Batch [100/600], D_loss: 0.189, G_loss: 1.234
...
Epoch [100/100], Batch [600/600], D_loss: 0.056, G_loss: 0.890

训练过程示例

生成图像示例

批量生成示例 生成制定数字2312966
批量生成示例 生成指定数字

📄 参考与致谢

📜 许可协议

本项目基于 MIT 协议开源,详见 LICENSE 文件。

关于

A Jittor implementation of Conditional GAN (CGAN)

37.8 MB
邀请码