@mberr
Created May 4, 2022 21:42
A simple PyTorch implementation of rational activation functions.
import torch
from torch import nn


class RationalActivation(nn.Module):
    """
    A rational activation function with trainable parameters.

    Computes P(x) / (1 + |Q(x)|), where P and Q are polynomials of degree
    n and m with learnable coefficients.

    Inspired by https://arxiv.org/abs/2205.01549.

    .. seealso::
        https://en.wikipedia.org/wiki/Rational_function
        https://github.com/ml-research/rational_activations
        https://arxiv.org/abs/1907.06732
    """

    def __init__(self, n: int = 4, m: int = 5) -> None:
        """
        Initialize the activation.

        :param n:
            the degree of the numerator polynomial.
        :param m:
            the degree of the denominator polynomial.
        """
        super().__init__()
        # add one coefficient for the zero-degree (constant) term
        n, m = n + 1, m + 1
        self.m = m
        self.n = n
        self.k = max(m, n)
        self.w_p = nn.Parameter(torch.empty(n))
        self.w_q = nn.Parameter(torch.empty(m))
        self.reset_parameters()

    def reset_parameters(self):
        # these could be fitted to match existing activation functions
        nn.init.uniform_(self.w_p)
        nn.init.uniform_(self.w_q)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # calculate the powers x^0, ..., x^{k-1} only once;
        # broadcasting (..., 1) ** (k,) yields shape (..., k)
        x = x.unsqueeze(dim=-1) ** torch.arange(self.k, device=x.device, dtype=x.dtype)
        # use einsum to avoid manual reshaping
        nom = torch.einsum("...i, i -> ...", x[..., : self.n], self.w_p)
        denom = torch.einsum("...i, i -> ...", x[..., : self.m], self.w_q)
        # make sure the rational function does not have poles
        return nom / (1 + denom.abs())
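
Since the module applies element-wise and preserves the input shape, it can be dropped in as a replacement for a fixed activation such as nn.ReLU. A minimal usage sketch (not part of the gist; the layer sizes and batch shape below are arbitrary choices for illustration):

model = nn.Sequential(
    nn.Linear(8, 16),
    RationalActivation(),
    nn.Linear(16, 1),
)
x = torch.randn(32, 8)
y = model(x)  # shape: (32, 1); gradients flow into w_p and w_q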
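The comment in reset_parameters notes that the coefficients could be fitted to match an existing activation function. A hedged sketch of one way to do that, by regressing the rational function onto a target activation over a sampled grid (the target, grid, optimizer, and step count are all assumptions, not part of the gist):

act = RationalActivation()
target = torch.nn.functional.gelu  # hypothetical choice of target activation
xs = torch.linspace(-3, 3, steps=1024)
opt = torch.optim.Adam(act.parameters(), lr=1e-2)
for _ in range(2000):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(act(xs), target(xs))
    loss.backward()
    opt.step()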