
Conditional GAN with Jittor

Overview

This repository contains the implementation of a Conditional Generative Adversarial Network (CGAN) using the Jittor deep learning framework. The network is designed to generate images conditioned on class labels. It has been implemented to train on the MNIST dataset, a popular dataset of handwritten digits.

Requirements

  • Python 3.6+
  • Jittor
  • Pillow (for image saving and manipulation)

Setup

  1. Install Jittor:
    pip3 install jittor
  2. Install Pillow:
    pip3 install pillow

Usage

Training the CGAN

The main script is designed to train the CGAN model on the MNIST dataset. You can customize various training parameters using command-line arguments.

python CGAN.py --n_epochs 100 --batch_size 128 --lr 0.0001 --b1 0.5 --b2 0.999 --n_cpu 8 --latent_dim 100 --n_classes 10 --img_size 32 --channels 1 --sample_interval 1000

Command-line arguments:

  • --n_epochs: Number of epochs of training (default: 100)
  • --batch_size: Size of the batches (default: 128)
  • --lr: Learning rate for the Adam optimizer (default: 0.0001)
  • --b1: Decay of the first-order momentum of the gradient for the Adam optimizer (default: 0.5)
  • --b2: Decay of the second-order momentum of the gradient for the Adam optimizer (default: 0.999)
  • --n_cpu: Number of CPU threads to use during batch generation (default: 8)
  • --latent_dim: Dimensionality of the latent space (default: 100)
  • --n_classes: Number of classes for the dataset (default: 10)
  • --img_size: Size of each image dimension (default: 32)
  • --channels: Number of image channels (default: 1)
  • --sample_interval: Interval between image sampling (default: 1000)
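For reference, here is one way these flags and their documented defaults could be declared with Python's standard argparse module. This is a sketch, not code copied from the training script: the helper name build_parser is hypothetical, and the integer/float types are assumptions inferred from the defaults.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser declaring the training flags listed above."""
    parser = argparse.ArgumentParser(description="Train a conditional GAN on MNIST")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0001, help="Adam learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="Adam: decay of first-order momentum")
    parser.add_argument("--b2", type=float, default=0.999, help="Adam: decay of second-order momentum")
    parser.add_argument("--n_cpu", type=int, default=8, help="CPU threads used for batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--n_classes", type=int, default=10, help="number of classes in the dataset")
    parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=1000, help="interval between image sampling")
    return parser
```

Inside a script, opt = build_parser().parse_args() reads the real command line. Any flag left out keeps its default, so python CGAN.py --lr 0.0002 changes only the learning rate.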

Generating Images

After training, you can generate images conditioned on specific labels. To do so, run test.py and pass the digits to generate with the --number argument.

  1. Ensure you have trained models saved in the ./models/ directory.
  2. Run the script to generate and save images.
# Test example
python test.py --number 20773852072461
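This README does not say exactly how test.py interprets --number. The example above and the directory listing ("generate an image of a given number") suggest that each digit is treated as one class label and the generated digits are placed in a single row. The sketch below assumes that interpretation and uses only the standard library. Both digits_to_labels and paste_offsets are hypothetical helpers, not functions from test.py.

```python
from typing import List, Tuple


def digits_to_labels(number: str, n_classes: int = 10) -> List[int]:
    """Turn a digit string such as "2077" into class labels [2, 0, 7, 7]."""
    if not number or any(ch not in "0123456789" for ch in number):
        raise ValueError("--number must be a non-empty string of digits 0-9")
    labels = [int(ch) for ch in number]
    # Each digit must be a valid class index for the conditional generator.
    if any(label >= n_classes for label in labels):
        raise ValueError("every digit must be smaller than n_classes")
    return labels


def paste_offsets(n_images: int, img_size: int = 32) -> Tuple[Tuple[int, int], List[Tuple[int, int]]]:
    """Return the canvas size (width, height) and the top-left corner of each
    generated digit when the digits are laid out side by side in one row."""
    canvas = (n_images * img_size, img_size)
    offsets = [(i * img_size, 0) for i in range(n_images)]
    return canvas, offsets
```

With the example number 20773852072461, this gives 14 labels and a 448 x 32 canvas at the default --img_size of 32. With Pillow, each generated digit could then be placed on the canvas with Image.paste(digit_image, offset).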

Code Explanation

Main Components

  1. Generator:

    • Takes noise and class labels as input.
    • Consists of several fully connected layers with batch normalization and LeakyReLU activation.
    • Outputs an image whose shape is set by --channels and --img_size.
  2. Discriminator:

    • Takes an image and class labels as input.
    • Consists of several fully connected layers with dropout and LeakyReLU activation.
    • Outputs a single validity score indicating whether the image is judged real or generated.
  3. Training Loop:

    • The training loop alternates between training the generator and the discriminator.
    • The discriminator is trained to distinguish between real and generated images.
    • The generator is trained to produce images that can fool the discriminator.
    • Losses are calculated using mean squared error (MSE).
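The list above does not give input sizes or loss targets. The sketch below illustrates both using plain Python lists instead of Jittor Vars, so it is not the repository's model code. It rests on several assumptions that this README does not state. First, the label is encoded as a one-hot vector and concatenated to the input; the real model may use a learned embedding instead. Second, the MSE targets are 1 for real images and 0 for generated ones, as in least-squares GANs. Third, the discriminator loss is the average of its real and fake terms. All function names are hypothetical.

```python
import random
from typing import List


def sample_noise(latent_dim: int = 100) -> List[float]:
    # Standard-normal noise vector z, one value per latent dimension.
    return [random.gauss(0.0, 1.0) for _ in range(latent_dim)]


def one_hot(label: int, n_classes: int = 10) -> List[float]:
    # Assumed label encoding: 1.0 at the label index, 0.0 elsewhere.
    vec = [0.0] * n_classes
    vec[label] = 1.0
    return vec


def generator_input(noise: List[float], label: int, n_classes: int = 10) -> List[float]:
    # Generator conditioning: noise concatenated with the label encoding.
    return noise + one_hot(label, n_classes)


def discriminator_input(flat_image: List[float], label: int, n_classes: int = 10) -> List[float]:
    # Discriminator conditioning: flattened image plus the same label encoding.
    return flat_image + one_hot(label, n_classes)


def mse(scores: List[float], target: float) -> float:
    # Mean squared error between each score and a constant target.
    return sum((s - target) ** 2 for s in scores) / len(scores)


def discriminator_loss(real_scores: List[float], fake_scores: List[float]) -> float:
    # Push real scores toward 1 and generated scores toward 0, then average.
    return (mse(real_scores, 1.0) + mse(fake_scores, 0.0)) / 2


def generator_loss(fake_scores: List[float]) -> float:
    # The generator is rewarded when its samples are scored as real (target 1).
    return mse(fake_scores, 1.0)
```

Under these assumptions and the default flags (--latent_dim 100, --n_classes 10, --img_size 32, --channels 1), the generator's first fully connected layer receives 100 + 10 = 110 values. The discriminator receives 1 * 32 * 32 + 10 = 1034 values.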

Saving and Loading Models

  • Models are saved every 10 epochs to the ./models/ directory.
  • The script includes code to load the last saved generator and discriminator models for inference.
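The README gives the save frequency but not the checkpoint file names. The sketch below shows one way to express the saving schedule and to select the most recent checkpoint. It assumes 0-indexed epochs and a hypothetical <prefix>_<epoch>.pkl naming scheme. The helpers should_save and latest_checkpoint are illustrative and are not taken from CGAN.py or test.py.

```python
import re
from typing import List, Optional

SAVE_EVERY = 10  # README: models are saved every 10 epochs


def should_save(epoch: int, every: int = SAVE_EVERY) -> bool:
    """True at the end of the 10th, 20th, 30th, ... epoch (epoch is 0-indexed)."""
    return (epoch + 1) % every == 0


def latest_checkpoint(filenames: List[str], prefix: str = "generator") -> Optional[str]:
    """Return the <prefix>_<epoch>.pkl file with the highest epoch, or None."""
    pattern = re.compile(r"^" + re.escape(prefix) + r"_(\d+)\.pkl$")
    best_name, best_epoch = None, -1
    for name in filenames:
        match = pattern.match(name)
        # Compare epochs as integers, not strings.
        if match and int(match.group(1)) > best_epoch:
            best_name, best_epoch = name, int(match.group(1))
    return best_name
```

In practice, the file list would come from os.listdir("./models"), and the chosen path could be passed to the Jittor module's load method. Epochs are compared as integers so that generator_100.pkl is preferred over generator_90.pkl; a plain string comparison would get that backwards.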

Directory Structure

├── CGAN.py     # Main script with model definitions and training loop
├── test.py     # Script to generate an image of a given number
├── models      # Directory for saved trained models
├── images      # Directory for generated images
└── README.md   # This README file
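About

This code uses the Jittor framework to implement a conditional generative adversarial network (CGAN) that generates images controlled by class labels. The model is trained on the MNIST dataset.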
