Sofian Mejjoute (Ryu1845)

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()  # required for nn.Module subclasses
        self.pre_norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, 3 * dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.pre_norm(x)
        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
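The preview cuts off before the attention computation and output projection; a generic single-head completion for illustration (this assumes PyTorch 2.x's scaled_dot_product_attention and is not necessarily the gist's actual continuation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionSketch(nn.Module):
    """Illustrative completion of the truncated preview above (single head, no masking)."""

    def __init__(self, dim):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, 3 * dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):                               # x: (B, L, dim)
        x = self.pre_norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        attn = F.scaled_dot_product_attention(q, k, v)  # softmax(QK^T / sqrt(d)) V
        return self.to_out(attn)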
def apply_p_rope(
    inputs: jax.Array,       # [B, L]
    positions: jax.Array,    # [B, L]
    head_dim: int,
    max_wavelength: int = _MAX_WAVELENGTH,
    rope_percentage: float = 1.0,
) -> jax.Array:
  """Applies p-RoPE."""
  rope_angles = int(rope_percentage * head_dim // 2)
  nope_angles = head_dim // 2 - rope_angles
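The preview stops right after splitting the rotary and non-rotary angle counts; a hedged sketch of one way the remaining pairs can be left unrotated (an infinite timescale yields a zero rotation angle), written as a generic illustration rather than the gist's exact continuation:

import jax.numpy as jnp

_MAX_WAVELENGTH = 10_000

def p_rope_sketch(inputs, positions, head_dim, max_wavelength=_MAX_WAVELENGTH, rope_percentage=1.0):
    """Illustrative p-RoPE: rotate only the first rope_percentage of frequency pairs.

    inputs: [B, L, N, head_dim], positions: [B, L]. Names and shapes are assumptions.
    """
    rope_angles = int(rope_percentage * head_dim // 2)
    nope_angles = head_dim // 2 - rope_angles

    fraction = 2.0 * jnp.arange(rope_angles) / head_dim
    timescale = jnp.concatenate(
        [max_wavelength**fraction, jnp.full((nope_angles,), jnp.inf)]
    )  # inf timescale -> zero angle -> identity rotation for the "nope" pairs

    angles = positions[..., None] / timescale[None, None, :]  # [B, L, head_dim // 2]
    sin = jnp.sin(angles)[..., None, :]                       # broadcast over heads
    cos = jnp.cos(angles)[..., None, :]

    first, second = jnp.split(inputs, 2, axis=-1)
    return jnp.concatenate(
        [first * cos - second * sin, second * cos + first * sin], axis=-1
    ).astype(inputs.dtype)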
from typing import Callable

import numpy as np
from tqdm import tqdm

def wsola_chunked_processing(audio: np.ndarray, sr: int, chunk_size: int, hop_size: int, mod_func: Callable[[np.ndarray], np.ndarray]):
    # Check if chunk_size is larger than the audio length
    if chunk_size >= len(audio):
        # Process the entire audio in one go
        output = mod_func(audio).squeeze()
Ryu1845 / convnext.py
Last active August 5, 2024 11:57 — forked from amirshamaei/convnext.py
A ConvNet for the 2020s (1D version)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
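Only the license header and imports survive in this preview; as a hedged illustration of what a 1D ConvNeXt block usually looks like (following the original 2D design, not necessarily this fork's exact code):

import torch
import torch.nn as nn

class ConvNeXtBlock1d(nn.Module):
    """Illustrative 1D ConvNeXt block: depthwise conv -> LayerNorm -> pointwise MLP -> residual."""

    def __init__(self, dim: int, layer_scale_init: float = 1e-6):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim))  # layer scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, C, L)
        residual = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (B, L, C) so LayerNorm/Linear act on channels
        x = self.pwconv2(self.act(self.pwconv1(self.norm(x))))
        x = self.gamma * x
        x = x.permute(0, 2, 1)  # back to (B, C, L)
        return residual + x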
Ryu1845 / fsq.py
Created September 28, 2023 08:01
FSQ implementation copied from the paper ["Finite Scalar Quantization: VQ-VAE Made Simple"](https://arxiv.org/abs/2309.15505)
import jax
import jax.numpy as jnp

def round_ste(z):
  """Round with straight through gradients."""
  zhat = jnp.round(z)
  return z + jax.lax.stop_gradient(zhat - z)

class FSQ:
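For context, a minimal sketch of how round_ste is typically used for finite scalar quantization (the bounding details are simplified relative to the paper's full implementation, and fsq_quantize is an illustrative name, not the gist's API):

import jax
import jax.numpy as jnp

def round_ste(z):
    # as defined in the gist above
    zhat = jnp.round(z)
    return z + jax.lax.stop_gradient(zhat - z)

def fsq_quantize(z, levels):
    """Bound each channel to its level count, then round with a straight-through estimator.

    levels: e.g. jnp.array([8, 5, 5, 5]); this is a sketch, not the gist's exact API.
    """
    half = (levels - 1) / 2.0          # half-width of each channel's integer grid
    bounded = jnp.tanh(z) * half       # squash each channel into [-half, half]
    return round_ste(bounded) / half   # snap to integers, rescale to roughly [-1, 1]

codes = fsq_quantize(jnp.ones((2, 4)), jnp.array([8, 5, 5, 5]))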
Ryu1845 / zero_init.py
Last active August 1, 2024 19:38
ZerO Initialization copied from the original repo (https://github.com/jiaweizzhao/ZerO-initialization/)
import math

import torch

def hadamard(n: int, dtype=torch.int8):
    """This function is a port of the one in scipy.linalg"""
    if n < 1:
        lg2 = 0
    else:
        lg2 = int(math.log(n, 2))
Ryu1845 / came.py
Last active July 17, 2023 23:22
CAME: Confidence-guided Adaptive Memory Efficient Optimization from the official repo (https://github.com/huawei-noah/Pretrained-Language-Model/blob/master/CAME/came.py)
import math

import torch
import torch.optim

class CAME(torch.optim.Optimizer):
    """Implements CAME algorithm.
    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`
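A hedged usage sketch for the class above (the constructor arguments are assumptions modeled on Adafactor-style optimizers; check the actual signature in the gist):

import torch

model = torch.nn.Linear(128, 10)
# Hyperparameter names and values are illustrative, not prescriptive.
optimizer = CAME(model.parameters(), lr=2e-4,
                 betas=(0.9, 0.999, 0.9999),  # third beta drives the confidence-guided term
                 weight_decay=0.01)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()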
Ryu1845 / bst_layer.py
Created June 19, 2023 20:29
JAX implementation of Block-State Transformer (copied from https://arxiv.org/abs/2306.09539)
"""Block-State Transformer Layer."""
# Block Transformers are non-recurrent and parallelizable.
block_transformer = jax.vmap(BRecT.nonrecurrent_cell)
def BST(u):
"""Block-State Transformer Layer."""
global MF # True if Multi-Filter, False otherwise (SH/MH)
# split inputs into windows (l/w, w, d)
Ryu1845 / lru.py
Last active December 13, 2023 13:17
Linear Recurrent Unit (LRU) from the paper ["Resurrecting Recurrent Neural Networks for Long Sequences"](https://arxiv.org/abs/2303.06349)
"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} +\exp(\gamma^{\log})\odot (B u_{k})$,
and the output is computed at each timestamp $k$ as follows: $y_k = C x_k + D u_k$.
In our code, $B,C$ follow Glorot initialization, with $B$ scaled additionally by a factor 2
to account for halving the state variance by taking the real part of the output projection.
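To make the recurrence concrete, here is a naive sequential sketch of the formulas above (the gist itself parallelizes this with an associative scan; function and variable names are mine):

import jax
import jax.numpy as jnp

def lru_scan(lambda_diag, B, C, D, gamma_log, u):
    """Naive sketch of x_k = Lambda x_{k-1} + exp(gamma_log) * (B u_k), y_k = Re(C x_k) + D u_k.

    lambda_diag: complex diagonal of Lambda, shape (N,)
    B: (N, H) complex, C: (H, N) complex, D: (H,) real, u: (L, H) real.
    """
    Bu = (u @ B.T) * jnp.exp(gamma_log)       # (L, N): normalized input projection

    def step(x_prev, bu_k):
        x_k = lambda_diag * x_prev + bu_k     # element-wise diagonal recurrence
        return x_k, x_k

    x0 = jnp.zeros_like(Bu[0])
    _, xs = jax.lax.scan(step, x0, Bu)        # sequential scan over length L
    ys = jnp.real(xs @ C.T) + u * D           # output projection, keep the real part
    return ys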
Ryu1845 / split.py
Last active March 18, 2023 19:45 — forked from ptrblck/numpy split vs PyTorch split
numpy split vs PyTorch split
import torch
import numpy as np

# numpy
a = np.random.rand(10, 20)
tmp0 = np.split(a, indices_or_sections=5, axis=0)  # split into 5 sections
for t in tmp0:
    print(t.shape)
# (2, 20)
# (2, 20)
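The PyTorch half of the comparison is cut off in this preview; a hedged sketch of the usual contrast (the gist's exact continuation may differ):

import torch

# PyTorch: split_size_or_sections is the *size of each chunk*,
# not the number of sections as in numpy.
b = torch.rand(10, 20)
tmp1 = torch.split(b, split_size_or_sections=5, dim=0)  # chunks of size 5 -> 2 chunks
for t in tmp1:
    print(t.shape)
# torch.Size([5, 20])
# torch.Size([5, 20])

# torch.chunk(b, chunks=5, dim=0) matches numpy's "number of sections" behavior.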