from __future__ import annotations

"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
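# Not part of the gist: a plain-PyTorch reference of the online-softmax
# recurrence that the fused kernel computes block by block. Single head,
# q/k/v of shape [L, d]; matches torch.softmax(q @ k.T * scale, -1) @ v
# up to floating-point error.
import torch

def attention_reference(q, k, v, block_size=128):
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    m = torch.full((q.shape[0], 1), float("-inf"))  # running row-wise max
    l = torch.zeros(q.shape[0], 1)                  # running softmax denominator
    for s in range(0, k.shape[0], block_size):
        scores = q @ k[s:s + block_size].T * scale
        m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)                # rescales previous partials
        p = torch.exp(scores - m_new)
        l = alpha * l + p.sum(dim=-1, keepdim=True)
        out = alpha * out + p @ v[s:s + block_size]
        m = m_new
    return out / l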
""" | |
Proof-of-concept for NAT traversal and low-latency communication over QUIC | |
between two Modal containers. | |
In theory this could be used to establish a low-latency p2p connection between a | |
service running outside Modal and a Modal GPU container, e.g. for real-time | |
inference on a video stream. Please let us know if you try it! | |
Usage: | |
> modal run modal_quic_hole_punch.py |
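# Not part of the gist: the hole-punch step in isolation, using plain UDP
# sockets (QUIC runs over UDP, so the punch itself is protocol-agnostic).
# Assumes each peer has already learned the other's public (ip, port) out of
# band, e.g. via STUN plus some rendezvous channel.
import socket
import time

def udp_hole_punch(local_port, peer_addr, tries=10):
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(("0.0.0.0", local_port))
    sock.settimeout(1.0)
    for _ in range(tries):
        sock.sendto(b"punch", peer_addr)  # outbound packet opens our NAT mapping
        try:
            data, addr = sock.recvfrom(1024)
            if addr == peer_addr:         # peer's packet made it through: punched
                return sock               # hand this socket off to the QUIC stack
        except socket.timeout:
            time.sleep(0.5)
    raise TimeoutError("hole punch failed (symmetric NATs usually need a relay)")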
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read()  # read the code of this file ASAP, for logging
import uuid
import time
import glob
import subprocess
import contextlib
from dataclasses import dataclass
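# Not part of the gist: `code` is captured before the heavyweight imports so
# the exact script can be written into the run's log for reproducibility. A
# sketch of that presumed use; the logfile path is hypothetical.
def log_code_header(logfile: str = "logs/run.txt") -> None:
    os.makedirs(os.path.dirname(logfile), exist_ok=True)
    with open(logfile, "w") as f:
        f.write(code)              # full source of this script, read above
        f.write("=" * 100 + "\n")  # separator before the training log proper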
"""DiM (Diffusion Mixer).""" | |
import math | |
import typing | |
import einops | |
import torch | |
class DiMConfig(typing.NamedTuple): |
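# Not part of the gist, which truncates at the config class. "Mixer" suggests
# MLP-Mixer-style blocks; a hypothetical sketch of one (all names here are
# invented, not DiM's actual architecture), for tokens of shape [B, S, D]:
class MixerBlock(torch.nn.Module):
    def __init__(self, seq_len: int, dim: int, hidden: int):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(dim)
        self.token_mlp = torch.nn.Sequential(   # mixes information across tokens
            torch.nn.Linear(seq_len, hidden),
            torch.nn.GELU(),
            torch.nn.Linear(hidden, seq_len),
        )
        self.norm2 = torch.nn.LayerNorm(dim)
        self.channel_mlp = torch.nn.Sequential(  # mixes information across channels
            torch.nn.Linear(dim, hidden),
            torch.nn.GELU(),
            torch.nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, S, D]
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        return x + self.channel_mlp(self.norm2(x))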
# https://x.com/shxf0072/status/1873038335427658011
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from collections import OrderedDict

from ohara.modules.norm import RMSNorm
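# Not part of the gist: if `ohara` isn't installed, RMSNorm is small enough to
# inline. Standard definition (x * w / rms(x)); the constructor signature is
# assumed here, not ohara's documented API.
class RMSNormInline(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)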
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange

from .modules import HiFiGANEncoder, HiFiGANDecoder, GroupFiniteScalarQuantizer


class AudioCodecModel(nn.Module):
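# Not part of the gist, which truncates at the class header. The imports imply
# the usual neural-codec pipeline; a hypothetical sketch of the forward pass
# (attribute names and return conventions are guesses, not the gist's API):
class AudioCodecSketch(nn.Module):
    def __init__(self, encoder: nn.Module, quantizer: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder, self.quantizer, self.decoder = encoder, quantizer, decoder

    def forward(self, audio: Tensor) -> Tuple[Tensor, Tensor]:
        z = self.encoder(audio)         # waveform -> latent frames
        z_q, codes = self.quantizer(z)  # finite scalar quantization -> discrete codes
        recon = self.decoder(z_q)       # latent frames -> reconstructed waveform
        return recon, codes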
# Train GPT-2 in five minutes -- for free
#
# ```bash
# pip install modal
# modal setup
# modal run wrapper.py
# ```
#
# Note that the end-to-end latency the first time is more like 25 minutes:
# - five minutes to install Torch (rip)
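# Not part of the gist: the Torch install only happens while the image builds,
# and Modal caches image layers, so repeat runs skip it. A minimal sketch of
# that pattern (app name, GPU type, and function body are placeholders):
import modal

app = modal.App("gpt2-speedrun-sketch")
image = modal.Image.debian_slim().pip_install("torch")  # baked into a cached layer

@app.function(gpu="H100", image=image, timeout=60 * 30)
def train():
    import torch  # already present in the image, so this import is fast
    print(torch.cuda.get_device_name())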
"""Ring attention for PyTorch. | |
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py. | |
""" | |
import flash_attn.flash_attn_interface as fai | |
import torch | |
import torch.distributed as dist |
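# Not part of the gist: the numerical core of ring attention. Each rank
# attends to one K/V block at a time while the blocks rotate around the
# process ring; the partial results are merged using each block's stored
# log-sum-exp. Sketch of the merge rule, for out_* of shape [L, d] and
# lse_* of shape [L]:
def merge_partials(out_a, lse_a, out_b, lse_b):
    lse = torch.logaddexp(lse_a, lse_b)                  # combined normalizer
    out = (out_a * torch.exp(lse_a - lse)[:, None]
           + out_b * torch.exp(lse_b - lse)[:, None])    # reweight each partial
    return out, lse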
import jax
import jax.numpy as jnp

_MAX_WAVELENGTH = 10_000  # assumed RoPE base; the gist preview omits the definition


def apply_p_rope(
    inputs: jax.Array,     # [B, L]
    positions: jax.Array,  # [B, L]
    head_dim: int,
    max_wavelength: int = _MAX_WAVELENGTH,
    rope_percentage: float = 1.0,
) -> jax.Array:
  """Applies p-RoPE."""
  rope_angles = int(rope_percentage * head_dim // 2)
  nope_angles = head_dim // 2 - rope_angles
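  # Continuation sketch, not from the gist preview (which truncates here):
  # ordinary RoPE on the first `rope_angles` frequency pairs, and an infinite
  # timescale (i.e. no rotation, "NoPE") on the remaining `nope_angles` pairs.
  fraction = 2.0 * jnp.arange(0, rope_angles) / head_dim
  timescale = jnp.concatenate(
      [max_wavelength**fraction, jnp.full((nope_angles,), jnp.inf)], axis=-1
  )
  sinusoid_inp = positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
  sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
  first_half, second_half = jnp.split(inputs, 2, axis=-1)
  out_first = first_half * cos - second_half * sin
  out_second = second_half * cos + first_half * sin
  return jnp.concatenate([out_first, out_second], axis=-1)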
from typing import Callable

import numpy as np
from tqdm import tqdm


def wsola_chunked_processing(
    audio: np.ndarray,
    sr: int,
    chunk_size: int,
    hop_size: int,
    mod_func: Callable[[np.ndarray], np.ndarray],
):
    # Check if chunk_size is larger than the audio length
    if chunk_size >= len(audio):
        # Process the entire audio in one go
        output = mod_func(audio).squeeze()
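# Not part of the gist: example invocation hitting the single-chunk branch
# above (the chunked WSOLA path continues below this preview). The identity
# mod_func should leave the audio unchanged.
if __name__ == "__main__":
    sr = 16_000
    t = np.arange(sr) / sr
    audio = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)  # 1 s of A440
    wsola_chunked_processing(audio, sr=sr, chunk_size=2 * sr, hop_size=sr // 4,
                             mod_func=lambda x: x)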