
Evan Walters evanatyourservice

  • Denver, CO
evanatyourservice / hellaswag_jax.py
Last active September 15, 2024 00:54
How to prepare and evaluate on HellaSwag in JAX
import json
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
import optax.tree_utils as otu
import tensorflow as tf
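HellaSwag is commonly scored by running the model on each of the four context+ending candidates and picking the ending with the highest (often length-normalized) log-likelihood. Below is a minimal JAX sketch of that scoring step. It is not the gist's code, and it assumes you already have logits of shape `[4, T, V]` plus a mask marking which tokens belong to the ending. `score_endings` and its argument layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def score_endings(logits, tokens, mask):
    """Pick the most likely ending for one HellaSwag example (sketch).

    logits: [4, T, V] model outputs for the 4 context+ending rows
    tokens: [4, T] int token ids (right-padded)
    mask:   [4, T] 1 where the token belongs to the ending, else 0
    Returns (pred, avg_logprob): argmax index and per-ending scores.
    """
    # Position t predicts token t+1, so drop the last logit and the first target.
    logp = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)
    targets = tokens[:, 1:]
    m = mask[:, 1:].astype(logp.dtype)
    # Log-probability the model assigns to each actual next token: [4, T-1].
    tok_logp = jnp.take_along_axis(logp, targets[..., None], axis=-1)[..., 0]
    # Average over ending tokens only (length normalization).
    avg = (tok_logp * m).sum(axis=1) / jnp.maximum(m.sum(axis=1), 1.0)
    return jnp.argmax(avg), avg
```

Summing instead of averaging `tok_logp * m` gives the unnormalized variant. Averaging avoids biasing the choice toward shorter endings.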
evanatyourservice / Beta-TCVAE in JAX Flax
Created January 1, 2024 20:10
Beta-TCVAE in JAX Flax
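The β-TCVAE objective splits the VAE's KL term into index-code mutual information (MI), total correlation (TC), and dimension-wise KL, then weights TC by β. The following is a minimal JAX sketch of the minibatch weighted sampling (MWS) estimator for those three terms, offered as an illustration only. It is not the gist's code and makes no claim about which typo the author fixed. It assumes a diagonal-Gaussian encoder and a standard-normal prior. `gaussian_log_density` and `tc_decomposition` are hypothetical names. The β-TCVAE paper also describes a minibatch stratified sampling variant, which this sketch does not implement.

```python
import jax
import jax.numpy as jnp


def gaussian_log_density(z, mu, logvar):
    # Elementwise log N(z; mu, exp(logvar)) for a diagonal Gaussian.
    return -0.5 * (jnp.log(2 * jnp.pi) + logvar + (z - mu) ** 2 / jnp.exp(logvar))


def tc_decomposition(z, mu, logvar, dataset_size):
    """MWS estimates of (MI, TC, dimension-wise KL), averaged over the batch.

    z, mu, logvar: [M, D] sampled codes and encoder outputs for M inputs.
    dataset_size: N, the number of training examples.
    """
    m = z.shape[0]
    # log q(z_i[d] | x_j[d]) for every pair (i, j) and dim d: [M, M, D].
    log_qz_ij = gaussian_log_density(z[:, None, :], mu[None, :, :], logvar[None, :, :])
    log_norm = jnp.log(dataset_size * m)
    # log q(z_i): sum over dims, then logsumexp over the batch j.
    log_qz = jax.nn.logsumexp(log_qz_ij.sum(-1), axis=1) - log_norm
    # log prod_d q(z_i[d]): logsumexp over j per dim, then sum over dims.
    log_prod_qzd = (jax.nn.logsumexp(log_qz_ij, axis=1) - log_norm).sum(-1)
    # log q(z_i | x_i) and log p(z_i) under a standard-normal prior.
    log_qz_x = gaussian_log_density(z, mu, logvar).sum(-1)
    log_pz = gaussian_log_density(z, 0.0, 0.0).sum(-1)
    mi = jnp.mean(log_qz_x - log_qz)
    tc = jnp.mean(log_qz - log_prod_qzd)
    dw_kl = jnp.mean(log_prod_qzd - log_pz)
    return mi, tc, dw_kl
```

A training loss could then look like `recon + mi + beta * tc + dw_kl`. The three terms telescope to the single-sample KL estimate `mean(log q(z|x) - log p(z))`, which makes a convenient sanity check. The MWS TC estimate is not guaranteed to be non-negative on a given batch.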
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
"""
There's a typo in most B-TCVAE implementations on GitHub, so I thought I'd make a
quick gist of a working B-TCVAE.