
Alexander Pivovarov (apivovarov)

  • Sunnyvale, CA
  • LinkedIn in/pivovaal
@apivovarov
apivovarov / build-xla-with-jax-container.md
Created June 10, 2025 01:20
Build XLA with CUDA/cuDNN Support Using the JAX CI/Release Container

XLA is the compiler JAX uses internally. JAX is distributed as PyPI wheels, and the JAX Continuous Integration documentation explains how to build those wheels using the tensorflow/ml-build:latest Docker container.

We can extend these instructions to build XLA targets inside the same container. This keeps the XLA targets' build configuration consistent with the JAX/XLA wheel build, which is useful when reproducing, with standalone XLA tools, workloads that were originally created in JAX.

Build XLA Targets in the JAX CI Container

  1. Clone the JAX repository and navigate to the 'jax' directory (the remaining steps are truncated in this preview; see the sketch below)
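
A minimal sketch of the full flow, assuming the tensorflow/ml-build:latest image named above and XLA's standard configure.py/Bazel workflow; the gist's exact commands may differ:

git clone https://github.com/jax-ml/jax.git
cd jax
# 2. (assumed) Start the CI/Release container with GPU access, mounting the checkout
docker run -it --gpus all -v "$PWD:/jax" -w /jax tensorflow/ml-build:latest bash
# 3. (assumed) Inside the container, fetch XLA and build a CUDA-enabled target
git clone https://github.com/openxla/xla.git
cd xla
python configure.py --backend=CUDA
bazel build --config=cuda //xla/tools:run_hlo_module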
#include <cstdint>
int32_t add_int32_using_float32(int32_t a, int32_t b) {
  const uint32_t SHIFT = 16;
  const uint32_t MASK = 0xFFFF;
  int32_t a_high = a >> SHIFT; // Sign-extended
  uint32_t a_low = static_cast<uint32_t>(a) & MASK;
  int32_t b_high = b >> SHIFT; // Sign-extended
  uint32_t b_low = static_cast<uint32_t>(b) & MASK;
  // Add low parts (this will fit in float32)
  // Preview truncated; hedged completion (float32 is exact for integers below 2^24):
  uint32_t low = static_cast<uint32_t>(static_cast<float>(a_low) + static_cast<float>(b_low));
  float high_sum = static_cast<float>(a_high) + static_cast<float>(b_high) + static_cast<float>(low >> SHIFT);
  uint32_t high = static_cast<uint32_t>(static_cast<int32_t>(high_sum));
  return static_cast<int32_t>((high << SHIFT) | (low & MASK));
}
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
// Grok 3 early
int32_t int32_add_using_float(int32_t a, int32_t b) {
  // Masks for splitting into high and low 16-bit parts
  const int32_t MASK_16BIT = 0xFFFF; // 16-bit mask: 0x0000FFFF
  // Preview truncated; hedged completion mirroring the variant above:
  uint32_t ua = static_cast<uint32_t>(a);
  uint32_t ub = static_cast<uint32_t>(b);
  uint32_t low = static_cast<uint32_t>(static_cast<float>(ua & MASK_16BIT) + static_cast<float>(ub & MASK_16BIT));
  float high_sum = static_cast<float>(ua >> 16) + static_cast<float>(ub >> 16) + static_cast<float>(low >> 16);
  uint32_t high = static_cast<uint32_t>(high_sum) & MASK_16BIT;
  return static_cast<int32_t>((high << 16) | (low & MASK_16BIT));
}
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t add_uint32_using_float32(uint32_t a, uint32_t b) {
  // Split the 32-bit numbers into two 16-bit halves (high and low)
  const uint32_t mask16 = 0xFFFF; // 16-bit mask
  uint32_t a_low = a & mask16; // Lower 16 bits of a
  // Preview truncated; hedged completion using the same halving scheme:
  uint32_t a_high = a >> 16;
  uint32_t b_low = b & mask16;
  uint32_t b_high = b >> 16;
  uint32_t low = static_cast<uint32_t>(static_cast<float>(a_low) + static_cast<float>(b_low));
  float high_sum = static_cast<float>(a_high) + static_cast<float>(b_high) + static_cast<float>(low >> 16);
  uint32_t high = static_cast<uint32_t>(high_sum) & mask16;
  return (high << 16) | (low & mask16);
}
@apivovarov
apivovarov / divide_uint32.cc
Created February 20, 2025 02:04
Function to perform uint32 division without loops using float32 div, uint32 add, sub, mul
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) {
  // Handle division by zero
  uint32_t is_zero_divisor = (divisor == 0);
  if (is_zero_divisor) { return 0xFFFFFFFFu; } // sentinel assumed; preview truncated here
  // Hedged completion: scale the float32 quotient estimate down slightly so it
  // never overestimates, refine once with a second float32 division, then fix up.
  uint32_t q = static_cast<uint32_t>(static_cast<float>(dividend) * 0.999998f / static_cast<float>(divisor));
  uint32_t r = dividend - q * divisor; // exact, since q <= dividend / divisor
  uint32_t dq = static_cast<uint32_t>(static_cast<float>(r) * 0.999998f / static_cast<float>(divisor));
  q += dq; r -= dq * divisor;
  if (r >= divisor) { q += 1; } // at most one unit of error remains
  return q;
}
@apivovarov
apivovarov / jax_softmax_accuracy.py
Last active February 1, 2025 03:41
jax softmax accuracy
# ======== softmax ========================================
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
def softmax(x):
    x_max = np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x - x_max)
    # Preview truncated; the standard completion:
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
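
The comparison harness is not shown in the preview. A plausible continuation, assuming the gist checks jax.nn.softmax (float32) against the float64 NumPy reference above:

x = np.random.uniform(low=-10.0, high=10.0, size=(8, 128)).astype(np.float32)
ref = softmax(x.astype(np.float64))             # float64 reference
out = jax.nn.softmax(jnp.asarray(x), axis=-1)   # float32 on the JAX device
print("max abs diff:", np.max(np.abs(np.asarray(out, dtype=np.float64) - ref)))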
@apivovarov
apivovarov / jax_reduce_accuracy.py
Created January 29, 2025 01:50
JAX reduce op accuracy
# REDUCE
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[[[3214685668, 1050640488, 1060743252, 3209803584], [3204519310, 1067654368, 1067817699, 1067875232], [1056212128, 3212040969, 3205718709, 1065846737], [3212748857, 3210953055, 3206425550, 3214376535]], [[3216393659, 3204671589, 1046392801, 3210937971], [3212871310, 1011922854, 3201903270, 1056981194], [1057317906, 1057615558, 1049853029, 1054672679], [3212476770, 3221471932, 3220222283, 3214302880]], [[3223731609, 3214052862, 3180226583, 3214602181], [1058005824, 1066321194, 3196840352, 3205731707], [1063641844, 1058202109, 3199602305, 1062816921], [3213035287, 3205352409, 3207120713, 3215062456]], [[3213243313, 3198190149, 3200959705, 3220727198], [3216848691, 1065951977, 1058746486, 3187463331], [3208759739, 3209999898, 3201950053, 1057709270], [3215410801, 3220972337, 3217900520, 3205565316]], [[3204515228, 3200171121, 1036747732, 3212008346], [3212994805, 1064
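
The array literal above is cut off by the preview; the integers look like uint32 bit patterns, presumably reinterpreted via .view(np.float32) as in the rsqrt gist below. A minimal sketch of the likely comparison, with random input standing in for the truncated data:

a_np = np.random.uniform(low=-2.0, high=2.0, size=(4, 4, 4, 4)).astype(np.float32)
ref = np.sum(a_np.astype(np.float64), axis=-1)   # float64 reference reduction
out = jnp.sum(jnp.asarray(a_np), axis=-1)        # float32 reduction in JAX
print("max abs diff:", np.max(np.abs(np.asarray(out, dtype=np.float64) - ref)))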
@apivovarov
apivovarov / jax_exp_accuracy.py
Last active January 29, 2025 02:44
jax exp accuracy
# EXP
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[ 4603795724020441080, 4600877961149805680, 13824041971732116924, 13819622319934898176, 4591988256500146912, 13827780906010235892, 4592430392547641040, 13825287425120520724, 13828493035404110830, 4606660095605074408, 4602255938349038484, 4606604461173478372, 13827339776963088236, 4604220211627411202, 13824977737624476168, 4600871068755542572, 4606725295095174736, 4590846528761481504, 13822202805976138512, 13815414477567454288, 4606044597329528470, 4600076899447978620, 4591604889013648656, 4602697894782769988, 13827420058876989058, 4604756101012177124, 4605317067340288710, 13816787606007860464, 4602913948182442482, 4606332577827281932, 4604773657456759696, 13827381634999381006], [13826069199251957758, 4605825876037599016, 4599813349846477404, 4606034169064504034, 13810066424407939872, 4599928033056128564, 4604540450308053286, 4604372500785701914, 45965253622748
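
This literal is also truncated; the values look like uint64 bit patterns, presumably viewed as float64. A hedged sketch of the likely check, with random float32 input standing in for the truncated data:

a_np = np.random.uniform(low=-5.0, high=5.0, size=(32, 32)).astype(np.float32)
ref = np.exp(a_np.astype(np.float64))   # float64 reference
out = jnp.exp(jnp.asarray(a_np))        # float32 exp in JAX
print("max rel diff:", np.max(np.abs(np.asarray(out, dtype=np.float64) - ref) / ref))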
@apivovarov
apivovarov / jax_rsqrt_accuracy.py
Last active January 29, 2025 01:22
jax rsqrt accuracy
# RSQRT
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[[1048684900], [1052291356], [1049963963], [1050007938], [1051252185], [1050317382]], [[1045717137], [1050494007], [1050815620], [1049559979], [1051598875], [1051539171]]], dtype=np.uint32).view(np.float32)
c_np = 1.0 / np.sqrt(a_np)
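
The preview stops at the NumPy reference. A plausible continuation comparing JAX on the same data (whether the gist uses jax.lax.rsqrt or 1/jnp.sqrt is an assumption):

out = jax.lax.rsqrt(jnp.asarray(a_np))         # float32 rsqrt in JAX
ref = 1.0 / np.sqrt(a_np.astype(np.float64))   # float64 reference
print("jax float32 max abs err:", np.max(np.abs(np.asarray(out, dtype=np.float64) - ref)))
print("np  float32 max abs err:", np.max(np.abs(c_np.astype(np.float64) - ref)))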
@apivovarov
apivovarov / jax_matmul_accuracy.py
Created January 24, 2025 22:42
JAX matmul accuracy
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
shape = (16,16)
a_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
b_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
c_np = a_np @ b_np
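
The preview ends at the NumPy product. A minimal sketch of the likely continuation, comparing the JAX float32 matmul against a float64 reference:

c_ref = a_np.astype(np.float64) @ b_np.astype(np.float64)   # float64 reference
c_jax = jnp.asarray(a_np) @ jnp.asarray(b_np)               # float32 matmul in JAX
print("jax float32 max abs err:", np.max(np.abs(np.asarray(c_jax, dtype=np.float64) - c_ref)))
print("np  float32 max abs err:", np.max(np.abs(c_np.astype(np.float64) - c_ref)))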