Alex Wang (KeAWang)

KeAWang / tree_stack.py
Last active June 5, 2023 16:05 — forked from willwhitney/tree_stack.py
Utils for stacking and unstacking JAX pytrees to deal with vmap
```python
import jax
import jax.numpy as jnp


def tree_stack(trees):
    """Takes a list of trees and stacks every corresponding leaf.
    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.
    """
    ...  # remainder of the gist is not shown in this preview
```
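The stacking idea in the docstring, together with its inverse, can be sketched with `jax.tree_util`. This is a minimal sketch assuming every tree in the list has the same structure and the same leaf shapes; it is not the gist's own code, and `tree_unstack`, `loss`, and `params_list` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def tree_stack(trees):
    # Stack corresponding leaves of structurally identical pytrees along a new axis 0
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def tree_unstack(tree):
    # Inverse of tree_stack: slice every leaf along axis 0 and rebuild one tree per slice
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]
    return [
        jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        for i in range(n)
    ]


# Hypothetical example: four (weight, bias) parameter sets evaluated in one vmapped call
def loss(params, x):
    w, b = params
    return jnp.sum(w * x + b)


params_list = [(jnp.ones(3) * i, jnp.asarray(float(i))) for i in range(4)]
stacked = tree_stack(params_list)  # leaves now have shapes (4, 3) and (4,)
out = jax.vmap(loss, in_axes=(0, None))(stacked, jnp.ones(3))  # maps over axis 0 of every leaf
back = tree_unstack(stacked)  # a list of four (weight, bias) tuples again
```

Because every leaf gains the same leading axis, a single `in_axes=0` entry covers the whole parameter tuple; here `out` holds `6 * i` for parameter set `i`, and `tree_unstack` recovers the original list.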
KeAWang / mfu_compute.py
Created April 11, 2024 17:17 — forked from Chillee/mfu_compute.py
Compute FLOP Utilization in PyTorch
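```python
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench


def get_flops_achieved(f):
    # Count the FLOPs performed by one call to f without printing the summary table
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    # Benchmark f; do_bench reports the runtime in milliseconds
    ms_per_iter = do_bench(f)
    ...  # remainder of the gist is not shown in this preview
```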
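`get_flops_achieved` pairs a FLOP count with a measured time per iteration; turning those two numbers into a utilization figure is plain arithmetic. The sketch below assumes the common convention that an (M x K) @ (K x N) matmul costs 2*M*N*K FLOPs (one multiply and one add per term). The 1 ms timing and the 312 TFLOP/s peak (roughly an A100's dense BF16 tensor-core rate) are placeholder inputs, not measurements, and `matmul_flops` and `flop_utilization` are hypothetical helpers, not part of the gist.

```python
def matmul_flops(m, n, k):
    # (m x k) @ (k x n): each of the m*n outputs needs k multiplies and k adds
    return 2 * m * n * k


def flop_utilization(total_flops, ms_per_iter, peak_flops_per_s):
    # Convert milliseconds to seconds to get achieved FLOP/s, then compare to the peak
    achieved_flops_per_s = total_flops / (ms_per_iter / 1e3)
    return achieved_flops_per_s, achieved_flops_per_s / peak_flops_per_s


# Placeholder numbers: a 4096^3 matmul assumed to take 1 ms on a device with an assumed 312 TFLOP/s peak
flops = matmul_flops(4096, 4096, 4096)
achieved, utilization = flop_utilization(flops, ms_per_iter=1.0, peak_flops_per_s=312e12)
print(f"{achieved / 1e12:.1f} TFLOP/s, {utilization:.1%} of peak")
```

With these placeholder inputs the script reports about 137.4 TFLOP/s, roughly 44% of the assumed peak. In practice `total_flops` and `ms_per_iter` would come from the counter and benchmark above, and the peak from the hardware vendor's spec sheet for the dtype in use.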
KeAWang / gpt-2-wikitext-103.py
Created September 23, 2024 20:23 — forked from thomwolf/gpt-2-wikitext-103.py
A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103
```python
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """
import os
from collections import namedtuple

from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
```
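Training a GPT-2-style model on wikitext-103 is next-token language modeling. One common way to prepare the data, assumed here because the rest of the gist is not shown in this preview and may batch differently, is to concatenate the tokenized corpus into a single stream and cut it into fixed-length blocks, where each target sequence is the input shifted one position to the left. `make_lm_blocks` is a hypothetical plain-Python helper illustrating only that step.

```python
def make_lm_blocks(token_ids, seq_len):
    """Split a flat token stream into (inputs, targets) pairs for next-token prediction."""
    # Each block needs seq_len inputs plus one extra token for the shifted target;
    # a trailing remainder too short to fill a block is dropped
    n_blocks = max(0, (len(token_ids) - 1) // seq_len)
    blocks = []
    for b in range(n_blocks):
        start = b * seq_len
        inputs = token_ids[start:start + seq_len]
        targets = token_ids[start + 1:start + seq_len + 1]  # same window shifted by one
        blocks.append((inputs, targets))
    return blocks


# Toy "token" stream standing in for a tokenized corpus
blocks = make_lm_blocks(list(range(10)), seq_len=4)
```

With ten tokens and `seq_len=4` this yields two blocks, `([0, 1, 2, 3], [1, 2, 3, 4])` and `([4, 5, 6, 7], [5, 6, 7, 8])`; token 9 only ever appears as a target. A model trained on such pairs learns to predict position `t + 1` from positions `0..t`.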