JAX
The Differential Programming with JAX course ↗ is nice. Meta Optimal Transport ↗ is a nice JAX repo to run and study.
Robert Lange ↗ has nice JAX repos.
Notes
Links
- audax ↗ - Home for audio ML in JAX. Has common features, learnable frontends, pretrained supervised and self-supervised models.
- tinygp ↗ - Extremely lightweight library for building Gaussian Process models in Python, built on top of JAX.
- GPJax ↗ - Didactic Gaussian process package for researchers in JAX.
- Mctx ↗ - Monte Carlo tree search in JAX.
- Pipelined Swarm Training ↗ - Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes.
- JAX MuZero ↗ - JAX implementation of the MuZero agent.
- Jax Influence ↗ - Scalable implementation of Influence Functions in JAX.
- BlackJAX ↗ - Library of samplers for JAX that works on CPU as well as GPU. (Twitter ↗)
- GPax ↗ - JAX/Flax codebase for Gaussian processes, including meta and multi-task Gaussian processes.
- jax-fenics-adjoint ↗ - Differentiable interface to FEniCS/Firedrake for JAX using dolfin-adjoint/pyadjoint.
- jax-ekf ↗ - Generic EKF, with support for non-Euclidean manifolds.
- PaLM - Jax ↗ - Implementation of the specific Transformer architecture from "PaLM: Scaling Language Modeling with Pathways", in JAX.
- Pre-trained image classification models for Jax/Haiku ↗
- Flaxformer: transformer architectures in JAX/Flax ↗
- KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX ↗
- flowjax ↗ - Normalizing flow implementations in JAX.
- Jax3D ↗ - Library for neural rendering in JAX that aims to be a nimble NeRF ecosystem.
- DALL·E 2 in JAX ↗
- JAXNS ↗ - Nested sampling in JAX.
- AUX ↗ - Audio processing library in JAX, for JAX.
- Nice DeepMind Jax libraries ↗
- Machine Learning with JAX - From Zero to Hero (2021) ↗
- Flax ↗ - Neural network library for JAX designed for flexibility. (Docs ↗)
- JAX talks by HuggingFace ↗
- Homomorphic Encryption in JAX ↗
- JAX implementation of Learning to learn by gradient descent by gradient descent ↗
- Normalizing Flows in JAX ↗
- Big Vision ↗ - Designed for training large-scale vision models on Cloud TPU VMs. Based on the JAX/Flax libraries.
- Jax vs. Julia (Vs PyTorch) (2022) ↗ (HN ↗)
- minGPT in JAX ↗
- flaxvision ↗ - Selection of neural network models ported from torchvision for JAX & Flax.
- JAX version of clip guided diffusion scripts ↗
- Functorch ↗ - Jax-like composable function transforms for PyTorch. (HN ↗)
- Ninjax ↗ - Module system for JAX that offers full state access and makes it easy to combine modules from other libraries.
- Functional Transformer ↗ - Pure-functional implementation of a machine learning transformer model in Python/JAX.
- JAX + Units ↗ - Provides an interface between JAX and Pint so that JAX supports operations with units.
- Infinite Recommendation Networks (∞-AE) in JAX ↗
- Differential Programming with JAX course ↗ (Code ↗)
- Algorithms for Privacy-Preserving Machine Learning in JAX ↗
- Connex ↗ - Small JAX library built on Equinox whose aim is to incorporate artificial analogues of biological neural network attributes into deep learning research and architecture design.
- Rax ↗ - Composable Learning to Rank using JAX.
- JAX is faster than PyTorch but harder to debug ↗
- JAX Meta Learning ↗ - Collection of meta-learning algorithms in JAX.
- Gymnax ↗ - RL Environments in JAX.
- Pax ↗ - Framework to configure and run machine learning experiments on top of Jax.
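Several entries above, such as the differential programming course and Functorch, revolve around JAX's composable function transforms. Below is a minimal sketch of stacking `jax.grad`, `jax.jit` and `jax.vmap` on a toy linear least-squares loss. The model, the data and the names `loss`, `grad_loss` and `per_example_grads` are hypothetical illustrations. They do not come from any of the linked libraries.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical toy model: mean squared error of a linear predictor x @ w.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (w) by default;
# jax.jit then compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# jax.vmap maps the gradient over the leading axis of x and y (one example at
# a time) while sharing w (in_axes=None), which yields per-example gradients.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

# Illustrative data, made up for this sketch.
w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 1.0, 1.0])

g = grad_loss(w, x, y)               # batch gradient, shape (2,)
g_each = per_example_grads(w, x, y)  # per-example gradients, shape (3, 2)
print(g)
print(g_each)
```

The loss is a mean over examples, so averaging `g_each` over its first axis recovers `g`. This is a quick sanity check that the transforms compose as expected.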