Created
October 31, 2023 21:24
-
-
Save proger/66643e0a3ff6b775d189d98dfd8d081f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import math | |
def scan0(f, x):
    """Run the linear recurrence ``y[t+1] = f * y[t] + x[..., t]`` along the last axis.

    The scan starts from a zero state (``y[0] = 0``), so the result has one
    more entry on the last axis than ``x``. Entry ``t`` holds the state
    *before* input ``t`` is folded in.

    Args:
        f: decay factor; a scalar or a tensor broadcastable against
            ``x[..., t]`` (i.e. against ``x`` with its last axis removed).
        x: input tensor; the recurrence runs over its last axis.

    Returns:
        Tensor of shape ``x.shape[:-1] + (x.size(-1) + 1,)``.
    """
    # new_zeros also works when the scanned axis is empty (x[..., 0] would not).
    y = [x.new_zeros(x.shape[:-1])]
    for i in range(x.size(-1)):
        # y is a Python list, so the previous state is simply its last element.
        y.append(f * y[-1] + x[..., i])
    return torch.stack(y, dim=-1)
class TimeMixingBlock(nn.Module):
    """RWKV time-mixing (WKV) sub-layer.

    The residual stream has width ``S = D // 4``. Gate, key and value are
    projected up to width ``D``, combined with the WKV recurrence over time,
    and projected back to ``S``.

    Args:
        layer: zero-based index of this layer (drives depth-dependent init).
        num_layers: total number of layers.
        D: inner width of the gate/key/value projections.
    """

    def __init__(self, layer, num_layers, D):
        super().__init__()
        S = D // 4
        # 0 at the first layer, 1 at the last; guarded so a 1-layer model
        # does not divide by zero.
        depth = layer / max(num_layers - 1, 1)
        layer_frac = 1 - layer / num_layers

        channel = torch.arange(S) / S
        self.mu_key = nn.Parameter(channel.pow(layer_frac))
        self.mu_value = nn.Parameter(channel.pow(layer_frac) + 0.3 * depth)
        # RWKV-4 reference init for the receptance/gate mix: x ** (0.5 * p).
        self.mu_gate = nn.Parameter(channel.pow(0.5 * layer_frac))

        ramp = torch.arange(D) / max(D - 1, 1)
        # Raw per-channel decay; forward() turns it into a factor exp(-exp(decay)).
        self.decay = nn.Parameter(-5 + 8 * ramp.pow(0.7 + 1.3 * depth), requires_grad=False)
        # Zigzag pattern 0, +0.5, -0.5, ... over *integer* channel indices.
        zigzag = 0.5 * (((torch.arange(D) + 1) % 3) - 1)
        self.bonus = nn.Parameter(zigzag + math.log(0.3), requires_grad=False)

        self.w_gate = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_gate.weight)
        self.w_key = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_key.weight)
        self.w_value = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_value.weight)
        self.w_o = nn.Linear(D, S, bias=False)
        nn.init.normal_(self.w_o.weight, std=math.sqrt(D / S))

    def forward(self, x):
        """Apply token-shift mixing and WKV attention.

        Args:
            x: tensor whose last two axes are (time, S). Presumably
                (batch, time, S); TODO confirm against callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Token shift: previous time step, zeros at t = 0.
        x_prev = torch.cat([torch.zeros_like(x[..., :1, :]), x[..., :-1, :]], dim=-2)
        gate = self.w_gate(self.mu_gate * x + (1 - self.mu_gate) * x_prev)
        key = self.w_key(self.mu_key * x + (1 - self.mu_key) * x_prev)
        value = self.w_value(self.mu_value * x + (1 - self.mu_value) * x_prev)

        # Per-channel decay factor strictly inside (0, 1).
        decay = torch.exp(-torch.exp(self.decay))
        key_exp = key.exp()
        key_bonus_exp = (self.bonus + key).exp()
        # TODO: prevent overflows (key.exp() is unbounded).
        # scan0 scans its last axis, so move time there and back again.
        a = scan0(decay, (key_exp * value).transpose(-1, -2)).transpose(-1, -2)  # (..., T+1, D)
        b = scan0(decay, key_exp.transpose(-1, -2)).transpose(-1, -2)  # (..., T+1, D)
        # a[..., t, :] / b[..., t, :] hold the state before step t.
        wkv = (a[..., :-1, :] + key_bonus_exp * value) / (b[..., :-1, :] + key_bonus_exp)
        return self.w_o(gate.sigmoid() * wkv)
class ChannelMixingBlock(nn.Module):
    """RWKV channel-mixing (squared-ReLU feed-forward) sub-layer.

    The residual stream has width ``S = D // 4``. The hidden layer has width ``D``.

    Args:
        layer: zero-based index of this layer (drives depth-dependent init).
        num_layers: total number of layers.
        D: hidden width of the feed-forward layer.
    """

    def __init__(self, layer, num_layers, D):
        super().__init__()
        S = D // 4
        channel = torch.arange(S) / S
        layer_frac = 1 - layer / num_layers
        self.mu_key = nn.Parameter(channel.pow(layer_frac))
        # NOTE(review): forward() never reads mu_value. It is kept only so
        # existing parameter names / state_dict keys stay unchanged.
        self.mu_value = nn.Parameter(channel.pow(layer_frac))
        # forward() needs a gate mix; previously it was never created (AttributeError).
        self.mu_gate = nn.Parameter(channel.pow(layer_frac))
        self.w_gate = nn.Linear(S, S, bias=False)
        nn.init.zeros_(self.w_gate.weight)
        self.w_key = nn.Linear(S, D, bias=False)
        nn.init.zeros_(self.w_key.weight)
        self.w_value = nn.Linear(D, S, bias=False)
        nn.init.normal_(self.w_value.weight, std=math.sqrt(D / S))

    def forward(self, x):
        """Apply token-shift mixing and a gated squared-ReLU feed-forward.

        Args:
            x: tensor whose last two axes are (time, S). Presumably
                (batch, time, S); TODO confirm against callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Token shift: previous time step, zeros at t = 0.
        x_prev = torch.cat([torch.zeros_like(x[..., :1, :]), x[..., :-1, :]], dim=-2)
        gate = self.w_gate(self.mu_gate * x + (1 - self.mu_gate) * x_prev)
        key = self.w_key(self.mu_key * x + (1 - self.mu_key) * x_prev)
        value = self.w_value(key.relu().pow(2))
        return gate.sigmoid() * value
class RWKVBlock(nn.Module):
    """One pre-norm RWKV layer: time mixing, then channel mixing.

    Each sub-layer is wrapped in a residual connection. The residual width
    is ``D // 4``, and both sub-layers are constructed with ``D``.
    """

    def __init__(self, layer, num_layers, D):
        super().__init__()
        width = D // 4
        # Registration order is kept so parameter / state_dict order is unchanged.
        self.ln_time = nn.LayerNorm(width)
        self.time_mix = TimeMixingBlock(layer, num_layers, D)
        self.ln_channel = nn.LayerNorm(width)
        self.channel_mix = ChannelMixingBlock(layer, num_layers, D)

    def forward(self, x):
        after_time = x + self.time_mix(self.ln_time(x))
        return after_time + self.channel_mix(self.ln_channel(after_time))
class RWKV(nn.Module):
    """Token embedding, a stack of RWKV layers, and a linear decoder to vocab logits.

    The embedding / residual width is ``D // 4``. ``D`` is forwarded to every layer.
    """

    def __init__(self, vocab_size=512, num_layers=2, D=16):
        super().__init__()
        width = D // 4
        self.tokens = nn.Embedding(vocab_size, width)
        # Tiny uniform init for the embedding table.
        nn.init.uniform_(self.tokens.weight, -1e-4, 1e-4)
        self.ln_pre = nn.LayerNorm(width)
        self.blocks = nn.ModuleList(
            RWKVBlock(layer, num_layers, D) for layer in range(num_layers)
        )
        self.ln_post = nn.LayerNorm(width)
        self.decoder = nn.Linear(width, vocab_size, bias=False)

    def forward(self, x):
        hidden = self.ln_pre(self.tokens(x))
        for block in self.blocks:
            hidden = block(hidden)
        return self.decoder(self.ln_post(hidden))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment