Gist Lexie88rus/6f8e6ab48f0729f548636472a07beaa4, created June 27, 2019 08:38
SiLU implementation
import torch
import torch.nn as nn


# define SiLU as a plain function that operates on tensors
def silu(input):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
        SiLU(x) = x * sigmoid(x)
    '''
    # build on the builtin torch.sigmoid so the computation uses PyTorch's
    # vectorized kernels and autograd handles the backward pass
    return input * torch.sigmoid(input)
# wrap the function in a PyTorch nn.Module so it can
# be used like any other layer in models
class SiLU(nn.Module):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
        SiLU(x) = x * sigmoid(x)

    Shape:
        - Input: (N, *) where * means any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    References:
        - Related paper:
          https://arxiv.org/pdf/1606.08415.pdf

    Examples:
        >>> m = SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''
    def __init__(self):
        '''
        Init method.
        '''
        super().__init__()  # init the base class

    def forward(self, input):
        '''
        Forward pass of the function.
        '''
        return silu(input)  # apply the SiLU function defined above
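To make the formula concrete without needing PyTorch installed, here is a minimal pure-Python sketch that works on single floats only (no tensors). It also spells out the derivative that autograd computes for the backward pass, SiLU'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), and checks it against a central finite difference. The helper names `sigmoid`, `silu_scalar`, `silu_grad` and `finite_difference` are hypothetical illustrations, not part of the gist.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + exp(-x)), evaluated in a numerically stable way."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # for negative x, exp(-x) could overflow; rewrite as exp(x) / (1 + exp(x))
    z = math.exp(x)
    return z / (1.0 + z)


def silu_scalar(x: float) -> float:
    """SiLU(x) = x * sigmoid(x) for a single float (hypothetical helper)."""
    return x * sigmoid(x)


def silu_grad(x: float) -> float:
    """Analytic derivative of SiLU: sigmoid(x) * (1 + x * (1 - sigmoid(x)))."""
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))


def finite_difference(f, x: float, h: float = 1e-6) -> float:
    """Central-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
        print(
            f"x={x:+.1f}  silu={silu_scalar(x):+.6f}  "
            f"grad={silu_grad(x):+.6f}  "
            f"fd={finite_difference(silu_scalar, x):+.6f}"
        )
```

Running the script prints each value next to its analytic and finite-difference gradients, which should agree closely. The derivative follows from the product rule, since sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)). Recent PyTorch releases also ship this activation built in as `torch.nn.SiLU` and `torch.nn.functional.silu`.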