Created
November 11, 2024 03:00
-
-
Save pszemraj/d3c0a873866b4b4517ffb4f0072219f6 to your computer and use it in GitHub Desktop.
working CEMA for cpu
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Based on the Megalodon implementation:
https://github.com/XuezheMax/megalodon/blob/53cbaa3a3b3e05ea439564bd67cb352012ba6b97/megalodon/modules/complex_exponential_moving_average.py
""" | |
import math | |
import torch | |
from torch import nn | |
from typing import Optional, Tuple | |
def _reset_parameters(alpha, delta, theta, gamma, omega, embed_dim, device):
    """
    Fill the parameter tensors of a MultiHeadComplexEMA layer in place.

    Random draws are made in the order alpha, delta, gamma, omega, so results
    are reproducible for a given RNG seed.

    Args:
        alpha (torch.Tensor): Filled with N(0, 0.2^2) samples.
        delta (torch.Tensor): Filled with N(0, 0.2^2) samples.
        theta (torch.Tensor): Filled deterministically; must hold
            ``embed_dim`` elements viewable as (embed_dim, 1, 1).
        gamma (torch.Tensor): Filled with N(0, 1) samples, then index 1 of
            dim 2 (the imaginary component) is zeroed.
        omega (torch.Tensor): Filled with N(0, 0.25^2) samples truncated to
            [-1, 1].
        embed_dim (int): Embedding dimension.
        device (torch.device): Device used for the temporary theta values.
    """
    # alpha and delta share the same normal initialisation
    for tensor in (alpha, delta):
        nn.init.normal_(tensor, mean=0.0, std=0.2)

    # theta: logit of exp(-i * log(D) / D) for i = 1..D
    rate = math.log(embed_dim) / embed_dim
    positions = torch.arange(1, embed_dim + 1, device=device, dtype=theta.dtype)
    ratios = torch.exp(positions * -rate)
    logits = torch.log(ratios / (1.0 - ratios))
    with torch.no_grad():
        theta.copy_(logits.view(embed_dim, 1, 1))

    # gamma starts purely real: draw both components, then clear the imaginary one
    nn.init.normal_(gamma, mean=0.0, std=1.0)
    with torch.no_grad():
        gamma[:, :, 1].zero_()

    # omega: truncated normal
    nn.init.trunc_normal_(omega, mean=0.0, std=0.25, a=-1.0, b=1.0)
class MultiHeadComplexEMA(nn.Module):
    """Complex Exponential Moving Average (CEMA) layer.

    For every channel the layer runs ``ndim`` complex-valued moving averages

        h_t = q * h_{t-1} + p * x_t
        y_t = Re(sum_n gamma_n * h_t[n]) + omega * x_t

    evaluated for a whole sequence at once with an FFT convolution. Only
    standard PyTorch operations are used (no custom CUDA extensions).

    The kernel, the carried-state bias and the returned last state all follow
    the recurrence above, so processing a sequence in chunks while passing the
    returned state along gives the same result as processing it in one call.

    Args:
        embed_dim (int): Dimension of the embedding (number of channels D).
        ndim (int, optional): Number of EMA states N per channel. Default is 16.
    """

    def __init__(
        self,
        embed_dim: int,
        ndim: int = 16,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ndim = ndim
        self.scale = math.sqrt(1.0 / self.ndim)

        # Storage only; values are filled in by _init_parameters().
        self.alpha = nn.Parameter(torch.empty(embed_dim, ndim, 1))
        self.delta = nn.Parameter(torch.empty(embed_dim, ndim, 1))
        self.theta = nn.Parameter(torch.empty(embed_dim, 1, 1))
        # Last dim holds (real, imag) of the complex output coefficients.
        self.gamma = nn.Parameter(torch.empty(embed_dim, ndim, 2))
        self.omega = nn.Parameter(torch.empty(embed_dim, 1))

        self._coeffs = None  # (p, q, gamma) cache, used in eval mode only
        self._init_parameters()

    def _init_parameters(self):
        """(Re-)initialize all parameters and drop any cached coefficients."""
        self._coeffs = None
        _reset_parameters(
            self.alpha,
            self.delta,
            self.theta,
            self.gamma,
            self.omega,
            self.embed_dim,
            self.alpha.device,
        )

    def train(self, mode: bool = True):
        """Switch train/eval mode and invalidate the coefficient cache.

        Without this, coefficients cached during an earlier eval phase would
        be reused after the parameters were updated by further training.
        """
        self._coeffs = None
        return super().train(mode)

    def _calc_coeffs(self):
        """
        Calculate the EMA coefficients p, q, and gamma from the parameters.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - p: (D, N, 1), real
                - q: (D, N), complex
                - gamma: (D, N), complex
        """
        # Sigmoid constrains alpha and delta to (0, 1)
        alpha = torch.sigmoid(self.alpha.float())  # (D, N, 1)
        delta = torch.sigmoid(self.delta.float())  # (D, N, 1)

        # Base angle per channel, multiplied by 1..N for the N states
        theta = torch.sigmoid(self.theta.float()) * (
            2 * math.pi / self.ndim
        )  # (D, 1, 1)
        wavelets = torch.arange(
            1, self.ndim + 1, dtype=theta.dtype, device=theta.device
        ).view(1, self.ndim, 1)  # (1, N, 1)
        theta = wavelets * theta  # (D, N, 1)

        p = alpha  # (D, N, 1)
        magnitude = (1.0 - alpha * delta).squeeze(2)  # (D, N)
        angle = theta.squeeze(2)  # (D, N)
        q = torch.polar(magnitude, angle)  # (D, N), complex

        gamma = (
            torch.view_as_complex(self.gamma.float()) * self.scale
        )  # (D, N), complex
        return p, q, gamma

    def coeffs(self):
        """
        Retrieve or calculate the EMA coefficients.

        In training mode the coefficients are recomputed on every call. In
        eval mode they are computed once and cached; the cache is rebuilt when
        it is empty or lives on a different device than the parameters.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - p: (D, N, 1)
                - q: (D, N), complex
                - gamma: (D, N), complex
        """
        if self.training:
            return self._calc_coeffs()
        cached = self._coeffs
        if cached is None or cached[0].device != self.alpha.device:
            cached = self._calc_coeffs()
            self._coeffs = cached
        return cached

    @staticmethod
    def _q_powers(q: torch.Tensor, length: int) -> torch.Tensor:
        """Return q^0 .. q^(length-1) stacked on a new last dim: (D, N, length)."""
        steps = torch.arange(length, device=q.device, dtype=q.real.dtype)
        # Polar form: |q|^t and t * arg(q); real pow gives 0^0 == 1.
        radius = q.abs().unsqueeze(-1) ** steps
        phase = q.angle().unsqueeze(-1) * steps
        return torch.polar(radius, phase)

    def fftconv(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """
        Causal FFT-based convolution of the input with the summed kernel.

        Both signals are zero-padded to length 2L before the FFT so that the
        product corresponds to a linear convolution; without padding the
        result would be circular and late inputs would leak into early outputs.

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, L), real.
            k (torch.Tensor): Kernel tensor of shape (D, N, L), complex.

        Returns:
            torch.Tensor: Convolved output of shape (B, D, L), real.
        """
        seq_len = x.size(-1)
        # x is real, so Re(sum_n x * k_n) == x * Re(sum_n k_n): reduce the
        # kernel first instead of building a (B, D, N, L) intermediate.
        kernel = k.sum(dim=1)  # (D, L)
        if torch.is_complex(kernel):
            kernel = kernel.real
        dtype = torch.promote_types(x.dtype, kernel.dtype)
        if seq_len == 0:
            return x.new_zeros(x.shape, dtype=dtype)

        fft_len = 2 * seq_len
        x_f = torch.fft.rfft(x.to(dtype), n=fft_len, dim=-1)  # (B, D, L + 1)
        k_f = torch.fft.rfft(kernel.to(dtype), n=fft_len, dim=-1)  # (D, L + 1)
        y = torch.fft.irfft(x_f * k_f.unsqueeze(0), n=fft_len, dim=-1)
        return y[..., :seq_len]  # keep the causal part only

    def ema_parameters(
        self,
        p: torch.Tensor,
        q: torch.Tensor,
        gamma: torch.Tensor,
        hx: Optional[torch.Tensor],
        length: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Generate convolution kernels and bias from EMA coefficients and hidden state.

        Args:
            p (torch.Tensor): p coefficients of shape (D, N, 1).
            q (torch.Tensor): q coefficients of shape (D, N), complex.
            gamma (torch.Tensor): gamma coefficients of shape (D, N), complex.
            hx (Optional[torch.Tensor]): Hidden state tensor of shape (B, D, N, 2).
            length (int): Sequence length.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - k: Kernel ``p * gamma * q^t`` of shape (D, N, L), complex.
                - b: Contribution ``Re(sum_n gamma * q^(t+1) * hx)`` of the
                  carried state, shape (B, D, L), if hx is provided, else None.
        """
        q_pow = self._q_powers(q, length)  # (D, N, L)
        gamma_col = gamma.unsqueeze(-1)  # (D, N, 1)
        k = (p * gamma_col) * q_pow  # (D, N, L), complex
        if hx is None:
            return k, None

        h_prev = torch.view_as_complex(hx.contiguous()).to(k.dtype)  # (B, D, N)
        # The carried state has decayed by q^(t+1) when it is read out at step t.
        readout = gamma_col * q.unsqueeze(-1) * q_pow  # (D, N, L)
        bias = torch.einsum("bdn,dnl->bdl", h_prev, readout).real  # (B, D, L)
        return k, bias

    def ema_hidden(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        q: torch.Tensor,
        hx: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """
        Compute the hidden state after the last step of the sequence.

        Unrolling ``h_t = q * h_{t-1} + p * x_t`` over L steps gives
        ``h = q^L * h_prev + sum_j p * q^(L-1-j) * x_j``.

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, L).
            p (torch.Tensor): p coefficients of shape (D, N, 1).
            q (torch.Tensor): q coefficients of shape (D, N), complex.
            hx (Optional[torch.Tensor]): Previous hidden state of shape
                (B, D, N, 2); treated as zero when None.

        Returns:
            Optional[torch.Tensor]: Updated hidden state of shape (B, D, N, 2).
        """
        seq_len = x.size(-1)
        q_pow = self._q_powers(q, seq_len + 1)  # q^0 .. q^L
        # After flipping, position j carries the weight p * q^(L-1-j).
        weights = p * q_pow[..., :seq_len].flip(-1)  # (D, N, L)
        h_new = torch.einsum("bdl,dnl->bdn", x.to(weights.dtype), weights)
        if hx is not None:
            h_prev = torch.view_as_complex(hx.contiguous()).to(weights.dtype)
            h_new = h_new + q_pow[..., seq_len] * h_prev
        return torch.view_as_real(h_new)  # (B, D, N, 2)

    def forward(
        self,
        x: torch.Tensor,
        hx: Optional[torch.Tensor] = None,
        compute_last_state: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass of the CEMA layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, L).
            hx (Optional[torch.Tensor]): Hidden state tensor of shape (B, D, N, 2).
            compute_last_state (bool): Whether to compute and return the last hidden state.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - Output tensor of shape (B, D, L).
                - Updated hidden state if compute_last_state is True, else None.
        """
        _, _, seq_len = x.shape
        residual = x * self.omega  # (B, D, L), real
        p, q, gamma = self.coeffs()  # p: (D, N, 1), q: (D, N), gamma: (D, N)

        # k: (D, N, L), complex; b: (B, D, L) or None
        k, b = self.ema_parameters(p, q, gamma, hx, seq_len)

        output = self.fftconv(x, k)  # (B, D, L), real
        if b is not None:
            output = output + b.to(output.dtype)

        h = self.ema_hidden(x, p, q, hx) if compute_last_state else None

        output = output + residual  # (B, D, L), real
        return output, h

    def extra_repr(self) -> str:
        return f"embed_dim={self.embed_dim}, ndim={self.ndim}"
# Example usage
if __name__ == "__main__":
    # Toy problem sizes: batch, channels, EMA states per channel, time steps
    n_batch, n_channels, n_states, n_steps = 2, 8, 4, 16

    # CPU is the default; change to 'cuda' if using GPU
    device = torch.device("cpu")

    # The input is drawn before the layer is built, then both go to the device
    x = torch.randn(n_batch, n_channels, n_steps)
    cema = MultiHeadComplexEMA(embed_dim=n_channels, ndim=n_states).to(device)
    x = x.to(device)

    # First pass: no carried state, no state requested
    output, h = cema(x, compute_last_state=False)
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Hidden state:", h)

    # Second pass: start from an all-zero state and ask for the updated one
    initial_state = torch.zeros(
        n_batch, n_channels, n_states, 2, device=device, dtype=x.dtype
    )
    output, h = cema(x, hx=initial_state, compute_last_state=True)
    state_shape = None if h is None else h.shape
    print("\nWith Hidden State:")
    print("Output shape:", output.shape)
    print("Updated Hidden state shape:", state_shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
expected output: