Jittor Conditional GAN (CGAN) for MNIST Digit Generation

This project implements a Conditional Generative Adversarial Network (CGAN) using the Jittor deep learning framework. The model is trained on the MNIST dataset to generate handwritten digit images conditioned on specific class labels (0-9). This implementation was developed as part of the Graphics Course PA3 at Tsinghua University.

Project Status: Completed for course assignment.

Table of Contents

  • Overview
  • Features
  • Model Architecture
  • Installation
  • Usage
  • Results
  • File Structure
  • Acknowledgements
  • License

Overview

Generative Adversarial Networks (GANs) are powerful models for generating realistic data, but standard GANs offer little control over the generated output. Conditional GANs (CGANs) extend the GAN framework by incorporating conditional information, such as class labels, into both the generator and the discriminator.

This project focuses on building a CGAN that takes a random noise vector z and a digit label y (0-9) as input and generates an image G(z, y) resembling a handwritten digit corresponding to the label y. The discriminator D learns to distinguish real MNIST images (x, y1) from generated images (G(z, y2), y2), considering both the image and its label. We use a Mean Squared Error (MSE) based loss function instead of the traditional logarithmic loss for potentially more stable training.
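
In the spirit of LSGAN, this MSE objective can be written as follows; this is a hedged formulation assuming a target of 1 for real pairs and 0 for generated pairs, matching the notation above:

$$\mathcal{L}_D = \mathbb{E}\big[(D(x, y_1) - 1)^2\big] + \mathbb{E}\big[D(G(z, y_2), y_2)^2\big], \qquad \mathcal{L}_G = \mathbb{E}\big[(D(G(z, y_2), y_2) - 1)^2\big]$$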

Features

  • Conditional Generation: Generates MNIST digit images based on provided labels (0-9).
  • Jittor Implementation: Built entirely using the high-performance Jittor framework.
  • Fully Connected Networks: Utilizes fully connected layers for both Generator and Discriminator, suitable for the MNIST image size (32x32).
  • Customizable Training: Offers command-line arguments to configure hyperparameters like learning rate, batch size, and epochs.
  • Model Persistence: Saves trained generator and discriminator models for later use.
  • Sequence Generation: Includes functionality to generate a sequence of digits based on a predefined string and save it as result.png.

Model Architecture

Both the Generator and the Discriminator embed the class label with nn.Embedding and process it together with their other inputs through stacks of fully connected layers; a hedged code sketch of both networks follows the list below.

  1. Generator (G):

    • Takes a latent noise vector z (dim: latent_dim) and a digit label y as input.
    • Embeds the label y into a vector using nn.Embedding.
    • Concatenates the noise z and the embedded label.
    • Passes the concatenated vector through a series of nn.Linear layers with nn.LeakyReLU activations and nn.BatchNorm1d (optional).
    • The final layer outputs a vector of size img_size * img_size * channels, which is reshaped into the image format.
    • Uses nn.Tanh activation on the output layer to scale pixel values between -1 and 1.
  2. Discriminator (D):

    • Takes an image x and a digit label y as input.
    • Embeds the label y using nn.Embedding.
    • Flattens the input image x.
    • Concatenates the flattened image and the embedded label.
    • Passes the concatenated vector through several nn.Linear layers with nn.LeakyReLU activations and nn.Dropout for regularization.
    • The final layer outputs a single scalar value representing the predicted validity (realness) of the input image-label pair (without a sigmoid activation, as MSE loss is used directly).
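
The following is a minimal, hedged sketch of how these two networks might be written in Jittor. The layer widths (128/256/512), the label-embedding size, and the defaults latent_dim=100, img_size=32, channels=1 are illustrative assumptions rather than the exact contents of CGAN.py.

```python
import jittor as jt
from jittor import nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, n_classes=10, img_size=32, channels=1):
        super().__init__()
        self.img_shape = (channels, img_size, img_size)
        self.label_emb = nn.Embedding(n_classes, n_classes)  # label -> dense vector

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            nn.Linear(512, channels * img_size * img_size),
            nn.Tanh(),  # pixel values scaled to [-1, 1]
        )

    def execute(self, z, labels):
        # concatenate noise and embedded label, then map to a flat image
        x = jt.concat([z, self.label_emb(labels)], dim=1)
        img = self.model(x)
        return img.reshape((img.shape[0], *self.img_shape))


class Discriminator(nn.Module):
    def __init__(self, n_classes=10, img_size=32, channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.model = nn.Sequential(
            nn.Linear(n_classes + channels * img_size * img_size, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(256, 1),  # raw validity score; no sigmoid since MSE loss is used
        )

    def execute(self, img, labels):
        flat = img.reshape((img.shape[0], -1))
        x = jt.concat([flat, self.label_emb(labels)], dim=1)
        return self.model(x)
```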

Installation

  1. Prerequisites:

    • Python 3.x
    • C++ Compiler (g++ >= 5.4 or clang >= 8.0)
    • Linux or Windows (WSL recommended for Windows)
  2. Install Jittor: Follow the official Jittor installation guide: https://cg.cs.tsinghua.edu.cn/jittor/download/ (Pip installation is usually the simplest: python -m pip install jittor)

  3. Clone the Repository:

    git clone https://gitlink.org.cn/Zenith/CGAN_jittor.git
    cd CGAN_jittor

Usage

The main script is CGAN.py.

Training

To train the model from scratch:

python CGAN.py

The script will:

  • Automatically download the MNIST dataset if not found.
  • Train the Generator and Discriminator using the specified hyperparameters (defaults are provided in the script).
  • Print training progress (Epoch, Batch, D loss, G loss).
  • Periodically save sample generated images (e.g., 1000.png, 2000.png) so progress can be monitored.
  • Save the Generator and Discriminator as generator_last.pkl and discriminator_last.pkl every 10 epochs and again when training completes.
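
A minimal sketch of one training iteration with the MSE objective, assuming the Generator and Discriminator classes sketched above and a batch of real_imgs / labels from Jittor's MNIST loader; the hyperparameter values here are illustrative, not necessarily the script's defaults:

```python
import jittor as jt
from jittor import nn

latent_dim, n_classes, lr = 100, 10, 0.0002      # illustrative values
generator, discriminator = Generator(), Discriminator()
opt_G = jt.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = jt.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
mse = nn.MSELoss()

def train_step(real_imgs, labels):
    batch = real_imgs.shape[0]
    valid = jt.ones((batch, 1))                  # MSE target for "real"
    fake = jt.zeros((batch, 1))                  # MSE target for "fake"

    # Generator: push D's score for generated pairs toward 1
    z = jt.randn(batch, latent_dim)
    gen_labels = jt.randint(0, n_classes, (batch,))
    gen_imgs = generator(z, gen_labels)
    g_loss = mse(discriminator(gen_imgs, gen_labels), valid)
    opt_G.step(g_loss)                           # Jittor: step(loss) backprops and updates

    # Discriminator: score real pairs as 1 and generated pairs as 0
    real_loss = mse(discriminator(real_imgs, labels), valid)
    fake_loss = mse(discriminator(gen_imgs.detach(), gen_labels), fake)
    d_loss = (real_loss + fake_loss) / 2
    opt_D.step(d_loss)
    return g_loss, d_loss
```

Checkpointing can then be done with generator.save('generator_last.pkl') and discriminator.save('discriminator_last.pkl').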

You can customize training using command-line arguments:

python CGAN.py --n_epochs 200 --batch_size 128 --lr 0.0001

See python CGAN.py --help for all available options.
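
For reference, a hedged sketch of how these flags could be parsed with argparse; the option names mirror the example command above, while the defaults shown are assumptions:

```python
import argparse

parser = argparse.ArgumentParser(description="Train a Jittor CGAN on MNIST")
parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=64, help="mini-batch size")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam learning rate")
opt = parser.parse_args()
```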

Generating Specific Digits

After training (or if you already have a pre-trained generator_last.pkl), the script automatically proceeds to generate a specific sequence of digits.

  1. Modify the Sequence (Optional): The script currently hardcodes the sequence "28132182806014". To generate a different sequence (e.g., your student ID), modify this line in CGAN.py:

    number = "YOUR_DESIRED_SEQUENCE" # e.g., "1234567890"
  2. Run the Script: If you just finished training, the generation step runs automatically. If you want to generate using existing model files without retraining, you can comment out the training loop section in CGAN.py and run:

    python CGAN.py
  3. Output: The script will generate images corresponding to the digits in the number string and save them concatenated horizontally into a single image file named result.png.
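
A minimal sketch of this generation step, assuming the Generator class sketched above, a trained generator_last.pkl, and latent_dim=100; the post-processing shown (rescaling from [-1, 1] to [0, 255] and horizontal concatenation) is an illustrative reconstruction rather than the exact code in CGAN.py:

```python
import numpy as np
import jittor as jt
from PIL import Image

latent_dim = 100
number = "28132182806014"                        # digit sequence to render

generator = Generator()
generator.load('generator_last.pkl')             # load the trained checkpoint
generator.eval()

z = jt.randn(len(number), latent_dim)
labels = jt.array([int(c) for c in number])
imgs = generator(z, labels).numpy()              # shape (N, 1, 32, 32), values in [-1, 1]

imgs = ((imgs + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
row = np.concatenate([img[0] for img in imgs], axis=1)   # concatenate digits horizontally
Image.fromarray(row).save('result.png')
```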

Results

The training process generates intermediate samples. After training, the script produces a result.png file containing the generated digits for the specified sequence.

File Structure

.
├── CGAN.py             # Main Python script for training and generation
├── README.md           # This file
├── LICENSE             # MIT License
└── .gitignore          # Files intentionally untracked by Git

Acknowledgements

  • This project was completed for the Graphics Course PA3 at Tsinghua University.
  • Based on the concepts from the Conditional GAN paper: Conditional Generative Adversarial Nets by Mehdi Mirza and Simon Osindero.
  • Built with the Jittor framework.

License

This project is open-sourced under the MIT License.
