This repository contains an implementation of the StyleGAN architecture based on the paper "A Style-Based Generator Architecture for Generative Adversarial Networks" by Tero Karras, Samuli Laine, and Timo Aila.
StyleGAN is a state-of-the-art architecture for generative adversarial networks that introduces a novel way to control the image synthesis process. It separates high-level attributes (pose, identity) from stochastic variation (freckles, hair) in generated images, enabling more intuitive control over the generation process.
This implementation focuses on the core components of StyleGAN:
- Mapping Network: Transforms the input latent code into an intermediate latent space
- Synthesis Network: Generates images using AdaIN (Adaptive Instance Normalization)
- Style Mixing: Allows combining styles from different latent codes (see the sketch after this list)
- Progressive Growing: Gradually increases the resolution of generated images
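
For intuition, here is a minimal, hypothetical sketch of style mixing at the level of w vectors. The `mapping` stand-in, the layer count, and the crossover point are illustrative assumptions, not the repository's exact code:

```python
import torch
import torch.nn as nn

# Stand-in for the mapping network (hypothetical; the real one is an
# 8-layer MLP, described below)
mapping = nn.Sequential(nn.Linear(128, 128), nn.LeakyReLU(0.2))

num_layers = 8   # number of synthesis blocks (assumed)
crossover = 4    # blocks before this index take styles from w1, the rest from w2

z1 = torch.randn(1, 128)
z2 = torch.randn(1, 128)
w1, w2 = mapping(z1), mapping(z2)

# One style vector per synthesis block; block i would feed styles[i]
# into its AdaIN layers.
styles = [w1 if i < crossover else w2 for i in range(num_layers)]
```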
The implementation includes the following key components:
- Mapping Network: Transforms the input latent vector z into an intermediate latent space w
- AdaIN (Adaptive Instance Normalization): Applies style-based modulation (sketched after this list)
- Noise Injection: Adds stochastic variation to the generated images
- Synthesis Network: Generates images progressively through multiple resolution blocks
- Discriminator: A convolutional critic trained with the WGAN-GP loss
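
To make the AdaIN step concrete, here is a minimal sketch of an AdaIN module. It is a generic formulation rather than the repository's exact code; the style projections are plain linear layers, and `w_dim=128` matches the hyperparameters later in this README:

```python
import torch.nn as nn

class AdaIN(nn.Module):
    """Adaptive Instance Normalization: normalize each feature map,
    then re-scale and re-shift it with projections of the style w."""

    def __init__(self, channels, w_dim=128):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels)
        self.to_scale = nn.Linear(w_dim, channels)
        self.to_shift = nn.Linear(w_dim, channels)

    def forward(self, x, w):
        # x: (B, C, H, W) feature maps; w: (B, w_dim) style vector
        x = self.norm(x)
        scale = self.to_scale(w).unsqueeze(2).unsqueeze(3)
        shift = self.to_shift(w).unsqueeze(2).unsqueeze(3)
        return scale * x + shift
```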
Dependencies:
- Python 3.6+
- PyTorch 1.7+
- torchvision
- numpy
- Pillow (PIL)
- tqdm
- matplotlib
The model is trained using the WGAN-GP loss with the following hyperparameters:
```python
# Hyperparameters
BATCH_SIZE = 16
Z_DIM = 128
lr = 0.0001
beta1 = 0.0
beta2 = 0.99
gp_lambda = 10.0
```
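
For reference, this is a standard formulation of the WGAN-GP gradient penalty term weighted by `gp_lambda`; it is a sketch of the usual computation, not necessarily the repository's exact implementation:

```python
import torch

def gradient_penalty(critic, real, fake, device):
    # Interpolate between real and fake samples (standard WGAN-GP step)
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = eps * real + (1 - eps) * fake
    interpolated.requires_grad_(True)

    scores = critic(interpolated)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grads = grads.view(batch_size, -1)

    # Penalize deviation of the gradient norm from 1
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```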
Training can be started with:

```python
# Initialize models
Generator_model = Generator(z_dim=128, w_dim=128).to(device)
Discriminator_model = Discriminator().to(device)

# Training loop (see training code in notebook)
```
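
As a rough illustration only (the actual loop is in the notebook), one alternating critic/generator update using the hyperparameters above and the `gradient_penalty` sketch from earlier might look like this; `dataloader` and `device` are assumed to be defined:

```python
import torch
import torch.optim as optim

opt_gen = optim.Adam(Generator_model.parameters(), lr=lr, betas=(beta1, beta2))
opt_disc = optim.Adam(Discriminator_model.parameters(), lr=lr, betas=(beta1, beta2))

for real in dataloader:  # `dataloader` is assumed to yield image batches
    real = real.to(device)
    z = torch.randn(real.size(0), Z_DIM, device=device)
    fake = Generator_model(z)

    # Critic update: WGAN loss plus the gradient penalty sketched above
    disc_loss = (
        Discriminator_model(fake.detach()).mean()
        - Discriminator_model(real).mean()
        + gp_lambda * gradient_penalty(Discriminator_model, real, fake.detach(), device)
    )
    opt_disc.zero_grad()
    disc_loss.backward()
    opt_disc.step()

    # Generator update: raise the critic's score on generated images
    gen_loss = -Discriminator_model(fake).mean()
    opt_gen.zero_grad()
    gen_loss.backward()
    opt_gen.step()
```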
To generate images using a trained model:

```python
import matplotlib.pyplot as plt
import numpy as np
import torch

# Generate a random latent vector and run it through the generator
Generator_model.eval()
with torch.no_grad():
    z = torch.randn(1, 128, device=device)
    generated_img = Generator_model(z)

# Map pixel values from [-1, 1] to [0, 1] and move channels last
img_np = (generated_img[0].cpu().numpy() + 1) / 2
img_np = np.transpose(img_np, (1, 2, 0))

# Display the image
plt.imshow(img_np)
plt.axis('off')
plt.show()
```

The generator consists of:
- Mapping Network: 8-layer MLP that maps the input latent code to the intermediate latent space (see the sketch after this list)
- Synthesis Network: Generates images through multiple resolution blocks with AdaIN and noise injection
- Constant Input: Starts with a learned constant input (4×4×128)
- Progressive Synthesis: Gradually increases resolution through upsampling
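
A minimal sketch of such an 8-layer mapping network follows. The activation choice and layer widths are assumptions consistent with `z_dim=128`/`w_dim=128` above, not necessarily the repository's exact definition:

```python
import torch.nn as nn

class MappingNetwork(nn.Module):
    """Maps a latent code z to the intermediate latent space w."""

    def __init__(self, z_dim=128, w_dim=128, num_layers=8):
        super().__init__()
        layers = []
        in_dim = z_dim
        for _ in range(num_layers):
            layers += [nn.Linear(in_dim, w_dim), nn.LeakyReLU(0.2)]
            in_dim = w_dim
        self.mapping = nn.Sequential(*layers)

    def forward(self, z):
        # Normalize z before mapping, as the StyleGAN paper does
        z = z / (z.pow(2).mean(dim=1, keepdim=True) + 1e-8).sqrt()
        return self.mapping(z)
```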
The discriminator is a standard convolutional network with:
- From RGB: Initial conversion from RGB to feature space
- Downsampling Blocks: Multiple blocks that progressively reduce resolution (see the sketch after this list)
- Final Classification: Outputs a single value for WGAN-GP loss
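
A minimal, hypothetical sketch of these discriminator pieces is below; the channel counts, depth, and final feature dimensions are illustrative assumptions:

```python
import torch.nn as nn

# "From RGB": initial 1x1 conv from RGB into feature space
from_rgb = nn.Conv2d(3, 64, kernel_size=1)

class DownBlock(nn.Module):
    """Convolution followed by 2x average-pool downsampling."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2),
        )

    def forward(self, x):
        return self.block(x)

# Final classification: flatten the lowest-resolution features and map
# them to a single score for the WGAN-GP loss (dims assumed: 128x4x4)
final = nn.Sequential(nn.Flatten(), nn.Linear(128 * 4 * 4, 1))
```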
This implementation is based on the StyleGAN paper by NVIDIA Research:
Karras, T., Laine, S., & Aila, T. (2019). A Style-Based Generator Architecture for Generative Adversarial Networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4401-4410).