PA3_oyx - Conditional GAN (CGAN) Implementation

This project implements a Conditional Generative Adversarial Network (CGAN) using the Jittor deep learning framework. The CGAN is trained on the MNIST dataset to generate handwritten digit images conditioned on specific class labels.

Overview

Conditional GANs extend the original GAN architecture by conditioning both the generator and discriminator on additional information (class labels). This allows for controlled generation of samples from specific classes.

Features

  • Generator: Takes random noise and class labels as input to generate realistic digit images
  • Discriminator: Classifies images as real/fake while considering class labels
  • Conditional Training: Both networks are trained with class label information
  • MNIST Dataset: Trained on MNIST handwritten digits resized to 32x32 grayscale images
  • GPU Support: Automatically uses CUDA if available (see the flag sketch after this list)
  • Model Persistence: Saves trained models every 10 epochs
  • Sample Generation: Generates sample images during training
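
For reference, GPU use in Jittor is a single global flag rather than per-tensor device placement. A minimal sketch of the usual pattern (CGAN.py presumably does the equivalent):

```python
import jittor as jt

# Use the GPU when one is available; Jittor otherwise runs on CPU.
jt.flags.use_cuda = 1 if jt.has_cuda else 0
```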

Requirements

jittor
numpy
PIL (Pillow)
argparse (Python standard library)

Installation

  1. Install Jittor:

    pip install jittor
  2. Install other dependencies:

    pip install pillow numpy

Usage

Basic Training

    python CGAN.py

Custom Parameters

    python CGAN.py --n_epochs 200 --batch_size 128 --lr 0.0001

Available Arguments

  • --n_epochs: Number of training epochs (default: 100)
  • --batch_size: Batch size for training (default: 64)
  • --lr: Learning rate for Adam optimizer (default: 0.0002)
  • --b1: Beta1 parameter for Adam (default: 0.5)
  • --b2: Beta2 parameter for Adam (default: 0.999)
  • --n_cpu: Number of CPU threads (default: 8)
  • --latent_dim: Dimensionality of latent space (default: 100)
  • --n_classes: Number of classes (default: 10)
  • --img_size: Size of generated images (default: 32)
  • --channels: Number of image channels (default: 1)
  • --sample_interval: Interval for saving sample images (default: 1000)
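
A minimal argparse setup matching the defaults above (a sketch; flag names mirror the list, but CGAN.py may differ in help text or ordering):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=64, help="batch size for training")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="Adam beta1")
parser.add_argument("--b2", type=float, default=0.999, help="Adam beta2")
parser.add_argument("--n_cpu", type=int, default=8, help="number of CPU threads")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes")
parser.add_argument("--img_size", type=int, default=32, help="size of generated images")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=1000, help="batches between image samples")
opt = parser.parse_args()
```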

Model Architecture

Generator

  • Input: Random noise (100D) + Class label embedding (10D)
  • Architecture:
    • 4 fully connected layers (110→128→256→512→1024)
    • Batch normalization and LeakyReLU activations
    • Final linear layer projecting to the flattened image size (32x32 = 1024)
  • Output: 32x32 grayscale images with Tanh activation
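
A hedged Jittor sketch of a generator matching this description (class and variable names are illustrative, not necessarily those in CGAN.py):

```python
import numpy as np
import jittor as jt
from jittor import nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, n_classes=10, img_shape=(1, 32, 32)):
        super().__init__()
        self.img_shape = img_shape
        # Each class label becomes a 10-D dense vector.
        self.label_emb = nn.Embedding(n_classes, n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        # 110 -> 128 -> 256 -> 512 -> 1024, then project to 32*32 = 1024 pixels.
        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def execute(self, noise, labels):
        # Conditioning: concatenate the label embedding with the noise vector.
        gen_input = jt.concat([self.label_emb(labels), noise], dim=1)
        img = self.model(gen_input)
        return img.reshape((img.shape[0],) + self.img_shape)
```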

Discriminator

  • Input: Flattened image (1024D) + Class label embedding (10D)
  • Architecture:
    • 4 fully connected layers with dropout (0.4) and LeakyReLU
    • Input dimension: 1034 (1024 + 10)
    • Hidden layers: 512 neurons each
  • Output: Single real/fake probability
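
A matching sketch for the discriminator, under the same naming caveats:

```python
class Discriminator(nn.Module):
    def __init__(self, n_classes=10, img_dim=1024):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        # Input width 1034 = flattened 32x32 image (1024) + 10-D label embedding.
        self.model = nn.Sequential(
            nn.Linear(img_dim + n_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),  # single real/fake score
        )

    def execute(self, img, labels):
        # Flatten the image, append the label embedding, then score.
        d_in = jt.concat([img.reshape((img.shape[0], -1)), self.label_emb(labels)], dim=1)
        return self.model(d_in)
```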

Training Process

  1. Generator Training: Learn to generate realistic images that fool the discriminator
  2. Discriminator Training: Learn to distinguish between real and generated images
  3. Conditional Loss: Both networks consider class label information
  4. Adversarial Loss: Mean Squared Error between predictions and target labels

The training alternates between:

  • Training generator to maximize discriminator error
  • Training discriminator to correctly classify real vs fake images
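
A condensed sketch of one training iteration under this scheme, assuming the module sketches above and Jittor's optimizer API (`train_loader` and other variable names are illustrative):

```python
criterion = nn.MSELoss()
opt_G = nn.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_D = nn.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

for real_imgs, labels in train_loader:
    batch = real_imgs.shape[0]
    valid = jt.ones((batch, 1))   # target for "real"
    fake = jt.zeros((batch, 1))   # target for "fake"

    # Generator step: produce images the discriminator should call "real".
    z = jt.array(np.random.normal(0, 1, (batch, 100)).astype(np.float32))
    gen_labels = jt.array(np.random.randint(0, 10, batch))
    gen_imgs = generator(z, gen_labels)
    g_loss = criterion(discriminator(gen_imgs, gen_labels), valid)
    opt_G.step(g_loss)  # Jittor: step(loss) runs backward and update together

    # Discriminator step: average the real and fake MSE terms.
    real_loss = criterion(discriminator(real_imgs, labels), valid)
    fake_loss = criterion(discriminator(gen_imgs.detach(), gen_labels), fake)
    d_loss = (real_loss + fake_loss) / 2
    opt_D.step(d_loss)
```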

Output Files

Generated During Training

  • {batch_number}.png: Sample images generated every sample_interval batches
  • Images are arranged in a 10x10 grid showing digits 0-9

Model Checkpoints

  • generator_last.pkl: Saved generator model (every 10 epochs)
  • discriminator_last.pkl: Saved discriminator model (every 10 epochs)
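
Checkpointing in Jittor is a single call per module; the likely pattern (a sketch, assuming an epoch counter from the training loop) is:

```python
if epoch % 10 == 0:
    generator.save("generator_last.pkl")
    discriminator.save("discriminator_last.pkl")
```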

Final Output

  • result.png: Final generated image sequence for the specified number string

Project Structure

PA3_oyx/
├── CGAN.py                    # Main implementation file
├── README.md                  # This file
├── .gitignore                # Git ignore rules
├── generator_last.pkl        # Trained generator model
├── discriminator_last.pkl    # Trained discriminator model
├── result.png               # Final generated sequence
├── *.png                    # Training sample images
└── .git/                    # Git repository

Key Implementation Details

Label Embedding

Both generator and discriminator use embedding layers to convert class labels into dense 10-D vectors; the generator concatenates the embedding with the noise vector, while the discriminator concatenates it with the flattened image, as sketched below.
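
Concretely, the shapes work out as follows (a sketch on the generator side, reusing names and imports from the architecture sketches above):

```python
label_emb = nn.Embedding(10, 10)        # 10 classes -> 10-D vectors
emb = label_emb(labels)                 # (batch, 10)
g_in = jt.concat([emb, noise], dim=1)   # (batch, 110) fed to the generator
```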

Loss Function

Uses Mean Squared Error (MSE) loss for adversarial training, i.e. the least-squares GAN (LSGAN) formulation:

  • Generator loss: MSE between discriminator output and “real” labels
  • Discriminator loss: Average of MSE for real images (target=1) and fake images (target=0)

Training Schedule

  • Models are saved every 10 epochs
  • Sample images are generated every 1000 batches
  • Training progress is printed every 50 batches

Results

The trained model generates class-conditioned digit images: you specify which digit (0-9) to produce. The final output demonstrates this by generating images for the sequence “2023010788”.
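
A hedged sketch of producing such a sequence from the saved checkpoint, reusing the imports above (the PIL stitching is an assumption about how result.png is assembled):

```python
from PIL import Image

number = "2023010788"
generator.load("generator_last.pkl")   # restore the saved weights
generator.eval()                       # use running BatchNorm statistics

labels = jt.array(np.array([int(c) for c in number]))
z = jt.array(np.random.normal(0, 1, (len(number), 100)).astype(np.float32))
imgs = generator(z, labels)            # shape: (10, 1, 32, 32)

# Map Tanh output from [-1, 1] to [0, 255], then stitch digits side by side.
arr = ((imgs.numpy().squeeze(1) + 1) / 2 * 255).astype(np.uint8)
Image.fromarray(np.concatenate(list(arr), axis=1)).save("result.png")
```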

Student Information

  • Student ID: 2023010788
  • Project: PA3 Assignment
  • Framework: Jittor
  • Model: Conditional GAN

License

This project is for educational purposes as part of PA3 assignment.
