Skip to content

Instantly share code, notes, and snippets.

View YouJiacheng's full-sized avatar

You Jiacheng YouJiacheng

  • IIIS, Tsinghua University
  • Beijing, China
View GitHub Profile
from functools import partial
import jax
import jax.numpy as jnp
import optax
# NOTE(review): gist-preview fragment — the function body is cut off after the
# cast below, and the scrape flattened the body indentation.
def poly(x: jnp.ndarray, w: jnp.ndarray):
# Exactly three coefficients required — presumably a quadratic evaluation; rest of body not captured.
assert w.shape == (3,)
w = w.astype(jnp.float32)  # promote coefficients to float32 before use
# NOTE(review): the lines below are three concatenated file *previews* from a
# gist listing, not one coherent module. Each preview repeats the same
# training-script preamble: read the script's own source immediately (for
# experiment logging) and set the CUDA allocator option before torch is
# imported, since PYTORCH_CUDA_ALLOC_CONF is only read at import time.
import os
import sys
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache, partial
import os
import sys
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache, partial
import os
import sys
from typing import override
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import contextlib
import time
import uuid
from dataclasses import dataclass
# Model hyper-parameters. NOTE(review): preview is likely truncated — the full
# dataclass presumably declares more fields; indentation flattened by the scrape.
@dataclass
class Args:
vocab_size: int = 129280  # tokenizer vocabulary size
dim: int = 7168  # hidden (residual) width
inter_dim: int = 18432  # dense MLP intermediate size
moe_inter_dim: int = 2048  # per-expert MLP intermediate size — presumably MoE; confirm against full file
n_layers: int = 61  # number of transformer layers
def abs_cdf(t: Tensor, thresholds: list[float]):
    """Empirical CDF of ``|t|`` evaluated at each threshold, plus the total mass.

    Returns a 1-D tensor of length ``len(thresholds) + 1`` whose k-th entry is
    the fraction of elements whose absolute value does not exceed the k-th
    threshold (the last entry is always 1.0). ``thresholds`` must be sorted
    ascending, as required by ``torch.bucketize``.
    """
    magnitudes = t.abs()
    edges = magnitudes.new_tensor(thresholds)
    # Bucket index == number of thresholds strictly below each |value|,
    # i.e. sum(x > v for v in thresholds).
    levels = torch.bucketize(magnitudes, edges, out_int32=True)
    counts = levels.flatten().bincount(minlength=len(thresholds) + 1)
    return counts.cumsum(0) / magnitudes.numel()
# reference: https://github.com/pytorch/pytorch/issues/69519#issuecomment-2500366519
# NOTE(review): preview fragment — cut off after the slice below; the `density`
# handling and the return statement are not captured, and `Optional` / `Tensor`
# come from imports outside this excerpt.
def histogram(input: Tensor, bins: Tensor, *, weight: Optional[Tensor] = None, density: bool = False):
bucket_indices = torch.bucketize(input, bins)  # index 0: input <= bins[0]; index bins.numel(): input > bins[-1]
counts = torch.bincount(bucket_indices, weights=weight, minlength=bins.size(0)+1)
counts = counts[1:-1]  # drop under/overflow buckets, leaving len(bins)-1 in-range bins (bins act as edges)
@YouJiacheng
YouJiacheng / rope_shift.py
Created November 24, 2024 16:57
rope shift
import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE(review): preview fragment — the constructor signature is cut off
# mid-parameter-list by the gist listing and the scrape flattened the
# indentation; do not treat this as a complete class definition.
class RoPE(nn.Module):
def __init__(
self,
dim,
max_seq_len: int = 4096,
# NOTE(review): another scraped preview preamble (self-source read for logging),
# followed by the imports used by the benchmarking helper below.
import os
import sys
import torch._dynamo.compiled_autograd
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
import uuid
import glob
import time
import torch
import torch.utils.benchmark as benchmark
def benchmark_in_us(f, *args, **kwargs):
    """Mean runtime of ``f(*args, **kwargs)`` in microseconds.

    Uses ``torch.utils.benchmark.Timer`` so warm-up, looping, and timing
    methodology are handled by torch rather than hand-rolled.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6  # seconds -> microseconds
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
# Build a 1-D mesh of 4 CUDA devices named "shard" — presumably launched under
# torchrun with world size 4; NOTE(review): preview is cut off right after the
# list declaration below, so the DTensor usage itself is not captured.
mesh_1d = init_device_mesh("cuda", (4,), mesh_dim_names=("shard",))
rank = mesh_1d.get_rank()
dtensors: list[DTensor] = []