@CoffeeVampir3, created September 27, 2025
Fast VS Slow Tversky Multihead Bench
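Benchmark comparing two vectorized multi-head Tversky similarity implementations: a straightforward einsum version and a bmm-based version that avoids materializing a [batch, heads, prototypes, features] intermediate. Both compute a per-head, per-prototype score of the form theta * common - alpha * x_distinctive - beta * p_distinctive, mirroring the Tversky contrast model S(A, B) = theta*f(A ∩ B) - alpha*f(A - B) - beta*f(B - A), with feature "presence" taken to be the ReLU of the feature activations.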
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def tversky_multihead_similarity_vectorized_simple(x, features, prototypes, theta, alpha, beta, n_heads):
    batch_size, total_dim = x.shape
    d_head = total_dim // n_heads

    # Multihead expansion
    x = rearrange(x, 'b (h d) -> b h d', h=n_heads)  # [batch, heads, d_head]
    features = rearrange(features, 'f (h d) -> h f d', h=n_heads)  # [heads, features, d_head]
    prototypes = rearrange(prototypes, 'p (h d) -> h p d', h=n_heads)  # [heads, prototypes, d_head]

    # Feature activations for inputs and prototypes
    x_features = torch.einsum('bhd,hfd->bhf', x, features)  # [batch, heads, features]
    p_features = torch.einsum('hpd,hfd->hpf', prototypes, features)  # [heads, prototypes, features]

    # Presence masking
    x_present = F.relu(x_features)
    p_present = F.relu(p_features)
    x_weighted = x_features * x_present
    p_weighted = p_features * p_present

    # Common and distinctive feature terms via direct einsum contractions
    common = torch.einsum('bhf,hpf->bhp', x_weighted, p_weighted)  # [batch, heads, prototypes]
    x_weighted_sum = x_weighted.sum(dim=2, keepdim=True)
    x_p_interaction = torch.einsum('bhf,hpf->bhp', x_weighted, p_present)
    x_distinctive = x_weighted_sum - x_p_interaction
    p_weighted_sum = p_weighted.sum(dim=2).unsqueeze(0)
    x_p_weighted_interaction = torch.einsum('bhf,hpf->bhp', x_present, p_weighted)
    p_distinctive = p_weighted_sum - x_p_weighted_interaction

    theta = rearrange(theta, 'h 1 -> 1 h 1')
    alpha = rearrange(alpha, 'h 1 -> 1 h 1')
    beta = rearrange(beta, 'h 1 -> 1 h 1')
    return theta * common - alpha * x_distinctive - beta * p_distinctive  # [batch, heads, prototypes]
def tversky_multihead_similarity_vectorized_optimized(x, features, prototypes, theta, alpha, beta, n_heads):
    batch_size, total_dim = x.shape
    d_head = total_dim // n_heads

    # Multihead expansion
    x = rearrange(x, 'b (h d) -> b h d', h=n_heads)  # [batch, heads, d_head]
    features = rearrange(features, 'f (h d) -> h f d', h=n_heads)  # [heads, features, d_head]
    prototypes = rearrange(prototypes, 'p (h d) -> h p d', h=n_heads)  # [heads, prototypes, d_head]

    # Full features (hidden * features per head, proto * features per head)
    x_features = torch.einsum('bhd,hfd->bhf', x, features)  # [batch, heads, features]
    p_features = torch.einsum('hpd,hfd->hpf', prototypes, features)  # [heads, prototypes, features]

    # Presence masking
    x_present = F.relu(x_features)  # [batch, heads, features]
    p_present = F.relu(p_features)  # [heads, prototypes, features]
    x_weighted = x_features * x_present  # [batch, heads, features]
    p_weighted = p_features * p_present  # [heads, prototypes, features]

    # BMM to avoid [batch, heads, prototypes, features] materialization
    x_weighted_h = x_weighted.transpose(0, 1)  # [heads, batch, features]
    p_weighted_h = p_weighted.transpose(1, 2)  # [heads, features, prototypes]
    # Original: torch.einsum('bhf,hpf->bhp', x_weighted, p_weighted) would require broadcasting
    # where intermediate tensors expand to [batch, heads, prototypes, features] for element-wise multiply.
    # Equivalent: sum_f(x_weighted[b,h,f] * p_weighted[h,p,f]) for each (b,h,p)
    #           = x_weighted[b,h,:] @ p_weighted[h,:,p] for each head h
    #           = torch.bmm([heads, batch, features], [heads, features, prototypes])
    common = torch.bmm(x_weighted_h, p_weighted_h).transpose(0, 1)  # [batch, heads, prototypes]

    # Same idea, avoid [batch, heads, prototypes, features] materialization
    x_weighted_sum = x_weighted.sum(dim=2, keepdim=True)  # [batch, heads, 1]
    p_present_h = p_present.transpose(1, 2)  # [heads, features, prototypes]
    # Original: torch.einsum('bhf,hpf->bhp', x_weighted, p_present),
    # where p_present would broadcast to [batch, heads, prototypes, features].
    # Equivalent: sum_f(x_weighted[b,h,f] * p_present[h,p,f]) for each (b,h,p)
    #           = x_weighted[b,h,:] @ p_present[h,:,p] for each head h
    #           = torch.bmm([heads, batch, features], [heads, features, prototypes])
    x_p_interaction = torch.bmm(x_weighted_h, p_present_h).transpose(0, 1)  # [batch, heads, prototypes]
    x_distinctive = x_weighted_sum - x_p_interaction  # [batch, heads, prototypes]

    # And again the same, to avoid [batch, heads, prototypes, features] materialization
    p_weighted_sum = p_weighted.sum(dim=2).unsqueeze(0)  # [1, heads, prototypes]
    x_present_h = x_present.transpose(0, 1)  # [heads, batch, features]
    # Original: torch.einsum('bhf,hpf->bhp', x_present, p_weighted),
    # where x_present would broadcast to [batch, heads, prototypes, features].
    # Equivalent: sum_f(x_present[b,h,f] * p_weighted[h,p,f]) for each (b,h,p)
    #           = x_present[b,h,:] @ p_weighted[h,:,p] for each head h
    #           = torch.bmm([heads, batch, features], [heads, features, prototypes])
    p_x_interaction = torch.bmm(x_present_h, p_weighted_h).transpose(0, 1)  # [batch, heads, prototypes]
    p_distinctive = p_weighted_sum - p_x_interaction  # [batch, heads, prototypes]

    theta = rearrange(theta, 'h 1 -> 1 h 1')
    alpha = rearrange(alpha, 'h 1 -> 1 h 1')
    beta = rearrange(beta, 'h 1 -> 1 h 1')
    return theta * common - alpha * x_distinctive - beta * p_distinctive  # [batch, heads, prototypes]
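# A minimal standalone sanity check (a sketch, not part of the original gist) of the
# einsum <-> bmm equivalence that the optimized path relies on; the small shapes below
# are arbitrary.
_xw = torch.randn(2, 3, 5)  # [batch, heads, features]
_pw = torch.randn(3, 7, 5)  # [heads, prototypes, features]
_via_einsum = torch.einsum('bhf,hpf->bhp', _xw, _pw)
_via_bmm = torch.bmm(_xw.transpose(0, 1), _pw.transpose(1, 2)).transpose(0, 1)
assert torch.allclose(_via_einsum, _via_bmm, atol=1e-6)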
import torch
import torch.nn.functional as F
from einops import rearrange
import time
import tracemalloc
def test_tversky_multihead_implementations():
    torch.manual_seed(42)
    batch_size, n_heads, d_head = 32, 8, 64
    n_features, n_prototypes = 512, 128
    total_dim = n_heads * d_head

    x = torch.randn(batch_size, total_dim)
    features = torch.randn(n_features, total_dim)
    prototypes = torch.randn(n_prototypes, total_dim)
    theta = torch.randn(n_heads, 1)
    alpha = torch.randn(n_heads, 1)
    beta = torch.randn(n_heads, 1)

    def measure_memory_and_time(func, name):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        tracemalloc.start()
        start_time = time.perf_counter()
        result = func(x, features, prototypes, theta, alpha, beta, n_heads)
        end_time = time.perf_counter()
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return result, end_time - start_time, peak / 1024 / 1024
    result_simple, time_simple, memory_simple = measure_memory_and_time(
        tversky_multihead_similarity_vectorized_simple, "simple"
    )
    result_optimized, time_optimized, memory_optimized = measure_memory_and_time(
        tversky_multihead_similarity_vectorized_optimized, "optimized"
    )

    outputs_match = torch.allclose(result_simple, result_optimized, atol=1e-5)
    max_diff = torch.max(torch.abs(result_simple - result_optimized))
    print(f"Outputs match: {outputs_match}")
    print(f"Max difference: {max_diff:.2e}")
    print(f"Simple - Time: {time_simple:.4f}s, Memory: {memory_simple:.1f}MB")
    print(f"Optimized - Time: {time_optimized:.4f}s, Memory: {memory_optimized:.1f}MB")
    print(f"Speedup: {time_simple / time_optimized:.2f}x")
    print(f"Memory reduction: {memory_simple / memory_optimized:.2f}x")
    return outputs_match
test_tversky_multihead_implementations()
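# Hedged sketch (not part of the original benchmark): tracemalloc does not track CUDA
# allocations, so if the comparison is run on a GPU, peak tensor memory should instead be
# read from torch.cuda.max_memory_allocated(), and torch.cuda.synchronize() is needed
# around the timed region for accurate wall-clock numbers.
def measure_cuda_memory_and_time(func, x, features, prototypes, theta, alpha, beta, n_heads):
    device = torch.device("cuda")
    args = [t.to(device) for t in (x, features, prototypes, theta, alpha, beta)]
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    result = func(*args, n_heads)
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    peak_mb = torch.cuda.max_memory_allocated(device) / 1024 / 1024
    return result, elapsed, peak_mb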