rearrange files
本项目在数字图片数据集 MNIST 上训练 Conditional GAN(条件生成对抗网络)模型。通过输入随机噪声向量 z 和条件信息 y(类别标签),生成特定数字的手写体图像。
任务目标:
技术特点:
python 3.8 argon2-cffi 20.1.0 astunparse 1.6.3 async-generator 1.10 attrs 20.3.0 backcall 0.2.0 bleach 3.3.0 cffi 1.14.5 cycler 0.10.0 decorator 5.0.6 defusedxml 0.7.1 entrypoints 0.3 ipykernel 5.5.3 ipython 7.22.0 ipython-genutils 0.2.0 jedi 0.18.0 Jinja2 2.11.3 jittor 1.2.2.59 jsonschema 3.2.0 jupyter-client 6.1.13 jupyter-core 4.7.1 jupyterlab-pygments 0.1.2 kiwisolver 1.3.1 MarkupSafe 1.1.1 matplotlib 3.4.1 mistune 0.8.4 nbclient 0.5.3 nbconvert 6.0.7 nbformat 5.1.3 nest-asyncio 1.5.1 notebook 6.3.0 numpy 1.20.2 packaging 20.9 pandocfilters 1.4.3 parso 0.8.2 pexpect 4.8.0 pickleshare 0.7.5 Pillow 8.2.0 pip 20.0.2 prometheus-client 0.10.1 prompt-toolkit 3.0.18 ptyprocess 0.7.0 pybind11 2.6.2 pycparser 2.20 Pygments 2.8.1 pyparsing 2.4.7 pyrsistent 0.17.3 python-dateutil 2.8.1 pyzmq 22.0.3 Send2Trash 1.5.0 setuptools 45.2.0 six 1.15.0 terminado 0.9.4 testpath 0.4.4 tornado 6.1 tqdm 4.60.0 traitlets 5.0.5 wcwidth 0.2.5 webencodings 0.5.1 wheel 0.34.2
cd /path/to/project/ docker run -it --gpus all -v /path/to/project:/workspace jittor-cuda-12-2 bash cd /workspace python3.8 -m CGAN.py
部分可配置参数如下:
--n_epochs: 训练轮数,默认100 --batch_size: 批次大小,默认64 --lr: Adam优化器学习率,默认0.0002 --b1: Adam优化器一阶动量衰减率,默认0.5 --b2: Adam优化器二阶动量衰减率,默认0.999 --latent_dim: 潜在空间维度,默认100 --n_classes: 数据集类别数,默认10 --img_size: 图像尺寸,默认32x32 --channels: 图像通道数,默认1(灰度图) --sample_interval: 图像采样间隔,默认1000
输入:
noise
[batch_size, latent_dim]
labels
[batch_size]
网络结构:
输出:
[batch_size, 1, 32, 32]
img
[batch_size, 1]
. ├── CGAN.py # 主代码文件 ├── Dockerfile # 自定义 Dockerfile ├── README.md # README 文件 ├── result.png # 生成的数字序列图像(最终结果) ├── generator_last.pkl # 训练好的生成器模型 └── discriminator_last.pkl # 训练好的判别器模型
Jittor文档
感谢王潇策同学提供的镜像包
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN
项目简介
本项目在数字图片数据集 MNIST 上训练 Conditional GAN(条件生成对抗网络)模型。通过输入随机噪声向量 z 和条件信息 y(类别标签),生成特定数字的手写体图像。
任务目标:
技术特点:
运行环境
使用方法
部分可配置参数如下:
模型简介
生成器(Generator)
输入:
noise
: 随机噪声向量,维度为[batch_size, latent_dim]
(默认100维)labels
: 类别标签,维度为[batch_size]
,取值范围0-9网络结构:
输出:
[batch_size, 1, 32, 32]
,像素值范围[-1, 1]判别器(Discriminator)
输入:
img
: 图像数据,维度为[batch_size, 1, 32, 32]
labels
: 对应的类别标签,维度为[batch_size]
网络结构:
输出:
[batch_size, 1]
,表示输入图像为真实图像的概率损失函数
项目结构
参考资料
Jittor文档
感谢王潇策同学提供的镜像包