Skip to content

Instantly share code, notes, and snippets.

@proger
Created October 31, 2023 21:24
Show Gist options
  • Select an option

  • Save proger/66643e0a3ff6b775d189d98dfd8d081f to your computer and use it in GitHub Desktop.

Select an option

Save proger/66643e0a3ff6b775d189d98dfd8d081f to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import math
def scan0(f, x):
    """Run the linear recurrence ``y[i + 1] = f * y[i] + x[..., i]`` from ``y[0] = 0``.

    The scan runs over the LAST axis of ``x``.  The result has one more step
    than ``x`` along that axis: ``y[..., 0]`` is the zero initial state and
    ``y[..., i + 1]`` has absorbed ``x[..., :i + 1]``.

    Args:
        f: decay applied to the running state at every step; a Python scalar
            or a tensor that broadcasts against ``x[..., i]``.
        x: input tensor whose last axis is the scan axis (may be empty).

    Returns:
        Tensor of shape ``x.shape[:-1] + (x.size(-1) + 1,)``, assuming ``f``
        does not broadcast the state to a larger shape.
    """
    # new_zeros (rather than zeros_like(x[..., 0])) also works for an empty scan axis.
    state = x.new_zeros(x.shape[:-1])
    states = [state]
    for i in range(x.size(-1)):
        # Read the previous state from the list tail; the list cannot be tensor-indexed.
        state = f * state + x[..., i]
        states.append(state)
    return torch.stack(states, dim=-1)
class TimeMixingBlock(nn.Module):
    """RWKV time-mixing (WKV) sub-block.

    Token-shifts the input (residual width ``S = D // 4``), projects it to
    ``D`` gate/key/value channels, runs the WKV recurrence along time with
    ``scan0``, gates the result with a sigmoid and projects back to width ``S``.

    Args:
        layer: index of this layer, used by the depth-dependent initialisation.
        num_layers: total number of layers in the model.
        D: hidden width of the gate/key/value projections; the residual
            width is ``D // 4``.
    """

    def __init__(self, layer, num_layers, D):
        super().__init__()
        S = D // 4
        # Depth ratios used by the initialisation.  layer / (num_layers - 1) is
        # undefined for a one-layer model, so fall back to 0 there.
        ratio_0_to_1 = layer / (num_layers - 1) if num_layers > 1 else 0.0
        layer_frac = 1 - layer / num_layers

        indices = torch.arange(S) / S
        self.mu_key = nn.Parameter(indices.pow(layer_frac))
        self.mu_value = nn.Parameter(indices.pow(layer_frac) + 0.3 * ratio_0_to_1)
        # NOTE(review): the RWKV-4 reference initialises the receptance mix as
        # indices.pow(0.5 * layer_frac); this keeps the original formula.
        self.mu_gate = nn.Parameter(0.5 * indices.pow(layer_frac))

        channel = torch.arange(D)
        position = channel / max(D - 1, 1)
        # Frozen per-channel decay logits w in [-5, 3]; forward() turns them
        # into per-step factors exp(-exp(w)).
        self.decay = nn.Parameter(
            -5 + 8 * position.pow(0.7 + 1.3 * ratio_0_to_1), requires_grad=False
        )
        # Frozen current-token bonus: log(0.3) plus a 0, +0.5, -0.5 zigzag.  The
        # zigzag needs the INTEGER channel index; on the float positions in
        # [0, 1] the "% 3" is a no-op and the zigzag degenerates into a ramp.
        self.bonus = nn.Parameter(
            0.5 * ((channel + 1) % 3 - 1) + math.log(0.3), requires_grad=False
        )

        self.w_gate = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_gate.weight)
        self.w_key = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_key.weight)
        self.w_value = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_value.weight)
        self.w_o = nn.Linear(D, S, bias=False)
        nn.init.normal_(self.w_o.weight, std=math.sqrt(D / S))

    def forward(self, x):
        """Apply time mixing to ``x``.

        Treats axis -2 as time and axis -1 as the ``S`` features.  The caller
        in this file passes the output of ``nn.Embedding`` on token ids; for
        (batch, T) ids that is (batch, T, S).  Returns a tensor of the same
        shape as ``x``.
        """
        # Token shift: the previous time step, with zeros before the first token.
        x_prev = torch.cat([torch.zeros_like(x[..., :1, :]), x[..., :-1, :]], dim=-2)
        gate = self.w_gate(self.mu_gate * x + (1 - self.mu_gate) * x_prev)
        key = self.w_key(self.mu_key * x + (1 - self.mu_key) * x_prev)
        value = self.w_value(self.mu_value * x + (1 - self.mu_value) * x_prev)

        # Per-step decay exp(-exp(w)) lies in (0, 1).  exp(-w) would be > 1 for
        # w < 0 and make the recurrence blow up.
        decay = (-self.decay.exp()).exp()
        key_exp = key.exp()
        key_bonus_exp = (self.bonus + key).exp()
        # TODO: prevent overflows
        # scan0 runs along the last axis, so move time there and back again:
        # a and b have T + 1 steps along axis -2.
        a = scan0(decay, (key_exp * value).transpose(-1, -2)).transpose(-1, -2)
        b = scan0(decay, key_exp.transpose(-1, -2)).transpose(-1, -2)
        # State t holds tokens strictly before t; the current token enters
        # through the bonus term instead.
        wkv = (a[..., :-1, :] + key_bonus_exp * value) / (b[..., :-1, :] + key_bonus_exp)
        return self.w_o(gate.sigmoid() * wkv)
class ChannelMixingBlock(nn.Module):
    """RWKV channel-mixing (feed-forward) sub-block.

    Token-shifts the input (width ``S = D // 4``), expands it to ``D`` hidden
    units with a squared ReLU, projects back to ``S`` and gates the result
    with a sigmoid.

    Args:
        layer: index of this layer, used by the depth-dependent initialisation.
        num_layers: total number of layers in the model.
        D: hidden width; the residual width is ``D // 4``.
    """

    def __init__(self, layer, num_layers, D):
        super().__init__()
        S = D // 4
        indices = torch.arange(S) / S
        layer_frac = 1 - layer / num_layers
        self.mu_key = nn.Parameter(indices.pow(layer_frac))
        # NOTE(review): mu_value is not read by forward(); it is kept so the
        # existing parameter names stay stable.
        self.mu_value = nn.Parameter(indices.pow(layer_frac))
        # forward() mixes the gate input with mu_gate, which was never defined.
        self.mu_gate = nn.Parameter(indices.pow(layer_frac))
        self.w_gate = nn.Linear(S, S, bias=False)
        nn.init.zeros_(self.w_gate.weight)
        self.w_key = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_key.weight)
        self.w_value = nn.Linear(D, S, bias=False)
        nn.init.normal_(self.w_value.weight, std=math.sqrt(D / S))

    def forward(self, x):
        """Apply channel mixing to ``x``.

        Treats axis -2 as time and axis -1 as the ``S`` features.  Returns a
        tensor of the same shape as ``x``.
        """
        # Token shift: the previous time step, with zeros before the first token.
        x_prev = torch.cat([torch.zeros_like(x[..., :1, :]), x[..., :-1, :]], dim=-2)
        gate = self.w_gate(self.mu_gate * x + (1 - self.mu_gate) * x_prev)
        key = self.w_key(self.mu_key * x + (1 - self.mu_key) * x_prev)
        value = self.w_value(key.relu().pow(2))
        return gate.sigmoid() * value
class RWKVBlock(nn.Module):
    """One RWKV layer: pre-norm time mixing, then pre-norm channel mixing.

    Each sub-block is applied to a LayerNorm of its input and added back as a
    residual over the model width ``D // 4``.
    """

    def __init__(self, layer, num_layers, D):
        super().__init__()
        width = D // 4
        # Construction order is kept so parameter registration and RNG use match.
        self.ln_time = nn.LayerNorm(width)
        self.time_mix = TimeMixingBlock(layer, num_layers, D)
        self.ln_channel = nn.LayerNorm(width)
        self.channel_mix = ChannelMixingBlock(layer, num_layers, D)

    def forward(self, x):
        after_time = x + self.time_mix(self.ln_time(x))
        channel_in = self.ln_channel(after_time)
        return after_time + self.channel_mix(channel_in)
class RWKV(nn.Module):
    """Small RWKV language model.

    Pipeline: token embedding (width ``D // 4``) -> LayerNorm -> ``num_layers``
    RWKV blocks -> LayerNorm -> linear decoder to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size=512, num_layers=2, D=16):
        super().__init__()
        width = D // 4
        self.tokens = nn.Embedding(vocab_size, width)
        # Replace the default embedding init with a tiny uniform one.
        nn.init.uniform_(self.tokens.weight, -1e-4, 1e-4)
        self.ln_pre = nn.LayerNorm(width)
        self.blocks = nn.ModuleList(
            RWKVBlock(layer, num_layers, D) for layer in range(num_layers)
        )
        self.ln_post = nn.LayerNorm(width)
        self.decoder = nn.Linear(width, vocab_size, bias=False)

    def forward(self, x):
        hidden = self.ln_pre(self.tokens(x))
        for block in self.blocks:
            hidden = block(hidden)
        return self.decoder(self.ln_post(hidden))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment