Another short take on RWKV, towards use with time series
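In short, the RWKV module defined below can be called in two ways (a quick sketch of the interface, with made-up sizes): transformer-style on a whole (batch, time, channel) array, or rnn-style one sample at a time with a (5, C) state carried along.

rwkv = RWKV(6, nlayers=2)                        # 6 channels, 2 blocks
y = rwkv(torch.randn(4, 32, 6))                  # parallel mode: (B, T, C) -> (B, T, C)
yt, s = rwkv(torch.randn(6), torch.zeros(5, 6))  # rnn mode: one step plus carried state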
import numpy as np
import torch


class MyModule(torch.nn.Module):
    def add_param(self, key, shape):
        val = torch.randn(*shape) / np.prod(shape)
        setattr(self, key, torch.nn.Parameter(val))
    def add_params(self, keys, shape):
        for key in keys.split(' '):
            self.add_param(key, shape)
    def parameter_count(self):
        return sum(p.numel() for p in self.parameters())

class TimeMix(MyModule):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.add_params('mix_k mix_v mix_r first decay', (C,))
        self.layer_norm = torch.nn.LayerNorm(C)
        # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L174
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
        self.key = torch.nn.Linear(C, C, bias=False)
        self.value = torch.nn.Linear(C, C, bias=False)
        self.receptance = torch.nn.Linear(C, C, bias=False)
        self.output = torch.nn.Linear(C, C, bias=False)
    def forward(self, x, state=None):
        # in the training code, x.shape is (batch, time, channel),
        # so last_x is computed with a time shift
        if x.ndim == 1:
            # rnn-style generation: x is a single (C,) vector
            B, T, C = 0, 0, x.shape[0]
            rnn = True
            # state rows: [space last_x, time last_x, aa, bb, pp]
            last_x = state[1]
            aa = state[2]
            bb = state[3]
            pp = state[4]
        else:
            B, T, C = x.shape
            rnn = False
            last_x = self.time_shift(x)
            # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/cuda/wkv_cuda.cu#L96
            aa = torch.zeros_like(x[:, 0])           # (B, C) wkv numerator sum
            bb = torch.zeros_like(x[:, 0])           # (B, C) wkv denominator sum
            pp = torch.ones_like(x[:, 0]) * (-1e38)  # (B, C) running max exponent
        x = self.layer_norm(x)
        xk = x * self.mix_k + last_x * (1 - self.mix_k)
        xv = x * self.mix_v + last_x * (1 - self.mix_v)
        xr = x * self.mix_r + last_x * (1 - self.mix_r)
        r = torch.sigmoid(self.receptance(xr))
        k = self.key(xk)
        v = self.value(xv)
        # the rest is what RUN_CUDA does in the reference implementation:
        # RUN_CUDA(B, T, dim_att, time_decay, time_first, k, v)
        # wkv_cuda.forward(B, T, C, w, u, k, v, y), where y is the output
        # and dim_att == C is the attention dimension (n in this notebook).
        # aa, bb and pp are accumulated from the start of the sequence in the cuda kernel:
        # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/cuda/wkv_cuda.cu#L22
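        # my reading of that kernel (a paraphrase of the loop below, not the kernel verbatim):
        # per channel it computes
        #   wkv_t = (sum_{i<t} exp(k_i + (t-1-i)*decay) * v_i + exp(first + k_t) * v_t)
        #           / (sum_{i<t} exp(k_i + (t-1-i)*decay)      + exp(first + k_t))
        # where aa and bb carry the numerator and denominator sums between steps,
        # and pp tracks the largest exponent seen so far, so aa and bb are stored
        # rescaled by exp(-pp) and the exp() calls stay numerically stable.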
        wkv = []
        # TODO this could probably be a jit'd module
        for t in range(1 if rnn else T):
            # note the confusing variable-name changes between the codes:
            # time_decay -> w, shape (C,)
            # time_first -> u, shape (C,)
            # k, v are the same, shape (C,) or (B, T, C)
            # y -> wkv, shape (C,) or (B, T, C)
            # k, v, y are offset in cuda, k[block, :, c];
            # the thread grid computes wkv in parallel over B and C, iterating over T
            kk = k if rnn else k[:, t]
            vv = v if rnn else v[:, t]
            ww = self.first + kk        # u + kk; (C,) + (B,C) -> (B,C)
            qq = torch.maximum(pp, ww)  # p = max(pp, ww); (B,C)
            e1 = torch.exp(pp - qq)     # exp(pp - p)
            e2 = torch.exp(ww - qq)     # exp(ww - p)
            # y[ii] in cuda, wkv[:, t, :] here
            wkv.append(
                (e1 * aa + e2 * vv) / (e1 * bb + e2)  # (B, C)
            )
            ww = pp + self.decay        # (B,C) + (C,) -> (B,C)
            qq = torch.maximum(ww, kk)  # (B,C)
            e1 = torch.exp(ww - qq)
            e2 = torch.exp(kk - qq)
            # retain the running sums for the next iteration
            aa = e1 * aa + e2 * vv
            bb = e1 * bb + e2
            pp = qq
        wkv = wkv[0] if rnn else torch.stack(wkv, dim=1)
        out = self.output(r * wkv)  # whence the name rwkv
        # TODO: torch.jit.script doesn't like variable returns
        if rnn:  # rnn-style return: output plus this block's time-mix state
            return out, torch.stack((x, aa, bb, pp))
        else:    # parallel (transformer-style) return
            return out

class SpaceMix(MyModule):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.add_params('mix_k mix_r', (C,))
        # self.add_params('kw vw rw', (C, C))
        self.layer_norm = torch.nn.LayerNorm(C)
        # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L247
        # uses a time shift for last_x and nn.Linear(..., bias=False)
        # instead of raw weight matrices
        # TODO: also need to copy the initialization tricks
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
        self.kw = torch.nn.Linear(C, C, bias=False)
        self.vw = torch.nn.Linear(C, C, bias=False)
        self.rw = torch.nn.Linear(C, C, bias=False)
    def forward(self, x, state=None):
        if x.ndim == 1:  # rnn mode, x.shape is (C,)
            last_x, *_ = state
        else:            # transformer mode, x.shape is (B, T, C)
            last_x = self.time_shift(x)
        x = self.layer_norm(x)
        xk = x * self.mix_k + last_x * (1 - self.mix_k)
        xr = x * self.mix_r + last_x * (1 - self.mix_r)
        r = torch.sigmoid(self.rw(xr))
        k = torch.square(torch.relu(self.kw(xk)))  # squared relu, from the Primer paper
        rvwk = r * self.vw(k)
        if x.ndim == 1:  # rnn: also return the new last_x for the state
            return rvwk, x
        else:            # transformer
            return rvwk

class Block(MyModule):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.time_mix = TimeMix(C)
        self.space_mix = SpaceMix(C)
        self.time_ln = torch.nn.LayerNorm(C)
        self.space_ln = torch.nn.LayerNorm(C)
    def forward(self, x, state=None):
        if x.ndim == 1:  # rnn mode
            # in tvb terms, as if we're starting with afferent states
            x_ = self.time_ln(x)
            time_dx, time_state = self.time_mix(x_, state)
            x = x + time_dx
            x_ = self.space_ln(x)
            space_dx, space_state = self.space_mix(x_, state)
            x = x + space_dx
            # state layout: row 0 is the space-mix last_x,
            # rows 1-4 are the time-mix (last_x, aa, bb, pp)
            next_state = torch.concatenate((space_state.reshape((1, -1)), time_state))
            return x, next_state
        else:
            x = x + self.time_mix(self.time_ln(x))
            x = x + self.space_mix(self.space_ln(x))
            return x

class RWKV(MyModule):
    def __init__(self, C, nlayers=1, indim=None, outdim=None):
        super().__init__()
        self.C = C
        self.indim = indim or C
        self.outdim = outdim or C
        self.layers = torch.nn.ModuleList([
            Block(C) for i in range(nlayers)])
        self.pre_ln = torch.nn.LayerNorm(C)
        self.post_ln = torch.nn.LayerNorm(C)
        self.encode = torch.nn.Linear(self.indim, C, bias=False)
        self.decode = torch.nn.Linear(C, self.outdim, bias=False)
    def forward(self, x, state=None):
        x = self.encode(x)
        x = self.pre_ln(x)
        for layer in self.layers:
            if x.ndim == 1:
                # note: a single (5, C) state is threaded through all layers,
                # so each layer overwrites the previous layer's recurrent state
                x, state = layer(x, state)
            else:
                x = layer(x)
        x = self.post_ln(x)
        x = self.decode(x)
        # skipping the softmax here since we don't want to force probabilities
        if x.ndim == 1:
            return x, state
        else:
            return x

def test_time_mix():
    B, T, C = 2, 3, 6
    print('test time mix with B T C', B, T, C)
    # rnn mode: a single (C,) input plus a (5, C) state
    tm = TimeMix(C).to('mps')
    x = torch.randn(C).to('mps')
    state = torch.randn(5, C).to('mps')
    out, time_state = tm.forward(x, state)
    print(out.shape, time_state.shape)
    # transformer mode: a (B, T, C) input, no state
    tm = TimeMix(C).to('mps')
    x = torch.randn(B, T, C).to('mps')
    print(tm.forward(x).shape)

def test_space_mix():
    B, T, C = 2, 3, 6
    print('test space mix with B T C', B, T, C)
    sm = SpaceMix(C).to('mps')
    x = torch.randn(B, T, C).to('mps')
    print(sm.forward(x).shape)
    state = torch.randn(5, C).to('mps')
    rvw, x = sm.forward(x[0, 0], state)
    print(rvw.shape, x.shape)

def test_rwkv():
    B, T, C = 2, 3, 6
    # rnn mode
    # basic usage:
    x = torch.randn(C)
    state = torch.randn(5, C)
    rwkv = RWKV(C, 3)
    nx, nstate = rwkv(x, state)
    assert nx.shape == x.shape and nstate.shape == state.shape
    # extra inputs
    x = torch.randn(C + 3)
    state = torch.randn(5, C)
    rwkv = RWKV(C, 3, indim=C + 3)
    nx, nstate = rwkv(x, state)
    assert nx.shape[0] == C
    # extras carried through the network but not in the output
    x = torch.randn(C + 3)
    state = torch.randn(5, C + 3)
    rwkv = RWKV(C + 3, 3, outdim=C)
    nx, nstate = rwkv(x, state)
    assert nx.shape[0] == C
    # extra inputs, larger latent state, smaller output
    indim = C + 3
    latdim = 2 * C
    outdim = C
    x = torch.randn(indim)
    state = torch.randn(5, latdim)
    rwkv = RWKV(latdim, 4, indim=indim, outdim=outdim)
    nx, nstate = rwkv(x, state)
    assert nx.shape == (outdim,) and nstate.shape == (5, latdim)
    rwkv.parameter_count()
    # transformer mode
    # same idea but in parallel mode: check the different usages
    x = torch.randn(B, T, C)
    rwkv = RWKV(C, 3)
    nx = rwkv(x)
    assert x.shape == nx.shape
    x = torch.randn(B, T, C + 3)
    rwkv = RWKV(C, 3, indim=C + 3)
    nx = rwkv(x)
    assert x.shape[2] - 3 == nx.shape[2]
    x = torch.randn(B, T, C + 3)
    rwkv = RWKV(C + 3, 3, outdim=C)
    nx = rwkv(x)
    assert x.shape[2] - 3 == nx.shape[2]
    indim = C + 3
    latdim = 2 * C
    outdim = C
    x = torch.randn(B, T, indim)
    rwkv = RWKV(latdim, 4, indim=indim, outdim=outdim)
    nx = rwkv(x)
    assert nx.shape == (B, T, outdim)
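
Not part of the gist's tests, but as a sketch of the time-series use the description hints at (the synthetic series, the MSE loss, and the optimizer settings below are placeholders, not anything from the original): parallel mode can be trained on (batch, time, channel) windows to predict the next sample, and rnn mode can then free-run from the last observation while carrying the (5, C) state.

def demo_time_series():
    # a minimal sketch, assuming a (T, C) float32 array of observations
    series = np.random.randn(100, 6).astype(np.float32)
    C = series.shape[1]
    rwkv = RWKV(C, nlayers=2)
    # parallel mode: next-step prediction over the whole window
    x = torch.from_numpy(series[None, :-1])  # (1, T-1, C)
    y = torch.from_numpy(series[None, 1:])   # (1, T-1, C)
    opt = torch.optim.Adam(rwkv.parameters(), lr=1e-3)
    for _ in range(10):
        opt.zero_grad()
        loss = torch.mean((rwkv(x) - y) ** 2)
        loss.backward()
        opt.step()
    # rnn mode: free-run from the last observation, carrying the (5, C) state
    xt = torch.from_numpy(series[-1])
    state = torch.zeros(5, C)
    preds = []
    with torch.no_grad():
        for _ in range(20):
            xt, state = rwkv(xt, state)
            preds.append(xt)
    return torch.stack(preds)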
Just curious, is the RWKV module definition not complete?

Yes, the RWKV module def is not complete.. I had finished it locally and forgot to update it. I also realized that I don't have a great way to test the code for now.
@alfredyuan I updated it: fixed for training and inference, with a more reusable definition. I didn't include the softmax in the RWKV output, but that should be easy to add in any calling code.
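For example, something like this in the calling code (just a sketch, for when probabilities are actually wanted):

probs = torch.softmax(rwkv(x), dim=-1)  # (B, T, C) outputs squashed to probabilities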
Thanks a lot!