Last active
August 30, 2024 22:58
-
-
Save proger/60b15a0812b232fc30cfdfd67b66b1a8 to your computer and use it in GitHub Desktop.
tensor product representation capacity
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%% | |
from collections import defaultdict | |
import bisect | |
import json | |
import matplotlib.pyplot as plt | |
import torch | |
from matplotlib import rcParams | |
# Use a serif typeface for all text in the figures produced by this script.
rcParams['font.family'] = 'serif'
def cdiv(a, b):
    """Return the ceiling of ``a / b`` using integer arithmetic only.

    For an integer ``a`` and a positive integer ``b`` this is the number of
    size-``b`` chunks needed to cover ``a`` items, e.g. ``cdiv(5, 2) == 3``.

    Raises:
        ZeroDivisionError: if ``b`` is zero.
    """
    # Adding b - 1 before flooring rounds every non-exact quotient up.
    return (a + b - 1) // b
# Previously recorded per-(D, T) results, grouped by dimension D, so that
# configurations that were already measured are not recomputed.
trace = defaultdict(list)
## comment this out (or delete trace.jsonl) to start from scratch:
try:
    with open('trace.jsonl') as trace_file:
        for line in trace_file:
            if not line.strip():
                continue  # tolerate blank lines such as a trailing newline
            data = json.loads(line)
            # Keep only records that carry a 'chunks' field; other JSON lines
            # in the file are ignored.
            if 'chunks' in data:
                trace[data['D']].append(data)
except FileNotFoundError:
    # No cached trace yet: leave `trace` empty so everything is computed.
    pass
# Numeric setup: all tensors below are created on the GPU in half precision.
torch.manual_seed(0)
torch.set_default_device('cuda')
torch.set_default_dtype(torch.float16)

plt.figure(figsize=(10, 5))

dist = 'normal'  # or 'uniform'
# Alternative sweeps kept for quick experiments:
#times = list(range(1, 2**13))
#times = [2**t for t in range(14, 21)]
#times = [2**t for t in range(0, 10)]
times = [2**t for t in range(0, 21)]
dimensions = [2**d for d in range(0, 13)]  # 14 starts to OOM early with this code

print(json.dumps(dict(
    times=times,
    dimensions=dimensions,
)), flush=True)

plt.title('D$\\times$D Associative Retrieval')
plt.ylabel('Retrieval Accuracy')
# Raw strings: '\m' and '\s' are not valid Python escape sequences, so the
# non-raw form triggers an invalid-escape warning on recent Python versions.
dist_label = (
    r'$\mathcal{U}(-1,1)/\sqrt{D}$' if dist == 'uniform'
    else r'$\mathcal{N}(0,I)/\sqrt{D}$'
)
plt.xlabel('Stored key value pairs (ranks), each key and value is drawn from ' + dist_label)
plt.xticks(range(len(times)), [str(t) for t in times], rotation=90)

max_ts = []
for D in dimensions:
    # Reuse cached accuracies for this D if the trace has any; the sweep over
    # T below only runs when nothing was cached.
    accuracies = [t['accuracy'] for t in trace[D] if t['T'] in times]
    for T in (times if not accuracies else []):
        N = 10  # number of random draws averaged per (D, T)

        # Chunk size bounds the (chunk, D, D) outer-product intermediate.
        if D < 2048:
            max_chunk = 2048
        else:
            max_chunk = 128  # ok for 8192
            # 16384 == 2**14; the threshold was previously mistyped as 16834,
            # which skipped T == 16384 because T only takes powers of two.
            if T >= 16384 and D >= 8192:
                max_chunk = 1  # doesn't fit on 3090
        chunks = cdiv(T, max_chunk)

        accuracy = 0
        for _ in range(N):
            if dist == 'uniform':
                keys = (2*torch.rand(T, D)-1) / D**0.5
                values = (2*torch.rand(T, D)-1) / D**0.5
            else:
                keys = torch.randn(T, D) / D**0.5
                values = torch.randn(T, D) / D**0.5

            # W is the sum over all pairs of the outer product key_i (x) value_i.
            # NOTE(review): key_chunk.T @ value_chunk is the same sum without
            # the (chunk, D, D) intermediate, but float16 rounding may differ.
            W = torch.zeros(D, D)
            for key_chunk, value_chunk in zip(keys.chunk(chunks, dim=0), values.chunk(chunks, dim=0)):
                W.add_((key_chunk[:, :, None] * value_chunk[:, None, :]).sum(dim=0))

            hard_accuracy = None
            if T <= max_chunk:
                # hard attention
                A = (keys @ W) @ values.T
                hard_accuracy = (A.argmax(dim=-1) == torch.arange(T)).sum().item()

            # chunked hard attention: query i counts as correct when its
            # retrieved vector scores highest against stored value i.
            offset = 0
            soft_accuracy = 0
            for key_chunk in keys.chunk(chunks, dim=0):
                A = (key_chunk @ W) @ values.T
                soft_accuracy += (A.argmax(dim=-1) == torch.arange(offset, offset+key_chunk.shape[0])).sum().item()
                offset += key_chunk.shape[0]

            # Sanity check: the chunked count must match the one-shot count.
            if hard_accuracy is not None:
                assert soft_accuracy == hard_accuracy, f'{soft_accuracy} != {hard_accuracy}'
            accuracy += soft_accuracy

        accuracy = accuracy / (T*N)
        accuracies.append(accuracy)
        print(json.dumps(
            dict(D=D, T=T, accuracy=accuracy, chunks=chunks)
        ), flush=True)
        # Larger T will only do worse, so stop this D's sweep early.
        if accuracy < 0.6:
            break

    # Index of the first accuracy below 1 - 1/times[-1]. A linear scan is used
    # because bisect requires sorted input and accuracies are not guaranteed
    # to be monotonically decreasing in T.
    threshold = 1.0 - 1/times[-1]
    saturation_index = next(
        (i for i, a in enumerate(accuracies) if a < threshold),
        len(accuracies),
    )
    max_t = times[min(saturation_index, len(times)-1)]
    print(json.dumps(dict(
        D=D,
        max_t=max_t,
        optimal_words_per_pair=D*D/max_t,
        accuracies=accuracies,
        times=times,
    )), flush=True)
    max_ts.append(max_t)
    plt.plot(accuracies, label=f'D={D}', marker='o', linestyle='--', linewidth=0.5)

# The legend must be built after the labelled lines exist; calling it before
# plotting yields a legend with no entries.
plt.legend(loc='lower left')
plt.grid(True)
plt.savefig('retrieval.pdf', bbox_inches='tight')
plt.close()
# Second figure: for every dimension D, the max_t value recorded in max_ts,
# drawn on a base-2 logarithmic y-axis.
plt.figure(figsize=(8, 5))
plt.title('D$\\times$D Associative Matrix Capacity')
plt.ylabel('Optimal number of key value pairs until saturation')
plt.xlabel('D: key and value dimension')
plt.xticks(range(len(dimensions)), [str(d) for d in dimensions])
# Switch to the log scale before placing ticks: changing the scale installs
# that scale's default tick locator and formatter, which would otherwise
# discard the explicit ticks set here.
plt.yscale('log', base=2)
plt.yticks(times, [str(t) for t in times])
# Marker area encodes D*D / max_t, i.e. matrix entries per stored pair.
plt.scatter(range(len(dimensions)), max_ts, s=[D*D/max_t for D, max_t in zip(dimensions, max_ts)])
plt.grid(True)
plt.savefig('capacity.pdf', bbox_inches='tight')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{"times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576], "dimensions": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]} | |
{"D": 1, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 1, "T": 2, "accuracy": 0.65, "chunks": 1} | |
{"D": 1, "T": 4, "accuracy": 0.375, "chunks": 1} | |
{"D": 1, "max_t": 2, "optimal_words_per_pair": 0.5, "accuracies": [1.0, 0.65, 0.375], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 2, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 2, "T": 2, "accuracy": 0.85, "chunks": 1} | |
{"D": 2, "T": 4, "accuracy": 0.45, "chunks": 1} | |
{"D": 2, "max_t": 4, "optimal_words_per_pair": 1.0, "accuracies": [1.0, 0.85, 0.45], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 4, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 4, "T": 2, "accuracy": 0.85, "chunks": 1} | |
{"D": 4, "T": 4, "accuracy": 0.55, "chunks": 1} | |
{"D": 4, "max_t": 4, "optimal_words_per_pair": 4.0, "accuracies": [1.0, 0.85, 0.55], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 8, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 8, "T": 2, "accuracy": 0.95, "chunks": 1} | |
{"D": 8, "T": 4, "accuracy": 0.85, "chunks": 1} | |
{"D": 8, "T": 8, "accuracy": 0.6, "chunks": 1} | |
{"D": 8, "T": 16, "accuracy": 0.45, "chunks": 1} | |
{"D": 8, "max_t": 4, "optimal_words_per_pair": 16.0, "accuracies": [1.0, 0.95, 0.85, 0.6, 0.45], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 16, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 16, "T": 2, "accuracy": 0.95, "chunks": 1} | |
{"D": 16, "T": 4, "accuracy": 0.975, "chunks": 1} | |
{"D": 16, "T": 8, "accuracy": 0.925, "chunks": 1} | |
{"D": 16, "T": 16, "accuracy": 0.73125, "chunks": 1} | |
{"D": 16, "T": 32, "accuracy": 0.49375, "chunks": 1} | |
{"D": 16, "max_t": 2, "optimal_words_per_pair": 128.0, "accuracies": [1.0, 0.95, 0.975, 0.925, 0.73125, 0.49375], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 32, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 32, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 32, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 32, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 32, "T": 16, "accuracy": 0.96875, "chunks": 1} | |
{"D": 32, "T": 32, "accuracy": 0.871875, "chunks": 1} | |
{"D": 32, "T": 64, "accuracy": 0.684375, "chunks": 1} | |
{"D": 32, "T": 128, "accuracy": 0.390625, "chunks": 1} | |
{"D": 32, "max_t": 16, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 0.96875, 0.871875, 0.684375, 0.390625], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 64, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 64, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 64, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 64, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 64, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 64, "T": 32, "accuracy": 0.996875, "chunks": 1} | |
{"D": 64, "T": 64, "accuracy": 0.9875, "chunks": 1} | |
{"D": 64, "T": 128, "accuracy": 0.9296875, "chunks": 1} | |
{"D": 64, "T": 256, "accuracy": 0.670703125, "chunks": 1} | |
{"D": 64, "T": 512, "accuracy": 0.30703125, "chunks": 1} | |
{"D": 64, "max_t": 32, "optimal_words_per_pair": 128.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 0.996875, 0.9875, 0.9296875, 0.670703125, 0.30703125], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 128, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 128, "T": 256, "accuracy": 0.998046875, "chunks": 1} | |
{"D": 128, "T": 512, "accuracy": 0.9451171875, "chunks": 1} | |
{"D": 128, "T": 1024, "accuracy": 0.63330078125, "chunks": 1} | |
{"D": 128, "T": 2048, "accuracy": 0.22646484375, "chunks": 1} | |
{"D": 128, "max_t": 256, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.998046875, 0.9451171875, 0.63330078125, 0.22646484375], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 256, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 256, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 512, "accuracy": 1.0, "chunks": 1} | |
{"D": 256, "T": 1024, "accuracy": 0.99951171875, "chunks": 1} | |
{"D": 256, "T": 2048, "accuracy": 0.944482421875, "chunks": 1} | |
{"D": 256, "T": 4096, "accuracy": 0.559716796875, "chunks": 2} | |
{"D": 256, "max_t": 1024, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.99951171875, 0.944482421875, 0.559716796875], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 512, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 256, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 512, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 1024, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 2048, "accuracy": 1.0, "chunks": 1} | |
{"D": 512, "T": 4096, "accuracy": 0.999658203125, "chunks": 2} | |
{"D": 512, "T": 8192, "accuracy": 0.93275146484375, "chunks": 4} | |
{"D": 512, "T": 16384, "accuracy": 0.464215087890625, "chunks": 8} | |
{"D": 512, "max_t": 4096, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.999658203125, 0.93275146484375, 0.464215087890625], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 1024, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 256, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 512, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 1024, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 2048, "accuracy": 1.0, "chunks": 1} | |
{"D": 1024, "T": 4096, "accuracy": 1.0, "chunks": 2} | |
{"D": 1024, "T": 8192, "accuracy": 1.0, "chunks": 4} | |
{"D": 1024, "T": 16384, "accuracy": 0.9997314453125, "chunks": 8} | |
{"D": 1024, "T": 32768, "accuracy": 0.9050018310546875, "chunks": 16} | |
{"D": 1024, "T": 65536, "accuracy": 0.36622467041015627, "chunks": 32} | |
{"D": 1024, "max_t": 16384, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9997314453125, 0.9050018310546875, 0.36622467041015627], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 2048, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 2048, "T": 256, "accuracy": 1.0, "chunks": 2} | |
{"D": 2048, "T": 512, "accuracy": 1.0, "chunks": 4} | |
{"D": 2048, "T": 1024, "accuracy": 1.0, "chunks": 8} | |
{"D": 2048, "T": 2048, "accuracy": 1.0, "chunks": 16} | |
{"D": 2048, "T": 4096, "accuracy": 1.0, "chunks": 32} | |
{"D": 2048, "T": 8192, "accuracy": 1.0, "chunks": 64} | |
{"D": 2048, "T": 16384, "accuracy": 1.0, "chunks": 128} | |
{"D": 2048, "T": 32768, "accuracy": 1.0, "chunks": 256} | |
{"D": 2048, "T": 65536, "accuracy": 0.9996078491210938, "chunks": 512} | |
{"D": 2048, "T": 131072, "accuracy": 0.8634902954101562, "chunks": 1024} | |
{"D": 2048, "T": 262144, "accuracy": 0.2733623504638672, "chunks": 2048} | |
{"D": 2048, "max_t": 65536, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9996078491210938, 0.8634902954101562, 0.2733623504638672], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 4096, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 4096, "T": 256, "accuracy": 1.0, "chunks": 2} | |
{"D": 4096, "T": 512, "accuracy": 1.0, "chunks": 4} | |
{"D": 4096, "T": 1024, "accuracy": 1.0, "chunks": 8} | |
{"D": 4096, "T": 2048, "accuracy": 1.0, "chunks": 16} | |
{"D": 4096, "T": 4096, "accuracy": 1.0, "chunks": 32} | |
{"D": 4096, "T": 8192, "accuracy": 1.0, "chunks": 64} | |
{"D": 4096, "T": 16384, "accuracy": 1.0, "chunks": 128} | |
{"D": 4096, "T": 32768, "accuracy": 1.0, "chunks": 256} | |
{"D": 4096, "T": 65536, "accuracy": 1.0, "chunks": 512} | |
{"D": 4096, "T": 131072, "accuracy": 1.0, "chunks": 1024} | |
{"D": 4096, "T": 262144, "accuracy": 0.9992786407470703, "chunks": 2048} | |
{"D": 4096, "T": 524288, "accuracy": 0.8040372848510742, "chunks": 4096} | |
{"D": 4096, "max_t": 262144, "optimal_words_per_pair": 64.0, "accuracies": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9992786407470703, 0.8040372848510742], "times": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]} | |
{"D": 8192, "T": 1, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 2, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 4, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 8, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 16, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 32, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 64, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 128, "accuracy": 1.0, "chunks": 1} | |
{"D": 8192, "T": 256, "accuracy": 1.0, "chunks": 2} | |
{"D": 8192, "T": 512, "accuracy": 1.0, "chunks": 4} | |
{"D": 8192, "T": 1024, "accuracy": 1.0, "chunks": 8} | |
{"D": 8192, "T": 2048, "accuracy": 1.0, "chunks": 16} | |
{"D": 8192, "T": 4096, "accuracy": 1.0, "chunks": 32} | |
{"D": 8192, "T": 8192, "accuracy": 1.0, "chunks": 64} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment