@Birch-san
Created April 3, 2023 22:45
Reducing the softmax denominator to sum only as many attention scores as the in-distribution checkpoint would've, so that its outputs have in-distribution magnitudes
from torch import FloatTensor

vae_scale_factor = 8
# key length the checkpoint saw in-distribution: a 512x512 image becomes a 64x64 latent, i.e. 4096 key tokens
typical_self_attn_key_length = (512/vae_scale_factor) * (512/vae_scale_factor)
# key length at the resolution we want to sample at: a 768x768 image becomes a 96x96 latent, i.e. 9216 key tokens
desired_self_attn_key_length = (768/vae_scale_factor) * (768/vae_scale_factor)
# assumption: the caller sets this per attention layer; cross-attention key length is unaffected by image size
is_self_attn = True
# how many times more key tokens we have than the checkpoint saw in-distribution (9216/4096 = 2.25 here)
key_length_factor = desired_self_attn_key_length/typical_self_attn_key_length if is_self_attn else 1.

def softmax(x: FloatTensor, dim=-1) -> FloatTensor:
    # subtract the per-row max before exponentiating, for numerical stability (as in a standard softmax)
    maxes = x.max(dim, keepdim=True).values
    diffs = x - maxes
    x_exp = diffs.exp()
    key_tokens = x.size(-1)
    # sum only the top-k exponentiated scores, where k is the number of key tokens
    # the checkpoint would have attended over in-distribution
    preferred_token_count = int(key_tokens/key_length_factor)
    x_exp_sum = x_exp.topk(k=preferred_token_count, dim=dim).values.sum(dim, keepdim=True)
    quotient = x_exp/x_exp_sum
    return quotient
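
Below is a minimal usage sketch, not part of the original gist: it feeds a toy scores tensor (standing in for query·keyᵀ attention logits) through the modified softmax and compares row sums against a standard softmax. The shape is arbitrary; with key_length_factor = 2.25, only the top int(9 / 2.25) = 4 exponentiated scores per row enter the denominator, so each row sums to slightly more than 1.

import torch

# toy attention logits; a real 768x768 self-attention map would be
# shaped [batch * heads, 9216, 9216] (96x96 latent tokens)
scores = torch.randn(2, 4, 9)

probs = softmax(scores, dim=-1)

# rows sum to slightly more than 1, because scores outside the top-k are
# excluded from the denominator but still appear in the numerator
print(probs.sum(dim=-1))
# standard softmax baseline: rows sum to exactly 1
print(torch.softmax(scores, dim=-1).sum(dim=-1))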