@sanjoy
Created December 30, 2024 07:00
import numpy as np


def softmax_reduction(x, y):
    # Two-pass reference: materialize the full softmax of x, then take the
    # softmax-weighted sum of y. Subtracting the max keeps exp() from
    # overflowing.
    e_x = np.exp(x - np.max(x))
    softmax = e_x / e_x.sum(axis=0)
    return (softmax * y).sum()
def process_one_element(
        x, y, online_max, online_output, online_denominator):
    # Single-pass update: fold one (x, y) pair into the running state.
    # Whenever a new maximum appears, rescale the previous accumulators by
    # exp(old_max - new_max) so every term stays relative to the current max.
    old_online_max = online_max
    online_max = max(online_max, x)
    online_output = (
        online_output * np.exp(old_online_max - online_max) +
        np.exp(x - online_max) * y)
    online_denominator = (
        online_denominator * np.exp(old_online_max - online_max) +
        np.exp(x - online_max))
    return online_max, online_output, online_denominator
def softmax_reduction_online(x, y):
    # One-pass ("online") softmax reduction: stream over the (x, y) pairs,
    # maintaining the running max, numerator, and denominator.
    online_max = float('-inf')
    online_output = 0.0
    online_denominator = 0.0
    for xi, yi in zip(x, y):
        online_max, online_output, online_denominator = process_one_element(
            xi, yi, online_max, online_output, online_denominator)
    return online_output / online_denominator
def main():
    x = np.random.random(128)
    y = np.random.random(128)
    result = softmax_reduction(x, y)
    result_online = softmax_reduction_online(x, y)
    # The two implementations should agree up to floating-point error.
    print(np.allclose(result, result_online))


if __name__ == '__main__':
    main()
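
The rescaling step is what makes the single pass correct: after folding in any prefix of the stream, online_output / online_denominator already equals the two-pass reduction over that prefix. A minimal sketch checking that invariant against the reference (check_prefix_invariant is a name introduced here for illustration, not part of the original gist):

def check_prefix_invariant(x, y):
    # Hypothetical helper (not in the original gist): after each element,
    # the running ratio must match the two-pass reference computed on the
    # same prefix x[:k]. This is the invariant the online update maintains.
    online_max = float('-inf')
    online_output = 0.0
    online_denominator = 0.0
    for k, (xi, yi) in enumerate(zip(x, y), start=1):
        online_max, online_output, online_denominator = process_one_element(
            xi, yi, online_max, online_output, online_denominator)
        assert np.allclose(online_output / online_denominator,
                           softmax_reduction(x[:k], y[:k]))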