Deep Learning with JAX

Accelerate deep learning and other number-intensive tasks with JAX, Google's awesome high-performance numerical computing library.

In Deep Learning with JAX you will learn how to:
• Use JAX for numerical calculations
• Build differentiable models with JAX primitives
• Run distributed and parallelized computations with JAX
• Use high-level neural network libraries such as Flax and Haiku
• Leverage libraries and modules from the JAX ecosystem

The JAX numerical computing library tackles the core performance challenges at the heart of deep learning and other scientific computing tasks. By combining Google's Accelerated Linear Algebra platform (XLA) with a hyper-optimized version of NumPy and a variety of other high-performance features, JAX delivers a huge performance boost in low-level computations and transformations. Deep Learning with JAX is a hands-on guide to using JAX for deep learning and other mathematically intensive applications. Google Developer Expert Grigory Sapunov steadily builds your understanding of JAX's concepts. The engaging examples introduce the fundamental concepts on which JAX relies and then show you how to apply them to real-world tasks. You'll learn how to use JAX's ecosystem of high-level libraries and modules, and also how to combine TensorFlow and PyTorch with JAX for data loading and deployment.

About the technology
The JAX Python mathematics library is used by many successful deep learning organizations, including Google's groundbreaking DeepMind team. This exciting newcomer already boasts an amazing ecosystem of tools, including the high-level deep learning libraries Flax by Google and Haiku by DeepMind, gradient processing and optimization libraries, and libraries for evolutionary computation, federated learning, and much more. JAX brings a functional programming mindset to Python deep learning, making it easier to compose transformations and parallelize computations across a cluster.

About the book
Deep Learning with JAX teaches you how to use JAX and its ecosystem to build neural networks. You'll learn by exploring interesting examples, including an image classification tool, an image filter application, and a massive-scale neural network with distributed training across a cluster of TPUs. Discover how to work with JAX for hardware and other low-level aspects and how to solve common machine learning problems with JAX. By the time you're finished with this awesome book, you'll be ready to start applying JAX to your own research and prototyping.

About the reader
For intermediate Python programmers who are familiar with deep learning.

About the author
Grigory Sapunov is a co-founder and CTO of Intento. He is a software engineer with more than twenty years of experience. Grigory holds a Ph.D. in artificial intelligence and is a Google Developer Expert in Machine Learning.
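To give a flavor of the core features the book covers (a NumPy-like API plus composable transformations such as grad, jit, and vmap), here is a minimal sketch. It is not taken from the book; the tiny linear model, function names, and shapes are illustrative only.

# Minimal JAX sketch: NumPy-like arrays, autodiff, JIT compilation via XLA,
# and automatic vectorization. Names and shapes are illustrative, not from the book.
import jax
import jax.numpy as jnp

def predict(w, b, x):
    # A tiny linear model using the NumPy-like jax.numpy API.
    return jnp.dot(x, w) + b

def loss(w, b, x, y):
    # Mean squared error for a single example or a whole batch.
    return jnp.mean((predict(w, b, x) - y) ** 2)

# grad() builds a function returning dloss/dw and dloss/db; jit() compiles it with XLA.
grad_loss = jax.jit(jax.grad(loss, argnums=(0, 1)))

# vmap() vectorizes the per-example loss over the leading batch axis of x and y.
per_example_loss = jax.vmap(loss, in_axes=(None, None, 0, 0))

key = jax.random.PRNGKey(0)
key, w_key, x_key = jax.random.split(key, 3)  # explicit PRNG key handling
w = jax.random.normal(w_key, (3,))
b = 0.0
x = jax.random.normal(x_key, (8, 3))          # batch of 8 examples
y = jnp.ones(8)

dw, db = grad_loss(w, b, x, y)                # gradients w.r.t. w and b
print(per_example_loss(w, b, x, y))           # one loss value per example

The same functional style carries through the book's larger examples, where these transformations are combined with pmap and tensor sharding for multi-device training.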

Author(s): Grigory Sapunov
Edition: 1
Publisher: Manning
Year: 2024

Language: English
Commentary: Publisher's PDF
Pages: 408
City: Shelter Island, NY
Tags: Deep Learning; Python; Parallel Programming; JAX; Tensor Sharding; pytrees

Deep Learning with JAX
brief contents
contents
preface
acknowledgments
about this book
Who should read this book?
How this book is organized: A roadmap
About the code
liveBook discussion forum
Other online resources
about the author
about the cover illustration
Part 1
1 When and why to use JAX
1.1 Reasons to use JAX
1.1.1 Computational performance
1.1.2 Functional approach
1.1.3 JAX ecosystem
1.2 How is JAX different from NumPy?
1.2.1 JAX as NumPy
1.2.2 Composable transformations
1.3 How is JAX different from TensorFlow and PyTorch?
2 Your first program in JAX
2.1 A toy ML problem: Classifying handwritten digits
2.2 An overview of a JAX deep learning project
2.3 Loading and preparing the dataset
2.4 A simple neural network in JAX
2.4.1 Neural network initialization
2.4.2 Neural network forward pass
2.5 vmap: Auto-vectorizing calculations to work with batches
2.6 Autodiff: How to calculate gradients without knowing about derivatives
2.6.1 Loss function
2.6.2 Obtaining gradients
2.6.3 Gradient update step
2.6.4 Training loop
2.7 JIT: Compiling your code to make it faster
2.8 Saving and deploying the model
2.9 Pure functions and composable transformations: Why are they important?
Part 2
3 Working with arrays
3.1 Image processing with NumPy arrays
3.1.1 Loading an image into a NumPy array
3.1.2 Performing basic preprocessing operations with an image
3.1.3 Adding noise to the image
3.1.4 Implementing image filtering
3.1.5 Saving a tensor as an image file
3.2 Arrays in JAX
3.2.1 Switching to JAX NumPy-like API
3.2.2 What is Array?
3.2.3 Device-related operations
3.2.4 Asynchronous dispatch
3.2.5 Running computations on TPU
3.3 Differences from NumPy
3.3.1 Immutability
3.3.2 Types
3.4 High-level and low-level interfaces: jax.numpy and jax.lax
3.4.1 Control flow primitives
3.4.2 Type promotion
4 Calculating gradients
4.1 Different ways of getting derivatives
4.1.1 Manual differentiation
4.1.2 Symbolic differentiation
4.1.3 Numerical differentiation
4.1.4 Automatic differentiation
4.2 Calculating gradients with autodiff
4.2.1 Working with gradients in TensorFlow
4.2.2 Working with gradients in PyTorch
4.2.3 Working with gradients in JAX
4.2.4 Higher-order derivatives
4.2.5 Multivariable case
4.3 Forward- and reverse-mode autodiff
4.3.1 Evaluation trace
4.3.2 Forward mode and jvp()
4.3.3 Reverse mode and vjp()
4.3.4 Going deeper
5 Compiling your code
5.1 Using compilation
5.1.1 Using JIT compilation
5.1.2 Pure functions and compilation process
5.2 JIT internals
5.2.1 Jaxpr, an intermediate representation for JAX programs
5.2.2 XLA
5.2.3 Using AOT compilation
5.3 JIT limitations
5.3.1 Pure and impure functions
5.3.2 Exact numerics
5.3.3 Conditioning on input parameter values
5.3.4 Slow compilation
5.3.5 Class methods
5.3.6 Simple functions
6 Vectorizing your code
6.1 Different ways to vectorize a function
6.1.1 Naive approaches
6.1.2 Manual vectorization
6.1.3 Automatic vectorization
6.1.4 Speed comparisons
6.2 Controlling vmap() behavior
6.2.1 Controlling array axes to map over
6.2.2 Controlling output array axes
6.2.3 Using named arguments
6.2.4 Using decorator style
6.2.5 Using collective operations
6.3 Real-life use cases for vmap()
6.3.1 Batch data processing
6.3.2 Batching neural network models
6.3.3 Per-sample gradients
6.3.4 Vectorizing loops
7 Parallelizing your computations
7.1 Parallelizing computations with pmap()
7.1.1 Setting up a problem
7.1.2 Using pmap (almost) like vmap
7.2 Controlling pmap() behavior
7.2.1 Controlling input and output mapping axes
7.2.2 Using named axes and collectives
7.3 Data-parallel neural network training example
7.3.1 Preparing data and neural network structure
7.3.2 Implementing data-parallel training
7.4 Using multihost configurations
8 Using tensor sharding
8.1 Basics of tensor sharding
8.1.1 Device mesh
8.1.2 Positional sharding
8.1.3 An example with 2D mesh
8.1.4 Using replication
8.1.5 Sharding constraints
8.1.6 Named sharding
8.1.7 Device placement policy and errors
8.2 MLP with tensor sharding
8.2.1 Eight-way data parallelism
8.2.2 Four-way data parallelism, two-way tensor parallelism
9 Random numbers in JAX
9.1 Generating random data
9.1.1 Loading the dataset
9.1.2 Generating random noise
9.1.3 Performing a random augmentation
9.2 Differences with NumPy
9.2.1 How NumPy works
9.2.2 Seed and state in NumPy
9.2.3 JAX PRNG
9.2.4 Advanced JAX PRNG configuration
9.3 Generating random numbers in real-life applications
9.3.1 Building a complete data augmentation pipeline
9.3.2 Generating random initializations for a neural network
10 Working with pytrees
10.1 Representing complex data structures as pytrees
10.2 Functions for working with pytrees
10.2.1 Using tree_map()
10.2.2 Flatten/unflatten a pytree
10.2.3 Using tree_reduce()
10.2.4 Transposing a pytree
10.3 Creating custom pytree nodes
Part 3
11 Higher-level neural network libraries
11.1 MNIST image classification using an MLP
11.1.1 MLP in Flax
11.1.2 Optax gradient transformations library
11.1.3 Training a neural network the Flax way
11.2 Image classification using a ResNet
11.2.1 Managing state in Flax
11.2.2 Saving and loading a model using Orbax
11.3 Using the Hugging Face ecosystem
11.3.1 Using a pretrained model from the Hugging Face Model Hub
11.3.2 Going further with fine-tuning and pretraining
11.3.3 Using the diffusers library
12 Other members of the JAX ecosystem
12.1 Deep learning ecosystem
12.1.1 High-level neural network libraries
12.1.2 LLMs in JAX
12.1.3 Utility libraries
12.2 Machine learning modules
12.2.1 Reinforcement learning
12.2.2 Other machine learning libraries
12.3 JAX modules for other fields
A Installing JAX
B Using Google Colab
C Using Google Cloud TPUs
D Experimental parallelization
index