@maedoc
Last active January 24, 2024 15:09
Another short take on RWKV, towards use with time series
import numpy as np
import torch


class MyModule(torch.nn.Module):
    def add_param(self, key, shape):
        val = torch.randn(*shape) / np.prod(shape)
        setattr(self, key, torch.nn.Parameter(val))
    def add_params(self, keys, shape):
        for key in keys.split(' '):
            self.add_param(key, shape)
    def parameter_count(self):
        return sum([p.numel() for p in self.parameters()])

class TimeMix(MyModule):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.add_params('mix_k mix_v mix_r first decay', (C,))
        self.layer_norm = torch.nn.LayerNorm(C)
        # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L174
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
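        # ZeroPad2d((0, 0, 1, -1)) pads one zero row at the start of the time
        # axis and drops the last row, so last_x[:, t] is x[:, t-1] (zeros at t=0)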
        self.key = torch.nn.Linear(C, C, bias=False)
        self.value = torch.nn.Linear(C, C, bias=False)
        self.receptance = torch.nn.Linear(C, C, bias=False)
        self.output = torch.nn.Linear(C, C, bias=False)
    def forward(self, x, state=None):
        # in train code, x.size is (batch, time, channel)
        # so last_x is computed with time shift
        if x.ndim == 1:
            # rnn style gen
            B, T, C = 0, 0, x.shape[0]
            rnn = True
            # _, last_x, aa, bb, pp = state
            last_x = state[1]
            aa = state[2]
            bb = state[3]
            pp = state[4]
        else:
            B, T, C = x.shape
            rnn = False
            last_x = self.time_shift(x)
            # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/cuda/wkv_cuda.cu#L96
            aa = torch.zeros_like(x[:, 0])  # (B, C)
            bb = torch.zeros_like(x[:, 0])
            pp = torch.ones_like(x[:, 0]) * (-1e38)
        x = self.layer_norm(x)
        xk = x * self.mix_k + last_x * (1 - self.mix_k)
        xv = x * self.mix_v + last_x * (1 - self.mix_v)
        xr = x * self.mix_r + last_x * (1 - self.mix_r)
        r = torch.sigmoid(self.receptance(xr))
        k = self.key(xk)
        v = self.value(xv)
        # rest is in the RUN_CUDA thingy
        # RUN_CUDA(B, T, dim_att, time_decay, time_first, k, v)
        # wkv_cuda.forward(B, T, C, w, u, k, v, y), y is output
        # where dim_att==C is the attention dimension, n in this notebook
        # so where do aa bb pp come from?
        # they are computed from start of sequence in cuda kernel
        # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/cuda/wkv_cuda.cu#L22
        # not yet sure
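        # as far as I can tell, aa and bb accumulate the exp-weighted numerator
        # and denominator of the WKV average over past timesteps, and pp tracks
        # the running max exponent so the exp() calls below stay numerically stable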
        wkv = []
        # TODO this could be a jit'd module probably
        for t in range(1 if rnn else T):
            # note confusing changes in variable names btw different codes
            # time_decay -> w, shape (C,)
            # time_first -> u, shape (C,)
            # k, v are same, shape (C,) or (B,T,C)
            # y -> wkv, shape (C,) or (B,T,C)
            # k, v, y offset in cuda, k[block,:,c]
            # thread grid computes wkv parallel for all B and C, iters T
            kk = k if rnn else k[:, t]
            vv = v if rnn else v[:, t]
            ww = self.first + kk        # u + kk; (C,)+(B,C)->(B,C)
            qq = torch.maximum(pp, ww)  # p = max(pp, ww); (B,C)->(B,C)
            e1 = torch.exp(pp - qq)     # exp(pp - p)
            e2 = torch.exp(ww - qq)     # exp(ww - p)
            # y[ii], wkv[:,t,:]
            wkv.append(
                (e1 * aa + e2 * vv) / (e1 * bb + e2)  # (B,C)
            )
            ww = pp + self.decay        # (B,C)+(C,)->(B,C)
            qq = torch.maximum(ww, kk)  # (B,C)
            e1 = torch.exp(ww - qq)
            e2 = torch.exp(kk - qq)
            # retain moving averages for next iter
            aa = e1 * aa + e2 * vv
            bb = e1 * bb + e2
            pp = qq
        wkv = wkv[0] if rnn else torch.stack(wkv, dim=1)
        out = self.output(r * wkv)  # whence rwkv
        if rnn:  # rnn style return
            return out, torch.stack((x, aa, bb, pp))
        else:  # parallel
            return out
    # TODO: torch.jit.script doesn't like variable returns

class SpaceMix(MyModule):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.add_params('mix_k mix_r', (C,))
        # self.add_params('kw vw rw', (C, C))
        self.layer_norm = torch.nn.LayerNorm(C)
        # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L247
        # uses a time shift for last_x and nn.Linear(..., bias=False)
        # instead of weights
        # also need to copy initialization tricks
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
        self.kw = torch.nn.Linear(C, C, bias=False)
        self.vw = torch.nn.Linear(C, C, bias=False)
        self.rw = torch.nn.Linear(C, C, bias=False)
    def forward(self, x, state=None):
        if x.ndim == 1:  # rnn mode, x.shape is (C,)
            last_x, *_ = state
        else:  # xfmr mode, x.shape is (B,T,C)
            last_x = self.time_shift(x)
        x = self.layer_norm(x)
        xk = x * self.mix_k + last_x * (1 - self.mix_k)
        xr = x * self.mix_r + last_x * (1 - self.mix_r)
        r = torch.sigmoid(self.rw(xr))
        k = torch.square(torch.relu(self.kw(xk)))  # squared relu, from the Primer paper
        rvwk = r * self.vw(k)
        if x.ndim == 1:  # rnn
            return rvwk, x
        else:  # xfmr
            return rvwk

class Block(MyModule):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.time_mix = TimeMix(C)
        self.space_mix = SpaceMix(C)
        self.time_ln = torch.nn.LayerNorm(C)
        self.space_ln = torch.nn.LayerNorm(C)
    def forward(self, x, state=None):
        if x.ndim == 1:  # rnn mode
            # in tvb terms, as if we're starting with afferent states
            x_ = self.time_ln(x)
            time_dx, time_state = self.time_mix(x_, state)
            x = x + time_dx
            x_ = self.space_ln(x)
            space_dx, space_state = self.space_mix(x_, state)
            x = x + space_dx
            next_state = torch.concatenate((space_state.reshape((1, -1)), time_state))
            return x, next_state
        else:
            x = x + self.time_mix(self.time_ln(x))
            x = x + self.space_mix(self.space_ln(x))
            return x

class RWKV(MyModule):
    def __init__(self, C, nlayers=1, indim=None, outdim=None):
        super().__init__()
        self.C = C
        self.indim = indim or C
        self.outdim = outdim or C
        self.layers = torch.nn.ModuleList([
            Block(C) for i in range(nlayers)])
        self.pre_ln = torch.nn.LayerNorm(C)
        self.post_ln = torch.nn.LayerNorm(C)
        self.encode = torch.nn.Linear(self.indim, C, bias=False)
        self.decode = torch.nn.Linear(C, self.outdim, bias=False)
    def forward(self, x, state=None):
        x = self.encode(x)
        x = self.pre_ln(x)
        for layer in self.layers:
            if x.ndim == 1:
                x, state = layer(x, state)
            else:
                x = layer(x)
        x = self.post_ln(x)
        x = self.decode(x)
        # here I'm skipping the softmax since we don't want to force probabilities
        if x.ndim == 1:
            return x, state
        else:
            return x

def test_time_mix():
    B, T, C = 2, 3, 6
    print('test time mix with B T C', B, T, C)
    tm = TimeMix(C).to('mps')
    x = torch.randn(C).to('mps')
    state = torch.randn(5, C).to('mps')
    out, time_state = tm.forward(x, state)
    print(out.shape, time_state.shape)
    tm = TimeMix(C).to('mps')
    x = torch.randn(B, T, C).to('mps')
    state = torch.randn(5, B, T, C).to('mps')
    print(tm.forward(x).shape)

def test_space_mix():
    B, T, C = 2, 3, 6
    print('test space mix with B T C', B, T, C)
    sm = SpaceMix(C).to('mps')
    x = torch.randn(B, T, C).to('mps')
    print(sm.forward(x).shape)
    state = torch.randn(5, C).to('mps')
    rvw, x = sm.forward(x[0, 0], state)
    print(rvw.shape, x.shape)

def test_rwkv():
    B, T, C = 2, 3, 6
    # rnn mode
    # basic usage:
    x = torch.randn(C)
    state = torch.randn(5, C)
    rwkv = RWKV(C, 3)
    nx, nstate = rwkv(x, state)
    assert nx.shape == x.shape and nstate.shape == state.shape
    # extra input
    x = torch.randn(C + 3)
    state = torch.randn(5, C)
    rwkv = RWKV(C, 3, indim=C + 3)
    nx, nstate = rwkv(x, state)
    assert nx.shape[0] == C
    # extra carried through the network but not in output
    x = torch.randn(C + 3)
    state = torch.randn(5, C + 3)
    rwkv = RWKV(C + 3, 3, outdim=C)
    nx, nstate = rwkv(x, state)
    assert nx.shape[0] == C
    # extra inputs, large latent state, smaller output
    indim = C + 3
    latdim = 2 * C
    outdim = C
    x = torch.randn(indim)
    state = torch.randn(5, latdim)
    rwkv = RWKV(latdim, 4, indim=indim, outdim=outdim)
    nx, nstate = rwkv(x, state)
    assert nx.shape == (outdim,) and nstate.shape == (5, latdim)
    rwkv.parameter_count()
    # xfmr mode
    # same idea but in xfmr mode: check different usages
    x = torch.randn(B, T, C)
    rwkv = RWKV(C, 3)
    nx = rwkv(x)
    assert x.shape == nx.shape
    x = torch.randn(B, T, C + 3)
    rwkv = RWKV(C, 3, indim=C + 3)
    nx = rwkv(x)
    assert x.shape[2] - 3 == nx.shape[2]
    x = torch.randn(B, T, C + 3)
    rwkv = RWKV(C + 3, 3, outdim=C)
    nx = rwkv(x)
    assert x.shape[2] - 3 == nx.shape[2]
    indim = C + 3
    latdim = 2 * C
    outdim = C
    x = torch.randn(B, T, indim)
    rwkv = RWKV(latdim, 4, indim=indim, outdim=outdim)
    nx = rwkv(x)
    assert nx.shape == (B, T, outdim)
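
# toy sketch of the intended time-series usage: roll the RNN-mode model along a
# scalar series, carrying the recurrent state; the sizes and the zero / -1e38
# state initialization here are assumptions, not part of the tests above
def demo_time_series():
    C = 8
    rwkv = RWKV(C, nlayers=2, indim=1, outdim=1)
    series = torch.randn(100, 1)      # stand-in univariate series
    state = torch.zeros(5, C)
    state[4] = -1e38                  # pp row starts at the same floor used in xfmr mode
    preds = []
    for xt in series:
        yt, state = rwkv(xt, state)   # carry state from step to step
        preds.append(yt)
    return torch.stack(preds)         # (100, 1) one-step-ahead outputs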
@alfredyuan

Just curious, the RWKV module definition is not complete?

@maedoc
Author

maedoc commented Aug 3, 2023

Yes, the RWKV module definition is not complete. I have finished it locally and forgot to update the gist. I also realized that I don't have a great way to test the code for now.

@maedoc
Author

maedoc commented Aug 21, 2023

@alfredyuan I updated it: fixed for both training and inference, with a more reusable definition. I didn't include the softmax in the RWKV output, but that should be easy to add in any calling code.
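For example, a minimal sketch of adding that softmax in calling code for a next-token setup (the vocabulary size and shapes here are made up for illustration):

    vocab = 1000                                   # hypothetical vocabulary size
    model = RWKV(C=128, nlayers=4, indim=vocab, outdim=vocab)
    x = torch.nn.functional.one_hot(torch.tensor([[1, 5, 9]]), vocab).float()  # (B, T, vocab)
    logits = model(x)                              # (B, T, vocab), xfmr mode
    probs = torch.softmax(logits, dim=-1)          # per-step token probabilities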

@alfredyuan

Thanks a lot!
