目录
目录README.md

Jittor 热身赛 baseline

简介

本项目包含了第五届计图挑战赛热身赛的代码实现。本项目将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。本项目的特点是:采用了 基于Jittor框架的Conditional GAN 方法 对 MNIST 手写数字生成任务 处理,取得了 高效训练并可控生成特定数字 的效果。

安装

本项目无显卡要求,训练时间约为 0.5 小时。

运行环境

  • ubuntu 22.04.5 LTS
  • python==3.10.16
  • numpy==1.26.4
  • jittor==1.3.9.14

安装依赖

首先,从官网https://cg.cs.tsinghua.edu.cn/jittor/download/,选择”Ubuntu - Pip - CPU“,安装Jittor 然后,执行以下命令创造虚拟环境(假定本地已安装miniconda),其中<env_name> 可以替换为自己的虚拟环境名称

conda create -n <env_name> python=3.10

最后,执行以下命令安装 python 依赖

pip install -r requirements.txt

数据预处理

本项目使用Jittor自带的MNIST数据集。首先,将数据集图像变换为统一大小、像素值[-1, 1]的标准化图像;随后,加载数据集并打乱。

训练

训练模型可以运行以下命令:

python3 CGAN.py

推理

模型会在每一轮训练前将当前训练在0~9上的结果输出,训练后将生成器、判别器权重分别存储在generator_last.pkldiscriminator_last.pkl中。训练完毕后,会输出对测试集的结果到result.png中。最终的训练权重和测试已上传到仓库中。

致谢

此项目基于比赛官方提供的示例代码框架,其中包括了数据下载、模型定义、训练步骤等功能。

关于

A Jittor implementation of Conditional GAN (CGAN).

10.0 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号