Created September 27, 2025 04:01
Fast VS Slow Tversky Multihead Bench
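Two implementations of a multi-head Tversky similarity are benchmarked below: a straightforward einsum version and an algebraically equivalent one that routes the three batch-broadcast contractions through torch.bmm. For orientation, the per-head score follows the Tversky contrast model

    S(x, p) = theta * f(X ∩ P) - alpha * f(X \ P) - beta * f(P \ X)

where the "common" tensor plays the role of the intersection term, the two "distinctive" tensors play the set-difference terms, and theta, alpha, and beta are per-head weights.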
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

def tversky_multihead_similarity_vectorized_simple(x, features, prototypes, theta, alpha, beta, n_heads):
    batch_size, total_dim = x.shape
    d_head = total_dim // n_heads

    # Split the flat embedding dimension into heads.
    x = rearrange(x, 'b (h d) -> b h d', h=n_heads)                    # [batch, heads, d_head]
    features = rearrange(features, 'f (h d) -> h f d', h=n_heads)      # [heads, features, d_head]
    prototypes = rearrange(prototypes, 'p (h d) -> h p d', h=n_heads)  # [heads, prototypes, d_head]

    # Project inputs and prototypes onto the feature bank, per head.
    x_features = torch.einsum('bhd,hfd->bhf', x, features)             # [batch, heads, features]
    p_features = torch.einsum('hpd,hfd->hpf', prototypes, features)    # [heads, prototypes, features]

    # Presence masking and weighting.
    x_present = F.relu(x_features)
    p_present = F.relu(p_features)
    x_weighted = x_features * x_present
    p_weighted = p_features * p_present

    # Common features: sum over f of x_weighted * p_weighted.
    common = torch.einsum('bhf,hpf->bhp', x_weighted, p_weighted)

    # Features present in x but not in the prototype.
    x_weighted_sum = x_weighted.sum(dim=2, keepdim=True)
    x_p_interaction = torch.einsum('bhf,hpf->bhp', x_weighted, p_present)
    x_distinctive = x_weighted_sum - x_p_interaction

    # Features present in the prototype but not in x.
    p_weighted_sum = p_weighted.sum(dim=2).unsqueeze(0)
    x_p_weighted_interaction = torch.einsum('bhf,hpf->bhp', x_present, p_weighted)
    p_distinctive = p_weighted_sum - x_p_weighted_interaction

    # Per-head Tversky contrast score.
    theta = rearrange(theta, 'h 1 -> 1 h 1')
    alpha = rearrange(alpha, 'h 1 -> 1 h 1')
    beta = rearrange(beta, 'h 1 -> 1 h 1')
    return theta * common - alpha * x_distinctive - beta * p_distinctive
def tversky_multihead_similarity_vectorized_optimized(x, features, prototypes, theta, alpha, beta, n_heads):
    batch_size, total_dim = x.shape
    d_head = total_dim // n_heads

    # Multihead expansion
    x = rearrange(x, 'b (h d) -> b h d', h=n_heads)                    # [batch, heads, d_head]
    features = rearrange(features, 'f (h d) -> h f d', h=n_heads)      # [heads, features, d_head]
    prototypes = rearrange(prototypes, 'p (h d) -> h p d', h=n_heads)  # [heads, prototypes, d_head]

    # Full features (hidden * features per head, proto * features per head)
    x_features = torch.einsum('bhd,hfd->bhf', x, features)             # [batch, heads, features]
    p_features = torch.einsum('hpd,hfd->hpf', prototypes, features)    # [heads, prototypes, features]

    # Presence masking
    x_present = F.relu(x_features)       # [batch, heads, features]
    p_present = F.relu(p_features)       # [heads, prototypes, features]
    x_weighted = x_features * x_present  # [batch, heads, features]
    p_weighted = p_features * p_present  # [heads, prototypes, features]

    # BMM to avoid [batch, heads, prototypes, features] materialization
    x_weighted_h = x_weighted.transpose(0, 1)  # [heads, batch, features]
    p_weighted_h = p_weighted.transpose(1, 2)  # [heads, features, prototypes]

    # Original: torch.einsum('bhf,hpf->bhp', x_weighted, p_weighted) would require broadcasting,
    # where intermediate tensors expand to [batch, heads, prototypes, features] for the element-wise multiply.
    # Equivalent: sum_f(x_weighted[b,h,f] * p_weighted[h,p,f]) for each (b,h,p)
    #           = x_weighted[b,h,:] @ p_weighted[h,:,p] for each head h
    #           = torch.bmm([heads, batch, features], [heads, features, prototypes])
    common = torch.bmm(x_weighted_h, p_weighted_h).transpose(0, 1)  # [batch, heads, prototypes]

    # Same idea, avoid [batch, heads, prototypes, features] materialization
    x_weighted_sum = x_weighted.sum(dim=2, keepdim=True)  # [batch, heads, 1]
    p_present_h = p_present.transpose(1, 2)                # [heads, features, prototypes]
    # Original: torch.einsum('bhf,hpf->bhp', x_weighted, p_present),
    # where p_present would broadcast to [batch, heads, prototypes, features].
    # Equivalent: sum_f(x_weighted[b,h,f] * p_present[h,p,f]) for each (b,h,p)
    #           = x_weighted[b,h,:] @ p_present[h,:,p] for each head h
    #           = torch.bmm([heads, batch, features], [heads, features, prototypes])
    x_p_interaction = torch.bmm(x_weighted_h, p_present_h).transpose(0, 1)  # [batch, heads, prototypes]
    x_distinctive = x_weighted_sum - x_p_interaction  # [batch, heads, prototypes]

    # And again the same trick to avoid [batch, heads, prototypes, features] materialization
    p_weighted_sum = p_weighted.sum(dim=2).unsqueeze(0)  # [1, heads, prototypes]
    x_present_h = x_present.transpose(0, 1)              # [heads, batch, features]
    # Original: torch.einsum('bhf,hpf->bhp', x_present, p_weighted),
    # where x_present would broadcast to [batch, heads, prototypes, features].
    # Equivalent: sum_f(x_present[b,h,f] * p_weighted[h,p,f]) for each (b,h,p)
    #           = x_present[b,h,:] @ p_weighted[h,:,p] for each head h
    #           = torch.bmm([heads, batch, features], [heads, features, prototypes])
    p_x_interaction = torch.bmm(x_present_h, p_weighted_h).transpose(0, 1)  # [batch, heads, prototypes]
    p_distinctive = p_weighted_sum - p_x_interaction  # [batch, heads, prototypes]

    theta = rearrange(theta, 'h 1 -> 1 h 1')
    alpha = rearrange(alpha, 'h 1 -> 1 h 1')
    beta = rearrange(beta, 'h 1 -> 1 h 1')
    return theta * common - alpha * x_distinctive - beta * p_distinctive  # [batch, heads, prototypes]
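For a sense of scale under the shapes used in the benchmark below (batch 32, heads 8, prototypes 128, features 512, float32), a single broadcast intermediate of shape [batch, heads, prototypes, features], the expansion the comments above describe avoiding, would hold 32 * 8 * 128 * 512 = 16,777,216 elements, roughly 64 MiB. This is a back-of-the-envelope estimate only; the actual peak depends on how torch.einsum plans each contraction.

    # Rough size of one [batch, heads, prototypes, features] float32 intermediate.
    elements = 32 * 8 * 128 * 512        # 16,777,216 elements
    approx_mib = elements * 4 / 2**20    # ~64 MiB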
import torch
import torch.nn.functional as F
from einops import rearrange
import time
import tracemalloc

def test_tversky_multihead_implementations():
    torch.manual_seed(42)

    batch_size, n_heads, d_head = 32, 8, 64
    n_features, n_prototypes = 512, 128
    total_dim = n_heads * d_head

    x = torch.randn(batch_size, total_dim)
    features = torch.randn(n_features, total_dim)
    prototypes = torch.randn(n_prototypes, total_dim)
    theta = torch.randn(n_heads, 1)
    alpha = torch.randn(n_heads, 1)
    beta = torch.randn(n_heads, 1)

    def measure_memory_and_time(func, name):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        tracemalloc.start()
        start_time = time.perf_counter()
        result = func(x, features, prototypes, theta, alpha, beta, n_heads)
        end_time = time.perf_counter()
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return result, end_time - start_time, peak / 1024 / 1024  # peak in MB

    result_simple, time_simple, memory_simple = measure_memory_and_time(
        tversky_multihead_similarity_vectorized_simple, "simple"
    )
    result_optimized, time_optimized, memory_optimized = measure_memory_and_time(
        tversky_multihead_similarity_vectorized_optimized, "optimized"
    )

    # Numerical equivalence check between the two implementations.
    outputs_match = torch.allclose(result_simple, result_optimized, atol=1e-5)
    max_diff = torch.max(torch.abs(result_simple - result_optimized))

    print(f"Outputs match: {outputs_match}")
    print(f"Max difference: {max_diff:.2e}")
    print(f"Simple - Time: {time_simple:.4f}s, Memory: {memory_simple:.1f}MB")
    print(f"Optimized - Time: {time_optimized:.4f}s, Memory: {memory_optimized:.1f}MB")
    print(f"Speedup: {time_simple/time_optimized:.2f}x")
    print(f"Memory reduction: {memory_simple/memory_optimized:.2f}x")

    return outputs_match

test_tversky_multihead_implementations()
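The harness above times a single CPU call with time.perf_counter and reads peak memory from tracemalloc. As a minimal sketch of how the same comparison might be run on GPU, the hypothetical helper below (not part of the original gist) uses CUDA events and the CUDA allocator's peak counter; it assumes the input tensors have already been moved to a CUDA device, e.g. x = x.cuda().

    import torch

    def benchmark_cuda(func, x, features, prototypes, theta, alpha, beta, n_heads,
                       n_warmup=10, n_iters=100):
        # Warm-up iterations so one-off setup cost does not skew the timing.
        for _ in range(n_warmup):
            func(x, features, prototypes, theta, alpha, beta, n_heads)
        torch.cuda.synchronize()

        # CUDA kernels launch asynchronously, so device-side events are used
        # instead of a CPU timer that may return before the work finishes.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats()

        start.record()
        for _ in range(n_iters):
            func(x, features, prototypes, theta, alpha, beta, n_heads)
        end.record()
        torch.cuda.synchronize()

        ms_per_call = start.elapsed_time(end) / n_iters
        peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
        return ms_per_call, peak_mb

Calling benchmark_cuda on both implementations with the same CUDA tensors would then give comparable milliseconds-per-call and peak-MB numbers for the GPU case.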