GSOC21 Highlights

Important: Many blogs discuss how to write a successful GSOC proposal, but this isn’t one of them. This post is about my GSOC experience instead. Below are links to posts on writing a successful proposal that I agree with or used when crafting mine.

  1. https://blog.laymonage.com/posts/gsoc/
  2. https://blog.shubhank.codes/series/gsoc-21

The proposal and unlikely acceptance

Honestly, being part of Google Summer of Code (GSOC) 2021 has been a fever dream. It wasn’t something I expected, and I applied on a whim. I still remember scrolling through Twitter, seeing a tweet about it from the TensorFlow team, and deciding then and there that I wanted to apply. Frankly, I spent only 2 hours on my application before submitting it to TensorFlow on the GSOC website, far less than is typically recommended, at least based on the experiences of past GSOC participants on Reddit and Quora. It therefore surprised me when I got the news on 17 May that I had been accepted.

The first meeting and hitting the grind

The first meeting was a little intimidating. Everybody seemed so far ahead, and many of them had made several PRs to the PyProbML repo before GSOC started (for context, I had made only 1). These concerns basically evaporated the moment we hit the grind: I was soon wrapped up in my work and learning from everybody. I learnt and did so many things that it would be tough to list them all, so this post instead highlights some of the more interesting issues I worked on. For the impatient reader, the TL;DR is that I implemented several sampling methods and generative models in JAX and PyTorch that can be referenced and used to replicate some of the figures in Kevin’s textbook. In the end, I think the team finished the original set of issues within the first half of GSOC, so we had plenty of time to try other things, like implementing GANs and VAEs.

The big 3: Potts, VAEs and GANs

  1. JAX Potts models

    My first big challenge in GSOC was speeding up and then JAX-ifying my NumPy/Numba implementation of the Potts model. The Potts model is a generalisation of the Ising model: instead of binary states (spin-up (+1) or spin-down (-1)), each site takes one of $n$ states, where $n$ is some integer larger than 2. One reason this is surprisingly tricky is that Gibbs sampling on a Potts model does not converge the way it does on the simpler Ising model: it often needs more steps before giving good results, because convergence slows as $n$ increases. As a result, the demo originally ran for slightly over an hour, but by the end of the refactoring and iteration process, the full demonstration took about 1-2 minutes. This was made possible by 3 things.

    • JIT compilation in JAX
    • Modification of the algorithm (using the Gumbel trick to sample directly from the categorical distribution defined by the Potts energy, convolution operators to update the state, and block sampling to fully leverage hardware accelerators)
    • Hardware accelerators like GPUs

    These algorithmic changes are elaborated in this notebook I made to break down how the code works, if you are interested in reading more.
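To make the Gumbel trick, convolution and block-sampling ideas more concrete, here is a minimal JAX sketch of one way to combine them. It is not the notebook's code; it assumes a 2-D grid with 4-neighbour interactions, a free (zero-padded) boundary and the energy $E(\mathbf{x}) = -J \sum_{i \sim j} \mathbb{1}[x_i = x_j]$, so that the Gibbs conditional is $p(x_i = k \mid \mathbf{x}_{-i}) \propto \exp(J c_{ik})$, where $c_{ik}$ counts the neighbours of site $i$ in state $k$. A convolution computes every $c_{ik}$ at once, adding Gumbel noise to the logits and taking the argmax draws an exact categorical sample, and a checkerboard split lets half of the sites update in parallel because no two sites of the same colour are neighbours. The names `neighbour_counts`, `half_sweep` and `gibbs` are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy.signal

# 4-neighbour kernel: sums the up/down/left/right neighbours of a site.
KERNEL = jnp.array([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])


def neighbour_counts(state, n_states):
    """counts[r, c, k] = number of the 4 neighbours of site (r, c) in state k."""
    one_hot = jax.nn.one_hot(state, n_states)  # (H, W, n_states)
    conv = lambda channel: jax.scipy.signal.convolve2d(channel, KERNEL, mode="same")
    # Convolve each state channel separately; zero padding means a free boundary.
    return jax.vmap(conv, in_axes=2, out_axes=2)(one_hot)


def half_sweep(key, state, mask, coupling, n_states):
    """Resample every site where mask is True, all in parallel, via Gumbel-max."""
    logits = coupling * neighbour_counts(state, n_states)  # log p(x_i = k | rest) + const
    gumbel = jax.random.gumbel(key, logits.shape)
    proposal = jnp.argmax(logits + gumbel, axis=-1)  # exact categorical draw per site
    return jnp.where(mask, proposal, state)


@partial(jax.jit, static_argnums=(3, 4))  # n_states and n_sweeps must be static
def gibbs(key, state, coupling, n_states, n_sweeps):
    """Checkerboard (block) Gibbs sampler for a Potts model on a 2-D grid."""
    height, width = state.shape
    parity = jnp.arange(height)[:, None] + jnp.arange(width)[None, :]
    black = parity % 2 == 0  # no two black sites are neighbours

    def sweep(state, sweep_key):
        k_black, k_white = jax.random.split(sweep_key)
        state = half_sweep(k_black, state, black, coupling, n_states)
        state = half_sweep(k_white, state, ~black, coupling, n_states)
        return state, None

    state, _ = jax.lax.scan(sweep, state, jax.random.split(key, n_sweeps))
    return state
```

For example, `gibbs(jax.random.PRNGKey(1), jax.random.randint(jax.random.PRNGKey(0), (64, 64), 0, 5), 1.2, 5, 200)` runs 200 sweeps of a 5-state model with $J = 1.2$. Because `n_states` and `n_sweeps` are static arguments, changing either one triggers a recompile.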

    The most challenging part of writing the code was debugging. JAX was still quite new to me then, and I hadn’t yet had time to understand all its sharp bits. Furthermore, I had picked up some bad habits from NumPy, such as frequent in-place state mutation, which can be very costly in JAX (index updates, for example). This meant developing a new mental model of the cost of operations as I debugged and profiled the code to speed it up. Newcomers to JAX should be aware that their first project will involve an initial time sink while they get used to it. Still, I think this initial investment is worth it. Although I couldn’t use JAX much for subsequent projects, I used it in my own personal projects to do exciting things like a JAX version of Sinkhorn iteration and Langevin sampling. One thing these projects made me realise is that JAX is not just another autograd library. It has JIT compilation and, maybe more importantly, provides a high-level abstraction over hardware accelerators, making it easy to accelerate numerical code. Hopefully, more people working on numerical algorithms will notice this and use it to develop truly scalable numerical code.
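A minimal sketch of the NumPy mutation habit and its JAX counterparts (the function names are illustrative, not from the Potts code). JAX arrays are immutable, so in-place assignment fails and `.at[...].set(...)` returns a new array. Updating one element at a time in a Python loop therefore dispatches many small operations, whereas expressing the same update as one whole-array operation lets XLA compile a single kernel.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)

# NumPy habit: in-place assignment. JAX arrays are immutable, so this raises a TypeError.
try:
    x[0] = 10.0
    mutated = True
except TypeError:
    mutated = False

# Functional update: .at[...].set(...) returns a NEW array and leaves x untouched.
y = x.at[0].set(10.0)


def loop_update(x):
    """NumPy-style site-by-site updates: one small dispatched op per iteration,
    and outside jit each .at[].set() materialises a fresh copy of the array."""
    for i in range(x.shape[0]):
        x = x.at[i].set(x[i] * 2.0)
    return x


@jax.jit
def vectorised_update(x):
    """The same computation as a single whole-array op that XLA compiles once."""
    return x * 2.0
```

When an element-wise loop is genuinely needed, putting it inside `jax.jit` (for example with `jax.lax.fori_loop`) allows XLA to perform `.at[...].set(...)` in place when the old array is not reused afterwards, as described in JAX's "Sharp Bits" documentation.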

  2. VAE zoo

    After the JAX Potts model, I worked on several other sampling-related issues before moving on to deep generative models, i.e. VAEs and GANs. These 2 model families dominated the second half of GSOC, resulting in the VAE zoo and the GAN zoo, respectively. The 2 zoos consolidate the various architectures and tricks that I found and implemented for VAEs and GANs over the summer.

    The VAE zoo was heavily inspired by Anand Krishnamoorthy’s PyTorch-VAE library, which I recommend to anyone interested in VAEs. Taking many design cues from that library, I started designing my own and imagining how it would work. The key idea behind how the VAE zoo is organised is the assembler: a script that takes a configuration file specifying which components to use, along with their hyperparameters, and then builds a VAE. The assembler does this by taking the different components of the VAE from a model definition file, instantiating each one with its corresponding hyperparameters from a YAML file, and then placing each component object into a template that is used for training or inference. This separates the different parts of the model (the encoder, the decoder and the training loop), so when we want to make a new model, we only need to write a new template and can reuse the encoder and decoder code. In short, write a script that handles the plumbing for you, so that plumbing and architecture code stay separate. It is a simple idea that I honestly can’t believe I never thought of before. In the world of design patterns, I think this is called the builder pattern, though I may be mistaken.
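A minimal, framework-free sketch of the assembler idea. All component and template names (`ConvEncoder`, `MLPEncoder`, `MLPDecoder`, `VanillaVAE`, `assemble`) are hypothetical and not taken from the VAE zoo; the config is a plain dict standing in for a parsed YAML file (for example, the output of PyYAML's `yaml.safe_load`), and the dataclasses stand in for real PyTorch modules.

```python
from dataclasses import dataclass


# Hypothetical components; in a real zoo these would be PyTorch modules.
@dataclass
class ConvEncoder:
    latent_dim: int
    hidden_channels: int = 32


@dataclass
class MLPEncoder:
    latent_dim: int
    hidden_units: int = 256


@dataclass
class MLPDecoder:
    latent_dim: int
    hidden_units: int = 256


# A template wires components together and would own the loss / training logic.
@dataclass
class VanillaVAE:
    encoder: object
    decoder: object
    kl_weight: float = 1.0


# Registries playing the role of the "model definition file": config names -> classes.
ENCODERS = {"conv": ConvEncoder, "mlp": MLPEncoder}
DECODERS = {"mlp": MLPDecoder}
TEMPLATES = {"vanilla": VanillaVAE}


def assemble(config):
    """Build a VAE from a config dict without mutating the config."""
    enc_cfg = dict(config["encoder"])
    dec_cfg = dict(config["decoder"])
    tpl_cfg = dict(config["template"])
    encoder = ENCODERS[enc_cfg.pop("name")](**enc_cfg)
    decoder = DECODERS[dec_cfg.pop("name")](**dec_cfg)
    return TEMPLATES[tpl_cfg.pop("name")](encoder=encoder, decoder=decoder, **tpl_cfg)


config = {
    "encoder": {"name": "conv", "latent_dim": 16},
    "decoder": {"name": "mlp", "latent_dim": 16, "hidden_units": 512},
    "template": {"name": "vanilla", "kl_weight": 0.5},
}
vae = assemble(config)
```

Swapping the encoder is then a one-line config change, and a new model means registering one new template while reusing the existing encoders and decoders. Strictly, the name-to-class lookup is a registry (factory-style) step; whether the whole arrangement counts as the builder pattern is partly a matter of taste.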

  3. GAN zoo

    The GAN zoo was designed similarly to the VAE zoo but focuses on training tricks like top-k training and instance noise rather than on different architectures. It implements a number of papers, such as LOGAN and WGAN. Sadly, I did not have time to implement larger, more successful GAN architectures like StyleGAN or the more recent alias-free GAN, which show remarkable results in generating photo-realistic images.
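As a rough illustration of the two GAN tricks mentioned above, here is a JAX sketch under stated assumptions; it is not the GAN zoo's code, and the function names are hypothetical. Top-k training computes the generator loss only on the k generated samples the discriminator rates as most realistic (shown here with the non-saturating loss as an example); the original paper also anneals k over training, which is omitted. Instance noise adds Gaussian noise to both the real and fake inputs of the discriminator, with the noise scale typically annealed towards zero; the linear schedule below is just one assumed choice.

```python
import jax
import jax.numpy as jnp


def top_k_generator_loss(fake_logits, k):
    """Non-saturating generator loss on only the k fakes rated most "real".

    fake_logits: (batch,) discriminator logits for generated samples.
    """
    top_logits, _ = jax.lax.top_k(fake_logits, k)  # k must be a Python int
    # -log(sigmoid(logit)) == softplus(-logit)
    return jnp.mean(jax.nn.softplus(-top_logits))


def add_instance_noise(key, real, fake, sigma):
    """Add Gaussian noise of the same scale to real and fake discriminator inputs."""
    k_real, k_fake = jax.random.split(key)
    noisy_real = real + sigma * jax.random.normal(k_real, real.shape)
    noisy_fake = fake + sigma * jax.random.normal(k_fake, fake.shape)
    return noisy_real, noisy_fake


def linear_anneal(step, total_steps, sigma0):
    """Linearly decay the noise scale from sigma0 to 0 over total_steps."""
    return sigma0 * jnp.maximum(0.0, 1.0 - step / total_steps)
```

In a training step, the discriminator would see `add_instance_noise(key, real, fake, linear_anneal(step, total_steps, sigma0))`, while the generator update would use `top_k_generator_loss` on the discriminator's logits for the fake batch.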

Afterword

When I applied to GSOC, I set some goals for what I wanted to do if I was accepted: contribute to open source, strengthen my foundations, especially in sampling and Bayesian methods, and gain experience working in a team. I think it is safe to say I have achieved most of them. Thanks, Kevin and Mahmoud, for this amazing opportunity and experience. It was fun working with everyone in PyProbML over the summer. I hope to work with you guys again 🙂

Other issues I worked on

  1. Rejection Sampling (link to PR)
  2. Adaptive rejection sampling (link to PR)
  3. Slice sampling (link to PR)
  4. Ising denoising demo using CAVI (link to PR)
  5. ebBinomal (link to code)
  6. Gibbs sampling for Ising (link to PR)
  7. Hopfield (link to PR)
  8. Numba version of Potts sampling (link to PR)
  9. Flax versions of VAE (link to code)
  10. JAX MD example of multi-host TPU processing (link to code)
  11. Sinkhorn algorithm (link to code)

    Link to my end-of-GSOC slide deck here
