This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""DiM (Diffusion Mixer).""" | |
import math | |
import typing | |
import einops | |
import torch | |
class DiMConfig(typing.NamedTuple): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://x.com/shxf0072/status/1873038335427658011 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from dataclasses import dataclass | |
from collections import OrderedDict | |
from ohara.modules.norm import RMSNorm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import torch | |
from torch import nn, Tensor | |
import torch.nn.functional as F | |
from einops import rearrange | |
from .modules import HiFiGANEncoder, HiFiGANDecoder, GroupFiniteScalarQuantizer | |
class AudioCodecModel(nn.Module): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train GPT-2 in five minutes -- for free | |
# | |
# ```bash | |
# pip install modal | |
# modal setup | |
# modal run wrapper.py | |
# ``` | |
# | |
# Note that the end-to-end latency the first time is more like 25 minutes: | |
# - five minutes to install Torch (rip) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Ring attention for PyTorch. | |
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py. | |
""" | |
import flash_attn.flash_attn_interface as fai | |
import torch | |
import torch.distributed as dist | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def apply_p_rope( | |
inputs: jax.Array, # [B, L] | |
positions: jax.Array, # [B, L] | |
head_dim: int, | |
max_wavelength: int = _MAX_WAVELENGTH, | |
rope_percentage: float = 1.0, | |
) -> jax.Array: | |
"""Applies p-RoPE.""" | |
rope_angles = int(rope_percentage * head_dim // 2) | |
nope_angles = head_dim // 2 - rope_angles |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable | |
import numpy as np | |
from tqdm import tqdm | |
def wsola_chunked_processing(audio: np.ndarray, sr: int, chunk_size: int, hop_size: int, mod_func: Callable[[np.ndarray], np.ndarray]): | |
# Check if chunk_size is larger than the audio length | |
if chunk_size >= len(audio): | |
# Process the entire audio in one go | |
output = mod_func(audio).squeeze() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import einsum | |
import torch.distributed as dist | |
def tree_attn_decode(q, k, v): | |
""" | |
Algorithm 3 proposed in Tree Attention | |
https://arxiv.org/abs/2408.04093 | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Computing zeroth matrix powers via Lakic 1998. | |
paper: "On the Computation of the Matrix k-th Root" | |
Suppose we have a matrix G = USV^T and we want to compute | |
G^0 defined via G^0 = UV^T. We might want to do this to run | |
"stochastic spectral descent" of Carlson et al 2015. The | |
naive way to do this is via the SVD. But we can also just do | |
(GG^T)^(-1/2) G or alternatively G (G^TG)^(-1/2) and apply | |
the iterative method from Lakic 1998. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Generates a document causal attention mask based on a document ID tensor""" | |
from typing import List, Union | |
import torch | |
from torch import Tensor | |
from torch.nn.attention.flex_attention import _mask_mod_signature, or_masks | |
from attn_gym.masks import causal_mask | |
NewerOlder