Skip to content

Instantly share code, notes, and snippets.

View apivovarov's full-sized avatar

Alexander Pivovarov apivovarov

  • Amazon Web Services
  • Santa Clara, CA
  • 10:12 (UTC -08:00)
  • LinkedIn in/pivovaal
View GitHub Profile
int32_t add_int32_using_float32(int32_t a, int32_t b) {
const uint32_t SHIFT = 16;
const uint32_t MASK = 0xFFFF;
int32_t a_high = a >> SHIFT; // Sign-extended
uint32_t a_low = static_cast<uint32_t>(a) & MASK;
int32_t b_high = b >> SHIFT; // Sign-extended
uint32_t b_low = static_cast<uint32_t>(b) & MASK;
// Add low parts (this will fit in float32)
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
// Grok 3 early
int32_t int32_add_using_float(int32_t a, int32_t b) {
// Masks for splitting into high and low 16-bit parts
const int32_t MASK_16BIT = 0xFFFF; // 16-bit mask: 0x0000FFFF
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t add_uint32_using_float32(uint32_t a, uint32_t b) {
// Split the 32-bit numbers into two 16-bit halves (high and low)
const uint32_t mask16 = 0xFFFF; // 16-bit mask
uint32_t a_low = a & mask16; // Lower 16 bits of a
@apivovarov
apivovarov / divide_uint32.cc
Created February 20, 2025 02:04
Performs uint32 division without loops, using only float32 division and uint32 addition, subtraction, and multiplication
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) {
// Handle division by zero
uint32_t is_zero_divisor = (divisor == 0);
if (is_zero_divisor) {
@apivovarov
apivovarov / jax_softmax_accuracy.py
Last active February 1, 2025 03:41
JAX softmax accuracy
# ======== softmax ========================================
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
def softmax(x):
x_max = np.max(x, axis=-1, keepdims=True)
exp_x = np.exp(x - x_max)
@apivovarov
apivovarov / jax_reduce_accuracy.py
Created January 29, 2025 01:50
JAX reduce op accuracy
# REDUCE
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[[[3214685668, 1050640488, 1060743252, 3209803584], [3204519310, 1067654368, 1067817699, 1067875232], [1056212128, 3212040969, 3205718709, 1065846737], [3212748857, 3210953055, 3206425550, 3214376535]], [[3216393659, 3204671589, 1046392801, 3210937971], [3212871310, 1011922854, 3201903270, 1056981194], [1057317906, 1057615558, 1049853029, 1054672679], [3212476770, 3221471932, 3220222283, 3214302880]], [[3223731609, 3214052862, 3180226583, 3214602181], [1058005824, 1066321194, 3196840352, 3205731707], [1063641844, 1058202109, 3199602305, 1062816921], [3213035287, 3205352409, 3207120713, 3215062456]], [[3213243313, 3198190149, 3200959705, 3220727198], [3216848691, 1065951977, 1058746486, 3187463331], [3208759739, 3209999898, 3201950053, 1057709270], [3215410801, 3220972337, 3217900520, 3205565316]], [[3204515228, 3200171121, 1036747732, 3212008346], [3212994805, 1064
@apivovarov
apivovarov / jax_exp_accuracy.py
Last active January 29, 2025 02:44
JAX exp accuracy
# EXP
import jax
import jax.numpy as jnp
import numpy as np
print("JAX devices:", jax.devices())
a_np=np.array([[ 4603795724020441080, 4600877961149805680, 13824041971732116924, 13819622319934898176, 4591988256500146912, 13827780906010235892, 4592430392547641040, 13825287425120520724, 13828493035404110830, 4606660095605074408, 4602255938349038484, 4606604461173478372, 13827339776963088236, 4604220211627411202, 13824977737624476168, 4600871068755542572, 4606725295095174736, 4590846528761481504, 13822202805976138512, 13815414477567454288, 4606044597329528470, 4600076899447978620, 4591604889013648656, 4602697894782769988, 13827420058876989058, 4604756101012177124, 4605317067340288710, 13816787606007860464, 4602913948182442482, 4606332577827281932, 4604773657456759696, 13827381634999381006], [13826069199251957758, 4605825876037599016, 4599813349846477404, 4606034169064504034, 13810066424407939872, 4599928033056128564, 4604540450308053286, 4604372500785701914, 45965253622748
@apivovarov
apivovarov / jax_rsqrt_accuracy.py
Last active January 29, 2025 01:22
JAX rsqrt accuracy
# RSQRT
import jax
import jax.numpy as jnp
import numpy as np
# Show which devices JAX detected, so the accuracy numbers can be tied to a backend.
print("JAX devices:", jax.devices())
# The test inputs are written as raw float32 bit patterns (uint32) and then
# reinterpreted in place with .view(), which reproduces the exact float32
# values bit-for-bit instead of relying on decimal round-tripping.
a_np = np.array(
    [
        [[1048684900], [1052291356], [1049963963], [1050007938], [1051252185], [1050317382]],
        [[1045717137], [1050494007], [1050815620], [1049559979], [1051598875], [1051539171]],
    ],
    dtype=np.uint32,
).view(np.float32)
# NumPy reference for rsqrt: the reciprocal of the square root, i.e. 1 / sqrt(x).
c_np = np.divide(1.0, np.sqrt(a_np))
@apivovarov
apivovarov / jax_matmul_accuracy.py
Created January 24, 2025 22:42
JAX matmul accuracy
import jax
import jax.numpy as jnp
import numpy as np
# Show which devices JAX detected, so the accuracy numbers can be tied to a backend.
print("JAX devices:", jax.devices())
shape = (16,16)
# Use a fixed seed so that a failing accuracy comparison can be reproduced exactly.
# The previous unseeded global generator produced different inputs on every run.
seed = 0
rng = np.random.default_rng(seed)
# Inputs are drawn uniformly from [-1, 1) and then rounded to float32, the dtype under test.
a_np = rng.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
b_np = rng.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
# NumPy float32 reference product used to judge the JAX matmul result.
c_np = a_np @ b_np
#include <exception>
#include <string>
#include <sstream>
#include <execinfo.h>
#include <cxxabi.h>
#include <vector>
class ExceptionWithStackTrace : public std::exception {
public:
explicit ExceptionWithStackTrace(const std::string& message)