目录
目录README.md

第三届计图人工智能挑战赛 - Jittor 赛题

简介

本项目是第三届计图人工智能挑战赛的参赛代码。项目基于 Jittor 深度学习框架实现了一个条件生成对抗网络(CGAN),能够根据给定的类别标签(0-9 的数字)生成相应的手写数字图像。模型在 MNIST 数据集上进行训练。

环境依赖

本项目运行依赖以下环境:

  • Python 3.10.12
  • Jittor
  • Numpy
  • Pillow (PIL)

关于 Jittor 的安装和配置,请参考 Jittor 官网文档

快速开始

训练模型与生成结果

直接运行 CGAN.py 脚本即可开始训练,并在训练结束后生成特定数字序列的图像。

python3 CGAN.py

脚本功能说明

  1. 模型训练: 脚本会加载 MNIST 数据集,并训练 Generator 和 Discriminator。 默认训练 100 个 Epoch。 每隔一定步数会保存中间采样图片。

  2. 模型保存: 训练过程中每 10 个 Epoch 会保存一次模型权重。 最终模型权重保存在:

    • generator_last.pkl
    • discriminator_last.pkl
  3. 结果生成: 训练完成后,脚本会自动加载最终保存的模型,并根据代码中预设的数字序列(如 2312479)生成对应的拼接图像,保存为 result.png

参数说明

你可以通过命令行参数自定义训练配置:

python CGAN.py --n_epochs 200 --batch_size 64 --lr 0.0002

主要参数:

  • --n_epochs: 训练的总轮数 (默认: 100)
  • --batch_size: 批次大小 (默认: 64)
  • --lr: Adam 优化器的学习率 (默认: 0.0002)
  • --b1: Adam 优化器的 beta1 (默认: 0.5)
  • --b2: Adam 优化器的 beta2 (默认: 0.999)
  • --latent_dim: 隐向量的维度 (默认: 100)
  • --n_classes: 类别数量 (默认: 10,对应 MNIST 数字 0-9)
  • --img_size: 图片尺寸 (默认: 32)
  • --channels: 图片通道数 (默认: 1)
  • --sample_interval: 保存采样图片的间隔步数 (默认: 1000)

结果展示

运行结束后生成的 result.png 展示了模型根据特定数字序列生成的图像结果。

关于

热身赛作业

10.3 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号