Skip to content

Instantly share code, notes, and snippets.

View apivovarov's full-sized avatar

Alexander Pivovarov apivovarov

  • Sunnyvale, CA
  • 01:31 (UTC -07:00)
  • LinkedIn in/pivovaal
View GitHub Profile
@apivovarov
apivovarov / gist:848f7bf1814522f8848e9a1e4012c7d9
Created March 16, 2026 21:59
--xla_gpu_execution_terminate_timeout=20s
W 2026-03-16T14:24:05.274810-07:00 7452 tf_gpu-executable-hang-watchdog/7452 thread/thread.cc:1808] --- Thread 7fb2bc4eb700 (name: py_xla_execute/5842) stack: ---
stack used: 166 KiB of 2104 KiB
I 2026-03-16T14:24:05.274834-07:00 0 W0316 14:24:05.274810 7452 thread.cc:1808] --- Thread 7fb2bc4eb700 (name: py_xla_execute/5842) stack: ---
I 2026-03-16T14:24:05.274836-07:00 0 stack used: 166 KiB of 2104 KiB
W 2026-03-16T14:24:05.489353-07:00 7452 tf_gpu-executable-hang-watchdog/7452 thread/thread.cc:1808] @ 0x7fb4e02cc178 _binary_blaze_out_haswell_opt_cuda_genfiles_third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
I 2026-03-16T14:24:05.490090-07:00 0 W0316 14:24:05.489353 7452 thread.cc:1808] @ 0x7fb4e02cc178 _binary_blaze_out_haswell_opt_cuda_genfiles_third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
W 2026-03-16T14:24:05.669908-07:00 7452 tf_gpu-executable-hang-watchdog/7452 thread/thread.cc:1808] @ 0x7fb4e0e5e960 _binary
@apivovarov
apivovarov / build-xla-with-jax-container.md
Created June 10, 2025 01:20
Build XLA with CUDA/cuDNN Support Using the JAX CI/Release Container

Build XLA with CUDA/cuDNN Support Using the JAX CI/Release Container

XLA is a compiler used internally by JAX. JAX is distributed via PyPI wheels. The JAX Continuous Integration documentation explains how to build JAX wheels using the tensorflow/ml-build:latest Docker container.

We can extend these instructions to build XLA targets within the JAX container as well. This ensures that the XLA targets’ build configuration is consistent with the JAX/XLA build configuration, which can be useful if we want to reproduce workload results using XLA tools that were originally created in JAX.

Build XLA Targets in the JAX CI Container

  1. Clone the JAX repository and navigate to the 'jax' directory
int32_t add_int32_using_float32(int32_t a, int32_t b) {
const uint32_t SHIFT = 16;
const uint32_t MASK = 0xFFFF;
int32_t a_high = a >> SHIFT; // Sign-extended
uint32_t a_low = static_cast<uint32_t>(a) & MASK;
int32_t b_high = b >> SHIFT; // Sign-extended
uint32_t b_low = static_cast<uint32_t>(b) & MASK;
// Add low parts (this will fit in float32)
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
// Grok 3 early
int32_t int32_add_using_float(int32_t a, int32_t b) {
// Masks for splitting into high and low 16-bit parts
const int32_t MASK_16BIT = 0xFFFF; // 16-bit mask: 0x0000FFFF
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t add_uint32_using_float32(uint32_t a, uint32_t b) {
// Split the 32-bit numbers into two 16-bit halves (high and low)
const uint32_t mask16 = 0xFFFF; // 16-bit mask
uint32_t a_low = a & mask16; // Lower 16 bits of a
@apivovarov
apivovarov / divide_uint32.cc
Created February 20, 2025 02:04
Function to perform uint32 division without loops using float32 div, uint32 add, sub, mul
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) {
// Handle division by zero
uint32_t is_zero_divisor = (divisor == 0);
if (is_zero_divisor) {
@apivovarov
apivovarov / jax_softmax_accuracy.py
Last active February 1, 2025 03:41
jax softmax accuracy
# ======== softmax ========================================
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
def softmax(x):
x_max = np.max(x, axis=-1, keepdims=True)
exp_x = np.exp(x - x_max)
@apivovarov
apivovarov / jax_reduce_accuracy.py
Created January 29, 2025 01:50
JAX reduce op accuracy
# REDUCE
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[[[3214685668, 1050640488, 1060743252, 3209803584], [3204519310, 1067654368, 1067817699, 1067875232], [1056212128, 3212040969, 3205718709, 1065846737], [3212748857, 3210953055, 3206425550, 3214376535]], [[3216393659, 3204671589, 1046392801, 3210937971], [3212871310, 1011922854, 3201903270, 1056981194], [1057317906, 1057615558, 1049853029, 1054672679], [3212476770, 3221471932, 3220222283, 3214302880]], [[3223731609, 3214052862, 3180226583, 3214602181], [1058005824, 1066321194, 3196840352, 3205731707], [1063641844, 1058202109, 3199602305, 1062816921], [3213035287, 3205352409, 3207120713, 3215062456]], [[3213243313, 3198190149, 3200959705, 3220727198], [3216848691, 1065951977, 1058746486, 3187463331], [3208759739, 3209999898, 3201950053, 1057709270], [3215410801, 3220972337, 3217900520, 3205565316]], [[3204515228, 3200171121, 1036747732, 3212008346], [3212994805, 1064
@apivovarov
apivovarov / jax_exp_accuracy.py
Last active January 29, 2025 02:44
jax exp accuracy
# EXP
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[ 4603795724020441080, 4600877961149805680, 13824041971732116924, 13819622319934898176, 4591988256500146912, 13827780906010235892, 4592430392547641040, 13825287425120520724, 13828493035404110830, 4606660095605074408, 4602255938349038484, 4606604461173478372, 13827339776963088236, 4604220211627411202, 13824977737624476168, 4600871068755542572, 4606725295095174736, 4590846528761481504, 13822202805976138512, 13815414477567454288, 4606044597329528470, 4600076899447978620, 4591604889013648656, 4602697894782769988, 13827420058876989058, 4604756101012177124, 4605317067340288710, 13816787606007860464, 4602913948182442482, 4606332577827281932, 4604773657456759696, 13827381634999381006], [13826069199251957758, 4605825876037599016, 4599813349846477404, 4606034169064504034, 13810066424407939872, 4599928033056128564, 4604540450308053286, 4604372500785701914, 45965253622748
@apivovarov
apivovarov / jax_rsqrt_accuracy.py
Last active January 29, 2025 01:22
jax rsqrt accuracy
# RSQRT
import jax
import jax.numpy as jnp
import numpy as np

# Report which backends JAX can see (CPU/GPU/TPU).
print("JAX devices:", jax.devices())

# The inputs are given as raw IEEE-754 bit patterns so the exact float32
# values round-trip with no decimal-parsing error.
_bits = [
    [[1048684900], [1052291356], [1049963963],
     [1050007938], [1051252185], [1050317382]],
    [[1045717137], [1050494007], [1050815620],
     [1049559979], [1051598875], [1051539171]],
]
a_np = np.asarray(_bits, dtype=np.uint32).view(np.float32)

# Reference reciprocal-square-root computed in NumPy (stays float32).
c_np = np.reciprocal(np.sqrt(a_np))