Skip to content

Instantly share code, notes, and snippets.

import argparse
import ale_py
import cv2
import gymnasium as gym
import minari
import minigrid
import numpy as np
import sb3_contrib
import shortuuid
@carlosgmartin
carlosgmartin / hungarian_algorithm.py
Last active November 20, 2024 04:46
Comparison of JAX implementations of the Hungarian algorithm
from timeit import default_timer
import jax
import optax
import tqdm
from jax import lax, numpy as jnp, random
@jax.jit
def scenic_hungarian_algorithm(cost):
"""
Improving the lowering and compilation of unrolled lax.scan loops
https://github.com/jax-ml/jax/discussions/25336
"""
import argparse
import functools
import jax
from jax import lax, numpy as jnp, random
@carlosgmartin
carlosgmartin / jax_saturating_arithmetic.py
Last active April 27, 2025 05:20
Implementation of saturating arithmetic for JAX
"""
https://github.com/jax-ml/jax/issues/26566
"""
import itertools
import operator
import jax
from jax import numpy as jnp
from tqdm import tqdm