Sofian Mejjoute (Ryu1845), GitHub gists
"""DiM (Diffusion Mixer)."""
import math
import typing
import einops
import torch
class DiMConfig(typing.NamedTuple):
@joey00072
joey00072 / mla.py
Created December 28, 2024 16:25
multi head latent attention (MLA)
# https://x.com/shxf0072/status/1873038335427658011
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from collections import OrderedDict
from ohara.modules.norm import RMSNorm
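For context, a minimal sketch of the core MLA idea: keys and values are compressed into one low-rank latent (which is all the KV cache needs), then expanded per head at attention time. The module layout below, including d_latent, the projection names, and the absence of decoupled RoPE, is illustrative rather than the gist's exact implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLASketch(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_latent: int):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_down = nn.Linear(d_model, d_latent, bias=False)  # compress once; cache only this
        self.k_up = nn.Linear(d_latent, d_model, bias=False)     # expand per head at attention time
        self.v_up = nn.Linear(d_latent, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        latent = self.kv_down(x)  # [b, t, d_latent] low-rank joint K/V representation
        k = self.k_up(latent).view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_up(latent).view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, t, -1))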
@zjlww
zjlww / model.py
Created December 7, 2024 01:39
Stripped AudioCodecModel from NeMo @ bde672e
from typing import Tuple
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange
from .modules import HiFiGANEncoder, HiFiGANDecoder, GroupFiniteScalarQuantizer
class AudioCodecModel(nn.Module):
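For context, a rough sketch of how such a neural codec wires together: the HiFiGAN encoder and decoder sit on either side of the finite-scalar quantizer in an encode, quantize, decode path. The interfaces below (a quantizer returning both dequantized embeddings and codes) are assumptions; the real NeMo class adds configuration, padding, and length handling.

from typing import Tuple
from torch import nn, Tensor

class AudioCodecSketch(nn.Module):
    def __init__(self, encoder: nn.Module, quantizer: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder, self.quantizer, self.decoder = encoder, quantizer, decoder

    def forward(self, audio: Tensor) -> Tuple[Tensor, Tensor]:
        latents = self.encoder(audio)               # [B, D, T'] downsampled representation
        quantized, codes = self.quantizer(latents)  # dequantized embeddings plus discrete codes
        reconstructed = self.decoder(quantized)     # back to waveform samples
        return reconstructed, codes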
@charlesfrye
charlesfrye / wrapper.py
Last active February 24, 2025 16:16
Train GPT-2 in five minutes -- for free!
# Train GPT-2 in five minutes -- for free
#
# ```bash
# pip install modal
# modal setup
# modal run wrapper.py
# ```
#
# Note that the end-to-end latency the first time is more like 25 minutes:
# - five minutes to install Torch (rip)
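The recipe above is the whole workflow; a minimal sketch of the Modal wrapper pattern it relies on looks roughly like this. The image contents, GPU choice, and the hypothetical train() body are placeholders; the gist's actual training code and container image differ.

import modal

image = modal.Image.debian_slim().pip_install("torch", "numpy")
app = modal.App("gpt2-speedrun-sketch", image=image)

@app.function(gpu="H100", timeout=30 * 60)
def train():
    # Runs inside the container on the GPU; a real run would launch the GPT-2
    # training loop here and persist checkpoints (e.g. to a modal.Volume).
    import torch
    print("CUDA available:", torch.cuda.is_available())

@app.local_entrypoint()
def main():
    # `modal run wrapper.py` calls this locally and executes train() in the cloud.
    train.remote()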
@crowsonkb
crowsonkb / ring_attn.py
Created October 10, 2024 16:19
Ring attention for PyTorch.
"""Ring attention for PyTorch.
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py.
"""
import flash_attn.flash_attn_interface as fai
import torch
import torch.distributed as dist
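For context, a rough sketch of the ring schedule itself, assuming each rank holds one local shard of K/V and that per-shard partial results are merged with the usual log-sum-exp trick; the gist drives the flash_attn kernels instead of this plain-PyTorch math.

import torch
import torch.distributed as dist

def ring_attention_sketch(q, k, v):
    # q, k, v: [batch, heads, local_seq, head_dim]; every rank keeps its queries
    # and passes its K/V shard around the ring, one hop per step.
    rank, world = dist.get_rank(), dist.get_world_size()
    out = torch.zeros_like(q)
    lse = torch.full(q.shape[:-1], float("-inf"), device=q.device)  # running log-sum-exp
    for step in range(world):
        scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5      # [B, H, Lq, Lk]
        block_lse = torch.logsumexp(scores, dim=-1)
        block_out = scores.softmax(dim=-1) @ v
        new_lse = torch.logaddexp(lse, block_lse)
        out = out * (lse - new_lse).exp()[..., None] + block_out * (block_lse - new_lse).exp()[..., None]
        lse = new_lse
        if step + 1 < world:
            # Rotate the K/V shard to the next rank and receive the previous rank's shard.
            nxt, prv = (rank + 1) % world, (rank - 1) % world
            k_buf, v_buf = torch.empty_like(k), torch.empty_like(v)
            ops = [dist.P2POp(dist.isend, k, nxt), dist.P2POp(dist.isend, v, nxt),
                   dist.P2POp(dist.irecv, k_buf, prv), dist.P2POp(dist.irecv, v_buf, prv)]
            for req in dist.batch_isend_irecv(ops):
                req.wait()
            k, v = k_buf, v_buf
    return out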
import jax
import jax.numpy as jnp

_MAX_WAVELENGTH = 10_000  # assumed default base wavelength; the constant is not shown in this preview

def apply_p_rope(
    inputs: jax.Array,     # [B, L]
    positions: jax.Array,  # [B, L]
    head_dim: int,
    max_wavelength: int = _MAX_WAVELENGTH,
    rope_percentage: float = 1.0,
) -> jax.Array:
    """Applies p-RoPE."""
    rope_angles = int(rope_percentage * head_dim // 2)
    nope_angles = head_dim // 2 - rope_angles
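    # A sketch of how the body can continue, assuming the usual split-halves RoPE
    # formulation: the NoPE fraction gets an infinite timescale, i.e. zero rotation.
    # Not necessarily the original gist's exact continuation.
    fraction = 2 * jnp.arange(rope_angles) / head_dim
    timescale = jnp.concatenate(
        [max_wavelength**fraction, jnp.full((nope_angles,), jnp.inf)], axis=-1
    )
    sinusoid = positions[..., None] / timescale[None, None, :]  # [B, L, head_dim // 2]
    sinusoid = sinusoid[..., None, :]                           # broadcast over heads
    sin, cos = jnp.sin(sinusoid), jnp.cos(sinusoid)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)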
from typing import Callable
import numpy as np
from tqdm import tqdm
def wsola_chunked_processing(audio: np.ndarray, sr: int, chunk_size: int, hop_size: int, mod_func: Callable[[np.ndarray], np.ndarray]):
    # Check if chunk_size is larger than the audio length
    if chunk_size >= len(audio):
        # Process the entire audio in one go
        output = mod_func(audio).squeeze()
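A small usage sketch, assuming the function returns the processed audio and that mod_func maps a float32 chunk to another chunk at the same sample rate; the repeat-based stretcher below is a purely hypothetical stand-in for a real modifier.

import numpy as np

def naive_stretch(chunk: np.ndarray) -> np.ndarray:
    # Hypothetical modifier: double every sample for a crude 2x time-stretch.
    return np.repeat(chunk, 2)

sr = 16_000
audio = np.random.randn(5 * sr).astype(np.float32)  # five seconds of noise
stretched = wsola_chunked_processing(audio, sr, chunk_size=sr, hop_size=sr // 2, mod_func=naive_stretch)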
@lucidrains
lucidrains / tree_attn_decode.py
Created August 12, 2024 17:48
Tree Attention Decoding
import torch
from torch import einsum
import torch.distributed as dist
def tree_attn_decode(q, k, v):
"""
Algorithm 3 proposed in Tree Attention
https://arxiv.org/abs/2408.04093
"""
""" Computing zeroth matrix powers via Lakic 1998.
paper: "On the Computation of the Matrix k-th Root"
Suppose we have a matrix G = USV^T and we want to compute
G^0 defined via G^0 = UV^T. We might want to do this to run
"stochastic spectral descent" of Carlson et al 2015. The
naive way to do this is via the SVD. But we can also just do
(GG^T)^(-1/2) G or alternatively G (G^TG)^(-1/2) and apply
the iterative method from Lakic 1998.
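The docstring above states the goal; below is a reference SVD version of G^0 together with a Newton-Schulz polar iteration as a stand-in for an iterative scheme. Newton-Schulz is named explicitly because it is not Lakic's 1998 method, just a well-known alternative that converges to the same orthogonal factor UV^T.

import torch

def zeropower_svd(G: torch.Tensor) -> torch.Tensor:
    # Naive reference: explicit SVD, then drop the singular values.
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

def zeropower_newton_schulz(G: torch.Tensor, steps: int = 20) -> torch.Tensor:
    # X_{k+1} = 1.5 X_k - 0.5 X_k X_k^T X_k converges to U V^T once G is scaled
    # so its spectral norm is below 1 (the Frobenius norm is a safe upper bound).
    X = G / (G.norm() + 1e-7)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.mT @ X
    return X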
"""Generates a document causal attention mask based on a document ID tensor"""
from typing import List, Union
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _mask_mod_signature, or_masks
from attn_gym.masks import causal_mask
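A minimal sketch of the underlying mask logic, assuming a 1-D document_id tensor that maps every position to its document; attn_gym's helpers generalize and compose this (hence the or_masks / causal_mask imports above), so treat the snippet as illustrative rather than the gist's code.

import torch
from torch.nn.attention.flex_attention import create_block_mask

document_id = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])  # toy example: three packed documents

def document_causal_mask(b, h, q_idx, kv_idx):
    # Attend only within the same document, and never to future positions.
    same_doc = document_id[q_idx] == document_id[kv_idx]
    return same_doc & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal_mask, B=None, H=None, Q_LEN=8, KV_LEN=8, device="cpu")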