Skip to content

Instantly share code, notes, and snippets.

View proger's full-sized avatar
🎯
Focusing

Volodymyr Kyrylov proger

🎯
Focusing
View GitHub Profile
"""
DeltaNet implementation reference for Accelerated Scan. DeltaNet efficiently manages a large, fixed-size memory.
For a simple single-chunk version, see `forward_simple`.
It computes decayed values using a small amount of recurrence (`decay_values`)
and then applies linear attention (`causal_attend`).
`forward_chunkwise` is inspired by Yang 2024. It applies the single-chunk version to each chunk independently and
then stitches the chunks together at the chunk level.
digit 10076 27
digit 10154 2011
digit 1017 4
digit 10191 33
digit 1025 5
digit 10353 31
digit 10389 2008
digit 10411 120
digit 10607 01
digit 10858 195
"""
Randomized Binary Search Trees
https://www.cs.upc.edu/~conrado/research/papers/jacm-mr98.pdf
"""
import math
import random
from collections import Counter
class root:
"""
DeltaNet implementation reference for Accelerated Scan. DeltaNet efficiently manages a large, fixed-size memory.
`forward` is inspired by Yang 2024. It applies the single-chunk version to each chunk independently and then stitches the chunks together at the chunk level.
`forward_loop` is the reference implementation of the original recurrence.
References:
[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)
Method 1, section 3 guides `decay_values`.
// uses https://github.com/HazyResearch/ThunderKittens
#include "tk/src/kittens.cuh"
#include "tk/src/common/pyutils/torch_helpers.cuh"
#define NUM_WORKERS 2 // This kernel uses this many workers in parallel per block, to help issue instructions more quickly.
#define DIMENSION 64 // This kernel operates over 64-dimensional vectors
#define DEBUG 0
using namespace kittens; // this kernel only handles headdim=q_reg.cols for simplicity. Also n should be a multiple of 256 here.
"a linear RNN that receives ones as input and gives increasingly better approximations to pi as output"
import numpy as np
import math
def binary(digits: int):
    """Return the powers of two 2**0, 2**1, ..., 2**(digits - 1) as a NumPy array.

    Entries are ordered lowest bit first, so ``binary(n)`` is the place-value
    basis of an ``n``-bit binary number. The array uses the fixed-width integer
    dtype produced by ``np.arange``, so very large ``digits`` will overflow.
    """
    exponents = np.arange(digits)
    # Equivalent to `1 << exponents`: shift a single set bit left by each exponent.
    return np.left_shift(1, exponents)
def leibniz(n):
#%%
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
@proger
proger / tpr.py
Last active August 30, 2024 22:58
tensor product representation capacity
#%%
from collections import defaultdict
import bisect
import json
import matplotlib.pyplot as plt
import torch
from matplotlib import rcParams
rcParams['font.family'] = 'serif'
@proger
proger / abv.py
Last active April 26, 2024 07:53
# prompt: https://twitter.com/francoisfleuret/status/1783479122418716805
import os
os.environ['TORCH_LOGS'] = 'output_code' # shows all the bmms
import torch
torch.set_float32_matmul_precision('high')
N, T, D, U, C = 3, 128, 5, 32, 32 # batch, time, heads, head_dim, dim
# Second sequence length, set equal to T here.
# NOTE(review): presumably a source/key length for the attention-like products below — confirm against the full gist.
S = T
# Standard-normal entries divided by sqrt(U), so each entry has variance 1/U.
A = torch.randn(N, T, D, U) / U**0.5
@proger
proger / xor.py
Last active April 22, 2024 13:36
tensor network that can learn xor
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Inputs: all four combinations of two binary features, as a (4, 2) float tensor.
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]).float()
# Targets: XOR of the two input columns, i.e. [0., 1., 1., 0.].
y = torch.logical_xor(X[:, 0], X[:, 1]).float()