# Adapted from the triton implementation of flash-attention v2
# https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
import time

import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q1, Q2, K1, K2, V, sm_scale,
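The excerpt is cut off at the kernel signature, but the dual (Q1, K1) / (Q2, K2) inputs point to a two-branch attention in which positions inside a local window are scored with one query/key pair and positions outside it with the other, as in ReRoPE-style attention. The plain-PyTorch reference below sketches that behavior under those assumptions; it is not the fused kernel itself, and `window`, the branch-selection rule, and the causal mask are illustrative choices rather than details taken from the source.

def attention_reference(q1, q2, k1, k2, v, sm_scale, window=512):
    # q1, q2, k1, k2, v: (batch, heads, seq_len, head_dim)
    s1 = torch.einsum("bhid,bhjd->bhij", q1, k1) * sm_scale  # in-window scores
    s2 = torch.einsum("bhid,bhjd->bhij", q2, k2) * sm_scale  # out-of-window scores
    seq_len = q1.shape[2]
    i = torch.arange(seq_len, device=q1.device)[:, None]
    j = torch.arange(seq_len, device=q1.device)[None, :]
    # Pick the score branch by relative distance (assumed ReRoPE-style rule).
    scores = torch.where((i - j).abs() < window, s1, s2)
    scores = scores.masked_fill(j > i, float("-inf"))  # causal mask (assumed)
    return torch.softmax(scores, dim=-1) @ v

A fused Triton kernel computes the same quantity tile by tile with an online softmax, which is why both score branches have to be available inside each tile rather than materialized as full seq_len x seq_len matrices.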
import numpy as np
import torch
import jax
from tqdm import tqdm

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model

CKPT_PATH = "./checkpoints"
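These imports follow the pattern of xAI's grok-1 inference script (run.py): LanguageModelConfig and TransformerConfig describe the released checkpoint, QW8Bit marks its 8-bit-quantized weights, and InferenceRunner, ModelRunner, and sample_from_model drive checkpoint loading and sampling. The sketch below shows how they are typically wired together; the TransformerConfig hyperparameters are elided, and the specific argument values (pad sizes, mesh configuration, per-device batch size) are assumptions to verify against the repository.

def main():
    model_config = LanguageModelConfig(
        vocab_size=128 * 1024,
        pad_token=0,
        eos_token=2,
        sequence_len=8192,
        model=TransformerConfig(
            # ... checkpoint-specific hyperparameters (embedding size, layer
            # count, head counts, MoE settings, QW8Bit quantization) ...
        ),
    )
    inference_runner = InferenceRunner(
        pad_sizes=(1024,),
        runner=ModelRunner(
            model=model_config,
            bs_per_device=0.125,
            checkpoint_path=CKPT_PATH,
        ),
        name="local",
        load=CKPT_PATH,
        tokenizer_path="./tokenizer.model",
        local_mesh_config=(1, 8),
        between_hosts_config=(1, 1),
    )
    inference_runner.initialize()
    gen = inference_runner.run()

    prompt = "The answer to life the universe and everything is of the form"
    print(sample_from_model(gen, prompt, max_len=100, temperature=0.01))


if __name__ == "__main__":
    main()

In this pattern, initialize() loads and shards the checkpoint across the local device mesh, run() yields a generator over the model, and sample_from_model queries that generator once per prompt.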