# Create an isolated environment and build PyTorch from source,
# disabling optional components to keep the build small and fast.
uv venv --python=3.12 --managed-python
source .venv/bin/activate
uv pip install --group dev
uv pip install ninja  # or sudo apt install ninja-build
USE_DISTRIBUTED=0 USE_MKLDNN=0 BUILD_TEST=0 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 uv pip install --no-build-isolation -v -e .
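A quick sanity check after the build (a hedged suggestion, not part of the original snippet): components disabled via the USE_* flags above should report as unavailable from Python.

import torch

print(torch.__version__)
print("distributed available:", torch.distributed.is_available())  # False when built with USE_DISTRIBUTED=0
print("mkldnn available:", torch.backends.mkldnn.is_available())   # False when built with USE_MKLDNN=0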
import torch
from torch import Tensor

def varlen_attn(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
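The preview cuts off mid-signature. As a hedged aside (assumed convention, not taken from the gist itself): in varlen/packed attention APIs, cum_seq_q and cum_seq_k are normally prefix sums of the individual sequence lengths over a batch whose sequences have been concatenated along the token dimension. A minimal sketch of building them:

# three variable-length sequences packed into one token dimension
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cum_seq = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cum_seq[1:] = seq_lens.cumsum(0)        # -> [0, 5, 8, 15]
total_tokens = int(cum_seq[-1])         # packed q/k/v then cover 15 tokens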
import torch

print(torch.__version__)

group_size = 32
w = torch.randn(512, 1024)
w_groups = w.unflatten(1, (-1, group_size))   # (512, 32, 32): groups of 32 along dim 1
min_val = w_groups.amin(2, keepdim=True)
max_val = w_groups.amax(2, keepdim=True)
scale = (max_val - min_val) / 15              # map the (max - min) range onto 15 steps (16 levels for 4-bit)
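The preview ends at the scale computation. A hedged continuation of this asymmetric 4-bit min-max group quantization (my sketch, not necessarily what the full gist does):

zero_point = min_val
wq = ((w_groups - zero_point) / scale).round().clamp_(0, 15).to(torch.uint8)  # 4-bit codes stored in uint8
w_dq = (wq.float() * scale + zero_point).flatten(1)                           # dequantize back to (512, 1024)
print("mean abs error:", (w - w_dq).abs().mean().item())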
import torch
from torch import Tensor, nn
from tqdm import tqdm

class PerLayerOffloadWithBackwardGradient:
    "This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."

    def __init__(
        self,
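The preview stops inside __init__. As a hedged illustration of the kind of mechanism a class like this builds on (my own sketch under standard PyTorch APIs, not the gist's code): a post-accumulate-grad hook can copy each gradient to pinned CPU memory on a side stream and then free the GPU copy, so gradients never stay resident on the GPU.

import torch
from torch import nn

model = nn.Linear(1024, 1024, device="cuda")
offload_stream = torch.cuda.Stream()
cpu_grads = {p: torch.empty(p.shape, dtype=p.dtype, pin_memory=True) for p in model.parameters()}

def offload_grad(p: torch.Tensor) -> None:
    # run the device-to-host copy on a separate stream, after the grad is ready
    offload_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(offload_stream):
        cpu_grads[p].copy_(p.grad, non_blocking=True)
    p.grad.record_stream(offload_stream)  # don't let the allocator reuse the GPU grad too early
    p.grad = None                         # release GPU memory; the pinned CPU copy holds the gradient

for p in model.parameters():
    p.register_post_accumulate_grad_hook(offload_grad)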
import torch
from diffusers import FluxPipeline
from torch import nn

class ModelOffloaderV2:
    def __init__(self, model: nn.Module, record_stream: bool = False):
        # keep the master copy of the weights in CPU pinned memory,
        # which allows fast, asynchronous copies to the GPU later.
        for p in model.parameters():
            p.data = p.data.cpu().pin_memory()
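The preview ends here. As a hedged sketch of where such an offloader usually goes next (an assumed design, not the gist's actual code, and assuming model is the nn.Module whose parameters were pinned above): forward pre-/post-hooks move each layer's weights from the pinned CPU copy to the GPU just before that layer runs, then point them back afterwards to free GPU memory.

copy_stream = torch.cuda.Stream()
cpu_params = {p: p.data for p in model.parameters()}  # the pinned CPU copies made above

def pre_hook(module, args):
    # prefetch this module's weights to the GPU on a side stream
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        for p in module.parameters(recurse=False):
            p.data = cpu_params[p].to("cuda", non_blocking=True)
    torch.cuda.current_stream().wait_stream(copy_stream)

def post_hook(module, args, output):
    # drop the GPU copy: point the params back at pinned CPU memory
    for p in module.parameters(recurse=False):
        p.data = cpu_params[p]

for m in model.modules():
    if next(m.parameters(recurse=False), None) is not None:  # only modules that own parameters
        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(post_hook)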
import torch
import triton
import triton.language as tl
from torch import Tensor

# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
configs = [
    (128, 256, 64, 3, 8),
    (64, 256, 32, 4, 4),
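The config list is cut off in the preview. As a hedged note on how tuples in this format are typically consumed (my sketch, assuming the list above is closed in the full file): they get wrapped into triton.Config objects for @triton.autotune.

autotune_configs = [
    triton.Config(
        {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for bm, bn, bk, num_stages, num_warps in configs
]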
import shlex
import subprocess

import torch

def load_audio(path: str, sample_rate: int) -> torch.Tensor:
    # decode to mono, 32-bit signed little-endian PCM at the requested sample rate
    cmd = f"{FFMPEG_PATH} -i {path} -ar {sample_rate} -ac 1 -f s32le -"
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    if proc.returncode:
        raise RuntimeError(proc.stderr.decode())
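    # Hedged continuation (the preview cuts off here): interpret ffmpeg's stdout
    # as raw s32le samples and normalize to float32 in [-1, 1].
    audio = torch.frombuffer(bytearray(proc.stdout), dtype=torch.int32)
    return audio.float() / 2**31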
import torch

# Modified from https://github.com/ppwwyyxx/RAM-multiprocess-dataloader
class PyTorchStrList:
    def __init__(self, items: list[str]):
        # pack all strings into one flat uint8 tensor plus an offset index,
        # instead of keeping a Python list of str objects.
        data = [torch.frombuffer(x.encode(), dtype=torch.uint8) for x in items]
        lengths = [0] + [x.shape[0] for x in data]
        self.data = torch.cat(data, 0)
        self.index = torch.tensor(lengths).cumsum_(0)
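    # Hedged sketch (the preview cuts off here): the natural accessors for the
    # packed representation; slice by offsets and decode back to str.
    def __len__(self) -> int:
        return self.index.shape[0] - 1

    def __getitem__(self, i: int) -> str:
        chunk = self.data[self.index[i] : self.index[i + 1]]
        return bytes(chunk.tolist()).decode()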
from typing import List

from playwright.sync_api import sync_playwright
import requests
import re
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import argparse
from enum import IntEnum

class FieldType(IntEnum):
    # these values match the TIFF 6.0 field data types
    BYTE = 1
    ASCII = 2
    SHORT = 3
    LONG = 4
    RATIONAL = 5
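A hedged companion sketch (not from the preview): each of these TIFF field data types has a fixed size in bytes, and a lookup like this is the usual next step when parsing IFD entries.

FIELD_TYPE_SIZE = {
    FieldType.BYTE: 1,
    FieldType.ASCII: 1,
    FieldType.SHORT: 2,
    FieldType.LONG: 4,
    FieldType.RATIONAL: 8,  # two LONGs: numerator / denominator
}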