目录
目录README.md

lab0

项目概述

本项目利用 Jittor 框架实现了一个图像生成模型,主要基于生成对抗网络(GAN)架构,可生成特定数字图像。通过训练生成器与判别器,模型能学习数据分布,生成高质量图像,在图像生成领域具实用价值,如艺术创作、数据增强等场景。

功能特性

  • 自定义图像生成:依设定数字序列生成对应图像,输入不同数字组合可获多样化结果,助用户按需创作图像。
  • 模型训练与优化:采用 Adam 优化器与 MSELoss 损失函数训练,优化生成器与判别器参数,提升图像生成质量与模型性能。
  • 数据集支持:集成 MNIST 数据集训练模型,为模型学习图像特征与分布提供丰富数据,经数据预处理适配模型训练需求。

安装指南

  1. 环境依赖:确保系统安装 Python,建议 3.6 及以上版本;安装 Jittor 框架,依官方文档(https://cg.cs.tsinghua.edu.cn/jittor/)指引操作;其他必要库如 numpyargparsePIL 等可通过 pip install 安装。
  2. 项目克隆与配置:从项目仓库克隆代码至本地;在项目目录打开终端或命令提示符,依需求修改代码中参数(如训练轮数、批次大小、学习率等),于命令行执行脚本启动项目。

使用方法

  1. 模型训练:运行主脚本启动训练,训练进度与损失信息将打印于控制台,训练完成模型参数存为 generator_last.pkldiscriminator_last.pkl,可供后续使用或评估。
  2. 图像生成:运行生成代码段,修改 number 变量为目标数字序列(字符串),生成图像存为 result.png,生成图像依数字序列定制,展示模型生成能力。

项目结构

  • generator.py:定义生成器类 Generator,含嵌入层与多层全连接网络构建模型结构及前向传播逻辑,将随机噪声与标签转成图像。
  • discriminator.py:定义判别器类 Discriminator,由嵌入层与多层全连接网络组成,判别输入图像真伪,输出判别结果。
  • train.py:主训练脚本,设置训练参数、加载数据集、定义损失函数与优化器,迭代训练生成器与判别器,定期采样保存生成图像并保存模型。
  • utils.py:存放数据预处理、图像保存等工具函数,如 save_image 函数处理图像保存格式与归一化,确保图像正确存储显示。
关于

A Jittor implementation of Conditional GAN (CGAN).

1.1 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号