Deep Learning with JAX
English
By (author): Grigory Sapunov
Accelerate deep learning and other number-intensive tasks with JAX, Google's awesome high-performance numerical computing library.
In Deep Learning with JAX you will learn how to:
- Use JAX for numerical calculations
- Build differentiable models with JAX primitives
- Run distributed and parallelized computations with JAX
- Use high-level neural network libraries such as Flax and Haiku
- Leverage libraries and modules from the JAX ecosystem
The JAX numerical computing library tackles the core performance challenges at the heart of deep learning and other scientific computing tasks. By combining Google's Accelerated Linear Algebra platform (XLA) with a hyper-optimized version of NumPy and a variety of other high-performance features, JAX delivers a huge performance boost in low-level computations and transformations.
Deep Learning with JAX is a hands-on guide to using JAX for deep learning and other mathematically intensive applications. Google Developer Expert Grigory Sapunov steadily builds your understanding of JAX's concepts. The engaging examples introduce the fundamental concepts on which JAX relies and then show you how to apply them to real-world tasks. You'll learn how to use JAX's ecosystem of high-level libraries and modules, and also how to combine TensorFlow and PyTorch with JAX for data loading and deployment.