Jittor Conditional GAN (CGAN) for MNIST Digit Generation
This project implements a Conditional Generative Adversarial Network (CGAN) using the Jittor deep learning framework. The model is trained on the MNIST dataset to generate handwritten digit images conditioned on specific class labels (0-9). This implementation was developed as part of the Graphics Course PA3 at Tsinghua University.
Project Status: Completed for course assignment.
Overview
Generative Adversarial Networks (GANs) are powerful models for generating realistic data, but standard GANs offer little control over the generated output. Conditional GANs (CGANs) extend the GAN framework by incorporating conditional information, such as class labels, into both the generator and the discriminator.
This project focuses on building a CGAN that takes a random noise vector z and a digit label y (0-9) as input and generates an image G(z, y) resembling a handwritten digit of class y. The discriminator D learns to distinguish real MNIST image-label pairs (x, y) from generated pairs (G(z, y'), y'), considering both the image and its label. Following the least-squares GAN (LSGAN) formulation, we use a Mean Squared Error (MSE) loss instead of the traditional logarithmic loss, which can make training more stable.
Features
Conditional Generation: Generates MNIST digit images based on provided labels (0-9).
Jittor Implementation: Built entirely using the high-performance Jittor framework.
Fully Connected Networks: Utilizes fully connected layers for both Generator and Discriminator, which is sufficient for MNIST images (resized here to 32x32).
Customizable Training: Offers command-line arguments to configure hyperparameters like learning rate, batch size, and epochs.
Model Persistence: Saves trained generator and discriminator models for later use.
Sequence Generation: Includes functionality to generate a sequence of digits based on a predefined string and save it as result.png.
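The command-line interface for these hyperparameters is typically built with argparse. A minimal sketch follows; the flag names and defaults are illustrative assumptions, so check python CGAN.py --help for the real ones:

```python
import argparse

def build_parser():
    # Hypothetical flag names; the actual names in CGAN.py may differ.
    parser = argparse.ArgumentParser(description="Jittor CGAN for MNIST")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="mini-batch size")
    parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
    parser.add_argument("--latent_dim", type=int, default=100, help="size of the noise vector z")
    return parser

# Parse an explicit argument list (the script itself would call parse_args()).
opt = build_parser().parse_args(["--n_epochs", "50", "--lr", "0.001"])
```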
Model Architecture
Both the Generator and Discriminator leverage embeddings for the class labels and utilize sequences of fully connected layers.
Generator (G):
Takes a latent noise vector z (dim: latent_dim) and a digit label y as input.
Embeds the label y into a vector using nn.Embedding.
Concatenates the noise z and the embedded label.
Passes the concatenated vector through a series of nn.Linear layers with nn.LeakyReLU activations and nn.BatchNorm1d (optional).
The final layer outputs a vector of size img_size * img_size * channels, which is reshaped into the image format.
Uses nn.Tanh activation on the output layer to scale pixel values between -1 and 1.
Discriminator (D):
Takes an image x and a digit label y as input.
Embeds the label y using nn.Embedding.
Flattens the input image x.
Concatenates the flattened image and the embedded label.
Passes the concatenated vector through several nn.Linear layers with nn.LeakyReLU activations and nn.Dropout for regularization.
The final layer outputs a single scalar value representing the predicted validity (realness) of the input image-label pair (without a sigmoid activation, as MSE loss is used directly).
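The MSE objective applied to these raw validity scores can be sketched framework-agnostically. A minimal NumPy illustration with toy discriminator outputs (the 0.5 weighting is a common convention; CGAN.py may weight the terms differently):

```python
import numpy as np

def mse(pred, target):
    # Mean squared error between discriminator scores and targets.
    return np.mean((pred - target) ** 2)

# Toy raw discriminator scores (no sigmoid) for a batch of 4 pairs.
d_real = np.array([0.9, 1.1, 0.8, 1.0])   # D(x, y) on real pairs
d_fake = np.array([0.1, -0.2, 0.3, 0.0])  # D(G(z, y'), y') on fakes

# Least-squares discriminator loss: push real scores to 1, fake scores to 0.
d_loss = 0.5 * (mse(d_real, 1.0) + mse(d_fake, 0.0))

# Least-squares generator loss: push fake scores toward 1.
g_loss = mse(d_fake, 1.0)
```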
Installation
Prerequisites:
Install Jittor: Follow the official Jittor installation guide: https://cg.cs.tsinghua.edu.cn/jittor/download/ (pip installation is usually the simplest: python -m pip install jittor).
Clone the Repository:
Usage
The main script is CGAN.py.
Training
To train the model from scratch:
python CGAN.py
The script will:
Periodically save generated sample images (1000.png, 2000.png, etc.) to monitor progress.
Save generator_last.pkl and discriminator_last.pkl every 10 epochs and upon completion.
You can customize training using command-line arguments. See python CGAN.py --help for all available options.
Generating Specific Digits
After training (or if you have a pre-trained generator_last.pkl), the script automatically proceeds to generate a specific sequence of digits.
Modify the Sequence (Optional):
The script currently hardcodes the sequence "28132182806014". To generate a different sequence (e.g., your student ID), modify this line in CGAN.py:
number = "YOUR_DESIRED_SEQUENCE" # e.g., "1234567890"
Run the Script:
If you just finished training, the generation step runs automatically. If you want to generate using existing model files without retraining, you can comment out the training loop section in CGAN.py and run:
python CGAN.py
Output:
The script will generate images corresponding to the digits in the number string and save them concatenated horizontally into a single image file named result.png.
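The horizontal concatenation step can be sketched with NumPy. Toy arrays stand in for the generated 32x32 digit images here; the actual image-saving code in CGAN.py may differ:

```python
import numpy as np

# One toy 32x32 grayscale image per digit in the target sequence.
number = "2813"
digit_imgs = [np.full((32, 32), float(d) / 10.0) for d in number]

# Concatenate along the width axis into a single 32 x (32 * len) strip.
strip = np.concatenate(digit_imgs, axis=1)
```

The strip can then be rescaled to [0, 255], converted to uint8, and written out as result.png with any image library.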
Results
The training process generates intermediate samples. After training, the script produces a result.png file containing the generated digits for the specified sequence.
File Structure
.
├── CGAN.py # Main Python script for training and generation
├── README.md # This file
├── LICENSE # MIT License
└── .gitignore # Specifies intentionally untracked files by Git
Acknowledgements
This project was completed for the Graphics Course PA3 at Tsinghua University.
License
This project is open-sourced under the MIT License.