@pszemraj
Created November 11, 2024 03:00
working CEMA for cpu
"""
based on the megalodon implementation
https://github.com/XuezheMax/megalodon/blob/53cbaa3a3b3e05ea439564bd67cb352012ba6b97/megalodon/modules/complex_exponential_moving_average.py
"""
import math
import torch
from torch import nn
from typing import Optional, Tuple


def _reset_parameters(alpha, delta, theta, gamma, omega, embed_dim, device):
    """
    Initializes the parameters of the MultiHeadComplexEMA layer.

    Args:
        alpha (torch.Tensor): Parameter tensor for alpha.
        delta (torch.Tensor): Parameter tensor for delta.
        theta (torch.Tensor): Parameter tensor for theta.
        gamma (torch.Tensor): Parameter tensor for gamma.
        omega (torch.Tensor): Parameter tensor for omega.
        embed_dim (int): Embedding dimension.
        device (torch.device): Device to place tensors on.
    """
    # Initialize alpha & delta with normal distribution
    nn.init.normal_(alpha, mean=0.0, std=0.2)
    nn.init.normal_(delta, mean=0.0, std=0.2)

    # Initialize theta
    freqs = math.log(embed_dim) / embed_dim
    freqs = torch.exp(
        torch.arange(1, embed_dim + 1, device=device, dtype=theta.dtype) * -freqs
    )
    freqs = torch.log(freqs / (1.0 - freqs)).view(embed_dim, 1, 1)
    with torch.no_grad():
        theta.copy_(freqs)

    # Initialize gamma and omega
    nn.init.normal_(gamma, mean=0.0, std=1.0)
    with torch.no_grad():
        gamma[:, :, 1] = 0.0  # Set imaginary part to 0
    nn.init.trunc_normal_(omega, mean=0.0, std=0.25, a=-1.0, b=1.0)


class MultiHeadComplexEMA(nn.Module):
    """Complex Exponential Moving Average (CEMA) layer.

    This implementation uses standard PyTorch operations without any custom
    CUDA extensions.

    Args:
        embed_dim (int): Dimension of the embedding.
        ndim (int, optional): Number of dimensions for EMA. Default is 16.
    """

    def __init__(
        self,
        embed_dim: int,
        ndim: int = 16,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ndim = ndim
        self.scale = math.sqrt(1.0 / self.ndim)

        # Initialize parameters
        self.alpha = nn.Parameter(torch.Tensor(embed_dim, ndim, 1))
        self.delta = nn.Parameter(torch.Tensor(embed_dim, ndim, 1))
        self.theta = nn.Parameter(torch.Tensor(embed_dim, 1, 1))
        self.gamma = nn.Parameter(
            torch.Tensor(embed_dim, ndim, 2)
        )  # Complex coefficients stored as (real, imag)
        self.omega = nn.Parameter(torch.Tensor(embed_dim, 1))
        self._coeffs = None  # Cache coefficients for inference
        self._init_parameters()

    def _init_parameters(self):
        device = self.alpha.device
        _reset_parameters(
            self.alpha,
            self.delta,
            self.theta,
            self.gamma,
            self.omega,
            self.embed_dim,
            device,
        )

    def _calc_coeffs(self):
        """
        Calculate the EMA coefficients p, q, and gamma based on current parameters.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - p: (D, N, 1)
                - q: (D, N), complex
                - gamma: (D, N), complex
        """
        # Apply sigmoid to alpha and delta to constrain them between 0 and 1
        alpha = torch.sigmoid(self.alpha.float())  # (D, N, 1)
        delta = torch.sigmoid(self.delta.float())  # (D, N, 1)

        # Compute theta scaled by wavelets
        theta = torch.sigmoid(self.theta.float()) * (
            2 * math.pi / self.ndim
        )  # (D, 1, 1)
        wavelets = torch.arange(
            1, self.ndim + 1, dtype=theta.dtype, device=theta.device
        ).view(
            1, self.ndim
        )  # (1, N)
        theta = wavelets.unsqueeze(2) * theta  # (D, N, 1)

        # Compute p and q
        p = alpha  # (D, N, 1)
        magnitude = (1.0 - alpha * delta).squeeze(2)  # (D, N)
        angle = theta.squeeze(2)  # (D, N)
        q = torch.polar(magnitude, angle)  # (D, N), complex

        # Compute gamma
        gamma = (
            torch.view_as_complex(self.gamma.float()) * self.scale
        )  # (D, N), complex
        return p, q, gamma

    def coeffs(self):
        """
        Retrieve or calculate the EMA coefficients.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - p: (D, N, 1)
                - q: (D, N), complex
                - gamma: (D, N), complex
        """
        if self.training:
            return self._calc_coeffs()
        if self._coeffs is None:
            self._coeffs = self._calc_coeffs()
        return self._coeffs

    def fftconv(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """
        Perform FFT-based causal convolution.

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, L), real.
            k (torch.Tensor): Kernel tensor of shape (D, N, L), complex.

        Returns:
            torch.Tensor: Convolved output of shape (B, D, L), real.
        """
        B, D, L = x.shape
        # Zero-pad to 2L so the circular convolution implied by the FFT
        # realizes a linear (causal) convolution instead of wrapping around
        fft_len = 2 * L
        # Perform FFT on input (real) and kernel (complex)
        x_f = torch.fft.fft(x, n=fft_len, dim=-1)  # (B, D, 2L), complex
        k_f = torch.fft.fft(k, n=fft_len, dim=-1)  # (D, N, 2L), complex
        # Multiply in frequency domain: (B, D, 1, 2L) * (1, D, N, 2L) = (B, D, N, 2L)
        y_f = x_f.unsqueeze(2) * k_f.unsqueeze(0)  # (B, D, N, 2L), complex
        # Sum over the N dimension
        y_f = y_f.sum(dim=2)  # (B, D, 2L), complex
        # Inverse FFT, take the real part, and keep the first L samples
        y = torch.fft.ifft(y_f, dim=-1).real[..., :L]  # (B, D, L), real
        return y

    def ema_parameters(
        self,
        p: torch.Tensor,
        q: torch.Tensor,
        gamma: torch.Tensor,
        hx: Optional[torch.Tensor],
        length: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Generate convolution kernels and bias from EMA coefficients and hidden state.

        Unrolling the recurrence h_t = p * x_t + q * h_{t-1} with initial state
        h0 gives the kernel k_t = gamma * p * q^t and, when h0 is provided, the
        initial-state bias Re(sum_n gamma_n * q_n^(t+1) * h0_n).

        Args:
            p (torch.Tensor): p coefficients of shape (D, N, 1).
            q (torch.Tensor): q coefficients of shape (D, N), complex.
            gamma (torch.Tensor): gamma coefficients of shape (D, N), complex.
            hx (Optional[torch.Tensor]): Hidden state tensor of shape (B, D, N, 2).
            length (int): Sequence length.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - k: Kernel tensor of shape (D, N, L), complex.
                - b: Bias tensor of shape (B, D, L) if hx is provided, else None.
        """
        # Vandermonde powers q^t for t = 0..L-1, with real time steps
        time_steps = torch.arange(length, device=q.device, dtype=q.real.dtype).view(
            1, 1, length
        )  # (1, 1, L)
        q_t = torch.pow(q.unsqueeze(-1), time_steps)  # (D, N, L), complex
        # Kernel k_t = gamma * p * q^t; p scales the input in the recurrence
        k = gamma.unsqueeze(-1) * p * q_t  # (D, N, L), complex
        # Compute bias if hx is provided
        if hx is not None:
            # Convert hx from (real, imag) pairs to complex
            hx_complex = torch.view_as_complex(hx)  # (B, D, N), complex
            # Initial-state contribution: gamma * q^(t+1) * h0
            bias = hx_complex.unsqueeze(-1) * (gamma * q).unsqueeze(-1) * q_t  # (B, D, N, L)
            # Sum over the N dimension; the output is real, so take the real part
            bias = bias.sum(dim=2).real  # (B, D, L), real
        else:
            bias = None
        return k, bias

    def ema_hidden(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        q: torch.Tensor,
        hx: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """
        Compute the hidden state after processing the full input sequence.

        Applies the recurrence h_t = p * x_t + q * h_{t-1} in closed form:
        h_L = q^L * h0 + p * sum_t q^(L-1-t) * x_t.

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, L).
            p (torch.Tensor): p coefficients of shape (D, N, 1).
            q (torch.Tensor): q coefficients of shape (D, N), complex.
            hx (Optional[torch.Tensor]): Previous hidden state of shape (B, D, N, 2).

        Returns:
            Optional[torch.Tensor]: Updated hidden state of shape (B, D, N, 2).
        """
        B, D, L = x.shape
        # Decay powers q^(L-1-t): the most recent input decays the least
        time_steps = torch.arange(
            L - 1, -1, -1, device=q.device, dtype=q.real.dtype
        ).view(1, 1, L)
        q_t = torch.pow(q.unsqueeze(-1), time_steps)  # (D, N, L), complex
        # Weighted input sum: sum_t q^(L-1-t) * x_t -> (B, D, N)
        acc = torch.einsum("bdl,dnl->bdn", x.to(q_t.dtype), q_t)
        # Closed-form update: h_new = p * acc + q^L * h_old
        h_new = p.squeeze(-1).unsqueeze(0) * acc  # (B, D, N), complex
        if hx is not None:
            hx_complex = torch.view_as_complex(hx)  # (B, D, N), complex
            # Decay the previous state across the whole sequence
            h_new = h_new + (q ** L).unsqueeze(0) * hx_complex
        # Convert back to a real tensor of (real, imag) pairs
        return torch.view_as_real(h_new)  # (B, D, N, 2)

    def forward(
        self,
        x: torch.Tensor,
        hx: Optional[torch.Tensor] = None,
        compute_last_state: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass of the CEMA layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, L).
            hx (Optional[torch.Tensor]): Hidden state tensor of shape (B, D, N, 2).
            compute_last_state (bool): Whether to compute and return the last hidden state.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - Output tensor of shape (B, D, L).
                - Updated hidden state if compute_last_state is True, else None.
        """
        B, D, L = x.shape
        residual = x * self.omega  # (B, D, L), real

        p, q, gamma = self.coeffs()  # p: (D, N, 1), q: (D, N), gamma: (D, N)

        # Generate convolution kernels and bias
        k, b = self.ema_parameters(
            p, q, gamma, hx, L
        )  # k: (D, N, L), complex; b: (B, D, L) or None

        # Perform FFT-based convolution
        output = self.fftconv(x, k)  # (B, D, L), real
        if b is not None:
            output = output + b.to(output.dtype)  # (B, D, L), real

        # Update hidden state if required
        h = self.ema_hidden(x, p, q, hx) if compute_last_state else None

        # Add residual
        output = output + residual  # (B, D, L), real
        return output, h

    def extra_repr(self) -> str:
        return f"embed_dim={self.embed_dim}, ndim={self.ndim}"


# Example Usage
if __name__ == "__main__":
    # Define parameters
    batch_size = 2
    embed_dim = 8
    ndim = 4
    seq_length = 16

    # Create random input
    x = torch.randn(batch_size, embed_dim, seq_length)

    # Initialize CEMA layer
    cema = MultiHeadComplexEMA(embed_dim=embed_dim, ndim=ndim)

    # Move to CPU explicitly (optional, since default is CPU)
    device = torch.device("cpu")  # Change to "cuda" if using a GPU
    cema = cema.to(device)
    x = x.to(device)

    # Forward pass without hidden state
    output, h = cema(x, compute_last_state=False)
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Hidden state:", h)

    # Forward pass with hidden state
    hx = torch.zeros(
        batch_size, embed_dim, ndim, 2, device=device, dtype=x.dtype
    )  # Initialize hidden state
    output, h = cema(x, hx=hx, compute_last_state=True)
    print("\nWith Hidden State:")
    print("Output shape:", output.shape)
    print("Updated Hidden state shape:", h.shape if h is not None else None)
expected output:

Input shape: torch.Size([2, 8, 16])
Output shape: torch.Size([2, 8, 16])
Hidden state: None

With Hidden State:
Output shape: torch.Size([2, 8, 16])
Updated Hidden state shape: torch.Size([2, 8, 4, 2])
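
As a sanity check that the FFT path matches the recurrence it is meant to implement, the layer can be compared against a direct per-timestep scan. This is a minimal sketch, not part of the original gist: it assumes the recurrence h_t = p * x_t + q * h_{t-1} with output Re(gamma^T h_t) plus the omega residual, and the helper name naive_cema is made up for illustration.

import torch

@torch.no_grad()
def naive_cema(cema: MultiHeadComplexEMA, x: torch.Tensor) -> torch.Tensor:
    # Reference scan: h_t = p * x_t + q * h_{t-1}; y_t = Re(gamma^T h_t) + omega * x_t
    p, q, gamma = cema.coeffs()
    B, D, L = x.shape
    h = torch.zeros(B, D, cema.ndim, dtype=q.dtype, device=x.device)
    ys = []
    for t in range(L):
        h = p.squeeze(-1) * x[..., t].unsqueeze(-1) + q * h  # (B, D, N)
        ys.append((gamma * h).sum(dim=-1).real)  # (B, D)
    return torch.stack(ys, dim=-1) + x * cema.omega  # (B, D, L)

cema = MultiHeadComplexEMA(embed_dim=8, ndim=4).eval()
x = torch.randn(2, 8, 16)
y_fft, _ = cema(x)
print(torch.allclose(y_fft, naive_cema(cema, x), atol=1e-4))  # expected: True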

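The returned hidden state also enables chunked (streaming) processing: run a long sequence in pieces and carry hx across the boundary. A minimal sketch, assuming the kernel/bias handling of the initial state shown above; the tolerance is loose because the FFT path accumulates rounding error.

cema = MultiHeadComplexEMA(embed_dim=8, ndim=4).eval()
x = torch.randn(2, 8, 32)

# Full-sequence pass
y_full, _ = cema(x)

# Two-chunk pass, folding the first chunk into the state
y1, h1 = cema(x[..., :16], compute_last_state=True)
y2, _ = cema(x[..., 16:], hx=h1)
y_chunked = torch.cat([y1, y2], dim=-1)

print(torch.allclose(y_full, y_chunked, atol=1e-4))  # expected: True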