Skip to content

Instantly share code, notes, and snippets.

@proger
Last active August 30, 2024 22:58
Show Gist options
  • Save proger/60b15a0812b232fc30cfdfd67b66b1a8 to your computer and use it in GitHub Desktop.
Save proger/60b15a0812b232fc30cfdfd67b66b1a8 to your computer and use it in GitHub Desktop.
tensor product representation capacity
#%%
from collections import defaultdict
import bisect
import json
import matplotlib.pyplot as plt
import torch
from matplotlib import rcParams
# Render all figure text (titles, labels, ticks) with matplotlib's serif font family.
rcParams['font.family'] = 'serif'
def cdiv(a, b):
    """Ceiling division, computed as ``(a + b - 1) // b``.

    For integers ``a >= 0`` and ``b > 0`` this equals ``ceil(a / b)``, i.e. the
    number of groups of size at most ``b`` needed to hold ``a`` items.
    """
    padded_numerator = a + b - 1
    return padded_numerator // b
# Results of a previous run, grouped by dimension D. Each entry is one
# per-(D, T) record as printed by the sweep (the records carrying 'chunks');
# summary records without that key are ignored.
trace = defaultdict(list)
## delete or rename trace.jsonl to start from scratch:
try:
    # The context manager closes the file even if a line fails to parse.
    with open('trace.jsonl', encoding='utf-8') as trace_file:
        for line in trace_file:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the log
            data = json.loads(line)
            if 'chunks' in data:
                trace[data['D']].append(data)
except FileNotFoundError:
    # No previous trace: every configuration will be computed from scratch.
    pass
# Fixed seed for reproducible sampling; every tensor created below defaults to
# the CUDA device and to half precision.
torch.manual_seed(0)
torch.set_default_device('cuda')
torch.set_default_dtype(torch.float16)

plt.figure(figsize=(10, 5))
# NOTE: a legend can only pick up curves that already exist, so it has to be
# created after plotting; calling plt.legend() here, on empty axes, would only
# produce an empty legend and a "no artists with labels" warning.

# Sampling distribution for keys and values.
dist = 'normal' # or 'uniform'

# T: number of stored key/value pairs; D: key and value dimension.
#times = list(range(1, 2**13))
#times = [2**t for t in range(14, 21)]
times = [2**t for t in range(0, 21)]
dimensions = [2**d for d in range(0, 13)] # 14 starts to OOM early with this code
#times = [2**t for t in range(0, 10)]
#dimensions = [2**d for d in range(0, 13)]

# Log the sweep configuration as the first JSON line of the output.
print(json.dumps(dict(
    times=times,
    dimensions=dimensions,
)), flush=True)

plt.title('D$\\times$D Associative Retrieval')
plt.ylabel('Retrieval Accuracy')
# Raw strings: '\m' and '\s' are not valid escape sequences in a normal string
# literal and trigger a SyntaxWarning on recent Python versions.
dist_label = r'$\mathcal{U}(-1,1)/\sqrt{D}$' if dist == 'uniform' else r'$\mathcal{N}(0,I)/\sqrt{D}$'
plt.xlabel('Stored key value pairs (ranks), each key and value is drawn from ' + dist_label)
# The accuracy curves are plotted against the index into `times`, so label each
# index with the corresponding T.
plt.xticks(range(len(times)), [str(t) for t in times], rotation=90)
# Sweep over dimensions D and numbers of stored pairs T: store T random
# key/value pairs in a D x D matrix as a sum of outer products, then measure
# how often querying with a key retrieves its own value.
max_ts = []
N = 10  # independent random draws averaged per (D, T) configuration
for D in dimensions:
    # Reuse accuracies already recorded in the trace for this D; the sweep over
    # T is only run when nothing was recorded.
    accuracies = [t['accuracy'] for t in trace[D] if t['T'] in times]
    for T in (times if not accuracies else []):
        # The chunk size bounds the (chunk, D, D) outer-product intermediate
        # below. It depends only on (D, T), so compute it once per
        # configuration rather than once per draw.
        if D < 2048:
            max_chunk = 2048
        else:
            max_chunk = 128 # ok for 8192
        # 16384 == 2**14 (the original literal 16834 was a typo that delayed
        # this switch until T = 32768).
        if T >= 16384 and D >= 8192:
            max_chunk = 1 # doesn't fit on 3090
        chunks = cdiv(T, max_chunk)

        accuracy = 0
        for _ in range(N):
            if dist == 'uniform':
                keys = (2*torch.rand(T, D)-1) / D**0.5
                values = (2*torch.rand(T, D)-1) / D**0.5
            else:
                keys = torch.randn(T, D) / D**0.5
                values = torch.randn(T, D) / D**0.5

            # W = sum_t outer(key_t, value_t), accumulated chunk by chunk.
            # NOTE(review): mathematically this is key_chunk.T @ value_chunk;
            # the explicit outer product is kept so the half-precision
            # rounding of the experiment is unchanged.
            W = torch.zeros(D, D)
            for key_chunk, value_chunk in zip(keys.chunk(chunks, dim=0), values.chunk(chunks, dim=0)):
                W.add_((key_chunk[:, :, None] * value_chunk[:, None, :]).sum(dim=0))

            full_correct = None
            if T <= max_chunk:
                # hard attention over all keys at once (reference result)
                A = (keys @ W) @ values.T
                full_correct = (A.argmax(dim=-1) == torch.arange(T)).sum().item()

            # chunked hard attention: query i is correct when the best-scoring
            # stored value is value i
            offset = 0
            chunked_correct = 0
            for key_chunk in keys.chunk(chunks, dim=0):
                A = (key_chunk @ W) @ values.T
                expected = torch.arange(offset, offset + key_chunk.shape[0])
                chunked_correct += (A.argmax(dim=-1) == expected).sum().item()
                offset += key_chunk.shape[0]

            if full_correct is not None:
                # Sanity check: chunking must not change the result.
                assert chunked_correct == full_correct, f'{chunked_correct} != {full_correct}'
            accuracy += chunked_correct

        accuracy = accuracy / (T*N)
        accuracies.append(accuracy)
        print(json.dumps(
            dict(D=D, T=T, accuracy=accuracy, chunks=chunks)
        ), flush=True)
        if accuracy < 0.6:
            # Past saturation: larger T is not evaluated for this D.
            break

    # Index of the first T whose accuracy drops below (almost) perfect. A
    # linear scan is used instead of bisect because measured accuracies are
    # not guaranteed to be monotonic in T, which bisect requires.
    threshold = 1.0 - 1/times[-1]
    saturation_index = next(
        (i for i, a in enumerate(accuracies) if a < threshold),
        len(accuracies),
    )
    max_t = times[min(saturation_index, len(times)-1)]
    print(json.dumps(dict(
        D=D,
        max_t=max_t,
        optimal_words_per_pair=D*D/max_t,
        accuracies=accuracies,
        times=times,
    )), flush=True)
    max_ts.append(max_t)
    plt.plot(accuracies, label=f'D={D}', marker='o', linestyle='--', linewidth=0.5)

plt.grid(True)
# Create the legend only now that every labelled curve exists.
plt.legend(loc='lower left')
plt.savefig('retrieval.pdf', bbox_inches='tight')
plt.close()
# Second figure: for each dimension D, the number of stored pairs at which
# retrieval stops being (almost) perfect, with marker area D*D/max_t.
plt.figure(figsize=(8, 5))
plt.title('D$\\times$D Associative Matrix Capacity')
plt.ylabel('Optimal number of key value pairs until saturation')
plt.xlabel('D: key and value dimension')
# Select the log scale BEFORE setting ticks: changing the scale installs the
# scale's default tick locator/formatter and would discard custom ticks.
plt.yscale('log', base=2)
plt.xticks(range(len(dimensions)), [str(d) for d in dimensions])
plt.yticks(times, [str(t) for t in times])
plt.scatter(range(len(dimensions)), max_ts, s=[D*D/max_t for D, max_t in zip(dimensions, max_ts)])
plt.grid(True)
plt.savefig('capacity.pdf', bbox_inches='tight')
{"times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576], "dimensions": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]}
{"D": 1, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 1, "T": 2, "accuracy": 0.65, "chunks": 1}
{"D": 1, "T": 4, "accuracy": 0.375, "chunks": 1}
{"D": 1, "max_t": 2, "optimal_words_per_pair": 0.5, "accuracies": [1.0, 0.65, 0.375], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 2, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 2, "T": 2, "accuracy": 0.85, "chunks": 1}
{"D": 2, "T": 4, "accuracy": 0.45, "chunks": 1}
{"D": 2, "max_t": 4, "optimal_words_per_pair": 1.0, "accuracies": [1.0, 0.85, 0.45], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 4, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 4, "T": 2, "accuracy": 0.85, "chunks": 1}
{"D": 4, "T": 4, "accuracy": 0.55, "chunks": 1}
{"D": 4, "max_t": 4, "optimal_words_per_pair": 4.0, "accuracies": [1.0, 0.85, 0.55], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 8, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 8, "T": 2, "accuracy": 0.95, "chunks": 1}
{"D": 8, "T": 4, "accuracy": 0.85, "chunks": 1}
{"D": 8, "T": 8, "accuracy": 0.6, "chunks": 1}
{"D": 8, "T": 16, "accuracy": 0.45, "chunks": 1}
{"D": 8, "max_t": 4, "optimal_words_per_pair": 16.0, "accuracies": [1.0, 0.95, 0.85, 0.6, 0.45], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 16, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 16, "T": 2, "accuracy": 0.95, "chunks": 1}
{"D": 16, "T": 4, "accuracy": 0.975, "chunks": 1}
{"D": 16, "T": 8, "accuracy": 0.925, "chunks": 1}
{"D": 16, "T": 16, "accuracy": 0.73125, "chunks": 1}
{"D": 16, "T": 32, "accuracy": 0.49375, "chunks": 1}
{"D": 16, "max_t": 2, "optimal_words_per_pair": 128.0, "accuracies": [1.0, 0.95, 0.975, 0.925, 0.73125, 0.49375], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 32, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 32, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 32, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 32, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 32, "T": 16, "accuracy": 0.96875, "chunks": 1}
{"D": 32, "T": 32, "accuracy": 0.871875, "chunks": 1}
{"D": 32, "T": 64, "accuracy": 0.684375, "chunks": 1}
{"D": 32, "T": 128, "accuracy": 0.390625, "chunks": 1}
{"D": 32, "max_t": 16, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 0.96875, 0.871875, 0.684375, 0.390625], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 64, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 64, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 64, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 64, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 64, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 64, "T": 32, "accuracy": 0.996875, "chunks": 1}
{"D": 64, "T": 64, "accuracy": 0.9875, "chunks": 1}
{"D": 64, "T": 128, "accuracy": 0.9296875, "chunks": 1}
{"D": 64, "T": 256, "accuracy": 0.670703125, "chunks": 1}
{"D": 64, "T": 512, "accuracy": 0.30703125, "chunks": 1}
{"D": 64, "max_t": 32, "optimal_words_per_pair": 128.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 0.996875, 0.9875, 0.9296875, 0.670703125, 0.30703125], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 128, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 128, "T": 256, "accuracy": 0.998046875, "chunks": 1}
{"D": 128, "T": 512, "accuracy": 0.9451171875, "chunks": 1}
{"D": 128, "T": 1024, "accuracy": 0.63330078125, "chunks": 1}
{"D": 128, "T": 2048, "accuracy": 0.22646484375, "chunks": 1}
{"D": 128, "max_t": 256, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.998046875, 0.9451171875, 0.63330078125, 0.22646484375], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 256, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 256, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 512, "accuracy": 1.0, "chunks": 1}
{"D": 256, "T": 1024, "accuracy": 0.99951171875, "chunks": 1}
{"D": 256, "T": 2048, "accuracy": 0.944482421875, "chunks": 1}
{"D": 256, "T": 4096, "accuracy": 0.559716796875, "chunks": 2}
{"D": 256, "max_t": 1024, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.99951171875, 0.944482421875, 0.559716796875], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 512, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 256, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 512, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 1024, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 2048, "accuracy": 1.0, "chunks": 1}
{"D": 512, "T": 4096, "accuracy": 0.999658203125, "chunks": 2}
{"D": 512, "T": 8192, "accuracy": 0.93275146484375, "chunks": 4}
{"D": 512, "T": 16384, "accuracy": 0.464215087890625, "chunks": 8}
{"D": 512, "max_t": 4096, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.999658203125, 0.93275146484375, 0.464215087890625], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 1024, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 256, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 512, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 1024, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 2048, "accuracy": 1.0, "chunks": 1}
{"D": 1024, "T": 4096, "accuracy": 1.0, "chunks": 2}
{"D": 1024, "T": 8192, "accuracy": 1.0, "chunks": 4}
{"D": 1024, "T": 16384, "accuracy": 0.9997314453125, "chunks": 8}
{"D": 1024, "T": 32768, "accuracy": 0.9050018310546875, "chunks": 16}
{"D": 1024, "T": 65536, "accuracy": 0.36622467041015627, "chunks": 32}
{"D": 1024, "max_t": 16384, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9997314453125, 0.9050018310546875, 0.36622467041015627], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 2048, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 2048, "T": 256, "accuracy": 1.0, "chunks": 2}
{"D": 2048, "T": 512, "accuracy": 1.0, "chunks": 4}
{"D": 2048, "T": 1024, "accuracy": 1.0, "chunks": 8}
{"D": 2048, "T": 2048, "accuracy": 1.0, "chunks": 16}
{"D": 2048, "T": 4096, "accuracy": 1.0, "chunks": 32}
{"D": 2048, "T": 8192, "accuracy": 1.0, "chunks": 64}
{"D": 2048, "T": 16384, "accuracy": 1.0, "chunks": 128}
{"D": 2048, "T": 32768, "accuracy": 1.0, "chunks": 256}
{"D": 2048, "T": 65536, "accuracy": 0.9996078491210938, "chunks": 512}
{"D": 2048, "T": 131072, "accuracy": 0.8634902954101562, "chunks": 1024}
{"D": 2048, "T": 262144, "accuracy": 0.2733623504638672, "chunks": 2048}
{"D": 2048, "max_t": 65536, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9996078491210938, 0.8634902954101562, 0.2733623504638672], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 4096, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 4096, "T": 256, "accuracy": 1.0, "chunks": 2}
{"D": 4096, "T": 512, "accuracy": 1.0, "chunks": 4}
{"D": 4096, "T": 1024, "accuracy": 1.0, "chunks": 8}
{"D": 4096, "T": 2048, "accuracy": 1.0, "chunks": 16}
{"D": 4096, "T": 4096, "accuracy": 1.0, "chunks": 32}
{"D": 4096, "T": 8192, "accuracy": 1.0, "chunks": 64}
{"D": 4096, "T": 16384, "accuracy": 1.0, "chunks": 128}
{"D": 4096, "T": 32768, "accuracy": 1.0, "chunks": 256}
{"D": 4096, "T": 65536, "accuracy": 1.0, "chunks": 512}
{"D": 4096, "T": 131072, "accuracy": 1.0, "chunks": 1024}
{"D": 4096, "T": 262144, "accuracy": 0.9992786407470703, "chunks": 2048}
{"D": 4096, "T": 524288, "accuracy": 0.8040372848510742, "chunks": 4096}
{"D": 4096, "max_t": 262144, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9992786407470703, 0.8040372848510742], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]}
{"D": 8192, "T": 1, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 2, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 4, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 8, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 16, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 32, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 64, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 128, "accuracy": 1.0, "chunks": 1}
{"D": 8192, "T": 256, "accuracy": 1.0, "chunks": 2}
{"D": 8192, "T": 512, "accuracy": 1.0, "chunks": 4}
{"D": 8192, "T": 1024, "accuracy": 1.0, "chunks": 8}
{"D": 8192, "T": 2048, "accuracy": 1.0, "chunks": 16}
{"D": 8192, "T": 4096, "accuracy": 1.0, "chunks": 32}
{"D": 8192, "T": 8192, "accuracy": 1.0, "chunks": 64}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment