Conditional GAN with Jittor
Overview
This repository contains the implementation of a Conditional Generative Adversarial Network (CGAN) built with the Jittor deep learning framework. The network generates images conditioned on class labels and is set up to train on MNIST, a popular dataset of handwritten digits.
Requirements
Python 3.6+
Jittor
Pillow (for image saving and manipulation)
Setup
Install Jittor and Pillow:
pip3 install jittor pillow
Usage
Training the CGAN
The main script, CGAN.py, trains the CGAN model on the MNIST dataset. You can customize the training parameters with the following command-line arguments:
--n_epochs: Number of training epochs (default: 100)
--batch_size: Batch size (default: 128)
--lr: Learning rate for the Adam optimizer (default: 0.0001)
--b1: Adam decay rate for the first-moment (mean) estimate of the gradient (default: 0.5)
--b2: Adam decay rate for the second-moment (uncentered variance) estimate of the gradient (default: 0.999)
--n_cpu: Number of CPU threads used during batch generation (default: 8)
--latent_dim: Dimensionality of the latent (noise) space (default: 100)
--n_classes: Number of classes in the dataset (default: 10)
--img_size: Height and width of each image, in pixels (default: 32)
--channels: Number of image channels (default: 1)
--sample_interval: Interval between saving image samples (default: 1000)
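The flags above map naturally onto Python's standard argparse module. The sketch below is a hypothetical reconstruction of such a parser: the flag names and defaults come from the list above, while the function name build_parser and the help strings are assumptions rather than code taken from CGAN.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented CGAN.py flags and defaults."""
    parser = argparse.ArgumentParser(description="Train a conditional GAN on MNIST with Jittor")
    # Optimisation settings
    parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--lr", type=float, default=0.0001, help="Adam learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="Adam first-moment decay (beta1)")
    parser.add_argument("--b2", type=float, default=0.999, help="Adam second-moment decay (beta2)")
    parser.add_argument("--n_cpu", type=int, default=8, help="CPU threads for batch generation")
    # Model and data settings
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--n_classes", type=int, default=10, help="number of classes in the dataset")
    parser.add_argument("--img_size", type=int, default=32, help="height and width of each image")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    # Logging
    parser.add_argument("--sample_interval", type=int, default=1000, help="interval between image samples")
    return parser
```

Calling build_parser().parse_args() with no extra arguments yields the defaults; assuming CGAN.py parses its arguments this way, `python CGAN.py --n_epochs 50 --batch_size 64` would override only those two values.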
Generating Images
After training, you can generate images conditioned on specific labels by running test.py with the --number argument set to the digits you want rendered.
Make sure trained models are saved in the ./models/ directory.
Run the script to generate and save the images.
# Test example
python test.py --number 20773852072461
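test.py is described as generating an image of a given number, so the digit string passed to --number presumably becomes a sequence of class labels, one per digit. The following sketch shows that conversion under this assumption; number_to_labels is a hypothetical helper, not a function from test.py.

```python
from typing import List


def number_to_labels(number: str, n_classes: int = 10) -> List[int]:
    """Turn a --number string such as "2077" into per-digit class labels.

    Hypothetical helper: assumes one generated image per digit, in order.
    """
    # Reject empty strings and anything that is not an ASCII digit.
    if not number or any(ch not in "0123456789" for ch in number):
        raise ValueError("--number must be a non-empty string of digits 0-9")
    labels = [int(ch) for ch in number]
    # Each label must index one of the classes the generator was trained on.
    if max(labels) >= n_classes:
        raise ValueError("every digit must be smaller than n_classes")
    return labels
```

With the example above, number_to_labels("20773852072461") returns [2, 0, 7, 7, 3, 8, 5, 2, 0, 7, 2, 4, 6, 1], i.e. 14 labels, each of which would presumably be paired with its own noise vector of size --latent_dim.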
Code Explanation
Main Components
Generator:
Takes noise and class labels as input.
Consists of several fully connected layers with batch normalization and LeakyReLU activation.
Outputs an image whose shape is set by --channels and --img_size.
Discriminator:
Takes an image and class labels as input.
Consists of several fully connected layers with dropout and LeakyReLU activation.
Outputs a single value representing the probability of the image being real.
Training Loop:
The training loop alternates between training the generator and the discriminator.
The discriminator is trained to distinguish between real and generated images.
The generator is trained to produce images that can fool the discriminator.
Losses are calculated using mean squared error (MSE).
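Training with MSE against fixed real and fake targets is the least-squares GAN formulation. The plain-Python sketch below shows how the two losses could be computed from the discriminator's outputs on a batch. The targets 1.0 (real) and 0.0 (fake) and the halving of the discriminator loss are common conventions assumed here, not confirmed details of CGAN.py, which would operate on Jittor variables rather than Python lists.

```python
from typing import Sequence


def mse(predictions: Sequence[float], target: float) -> float:
    """Mean squared error between each prediction and a constant target."""
    return sum((p - target) ** 2 for p in predictions) / len(predictions)


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    # Assumed convention: real images should score 1, generated images 0,
    # and the two terms are averaged.
    return (mse(d_real, 1.0) + mse(d_fake, 0.0)) / 2


def generator_loss(d_fake: Sequence[float]) -> float:
    # The generator is rewarded when the discriminator scores its images as real (1).
    return mse(d_fake, 1.0)
```

For instance, discriminator outputs of [0.9, 0.8] on real images and [0.2, 0.1] on generated ones give a discriminator loss of 0.025 and a generator loss of 0.725: the discriminator is separating the two sets well, so the generator's loss is high.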
Saving and Loading Models
Models are saved every 10 epochs to the ./models/ directory.
The script includes code to load the last saved generator and discriminator models for inference.
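Saving every 10 epochs and later reloading the most recent checkpoint can be sketched as below. The file naming scheme (for example generator_100.pkl), the 1-based epoch counting, and both helper functions are hypothetical; the actual names used in ./models/ may differ. The second helper compares epoch numbers numerically, which avoids picking generator_20.pkl over generator_100.pkl as a plain string comparison would.

```python
import re
from typing import Iterable, List, Optional


def checkpoint_epochs(n_epochs: int, every: int = 10) -> List[int]:
    """Epochs (counted from 1) after which models would be saved."""
    return [epoch for epoch in range(1, n_epochs + 1) if epoch % every == 0]


def latest_checkpoint(filenames: Iterable[str], prefix: str) -> Optional[str]:
    """Return the file named '<prefix><epoch>.pkl' with the highest epoch.

    The naming pattern is a hypothetical example, not the repository's.
    """
    pattern = re.compile(re.escape(prefix) + r"(\d+)\.pkl")
    best_epoch, best_name = -1, None
    for name in filenames:
        match = pattern.fullmatch(name)
        # Compare the parsed epoch as an integer, not the filename as a string.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_name = int(match.group(1)), name
    return best_name
```

In practice the file list would come from os.listdir("./models/"), called once with a generator prefix and once with a discriminator prefix.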
Directory Structure
├── CGAN.py     # Main script with model definitions and training loop
├── test.py     # Script to generate images of a given number
├── models      # Directory to save trained models
├── images      # Directory to save generated images
└── README.md   # This readme file