Sofian Mejjoute (Ryu1845), GitHub gists
Birch-san / _06_fused_attention_blockptr_jvp.py
Last active June 29, 2025 17:08
Triton fused attention tutorial, updated with JVP support; the JVP matches the reference only to atol=1e-3.
from __future__ import annotations
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
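For reference on the atol=1e-3 JVP figure: a minimal sketch of checking a fused kernel's JVP against an eager attention baseline via torch.func.jvp. Shapes are illustrative and this is not the gist's test harness.

import torch

def attn(q, k, v):
    # Plain eager attention, so torch.func.jvp can push tangents through every op.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

# [batch, heads, seq, head_dim] primals and matching tangents
q, k, v = (torch.randn(1, 4, 128, 64) for _ in range(3))
tq, tk, tv = (torch.randn(1, 4, 128, 64) for _ in range(3))
out, jvp_out = torch.func.jvp(attn, (q, k, v), (tq, tk, tv))
# A fused kernel's forward and JVP outputs would then be compared against
# (out, jvp_out), e.g. torch.testing.assert_close(..., atol=1e-3, rtol=0).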
aksh-at / modal_quic_hole_punch.py
Last active May 4, 2025 12:51
Modal QUIC NAT hole-punching
"""
Proof-of-concept for NAT traversal and low-latency communication over QUIC
between two Modal containers.
In theory this could be used to establish a low-latency p2p connection between a
service running outside Modal and a Modal GPU container, e.g. for real-time
inference on a video stream. Please let us know if you try it!
Usage:
> modal run modal_quic_hole_punch.py
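The QUIC session rides on a plain UDP hole punch. A minimal sketch of that underlying step, assuming each peer has already learned the other's public IP and port from a rendezvous service (the gist coordinates this through Modal; the addresses here are placeholders):

import socket
import time

def punch(local_port: int, peer_ip: str, peer_port: int, attempts: int = 20):
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(("0.0.0.0", local_port))
    sock.settimeout(0.5)
    for _ in range(attempts):
        # Each outbound datagram opens/refreshes our NAT mapping toward the peer.
        sock.sendto(b"punch", (peer_ip, peer_port))
        try:
            # Succeeds once the peer's own outbound packets have opened its mapping.
            data, addr = sock.recvfrom(1024)
            return sock, addr
        except socket.timeout:
            time.sleep(0.1)
    raise TimeoutError("hole punch failed")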
tysam-code / diloco_nesterov_.7lr_.0_to_.9_momentum_1000_momentum_warmup_1-momentum_dampening_dampening_initial_step_bugfix_25_steps_all_run3.log
Created April 30, 2025 00:56
import os
import sys
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
import uuid
import time
import glob
import subprocess
import contextlib
from dataclasses import dataclass
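The filename encodes the experiment: a DiLoCo-style outer optimizer with Nesterov momentum at lr 0.7, momentum warmed up from 0.0 to 0.9 over 1000 steps, and dampening tied to 1 - momentum. A rough sketch of such an outer step, reconstructed from the filename alone (the log's actual training code is not shown here):

import torch

@torch.no_grad()
def outer_step(params, inner_params, momentum_buf, step, lr=0.7, warmup=1000):
    # Momentum warms up linearly from 0.0 to 0.9 over `warmup` outer steps.
    momentum = 0.9 * min(step / warmup, 1.0)
    for p, q, buf in zip(params, inner_params, momentum_buf):
        delta = q - p  # pseudo-gradient: how far the inner run moved this tensor
        buf.mul_(momentum).add_(delta, alpha=1.0 - momentum)  # dampening = 1 - momentum
        p.add_(delta + momentum * buf, alpha=lr)  # Nesterov-style lookahead update
        q.copy_(p)  # resync the inner copy for the next outer round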
"""DiM (Diffusion Mixer)."""
import math
import typing
import einops
import torch
class DiMConfig(typing.NamedTuple):
joey00072 / mla.py
Created December 28, 2024 16:25
multi head latent attention (MLA)
# https://x.com/shxf0072/status/1873038335427658011
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from collections import OrderedDict
from ohara.modules.norm import RMSNorm
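The core idea of MLA, in a minimal sketch (not the gist's code, and omitting the decoupled RoPE path; dimensions are illustrative): keys and values are reconstructed per head from one small shared latent, so a KV cache only needs to store that latent.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLA(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_latent=64):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.kv_down = nn.Linear(d_model, d_latent)  # compress; this is what gets cached
        self.k_up = nn.Linear(d_latent, d_model)     # reconstruct keys from the latent
        self.v_up = nn.Linear(d_latent, d_model)     # reconstruct values from the latent
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        def split(t):  # [B, T, d_model] -> [B, n_heads, T, d_head]
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        latent = self.kv_down(x)
        q, k, v = split(self.q_proj(x)), split(self.k_up(latent)), split(self.v_up(latent))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))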
zjlww / model.py
Created December 7, 2024 01:39
Stripped AudioCodecModel from NeMo @ bde672e
from typing import Tuple
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange
from .modules import HiFiGANEncoder, HiFiGANDecoder, GroupFiniteScalarQuantizer
class AudioCodecModel(nn.Module):
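The quantizer named in the imports is built on finite scalar quantization. A minimal sketch of the FSQ rounding step itself, assuming a straight-through gradient (the general technique, not NeMo's implementation):

import torch

def fsq(z: torch.Tensor, levels: int = 5) -> torch.Tensor:
    # Bound each latent dimension, snap it to `levels` evenly spaced values,
    # and pass gradients straight through the non-differentiable round.
    half = (levels - 1) / 2
    bounded = torch.tanh(z) * half
    quantized = torch.round(bounded)
    return (bounded + (quantized - bounded).detach()) / half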
charlesfrye / wrapper.py
Last active February 24, 2025 16:16
Train GPT-2 in five minutes -- for free!
# Train GPT-2 in five minutes -- for free
#
# ```bash
# pip install modal
# modal setup
# modal run wrapper.py
# ```
#
# Note that the end-to-end latency the first time is more like 25 minutes:
# - five minutes to install Torch (rip)
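The wrapper pattern the filename implies looks roughly like this on Modal's current API; the training body below is a placeholder, not the gist's actual GPT-2 run:

import modal

app = modal.App("gpt2-speedrun-sketch")
image = modal.Image.debian_slim().pip_install("torch", "numpy")

@app.function(gpu="H100", image=image, timeout=30 * 60)
def train():
    import torch
    print("training on", torch.cuda.get_device_name())  # the real gist trains GPT-2 here

@app.local_entrypoint()
def main():
    train.remote()  # `modal run wrapper.py` invokes this entrypoint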
crowsonkb / ring_attn.py
Created October 10, 2024 16:19
Ring attention for PyTorch.
"""Ring attention for PyTorch.
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py.
"""
import flash_attn.flash_attn_interface as fai
import torch
import torch.distributed as dist
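The step ring attention hinges on is merging partial attention results exactly as KV blocks arrive from ring neighbors. A minimal sketch of that merge, using each partial's log-sum-exp statistics (illustrative only; the gist delegates the per-block math to flash-attn kernels):

import torch

def merge_partials(o1, lse1, o2, lse2):
    # o1, o2: [..., d] partial outputs; lse1, lse2: [...] log-sum-exp of their score blocks.
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    # Exact softmax-weighted combination: order of KV blocks does not matter.
    return w1 * o1 + w2 * o2, lse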
def apply_p_rope(
inputs: jax.Array, # [B, L]
positions: jax.Array, # [B, L]
head_dim: int,
max_wavelength: int = _MAX_WAVELENGTH,
rope_percentage: float = 1.0,
) -> jax.Array:
"""Applies p-RoPE."""
rope_angles = int(rope_percentage * head_dim // 2)
nope_angles = head_dim // 2 - rope_angles
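Completing the thought in plain PyTorch: p-RoPE rotates only the first rope_percentage of each head's dimension pairs and leaves the rest position-independent. A minimal sketch under that reading (frequency schedule simplified; not the snippet's exact JAX code):

import torch

def p_rope(x: torch.Tensor, positions: torch.Tensor,
           rope_percentage: float = 0.5, max_wavelength: float = 10_000.0) -> torch.Tensor:
    # x: [B, L, D] with D even; positions: [B, L]
    d = x.shape[-1]
    rope_angles = int(rope_percentage * (d // 2))
    freq = max_wavelength ** (-torch.arange(rope_angles, dtype=x.dtype) / rope_angles)
    theta = positions[..., None].to(x.dtype) * freq  # [B, L, rope_angles]
    x1, x2 = x[..., :rope_angles], x[..., rope_angles:2 * rope_angles]
    rot1 = x1 * torch.cos(theta) - x2 * torch.sin(theta)
    rot2 = x1 * torch.sin(theta) + x2 * torch.cos(theta)
    # Rotated pairs first, untouched ("NoPE") dimensions last.
    return torch.cat([rot1, rot2, x[..., 2 * rope_angles:]], dim=-1)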
from typing import Callable
import numpy as np
from tqdm import tqdm
def wsola_chunked_processing(audio: np.ndarray, sr: int, chunk_size: int, hop_size: int, mod_func: Callable[[np.ndarray], np.ndarray]):
# Check if chunk_size is larger than the audio length
if chunk_size >= len(audio):
# Process the entire audio in one go
output = mod_func(audio).squeeze()
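For audio longer than one chunk, the WSOLA part this function relies on is the splice: align each new chunk to the tail already written by maximizing cross-correlation within a tolerance, then crossfade. A minimal sketch of that join, separate from the gist's implementation:

import numpy as np

def splice(tail: np.ndarray, chunk: np.ndarray, overlap: int, tolerance: int) -> np.ndarray:
    # Requires len(chunk) >= tolerance + overlap. Find the shift of `chunk` whose
    # first `overlap` samples best match the end of `tail`, then crossfade.
    ref = tail[-overlap:]
    scores = [np.dot(ref, chunk[s:s + overlap]) for s in range(tolerance)]
    s = int(np.argmax(scores))
    fade = np.linspace(0.0, 1.0, overlap)
    joined = ref * (1.0 - fade) + chunk[s:s + overlap] * fade
    return np.concatenate([tail[:-overlap], joined, chunk[s + overlap:]])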