JAX C++ stack trace walkthrough for TpuExecutor_Allocate

Twitter thread: https://twitter.com/theshawwn/status/1456925974919004165
Hacker News thread: https://news.ycombinator.com/item?id=29128998

November 6, 2021

How does JAX allocate memory on a TPU?

jnp.device_put(1) is deceptively simple to write in JAX. But on a TPU, what actually happens? How does a tensor containing the value 1 actually get onto a TPU?

Turns out, the answer is "C++", and a lot of it.

JAX in TPU mode calls into a library called libtpu. The source code isn't publicly available, but (to my amazement) it has a simple C API, called libtpu.h: https://twitter.com/cdleary/status/1336555074001141760?lang=en

I'll do a more detailed writeup soon, but for now I wanted to record a quick video (the audio is horrible, sorry) showing every C++ stack frame from the moment you call jnp.device_put(1) all the way down to the actual memory allocation on the TPU silicon.

Python trace for jnp.device_put(1)

First, here's the Python stack trace for jnp.device_put(1), up until it disappears into C++ land:

~/ml/libtpu-python/jaxtest.py:27 in <module> # jnp.device_put(1) happens here
~/ml/jax/jax/_src/api.py:2523 in device_put
~/ml/jax/jax/_src/tree_util.py:178 in tree_map
~/ml/jax/jax/_src/tree_util.py:178 in <genexpr>
~/ml/jax/jax/_src/api.py:2523 in <lambda>
~/ml/jax/jax/core.py:272 in bind
~/ml/jax/jax/core.py:624 in process_primitive
~/ml/jax/jax/interpreters/xla.py:1708 in _device_put_impl
~/ml/jax/jax/interpreters/xla.py:301 in device_put
~/ml/jax/jax/interpreters/xla.py:309 in _device_put_array # then jumps into C++ here

Notice how little info you get. Almost none of the answers are in the Python codebase, or at least not the answers to the engineering questions I cared about, like "What's going on under the hood?"
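If you want to capture a Python-side trace like the one above yourself, here is one possible approach (not necessarily the one used for this post). CPython's `sys.setprofile` hook fires a `'c_call'` event just before Python code calls a C-implemented function, so you can snapshot the Python stack at the exact point where control leaves Python. The helper `python_stack_at_c_call` is a hypothetical name, and `math.sqrt` stands in for JAX's C++ extension functions, which this sketch does not import.

```python
import math
import sys
import traceback


def python_stack_at_c_call(fn, target, *args):
    """Call fn(*args) and return (result, frames), where frames lists the
    (filename, lineno, function) tuples that were on the Python stack the
    first time the call chain invoked the C function `target`."""
    frames = []

    def profiler(frame, event, arg):
        # 'c_call' fires just before a C function is called; `arg` is that
        # function object and `frame` is the Python frame making the call.
        if event == "c_call" and arg is target and not frames:
            frames.extend(
                (fs.filename, fs.lineno, fs.name)
                for fs in traceback.extract_stack(frame)
            )

    sys.setprofile(profiler)
    try:
        result = fn(*args)
    finally:
        sys.setprofile(None)  # always uninstall the hook
    return result, frames


# Usage: math.sqrt plays the role of the C++ entry point.
def inner(x):
    return math.sqrt(x)


def outer(x):
    return inner(x) + 1


result, frames = python_stack_at_c_call(outer, math.sqrt, 16.0)
for filename, lineno, name in frames:
    print(f"{filename}:{lineno} in {name}")
```

The printed lines use the same `path:line in function` shape as the trace above, and like that trace they stop at the last Python frame: nothing below the C call is visible from Python.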

C++ trace for jnp.device_put(1)

Here's the rest of the trace you normally can't see. It's the actual code that drives the TPU in production.

Here are the C++ frames, outermost first:

  1. PyTreeDef::Unflatten (py::iterable) const
  2. py::object PyTreeDef::UnflattenImpl<py::iterable> (py::iterable) const
  3. PyClient::BufferFromPyval (py::handle, PjRtDevice*, bool, PjRtClient::HostBufferSemantics)
  4. DevicePut (py::handle, PjRtDevice*, const DevicePutOptions &)
  5. HandleNumpyArray (py::handle, PjRtDevice*, const DevicePutOptions &)
  6. PjRtStreamExecutorClient::BufferFromHostBuffer (const void *, const Shape &, PjRtClient::HostBufferSemantics, function<void ()>, PjRtDevice*)
  7. AllocateDestinationBuffer (const Shape &, PjRtDevice*, LocalDeviceState*, stream_executor::Stream*, bool, PjRtClient*, shared_ptr<BufferSequencingEvent>)
  8. TransferManager::AllocateScopedShapedBuffer (const Shape &, stream_executor::DeviceMemoryAllocator*, int, const fn<Shape (Shape &)>)
  9. stream_executor::StreamExecutorMemoryAllocator::Allocate (int, uint64_t, bool, int64_t)
  10. stream_executor::StreamExecutor::Allocate (uint64_t, int64_t)
  11. tensorflow::tpu::TpuExecutor::Allocate (uint64_t, int64_t)
  12. tensorflow::CurrentStackTrace (bool)

Not counting frame 12 (tensorflow::CurrentStackTrace, the call that captured this trace), it goes 11 function calls deeper into C++! That's a lot of info you normally don't see.
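To make the layering in frames 9 to 11 easier to see, here is a toy Python model. It only mirrors the shape of the call chain: an allocator picks a per-device executor by ordinal, and a platform-independent StreamExecutor forwards the request to a platform-specific backend. The class names are hypothetical simplifications, not the real TensorFlow/XLA classes, and `FakeTpuBackend` hands out made-up addresses where the real TpuExecutor::Allocate would presumably call into libtpu's C API (the TpuExecutor_Allocate entry point named in the title).

```python
from dataclasses import dataclass


@dataclass
class DeviceMemory:
    """Hypothetical stand-in for a device pointer plus its size in bytes."""
    address: int
    size: int


class FakeTpuBackend:
    """Plays the role of the platform-specific executor (frame 11).
    It hands out increasing fake addresses instead of calling libtpu."""

    def __init__(self, base_address=0x1000):
        self.next_address = base_address
        self.calls = []  # log of (size, memory_space) requests

    def allocate(self, size, memory_space):
        self.calls.append((size, memory_space))
        mem = DeviceMemory(self.next_address, size)
        self.next_address += size
        return mem


class StreamExecutor:
    """Platform-independent front end (frame 10): forwards to its backend."""

    def __init__(self, backend):
        self.backend = backend

    def allocate(self, size, memory_space=0):
        return self.backend.allocate(size, memory_space)


class MemoryAllocator:
    """Picks the executor for a device ordinal (frame 9)."""

    def __init__(self, executors):
        self.executors = executors

    def allocate(self, device_ordinal, size, memory_space=0):
        return self.executors[device_ordinal].allocate(size, memory_space)


# Usage: one device; request 4 bytes, e.g. for a 32-bit scalar.
backend = FakeTpuBackend()
allocator = MemoryAllocator([StreamExecutor(backend)])
buf = allocator.allocate(device_ordinal=0, size=4)
print(hex(buf.address), buf.size, backend.calls)
```

The design point is the split between the generic front end and the backend: in this toy, swapping `FakeTpuBackend` for any other class with an `allocate(size, memory_space)` method is all it takes to target different "hardware", which is the same reason the real stack can route a generic allocation down to a TPU-specific call.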

It would be possible to dump a C++ trace for every TPU API function. You'd end up with a "guidebook" to the TPU, since you'd be able to see very clearly what it's doing (and what all the components are). But for now, you'll have to settle for this. :)

The C++ stack trace was generated by compiling a mock libtpu.so. It's totally not ready to show; don't bother trying to build it yet.

(If you really want something to do, this tweet was the starting point, and you can play with it yourself: https://twitter.com/theshawwn/status/1455425961432948736)
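The idea behind dumping a trace for every API function (stub out each entry point and have the stub record who called it) can be sketched in a few lines of Python. This is only an analogy to the mock libtpu.so described above, which is C++ and not published: `make_tracing_stubs` is a hypothetical helper, each stub returns a caller-supplied dummy value instead of doing real work, and `TpuExecutor_Allocate` is used purely as a dictionary key.

```python
import traceback


def make_tracing_stubs(names, return_value=None):
    """Build a table of stub functions, one per API name. Each stub records
    the function names on the Python call stack that reached it, then
    returns `return_value`."""
    traces = {name: [] for name in names}

    def make_stub(name):
        def stub(*args, **kwargs):
            # Drop the last entry, which is this stub's own frame.
            stack = traceback.extract_stack()[:-1]
            traces[name].append([fs.name for fs in stack])
            return return_value
        return stub

    return {name: make_stub(name) for name in names}, traces


# Usage: a fake "API table" with a single entry point.
api, traces = make_tracing_stubs(["TpuExecutor_Allocate"], return_value=0)


def allocate_buffer(size):
    return api["TpuExecutor_Allocate"](size)


allocate_buffer(4)
print(" -> ".join(traces["TpuExecutor_Allocate"][0]))
```

Running every entry point of a real program against such a table would give you one recorded call path per API function, which is the "guidebook" idea in miniature.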

CLion video

The video was also meant to show off my custom build of iTerm2 (gnachman/iTerm2#454), which drops me into CLion when I command-click a filename with a line number. :)
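Command-clicking works because every frame in a trace like the Python one above has the form `path:line in function`. Here is a hypothetical sketch of turning such a line into an editor command; it is not the author's iTerm2 patch, and it assumes the JetBrains `clion` command-line launcher, which accepts `--line N path`. The sketch only builds the command and does not run it.

```python
import os
import re
import shlex

# Matches "path:line" at the start of a trace line, e.g.
# "~/ml/jax/jax/core.py:272 in bind".
FRAME_RE = re.compile(r"^\s*(?P<path>[^\s:]+):(?P<line>\d+)\b")


def editor_command(trace_line, launcher="clion"):
    """Return an argv list that opens the file from `trace_line` at its line
    number, or None if the line does not look like "path:line ..."."""
    m = FRAME_RE.match(trace_line)
    if m is None:
        return None
    path = os.path.expanduser(m.group("path"))  # "~" -> home directory
    return [launcher, "--line", m.group("line"), path]


cmd = editor_command("~/ml/jax/jax/core.py:272 in bind")
print(shlex.join(cmd))
```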

CLion requires a CMakeLists.txt file. I keep one up to date for TensorFlow here, which was generated using cmake-gen-globs. (I have no idea why I named it that, but at least it's reasonably unique.) It's also the worst possible script, and I'm sure there are a bunch of nicer ways to do the same thing. But hey, it works, and it requires no dependencies. 🎉
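For a sense of what such a generator does, here is a hypothetical, dependency-free sketch of the general idea. It is not cmake-gen-globs: it collects the C/C++ files under a tree and emits a CMakeLists.txt whose only job is to let the IDE index them. The target name `tf_index`, the extension list, and the single `add_library(... OBJECT ...)` target are assumptions, and the output is meant for code navigation, not for actually building TensorFlow.

```python
import os
import pathlib
import tempfile

SOURCE_EXTS = (".c", ".cc", ".cpp", ".h", ".hpp")


def cmake_for_indexing(root, target="tf_index"):
    """Return CMakeLists.txt text that lists every C/C++ file under `root`
    (relative paths, sorted) as sources of a single OBJECT library."""
    sources = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(SOURCE_EXTS):
                rel = os.path.relpath(os.path.join(dirpath, name), root)
                sources.append(rel.replace(os.sep, "/"))  # CMake expects "/"
    sources.sort()  # stable output regardless of walk order
    lines = [
        "cmake_minimum_required(VERSION 3.10)",
        "project(%s C CXX)" % target,
        "include_directories(${CMAKE_CURRENT_SOURCE_DIR})",
        "add_library(%s OBJECT" % target,
    ]
    lines += ["    " + src for src in sources]
    lines.append(")")
    return "\n".join(lines) + "\n"


# Usage: a tiny fake source tree standing in for a TensorFlow checkout.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "tensorflow", "core"))
    for name in ("tensorflow/core/a.cc", "tensorflow/core/a.h", "README.md"):
        pathlib.Path(root, name).touch()
    text = cmake_for_indexing(root)
print(text)
```

An alternative design is to emit CMake's own `file(GLOB_RECURSE ...)` and let CMake do the globbing at configure time; listing files explicitly keeps the generated file self-describing, at the cost of regenerating it whenever files are added or removed.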

-- Shawn
