This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """ | |
import os | |
from collections import namedtuple | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from ignite.engine import Engine, Events |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.flop_counter import FlopCounterMode | |
from triton.testing import do_bench | |
def get_flops_achieved(f): | |
flop_counter = FlopCounterMode(display=False) | |
with flop_counter: | |
f() | |
total_flops = flop_counter.get_total_flops() | |
ms_per_iter = do_bench(f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
def tree_stack(trees): | |
"""Takes a list of trees and stacks every corresponding leaf. | |
For example, given two trees ((a, b), c) and ((a', b'), c'), returns | |
((stack(a, a'), stack(b, b')), stack(c, c')). | |
Useful for turning a list of objects into something you can feed to a | |
vmapped function. |