目录

CGAN-Jittor:计图挑战热身赛

简介

本项目基于 Jittor 框架 实现,使用数字图片数据集 MNIST 训练一个条件生成对抗网络(Conditional GAN)模型,并通过输入一个随机向量 z 和额外的辅助信息 y(如类别标签),生成给定的特定数字序列的图像。

环境配置

  • Linux Ubuntu 22.04 LTS
  • GPU:NVIDIA GeForce RTX 4090
  • CUDA:12.4

使用 Anaconda 创建虚拟环境,名称为 jittor

conda create -n jittor python=3.9
conda activate jittor
conda install -c conda-forge gcc_linux-64==11.2.0
conda install -c conda-forge gxx_impl_linux-64==11.2.0
conda install -c conda-forge libstdcxx-ng=12
python3 -m pip install scipy trimesh matplotlib
python3 -m pip install jittor

运行指令

首先将 CGAN.py 中的 number 替换为目标数字序列字符串,运行:

python CGAN.py

如需调整训练参数,可通过命令行传入,如:

python CGAN.py --n_epochs 150 --batch_size 64 --lr 0.001

更多命令行参数详见 CGAN.py

运行过程

  • 数据处理:程序将调用 Jittor 框架自带接口加载 MNIST 数据集,并设置相应的图像变换。

  • 模型训练:在模型训练过程中,每 sample_interval 个 batch 将保存一次生成的图像样本,每 10 个 epoch 将保存一次模型生成器(generator_last.pkl)和判别器(discriminator_last.pkl)。

  • 模型推理:训练完成后,程序加载保存的模型,并根据设定的数字序列字符串生成对应的手写数字图像。

训练结果下载

可以通过 此云盘链接 下载训练得到的模型文件和生成的图像结果。

结果

致谢

本项目基于比赛官方提供的 代码框架 完成。

关于

A Jittor implementation of Conditional GAN (CGAN).

45.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号