Twitter thread: https://twitter.com/theshawwn/status/1456925974919004165
Hacker News thread: https://news.ycombinator.com/item?id=29128998
November 6, 2021
jax.device_put(1) is deceptively simple to write in JAX. But on a TPU, what actually happens? How does a tensor containing the value 1 actually get onto a TPU?
Turns out, the answer is "C++", and a lot of it.