We look at language models parametrized by neural networks, and at how they are capable of near transfer: generalizing to sequences similar to (but not exactly the same as) those in their training sets.

jax.pjit

With proper sharding, LLMs can scale far beyond the memory capacity of a single GPU/TPU. We derive the underlying math from the ground up, and then build a fully working implementation with JAX/Flax.
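The arithmetic behind sharding a single layer can be checked without any accelerator. Below is a minimal NumPy sketch, not the post's JAX/Flax code: it simulates a hypothetical number of "devices" by splitting a weight matrix in the two usual ways for a matrix multiply. Splitting `w` by columns lets each device compute its slice of the output independently, after which the slices are concatenated (an all-gather). Splitting `w` by rows (and `x` by columns) gives each device a partial sum of the full output, after which the partials are summed (an all-reduce). The function names and shapes are illustrative assumptions.

```python
import numpy as np

def column_sharded_matmul(x, w, n_devices):
    """Simulate column-parallel sharding: each 'device' holds a vertical slice of w.

    Each shard computes x @ w_i with no communication; the full result is the
    concatenation of the partial outputs along the last axis (an all-gather).
    """
    w_shards = np.split(w, n_devices, axis=1)       # one column block per device
    partial = [x @ w_i for w_i in w_shards]          # purely local compute
    return np.concatenate(partial, axis=1)           # all-gather along columns

def row_sharded_matmul(x, w, n_devices):
    """Simulate row-parallel sharding: w is split by rows, x by matching columns.

    Each device produces a full-shape partial product; summing the partials
    (an all-reduce) recovers x @ w.
    """
    w_shards = np.split(w, n_devices, axis=0)       # row block of w per device
    x_shards = np.split(x, n_devices, axis=1)       # matching column block of x
    partial = [x_i @ w_i for x_i, w_i in zip(x_shards, w_shards)]
    return np.sum(partial, axis=0)                   # all-reduce (sum)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))    # batch of 4 activations, hidden size 8
w = rng.standard_normal((8, 12))   # weight matrix

full = x @ w
print(np.allclose(column_sharded_matmul(x, w, 4), full))  # True
print(np.allclose(row_sharded_matmul(x, w, 4), full))     # True
```

Either way, each device stores only `1/n_devices` of `w`, which is what lets a model outgrow a single device's memory. Roughly speaking, in JAX these layouts are expressed as sharding annotations and the compiler inserts the communication; this sketch only checks that the math works out.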

Spherical Harmonics are a core building block of Equivariant Neural Networks. This post breaks them down by viewing them as an extension of the Fourier Series from the circle to the sphere in 3D.
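To make the Fourier analogy concrete: spherical harmonics Y_l^m(θ, φ) keep the Fourier mode e^(imφ) as their azimuthal factor and multiply it by an associated Legendre function of cos θ in the polar direction. The NumPy sketch below is a small illustration under our own choices, not code from the post: it writes out the l ≤ 1 harmonics by hand (complex form, Condon-Shortley phase) and checks numerically that they are orthonormal on the sphere, the analogue of Fourier modes being orthonormal on the circle. The names `HARMONICS` and `sphere_inner` and the quadrature grid size are assumptions.

```python
import numpy as np

# Complex spherical harmonics for l <= 1, written out by hand
# (Condon-Shortley phase). theta: polar angle in [0, pi]; phi: azimuth in [0, 2*pi).
# Note the azimuthal factor exp(1j * m * phi): a Fourier mode on the circle.
HARMONICS = {
    (0, 0): lambda t, p: 0.5 / np.sqrt(np.pi) * np.ones_like(t),
    (1, -1): lambda t, p: np.sqrt(3 / (8 * np.pi)) * np.sin(t) * np.exp(-1j * p),
    (1, 0): lambda t, p: np.sqrt(3 / (4 * np.pi)) * np.cos(t),
    (1, 1): lambda t, p: -np.sqrt(3 / (8 * np.pi)) * np.sin(t) * np.exp(1j * p),
}

def sphere_inner(f, g, n_theta=400, n_phi=400):
    """Approximate <f, g> = integral of f * conj(g) over the unit sphere (midpoint rule)."""
    dt, dp = np.pi / n_theta, 2 * np.pi / n_phi
    theta = (np.arange(n_theta) + 0.5) * dt
    phi = (np.arange(n_phi) + 0.5) * dp
    T, P = np.meshgrid(theta, phi, indexing="ij")
    # sin(theta) dtheta dphi is the surface-area element on the sphere.
    return np.sum(f(T, P) * np.conj(g(T, P)) * np.sin(T)) * dt * dp

keys = list(HARMONICS)
gram = np.array([[sphere_inner(HARMONICS[a], HARMONICS[b]) for b in keys] for a in keys])
print(np.round(gram.real, 4))
```

The Gram matrix comes out numerically equal to the identity. Pairs with different m vanish for the same reason Fourier modes are orthogonal (the φ integral of e^(i(m - m')φ) is zero unless m = m'), while pairs with equal m but different l, such as Y_0^0 and Y_1^0, vanish through the θ integral instead.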

Everyday human reasoning breaks down as the scales of time and space at play increase. Complex systems thinking gives us a new set of tools to better understand the chains of consequences involved and to make better decisions.

CUDA has a hierarchical programming model that requires thinking at the level of Grids, Blocks and Threads. We explore this hierarchy directly, using what we learn to write a simple GPU-accelerated addition kernel from scratch.
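The indexing scheme at the heart of this hierarchy fits in a few lines. As a hedged illustration (a pure-Python simulation, not GPU code and not the post's kernel), the sketch below plays the role of the CUDA runtime: it loops over every block in the grid and every thread in each block, and each simulated thread computes its global index as `blockIdx.x * blockDim.x + threadIdx.x`, the standard formula for a 1D kernel. The function names `add_kernel` and `launch` and the default of 256 threads per block are assumptions for illustration.

```python
def add_kernel(block_idx, block_dim, thread_idx, a, b, out, n):
    """Body of one simulated thread, mirroring a 1D CUDA vector-add kernel.

    In CUDA C the same body would read:
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) out[i] = a[i] + b[i];
    """
    i = block_idx * block_dim + thread_idx  # global index from the hierarchy
    if i < n:                               # guard: the last block may overhang
        out[i] = a[i] + b[i]

def launch(a, b, threads_per_block=256):
    """Simulate a kernel launch over a 1D grid of 1D blocks."""
    n = len(a)
    # Round up so the grid covers all n elements.
    blocks = (n + threads_per_block - 1) // threads_per_block
    out = [0.0] * n
    # On a GPU these iterations run in parallel; here they run sequentially.
    for block_idx in range(blocks):
        for thread_idx in range(threads_per_block):
            add_kernel(block_idx, threads_per_block, thread_idx, a, b, out, n)
    return out

print(launch([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]))  # [11.0, 22.0, 33.0]
```

The `if i < n` guard matters because the grid is rounded up to a whole number of blocks: with 1000 elements and 256 threads per block, 4 blocks launch 1024 threads and the last 24 do nothing.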