Twitter thread: https://twitter.com/theshawwn/status/1456925974919004165
Hacker News thread: https://news.ycombinator.com/item?id=29128998
November 6, 2021
jnp.device_put(1) is deceptively simple to write in JAX. But on a TPU, what actually happens? How does a tensor containing the value 1 actually get onto a TPU?
Turns out, the answer is "C++", and a lot of it.
JAX in TPU mode calls into a library called libtpu. The source code isn't publicly available, but (to my amazement) it has a simple C API, called libtpu.h: https://twitter.com/cdleary/status/1336555074001141760?lang=en
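Since libtpu ships as an ordinary shared library with a C ABI, you can poke at it directly from Python. Here's a minimal sketch using ctypes; the path is an assumption (it's where libtpu typically lives on a Cloud TPU VM), and the helper degrades gracefully when the library isn't present:

```python
import ctypes

def load_libtpu(path="/lib/libtpu.so"):
    """Try to load libtpu's C API; return the handle, or None if unavailable.

    The default path is the typical location on a Cloud TPU VM (an
    assumption; adjust for your setup).
    """
    try:
        return ctypes.CDLL(path)
    except OSError:
        return None  # not on a TPU VM, or libtpu lives elsewhere

libtpu = load_libtpu()
print("libtpu available:", libtpu is not None)
```

On a non-TPU machine this just reports that the library is missing, which is the point: there's no magic runtime here, just a .so file that JAX dlopens.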
I'll do a more detailed writeup soon, but for now I wanted to record a quick video (the audio is horrible, sorry) showing off every C++ stack frame from the moment you call jnp.device_put(1) all the way through to the actual memory allocation on the TPU silicon.
First, here's the Python stack trace for jnp.device_put(1), up until it disappears into C++ land:
~/ml/libtpu-python/jaxtest.py:27 in <module> # jnp.device_put(1) happens here
~/ml/jax/jax/_src/api.py:2523 in device_put
~/ml/jax/jax/_src/tree_util.py:178 in tree_map
~/ml/jax/jax/_src/tree_util.py:178 in <genexpr>
~/ml/jax/jax/_src/api.py:2523 in <lambda>
~/ml/jax/jax/core.py:272 in bind
~/ml/jax/jax/core.py:624 in process_primitive
~/ml/jax/jax/interpreters/xla.py:1708 in _device_put_impl
~/ml/jax/jax/interpreters/xla.py:301 in device_put
~/ml/jax/jax/interpreters/xla.py:309 in _device_put_array # then jumps into C++ here
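(As an aside, you can dump a "file:line in function" listing like the one above from any point in your own Python code using the stdlib traceback module. A minimal sketch; the function names here are mine, not JAX's:)

```python
import traceback

def current_frames():
    """Return 'file:line in function' strings for every live Python frame,
    similar to the listing above."""
    return [f"{fr.filename}:{fr.lineno} in {fr.name}"
            for fr in traceback.extract_stack()]

def device_put_stand_in():
    # Hypothetical stand-in for jnp.device_put: just print the frames
    # that led here.
    for line in current_frames():
        print(line)

device_put_stand_in()
```

Of course, this only sees Python frames; the moment execution crosses into C++, the stdlib trace goes dark, which is exactly the problem.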
Notice how little information you get. Almost none of the answers are in the Python codebase; at least, not the answers to the engineering questions I cared about, like "What's going on under the hood?"
Here's the rest of the trace you normally can't see. It's the actual code that drives the TPU in production.
Click any frame to see its C++ source code:
PyTreeDef::Unflatten (py::iterable) const
py::object PyTreeDef::UnflattenImpl<py::iterable> (py::iterable) const
PyClient::BufferFromPyval (py::handle, PjRtDevice*, bool, PjRtClient::HostBufferSemantics)
DevicePut (py::handle, PjRtDevice*, const DevicePutOptions &)
HandleNumpyArray (py::handle, PjRtDevice*, const DevicePutOptions &)
PjRtStreamExecutorClient::BufferFromHostBuffer (const void *, const Shape &, PjRtClient::HostBufferSemantics, function<void ()>, PjRtDevice*)
AllocateDestinationBuffer (const Shape &, PjRtDevice*, LocalDeviceState*, stream_executor::Stream*, bool, PjRtClient*, shared_ptr<BufferSequencingEvent>)
TransferManager::AllocateScopedShapedBuffer (const Shape &, stream_executor::DeviceMemoryAllocator*, int, const fn<Shape (Shape &)>)
stream_executor::StreamExecutorMemoryAllocator::Allocate (int, uint64_t, bool, int64_t)
stream_executor::StreamExecutor::Allocate (uint64_t, int64_t)
tensorflow::tpu::TpuExecutor::Allocate (uint64_t, int64_t)
tensorflow::CurrentStackTrace (bool)
It goes 11 function calls deeper into C++! That's a lot of info you normally don't see!
It would be possible to dump a C++ trace for every TPU API function. You'd end up with a "guidebook" to the TPU, since you'd be able to see very clearly what it's doing (and what all the components are). But for now, you'll have to settle for this. :)
The C++ stack trace was generated by compiling a mock libtpu.so. It's totally not ready to show; don't bother trying to build it yet.
(If you really want something to do, this tweet was the starting point, which you can play with yourself: https://twitter.com/theshawwn/status/1455425961432948736)
The video also shows off my custom build of iTerm2 (gnachman/iTerm2#454), which drops me into CLion when I command-click filenames with line numbers. :)
CLion requires a CMakeLists.txt file. I keep one up to date for Tensorflow here, which was generated using cmake-gen-globs. (I have no idea why I named it that, but at least it's reasonably unique. It's also the worst possible script, and I'm sure there are a bunch of nicer ways to do the same thing. But hey, it works, and it requires no dependencies. 🎉)
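For the curious, the core of such a generated file is presumably just a recursive glob plus include paths, enough for CLion to index the tree without actually building it. A minimal sketch of that approach (my guess at the shape, not the actual output of cmake-gen-globs):

```cmake
cmake_minimum_required(VERSION 3.10)
project(tensorflow-index LANGUAGES CXX)

# Recursively glob every source and header so CLion can index the tree.
# (Globbing is discouraged for real builds, but it's fine for an IDE index.)
file(GLOB_RECURSE TF_SOURCES
     ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow/*.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow/*.h)

# EXCLUDE_FROM_ALL: we never intend to compile this target, only index it.
add_library(tf_index STATIC EXCLUDE_FROM_ALL ${TF_SOURCES})
target_include_directories(tf_index PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
```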
-- Shawn
- @theshawwn on twitter: https://twitter.com/theshawwn
- shawwn on github: https://github.com/shawwn
- sillysaurusx on hn: https://news.ycombinator.com/threads?id=sillysaurusx
- support me on Patreon: https://www.patreon.com/shawwn