This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2019-present, Thomas Wolf. | |
# All rights reserved. This source code is licensed under the MIT-style license. | |
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """ | |
import os | |
from collections import namedtuple | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from ignite.engine import Engine, Events |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.flop_counter import FlopCounterMode | |
from triton.testing import do_bench | |
def get_flops_achieved(f): | |
flop_counter = FlopCounterMode(display=False) | |
with flop_counter: | |
f() | |
total_flops = flop_counter.get_total_flops() | |
ms_per_iter = do_bench(f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import torch | |
import numpy as np | |
def make_adding_dataset(num_seqs, seq_len, num_terms=2, seed=43141): | |
assert 0 <= num_terms <= seq_len | |
rng = np.random.default_rng(seed=seed) | |
numbers = rng.uniform(0, 1, (num_seqs, seq_len)) # B x T | |
mask = np.zeros_like(numbers) # B x T | |
non_zero = np.stack([rng.choice(seq_len, num_terms, replace=True) for _ in range(num_seqs)]) # B x 2 | |
mask[np.arange(num_seqs)[:, None], non_zero] = 1 # mask[i, non_zero[i, j]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class NanWrapper(torch.nn.Module): | |
"""Wrapper module around a torch Module that handles incoming nans""" | |
def __init__(self, module): | |
super().__init__() | |
self.module = module | |
def forward(self, x): | |
""" Masks the entire last dimension (usually the feature/channel dimension) if any element is NaN. """ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def count_params(model: torch.nn.Module, *, trainable_only: bool = False) -> int:
    """Count the scalar parameters of a PyTorch model.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: If True, count only parameters with
            ``requires_grad=True`` (i.e. the ones an optimizer would update).
            Defaults to False, which counts every parameter, frozen or not —
            this matches the function's original behavior.

    Returns:
        Total number of elements across the selected parameter tensors
        (0 for a module with no parameters).
    """
    # The old docstring promised "trainable" parameters, but the code never
    # looked at requires_grad; the flag makes that choice explicit without
    # changing the result for existing callers.
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def constrain(x, min, max, temperature: float = 1.):
    """Squash an unbounded tensor ``x`` into the range between ``min`` and ``max``.

    ``x / temperature`` is passed through a sigmoid to get a value between 0
    and 1. That value is then scaled by ``max - min`` and shifted by ``min``.
    A larger ``temperature`` flattens the sigmoid, making the mapping softer.

    Note: ``min``/``max`` shadow the builtins but are kept as-is because
    callers may pass them by keyword.
    """
    width = max - min
    gate = torch.sigmoid(x / temperature)
    return width * gate + min
def unconstrain(y, min, max, temperature:float=1, EPS:float=1e-8): | |
assert torch.all(y >= min) and torch.all(y <= max) | |
# ensure both numerator and denominator are positive | |
numerator = y - min |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# Hex colour for each range category, listed from "very_low" up to "very_high".
# NOTE(review): "tir" presumably stands for time-in-range (glucose) — confirm.
tir_palette = dict(
    very_low="#A61D2A",
    low="#EE1D23",
    in_range="#26B257",
    high="#FAAB1A",
    very_high="#F47D21",
)
def color_bg(bgs): | |
bins = [0, 54, 70, 180, 250, 1000] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from typing import Optional, List | |
def array_to_dataframe(array, axis_names: Optional[List[str]]=None): | |
"""Based on https://stackoverflow.com/questions/35525028/how-to-transform-a-3d-arrays-into-a-dataframe-in-python""" | |
if axis_names is None: | |
axis_names = list(range(array.ndim)) | |
index = pd.MultiIndex.from_product([range(s) for s in array.shape], names=names) | |
df = pd.DataFrame({"array": array.flatten()}, index=index)["array"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import jax.numpy as jnp | |
import jax | |
import equinox as eqx | |
from typing import Union, Any | |
from abc import ABC, abstractmethod | |
MaybeParameterizedArray = Union[jax.Array, "ParameterizedArray"] | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
class Data(namedtuple("Data", ("x", "y"))): | |
def to(self, device, non_blocking=False): | |
x = self.x.to(device, non_blocking=non_blocking) | |
y = self.y.to(device, non_blocking=non_blocking) | |
return Data(x, y) | |
def contiguous(self): |
NewerOlder