This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
int32_t add_int32_using_float32(int32_t a, int32_t b) { | |
const uint32_t SHIFT = 16; | |
const uint32_t MASK = 0xFFFF; | |
int32_t a_high = a >> SHIFT; // Sign-extended | |
uint32_t a_low = static_cast<uint32_t>(a) & MASK; | |
int32_t b_high = b >> SHIFT; // Sign-extended | |
uint32_t b_low = static_cast<uint32_t>(b) & MASK; | |
// Add low parts (this will fit in float32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <cstdint> | |
#include <limits> | |
#include <vector> | |
#include <utility> | |
// Grok 3 early | |
int32_t int32_add_using_float(int32_t a, int32_t b) { | |
// Masks for splitting into high and low 16-bit parts | |
const int32_t MASK_16BIT = 0xFFFF; // 16-bit mask: 0x0000FFFF |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <cstdint> | |
#include <limits> | |
#include <vector> | |
#include <utility> | |
uint32_t add_uint32_using_float32(uint32_t a, uint32_t b) { | |
// Split the 32-bit numbers into two 16-bit halves (high and low) | |
const uint32_t mask16 = 0xFFFF; // 16-bit mask | |
uint32_t a_low = a & mask16; // Lower 16 bits of a |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <cstdint> | |
#include <limits> | |
#include <vector> | |
#include <utility> | |
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) { | |
// Handle division by zero | |
uint32_t is_zero_divisor = (divisor == 0); | |
if (is_zero_divisor) { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ======== softmax ======================================== | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
print("JAX devices:", jax.devices()) | |
def softmax(x): | |
x_max = np.max(x, axis=-1, keepdims=True) | |
exp_x = np.exp(x - x_max) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# REDUCE | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
print("JAX devices:", jax.devices()) | |
a_np=np.array([[[[3214685668, 1050640488, 1060743252, 3209803584], [3204519310, 1067654368, 1067817699, 1067875232], [1056212128, 3212040969, 3205718709, 1065846737], [3212748857, 3210953055, 3206425550, 3214376535]], [[3216393659, 3204671589, 1046392801, 3210937971], [3212871310, 1011922854, 3201903270, 1056981194], [1057317906, 1057615558, 1049853029, 1054672679], [3212476770, 3221471932, 3220222283, 3214302880]], [[3223731609, 3214052862, 3180226583, 3214602181], [1058005824, 1066321194, 3196840352, 3205731707], [1063641844, 1058202109, 3199602305, 1062816921], [3213035287, 3205352409, 3207120713, 3215062456]], [[3213243313, 3198190149, 3200959705, 3220727198], [3216848691, 1065951977, 1058746486, 3187463331], [3208759739, 3209999898, 3201950053, 1057709270], [3215410801, 3220972337, 3217900520, 3205565316]], [[3204515228, 3200171121, 1036747732, 3212008346], [3212994805, 1064 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# EXP | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
print("JAX devices:", jax.devices()) | |
a_np=np.array([[ 4603795724020441080, 4600877961149805680, 13824041971732116924, 13819622319934898176, 4591988256500146912, 13827780906010235892, 4592430392547641040, 13825287425120520724, 13828493035404110830, 4606660095605074408, 4602255938349038484, 4606604461173478372, 13827339776963088236, 4604220211627411202, 13824977737624476168, 4600871068755542572, 4606725295095174736, 4590846528761481504, 13822202805976138512, 13815414477567454288, 4606044597329528470, 4600076899447978620, 4591604889013648656, 4602697894782769988, 13827420058876989058, 4604756101012177124, 4605317067340288710, 13816787606007860464, 4602913948182442482, 4606332577827281932, 4604773657456759696, 13827381634999381006], [13826069199251957758, 4605825876037599016, 4599813349846477404, 4606034169064504034, 13810066424407939872, 4599928033056128564, 4604540450308053286, 4604372500785701914, 45965253622748 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# RSQRT | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
print("JAX devices:", jax.devices()) | |
a_np=np.array([[[1048684900], [1052291356], [1049963963], [1050007938], [1051252185], [1050317382]], [[1045717137], [1050494007], [1050815620], [1049559979], [1051598875], [1051539171]]], dtype=np.uint32).view(np.float32) | |
c_np = 1.0 / np.sqrt(a_np) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
print("JAX devices:", jax.devices()) | |
shape = (16,16) | |
a_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32) | |
b_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32) | |
c_np = a_np @ b_np |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <exception> | |
#include <string> | |
#include <sstream> | |
#include <execinfo.h> | |
#include <cxxabi.h> | |
#include <vector> | |
class ExceptionWithStackTrace : public std::exception { | |
public: | |
explicit ExceptionWithStackTrace(const std::string& message) |
Newer Older