Skip to content

Instantly share code, notes, and snippets.

chu-tianxiang /
Created March 20, 2024 04:40
Convert Grok-1 weights to PyTorch
import numpy as np
import torch
import jax
from tqdm import tqdm
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
CKPT_PATH = "./checkpoints"
chu-tianxiang /
Last active November 22, 2024 21:18
Triton implementation of ReRoPE
# Adapted from the triton implementation of flash-attention v2
import time
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl
chu-tianxiang /
Created August 31, 2023 10:44
Triton implementation of the ReRoPE forward pass
import time
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl
def _fwd_kernel(
Q1, Q2, K1, K2, V, sm_scale,