Accelerate deep learning and other number-intensive tasks with JAX, Google’s awesome high-performance numerical computing library.
In Deep Learning with JAX you will learn how to:
Use JAX for numerical calculations
Build differentiable models with JAX primitives
Run distributed and parallelized computations with JAX
Use high-level neural network libraries such as Flax and Haiku
Leverage libraries and modules from the JAX ecosystem
JAX is a Python mathematics library with a NumPy interface, developed by Google. It is heavily used for machine learning research and has arguably become the third most popular deep learning framework, after TensorFlow and PyTorch. It is also the main deep learning framework at companies such as DeepMind, and more and more of Google's own research uses JAX. JAX promotes a functional programming paradigm in deep learning. It offers powerful composable function transformations, such as taking the gradient of a function, JIT compilation with XLA, auto-vectorization, and parallelization. JAX supports GPUs and TPUs and provides great performance.
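The paragraph above names JAX's core function transformations; the short sketch below shows, under simple assumptions, how they compose. The toy loss function, the variable names, and the array shapes are illustrative choices, not examples taken from the book; only jax.grad, jax.jit, and jax.vmap are actual JAX APIs.

import jax
import jax.numpy as jnp

def loss(w, x):
    # a simple scalar-valued function of the parameters w and an input x
    return jnp.sum((x @ w) ** 2)

grad_loss = jax.grad(loss)                             # gradient w.r.t. the first argument, w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))  # auto-vectorize over a batch of x
fast_batched_grad = jax.jit(batched_grad)              # JIT-compile the whole pipeline with XLA

w = jnp.ones((3,))
xs = jnp.ones((8, 4, 3))                  # a batch of 8 inputs, each of shape (4, 3)
print(fast_batched_grad(w, xs).shape)     # (8, 3): one gradient per batch element

Because each transformation takes a function and returns a function, they can be stacked in whatever order suits the workload.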
Deep Learning with JAX is a hands-on guide to using JAX for deep learning and other mathematically intensive applications. Google Developer Expert Grigory Sapunov steadily builds your understanding of JAX's concepts. The engaging examples introduce the fundamental ideas on which JAX relies and then show you how to apply them to real-world tasks. You'll learn how to use JAX's ecosystem of high-level libraries and modules, and how to combine TensorFlow and PyTorch with JAX for data loading and deployment.
about the technology
The JAX Python mathematics library is used by many successful deep learning organizations, including Google's groundbreaking DeepMind team. This exciting newcomer already boasts an amazing ecosystem of tools, including the high-level deep learning libraries Flax (by Google) and Haiku (by DeepMind), gradient processing and optimization libraries, and libraries for evolutionary computation, federated learning, and much more! JAX brings a functional programming mindset to Python deep learning, making it easier to compose transformations and to parallelize computation across a cluster.
about the book
Deep Learning with JAX teaches you how to use JAX and its ecosystem to build neural networks. You'll learn by exploring interesting examples, including an image classification tool, an image filter application, and a massive-scale neural network with distributed training across a cluster of TPUs. Discover how JAX works with hardware and other low-level aspects of the stack, and how to solve common machine learning problems with it. By the time you're finished with this awesome book, you'll be ready to start applying JAX to your own research and prototyping!
Author(s): Grigory Sapunov
Publisher: Manning Publications
Year: 2023
Language: English
Pages: 211
Deep Learning with JAX MEAP V06
Copyright
Welcome
Brief contents
Chapter 1: Intro to JAX
1.1 What is JAX?
1.1.1 JAX as NumPy
1.1.2 Composable transformations
1.2 Why use JAX?
1.2.1 Computational performance
1.2.2 Functional approach
1.2.3 JAX ecosystem
1.3 How is JAX different from TensorFlow/PyTorch?
1.4 Summary
Chapter 2: Your first program in JAX
2.1 A toy ML problem: classifying handwritten digits
2.2 Loading and preparing the dataset
2.3 A simple neural network in JAX
2.3.1 Neural network initialization
2.3.2 Neural network forward pass
2.4 vmap: auto-vectorizing calculations to work with batches
2.5 Autodiff: how to calculate gradients without knowing about derivatives
2.6 JIT: compiling your code to make it faster
2.7 Pure functions and composable transformations: why are they important?
2.8 An overview of a JAX deep learning project
2.9 Exercises
2.10 Summary
Chapter 3: Working with tensors
3.1 Image processing with NumPy arrays
3.1.1 Loading and storing images in NumPy arrays
3.1.2 Performing basic image processing with NumPy API
3.2 Tensors in JAX
3.2.1 Switching to JAX NumPy-like API
3.2.2 What is the DeviceArray?
3.2.3 Device-related operations
3.2.4 Asynchronous dispatch
3.2.5 Moving image processing to TPU
3.3 Differences with NumPy
3.3.1 Immutability
3.3.2 Types
3.4 High-level and low-level interfaces: jax.numpy and jax.lax
3.5 Exercises
3.6 Summary
Chapter 4: Autodiff
4.1 Different ways of getting derivatives
4.1.1 Manual differentiation
4.1.2 Symbolic differentiation
4.1.3 Numerical differentiation
4.1.4 Automatic differentiation
4.2 Calculating gradients with autodiff
4.2.1 Working with gradients in TensorFlow
4.2.2 Working with gradients in PyTorch
4.2.3 Working with gradients in JAX
4.2.4 Higher-order derivatives
4.2.5 Multivariable case
4.3 Forward and Reverse mode autodiff
4.3.1 Evaluation trace
4.3.2 Forward mode and jvp()
4.3.3 Reverse mode and vjp()
4.3.4 Going deeper
4.4 Summary
Chapter 5: Compiling your code
5.1 Using compilation
5.1.1 Using Just-in-Time (JIT) compilation
5.1.2 Pure functions
5.2 JIT internals
5.2.1 Jaxpr, an intermediate representation for JAX programs
5.2.2 XLA
5.2.3 Using Ahead-of-Time (AOT) compilation
5.3 JIT limitations
5.4 Summary
Chapter 6: Vectorizing your code
6.1 Different ways to vectorize a function
6.1.1 Naive approaches
6.1.2 Manual vectorization
6.1.3 Automatic vectorization
6.1.4 Speed comparisons
6.2 Controlling vmap() behavior
6.2.1 Controlling array axes to map over
6.2.2 Controlling output array axes
6.2.3 Using named arguments
6.2.4 Using decorator style
6.2.5 Using collective operations
6.3 Real-life use cases for vmap()
6.3.1 Batch data processing
6.3.2 Batching neural network models
6.3.3 Per-sample gradients
6.3.4 Vectorizing loops
6.4 Summary
Chapter 7: Parallelizing your computations
7.1 Parallelizing computations with pmap()
7.1.1 Setting up a problem
7.1.2 Using pmap (almost) like vmap
7.2 Controlling pmap() behavior
7.2.1 Controlling input and output mapping axes
7.2.2 Using named axes and collectives
7.3 Data parallel neural network training example
7.3.1 Preparing data and neural network structure
7.3.2 Implementing data parallel training
7.4 Summary