Alexander Pivovarov apivovarov

  • Sunnyvale, CA
  • LinkedIn in/pivovaal
@apivovarov
apivovarov / ra2a_r8_2x4.hlo
Created April 26, 2026 23:10
ra2a_r8_2x4.hlo
HloModule module, input_output_alias={ {}: (1, {}) }, num_partitions=1, replica_count=8
// HloModule module, num_partitions=1, replica_count=8
ENTRY entry {
  id = u32[] replica-id()
  input = f32[16] parameter(0)
  output = f32[16] parameter(1)
  send_sizes = s32[8] constant({1,1,1,1,1,1,1,1})
  recv_sizes = s32[8] constant({1,1,1,1,1,1,1,1})
  input_offsets = s32[8] constant({0, 2, 4, 6, 8, 10, 12, 14})
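The preview cuts off before the output_offsets operand and the ragged-all-to-all instruction itself. As a rough aid, here is a hedged NumPy sketch of the semantics this HLO appears to exercise; the output_offsets values and the exact send/receive placement rule are assumptions, not taken from the gist.

# Hedged NumPy model of ragged-all-to-all with 8 replicas (assumed semantics:
# replica i sends send_sizes[j] elements of its input, starting at
# input_offsets[j], to replica j, which stores them at output_offsets[i]).
import numpy as np

n = 8  # replica_count from the HLO
inputs = [np.arange(16, dtype=np.float32) + 100 * i for i in range(n)]
outputs = [np.zeros(16, dtype=np.float32) for _ in range(n)]
send_sizes = np.ones(n, dtype=np.int32)
recv_sizes = np.ones(n, dtype=np.int32)  # mirrors send_sizes in this HLO
input_offsets = np.arange(0, 16, 2, dtype=np.int32)  # {0, 2, ..., 14} as above
output_offsets = np.arange(n, dtype=np.int32)        # hypothetical values

for i in range(n):      # sending replica
    for j in range(n):  # receiving replica
        s = int(send_sizes[j])
        chunk = inputs[i][int(input_offsets[j]):int(input_offsets[j]) + s]
        outputs[j][int(output_offsets[i]):int(output_offsets[i]) + s] = chunk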
@apivovarov
apivovarov / SIGILL in tests
Last active March 31, 2026 02:24
executor->GetOrNullResource<se::DeviceInterconnectResource>() causes SIGILL in tests
[ RUN ] RaggedAllToAllTest/RaggedAllToAllTest.RaggedAllToAll_2GPUs_CommandBuffer/sync_one_shot_with_multi_gpu_barrier_with_nccl
*** SIGILL (UD1@0x7f9c00625cc9), see go/stacktraces#debugging-ubsan, received by PID 102457 (TID 104591) on cpu 145; stack trace: ***
E0330 19:21:38.081043 104590 process_state.cc:1066] RAW: Signal 4 raised at PC: 0x7f9c00625cc9 while already in FailureSignalHandler!
E0330 19:21:38.081811 104590 process_state.cc:1070] RAW: tid: 104590 raised new signal (old_tid: 104591)
PC: @ 0x7f9c00625cc9 (unknown) xla::gpu::RaggedAllToAllThunk::RunCollective()
@ 0x7f92d4008f08 1904 FailureSignalHandler()
@ 0x7f9d30337c60 (unknown) (unknown)
@ 0x7f9bf5215d8c 288 xla::gpu::CollectiveThunk::ExecuteOnStream()
@ 0x7f9bfa60d618 464 xla::gpu::ThunkExecutor::ExecuteOnStream()
@apivovarov
apivovarov / gist:848f7bf1814522f8848e9a1e4012c7d9
Created March 16, 2026 21:59
--xla_gpu_execution_terminate_timeout=20s
W 2026-03-16T14:24:05.274810-07:00 7452 tf_gpu-executable-hang-watchdog/7452 thread/thread.cc:1808] --- Thread 7fb2bc4eb700 (name: py_xla_execute/5842) stack: ---
stack used: 166 KiB of 2104 KiB
I 2026-03-16T14:24:05.274834-07:00 0 W0316 14:24:05.274810 7452 thread.cc:1808] --- Thread 7fb2bc4eb700 (name: py_xla_execute/5842) stack: ---
I 2026-03-16T14:24:05.274836-07:00 0 stack used: 166 KiB of 2104 KiB
W 2026-03-16T14:24:05.489353-07:00 7452 tf_gpu-executable-hang-watchdog/7452 thread/thread.cc:1808] @ 0x7fb4e02cc178 _binary_blaze_out_haswell_opt_cuda_genfiles_third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
I 2026-03-16T14:24:05.490090-07:00 0 W0316 14:24:05.489353 7452 thread.cc:1808] @ 0x7fb4e02cc178 _binary_blaze_out_haswell_opt_cuda_genfiles_third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
W 2026-03-16T14:24:05.669908-07:00 7452 tf_gpu-executable-hang-watchdog/7452 thread/thread.cc:1808] @ 0x7fb4e0e5e960 _binary
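For context, XLA GPU flags like the one in this gist's title are typically passed through the XLA_FLAGS environment variable. A minimal sketch, assuming a JAX program (the flag value itself is taken from the title above):

# Hedged sketch: append the flag to XLA_FLAGS before importing jax, since
# XLA reads the variable when the backend initializes.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_execution_terminate_timeout=20s"
)

import jax  # imported only after XLA_FLAGS is set

print(jax.devices())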
@apivovarov
apivovarov / build-xla-with-jax-container.md
Created June 10, 2025 01:20
Build XLA with CUDA/cuDNN Support Using the JAX CI/Release Container

Build XLA with CUDA/cuDNN Support Using the JAX CI/Release Container

XLA is a compiler used internally by JAX. JAX is distributed via PyPI wheels. The JAX Continuous Integration documentation explains how to build JAX wheels using the tensorflow/ml-build:latest Docker container.

We can extend these instructions to build XLA targets inside the same JAX container. This keeps the XLA targets' build configuration consistent with the JAX/XLA build configuration, which is useful when reproducing, with standalone XLA tools, results for workloads that were originally created in JAX.

Build XLA Targets in the JAX CI Container

  1. Clone the JAX repository and navigate to the 'jax' directory
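A hedged sketch of this first step as a small script (the jax-ml/jax URL is my assumption of the canonical repository):

# Clone JAX and enter the checkout; equivalent to
#   git clone https://github.com/jax-ml/jax.git && cd jax
import os
import subprocess

subprocess.run(["git", "clone", "https://github.com/jax-ml/jax.git"], check=True)
os.chdir("jax")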
int32_t add_int32_using_float32(int32_t a, int32_t b) {
  const uint32_t SHIFT = 16;
  const uint32_t MASK = 0xFFFF;
  int32_t a_high = a >> SHIFT;  // Sign-extended
  uint32_t a_low = static_cast<uint32_t>(a) & MASK;
  int32_t b_high = b >> SHIFT;  // Sign-extended
  uint32_t b_low = static_cast<uint32_t>(b) & MASK;
  // Add low parts (this will fit in float32)
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
// Grok 3 early
int32_t int32_add_using_float(int32_t a, int32_t b) {
  // Masks for splitting into high and low 16-bit parts
  const int32_t MASK_16BIT = 0xFFFF;  // 16-bit mask: 0x0000FFFF
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t add_uint32_using_float32(uint32_t a, uint32_t b) {
  // Split the 32-bit numbers into two 16-bit halves (high and low)
  const uint32_t mask16 = 0xFFFF;  // 16-bit mask
  uint32_t a_low = a & mask16;     // Lower 16 bits of a
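The three previews above all implement the same trick. Here is a hedged, self-contained Python model of it (the function name and test values are mine, and np.float32 stands in for hardware float32 arithmetic):

# Add two uint32 values using only float32 arithmetic on 16-bit halves: each
# half-sum is at most 2^17 - 2, which float32's 24-bit mantissa represents
# exactly, so no precision is lost.
import numpy as np

def add_uint32_using_float32(a: int, b: int) -> int:
    mask16 = 0xFFFF
    a_low, a_high = a & mask16, (a >> 16) & mask16
    b_low, b_high = b & mask16, (b >> 16) & mask16
    low = np.float32(a_low) + np.float32(b_low)     # exact in float32
    carry = np.float32(low >= np.float32(65536.0))  # 0.0 or 1.0
    low = low - carry * np.float32(65536.0)
    high = np.float32(a_high) + np.float32(b_high) + carry
    if high >= np.float32(65536.0):                 # emulate uint32 wraparound
        high = high - np.float32(65536.0)
    return (int(high) << 16) | int(low)

assert add_uint32_using_float32(0xFFFFFFFF, 1) == 0  # wraps like uint32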
@apivovarov
apivovarov / divide_uint32.cc
Created February 20, 2025 02:04
Function to perform uint32 division without loops using float32 div, uint32 add, sub, mul
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) {
  // Handle division by zero
  uint32_t is_zero_divisor = (divisor == 0);
  if (is_zero_divisor) {
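The preview stops at the zero-divisor check. A hedged Python model of the overall approach the description names (the refinement structure and the division-by-zero return value are my assumptions, with np.float32 standing in for hardware float32):

# Estimate the quotient with one float32 divide, then repair the rounding
# error: the first estimate can be off by a few hundred (float32 has only a
# 24-bit mantissa), one residual-based refinement brings it within ~2, and a
# fixed number of compare-and-adjust steps finishes. No data-dependent loops.
import numpy as np

def divide_uint32(dividend: int, divisor: int) -> int:
    if divisor == 0:
        return 0xFFFFFFFF  # assumed convention for division by zero
    q = int(np.float32(dividend) / np.float32(divisor))  # rough estimate
    r = dividend - q * divisor                           # exact residual
    q += int(np.float32(r) / np.float32(divisor))        # one refinement
    r = dividend - q * divisor
    for _ in range(2):                                   # bounded fix-up
        if r < 0:
            q -= 1
            r += divisor
        elif r >= divisor:
            q += 1
            r -= divisor
    return q

assert divide_uint32(0xFFFFFFFF, 3) == 0xFFFFFFFF // 3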
@apivovarov
apivovarov / jax_softmax_accuracy.py
Last active February 1, 2025 03:41
jax softmax accuracy
# ======== softmax ========================================
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
def softmax(x):
    x_max = np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)  # normalize to sum to 1
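Per its title, the gist is about accuracy. A hedged sketch of the kind of comparison it appears to set up (the input shape and error metric are mine):

# Compare jax.nn.softmax in float32 against the NumPy reference above
# evaluated in float64.
x = np.random.randn(4, 128).astype(np.float32)
ref = softmax(x.astype(np.float64))
out = np.asarray(jax.nn.softmax(jnp.asarray(x), axis=-1), dtype=np.float64)
print("max abs err:", np.max(np.abs(out - ref)))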
@apivovarov
apivovarov / jax_reduce_accuracy.py
Created January 29, 2025 01:50
JAX reduce op accuracy
# REDUCE
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[[[3214685668, 1050640488, 1060743252, 3209803584], [3204519310, 1067654368, 1067817699, 1067875232], [1056212128, 3212040969, 3205718709, 1065846737], [3212748857, 3210953055, 3206425550, 3214376535]], [[3216393659, 3204671589, 1046392801, 3210937971], [3212871310, 1011922854, 3201903270, 1056981194], [1057317906, 1057615558, 1049853029, 1054672679], [3212476770, 3221471932, 3220222283, 3214302880]], [[3223731609, 3214052862, 3180226583, 3214602181], [1058005824, 1066321194, 3196840352, 3205731707], [1063641844, 1058202109, 3199602305, 1062816921], [3213035287, 3205352409, 3207120713, 3215062456]], [[3213243313, 3198190149, 3200959705, 3220727198], [3216848691, 1065951977, 1058746486, 3187463331], [3208759739, 3209999898, 3201950053, 1057709270], [3215410801, 3220972337, 3217900520, 3205565316]], [[3204515228, 3200171121, 1036747732, 3212008346], [3212994805, 1064
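The integers above look like float32 bit patterns. Assuming that, and assuming the full a_np literal from the gist, a hedged sketch of the comparison the title suggests (the reinterpretation and the float64 reference are mine):

# Reinterpret the uint32 bit patterns as float32, then compare a JAX sum
# reduction against a float64 NumPy reference.
a = a_np.astype(np.uint32).view(np.float32)
ref = a.astype(np.float64).sum(axis=-1)
out = np.asarray(jnp.sum(jnp.asarray(a), axis=-1), dtype=np.float64)
print("max abs err:", np.max(np.abs(out - ref)))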