Skip to content

Instantly share code, notes, and snippets.

View YouJiacheng's full-sized avatar

You Jiacheng YouJiacheng

  • IIIS, Tsinghua University
  • Beijing, China
View GitHub Profile
from dataclasses import dataclass
@dataclass
class Args:
    """Bundle of integer size settings, each with a default value.

    NOTE(review): the field names suggest transformer / mixture-of-experts
    model hyperparameters, but nothing in view consumes them -- confirm the
    exact meaning of each field against the code that reads this class.
    """

    vocab_size: int = 129280
    dim: int = 7168
    inter_dim: int = 18432
    moe_inter_dim: int = 2048
    n_layers: int = 61
def abs_cdf(t: Tensor, thresholds: list[float]) -> Tensor:
    """Return the empirical CDF of ``|t|`` evaluated at each threshold.

    Args:
        t: Input tensor of any shape; it is flattened before counting.
        thresholds: Bucket boundaries. ``torch.bucketize`` expects them in
            increasing order. They are materialized with ``new_tensor`` and
            therefore take the dtype and device of ``t.abs()``.

    Returns:
        A 1-D tensor with ``len(thresholds) + 1`` entries. Entry ``i`` (for
        ``i < len(thresholds)``) is the fraction of elements whose absolute
        value is ``<= thresholds[i]``; the last entry counts every element.
    """
    magnitudes = t.abs()
    boundaries = magnitudes.new_tensor(thresholds)
    # For each element x: level == sum(x > v for v in thresholds)
    level = torch.bucketize(magnitudes, boundaries, out_int32=True)
    # minlength guarantees one slot per bucket even when a bucket is empty.
    counts = level.flatten().bincount(minlength=len(thresholds) + 1)
    return counts.cumsum(0) / magnitudes.numel()
# reference: https://github.com/pytorch/pytorch/issues/69519#issuecomment-2500366519
def histogram(input: Tensor, bins: Tensor, *, weight: Optional[Tensor] = None, density: bool = False):
bucket_indices = torch.bucketize(input, bins)
counts = torch.bincount(bucket_indices, weights=weight, minlength=bins.size(0)+1)
counts = counts[1:-1]
@YouJiacheng
YouJiacheng / rope_shift.py
Created November 24, 2024 16:57
rope shift
import torch
import torch.nn as nn
import torch.nn.functional as F
class RoPE(nn.Module):
def __init__(
self,
dim,
max_seq_len: int = 4096,
import os
import sys
import torch._dynamo.compiled_autograd
# Capture this script's own source text up front so it can be logged later.
# Python source is UTF-8 by default, so decode it explicitly instead of
# depending on the platform's locale encoding.
with open(sys.argv[0], encoding="utf-8") as f:
    code = f.read()  # read the code of this file ASAP, for logging
import uuid
import glob
import time
import torch
import torch.utils.benchmark as benchmark
def benchmark_in_us(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the mean runtime in microseconds.

    Args:
        f: Callable to measure.
        *args: Positional arguments forwarded to ``f`` on every timed call.
        **kwargs: Keyword arguments forwarded to ``f`` on every timed call.

    Returns:
        The ``mean`` of the measurement produced by
        ``Timer.blocked_autorange()``, multiplied by 1e6.
    """
    # The statement is evaluated inside the Timer's own namespace, so the
    # callable and its arguments are handed over through ``globals``.
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
    )
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
# Build a one-dimensional device mesh of size 4 on the "cuda" device type;
# its single mesh dimension is named "shard".
mesh_1d = init_device_mesh("cuda", (4,), mesh_dim_names=("shard",))
# Rank as reported by the mesh (presumably the calling process's rank --
# confirm against the DeviceMesh documentation).
rank = mesh_1d.get_rank()
# Starts out empty; nothing in view appends to it.
dtensors: list[DTensor] = []
import os
from typing import cast
import torch
import torch._inductor.config as config
import torch.distributed as dist
# Stub only: the body is a bare `...`, so calling this returns None even though
# the signature is annotated as returning torch.Tensor.
# NOTE(review): the name suggests a Newton-Schulz iteration applied to G for
# `steps` rounds with `eps` as a stabilizer -- the real implementation is not
# in view, confirm before relying on it.
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7) -> torch.Tensor: ...
import os
import base64
# Dynamically generated at test time: two independent identifiers, each made
# from 6 bytes of OS randomness rendered as URL-safe base64 text.
A_identifier, B_identifier = (
    base64.urlsafe_b64encode(os.urandom(6)).decode() for _ in range(2)
)
meta_prompt = f"""
@YouJiacheng
YouJiacheng / f-str.py
Last active September 26, 2024 14:34
from dataclasses import dataclass
class Default(dict[str, str]):
    """String-to-string mapping that echoes unknown keys as ``{key}``.

    Looking up a missing key returns the key wrapped in literal braces instead
    of raising ``KeyError``. Used with ``str.format_map`` this leaves unknown
    placeholders in the template untouched. The fallback value is not stored
    in the mapping.
    """

    def __missing__(self, key: str) -> str:
        """Return ``key`` wrapped in a single pair of literal braces."""
        # In an f-string, doubled braces emit one literal brace each.
        return f"{{{key}}}"
@dataclass
class Pair:
INFO global: Vagrant version: 2.3.3
INFO global: Ruby version: 2.7.6
INFO global: RubyGems version: 3.1.6
INFO global: VAGRANT_EXECUTABLE="C:\\HashiCorp\\Vagrant\\embedded\\gems\\2.3.3\\gems\\vagrant-2.3.3\\bin\\vagrant"
INFO global: VAGRANT_INSTALLER_EMBEDDED_DIR="C:\\HashiCorp\\Vagrant\\embedded"
INFO global: VAGRANT_INSTALLER_ENV="1"
INFO global: VAGRANT_INSTALLER_VERSION="2"
INFO global: VAGRANT_LOG="debug"
WARN global: resolv replacement has not been enabled!
DEBUG global: Loading core plugin: C:/HashiCorp/Vagrant/embedded/gems/2.3.3/gems/vagrant-2.3.3/plugins/commands/autocomplete/plugi