Progressive Growing of GANs in Flax

Flax (JAX) implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation. This code is meant to be a starting point you can fork for your own needs rather than a full re-implementation.

Below are some curated samples from the generator trained on the CelebA (not HQ) dataset. All hyperparameters are in src/conf/config.yaml. The samples are not as good as those in the original paper due to the significantly shorter training time.

Usage

  1. Download and extract CelebA.
  2. Install JAX (instructions vary by system).
  3. Install the other dependencies (ideally in a virtual environment) via pip install -r requirements.txt. Requires Python >=3.6.
  4. Run the code via python src/main.py data_dir=<celeba directory>.

The code was run on a TPUv3-8. You will need to adjust the hyperparameters in src/conf/config.yaml for your local system (e.g. set distributed: False and decrease the batch size).
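For example, a single-device run might look like the line below. This is only a sketch: data_dir and distributed are the only configuration keys mentioned in this README, and it assumes config values can be overridden on the command line in the same key=value style as data_dir; other settings should be changed directly in src/conf/config.yaml.

    python src/main.py data_dir=/path/to/celeba distributed=False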

Differences from Original Paper

  • Different learning rates and batch sizes.
  • Transition (interpolation between the previous and current stage) only lasts 80% of each stage instead of the entire stage (see the fade-in sketch after this list).
  • Slightly smaller model (with same architecture) since this implementation is for CelebA up to 128x128.
  • Trained with bfloat16 without loss scaling (as opposed to float16 with loss scaling).
  • tanh activation for the Generator outputs.
  • Gain used for the equalized learning rate is adjusted per activation instead of using sqrt(2) throughout (gains computed as in PyTorch; see the equalized learning rate sketch after this list).
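Regarding the shortened transition, below is a minimal sketch of how such a fade-in schedule could be computed in JAX. The 0.8 fraction matches the description above, but the function and variable names are illustrative and not taken from this repository.

    import jax.numpy as jnp

    def fade_in_alpha(step, steps_per_stage, transition_frac=0.8):
        """Blend factor between the previous and current resolution stage.

        Ramps linearly from 0 to 1 over the first `transition_frac` of the
        stage, then stays at 1 for the rest of the stage (sketch only).
        """
        progress = step / (transition_frac * steps_per_stage)
        return jnp.clip(progress, 0.0, 1.0)

    # As in the original ProGAN fade-in, the blended image interpolates the
    # upsampled previous-stage output with the current stage's output:
    #   image = (1 - alpha) * upsampled_prev + alpha * current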
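Regarding the per-activation gain, the sketch below shows one way to implement runtime weight scaling (equalized learning rate) with a per-activation gain in Flax. It assumes the gains follow PyTorch's torch.nn.init.calculate_gain; the names gain_for and EqualizedDense are illustrative and not the repository's actual code.

    import math
    import flax.linen as nn

    def gain_for(activation, leaky_slope=0.2):
        # Gains as computed by PyTorch's calculate_gain.
        if activation == 'leaky_relu':
            return math.sqrt(2.0 / (1.0 + leaky_slope ** 2))
        if activation == 'tanh':
            return 5.0 / 3.0
        return 1.0  # linear / identity

    class EqualizedDense(nn.Module):
        features: int
        gain: float = math.sqrt(2.0)

        @nn.compact
        def __call__(self, x):
            fan_in = x.shape[-1]
            # Parameters start from a unit normal; the He-style scale is
            # applied at every forward pass instead of at initialization.
            kernel = self.param('kernel', nn.initializers.normal(stddev=1.0),
                                (fan_in, self.features))
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            scale = self.gain / math.sqrt(fan_in)
            return x @ (kernel * scale) + bias

    # Illustrative usage: a layer followed by leaky ReLU would use
    # EqualizedDense(features=512, gain=gain_for('leaky_relu')).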

Training Results

Below are training curves with the configuration in src/conf/config.yaml. Each vertical grey line indicates going from one stage to the next. The spikes in time per step at the beginning of each stage correspond to compilation. The total training time was 11 hours 6 minutes on a TPUv3-8. Training for longer would likely give better results. Full training logs and checkpoints can be found here.
