
@Birch-san
Created April 3, 2023 00:12
Typical softmax
from torch import FloatTensor

def softmax(x: FloatTensor, dim=-1) -> FloatTensor:
    # subtract the per-slice max before exponentiating, for numerical stability
    # (softmax is invariant to adding a constant, so the result is unchanged)
    maxes = x.max(dim, keepdim=True).values
    diffs = x - maxes
    x_exp = diffs.exp()
    # normalize so each slice along `dim` sums to 1
    x_exp_sum = x_exp.sum(dim, keepdim=True)
    quotient = x_exp / x_exp_sum
    return quotient
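
A minimal usage sketch (not part of the original gist): it assumes the softmax function defined above and checks its output against torch's built-in softmax and the sums-to-one property.

import torch

x = torch.randn(2, 5)
probs = softmax(x, dim=-1)
# should agree with the built-in implementation within floating-point tolerance
assert torch.allclose(probs, torch.softmax(x, dim=-1))
# each row should sum to 1
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))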