Skip to content

Instantly share code, notes, and snippets.

@isaacmg
Created December 4, 2019 05:39
Show Gist options
  • Save isaacmg/dde5274a83c7510796990d47a3eb5443 to your computer and use it in GitHub Desktop.
Save isaacmg/dde5274a83c7510796990d47a3eb5443 to your computer and use it in GitHub Desktop.
class SimpleTransformer(torch.nn.Module):
    """Encoder-decoder transformer that maps multivariate series to one value per step.

    Each time step's ``n_time_series`` features are linearly projected to
    ``d_model`` dimensions and passed through ``SimplePositionalEncoding``.
    The result goes through ``self.transformer``'s encoder (source) or decoder
    (target). A final linear layer maps every decoder output step to a single
    value.

    Args:
        n_time_series: Number of features per time step (size of the last
            input dimension).
        seq_len: Default number of output steps per sample; ``decode_seq``
            uses it to shape its result when ``seq_size`` is not given.
        d_model: Embedding width used throughout the transformer.
        n_heads: Number of attention heads passed as ``nhead``. Defaults to 8,
            the value that was previously hard-coded. ``d_model`` must be
            divisible by it.
    """

    def __init__(self, n_time_series, seq_len, d_model=128, n_heads=8):
        super().__init__()
        self.dense_shape = torch.nn.Linear(n_time_series, d_model)
        self.pe = SimplePositionalEncoding(d_model)
        # NOTE(review): ``Transformer`` is presumably ``torch.nn.Transformer``
        # (default ``batch_first=False``); its import is not visible here.
        self.transformer = Transformer(d_model, nhead=n_heads)
        self.final_layer = torch.nn.Linear(d_model, 1)
        self.sequence_size = seq_len

    def forward(self, x, t, tgt_mask, src_mask=None):
        """Encode ``x``, then decode ``t`` against the encoded memory.

        Args:
            x: Source series; 3-D with features last (assumed
                ``(batch, src_len, n_time_series)``).
            t: Target/decoder input series, same feature layout as ``x``.
            tgt_mask: Attention mask forwarded to the decoder's ``tgt_mask``.
            src_mask: Optional attention mask forwarded to the encoder.

        Returns:
            The output of ``decode_seq``: one row of ``self.sequence_size``
            values per sample.
        """
        # src_mask is forwarded as-is (None or a tensor). Do not test its
        # truthiness: bool() of a multi-element tensor raises RuntimeError.
        memory = self.encode_sequence(x, src_mask)
        return self.decode_seq(memory, t, tgt_mask)

    def basic_feature(self, x):
        """Project, positionally encode, and transpose ``x`` for the transformer.

        ``x`` must be 3-D with ``n_time_series`` features in its last
        dimension. The result has its first two dimensions swapped and a last
        dimension of ``d_model``. If ``x`` is ``(batch, seq, features)``
        (assumed), the result is ``(seq, batch, d_model)``, the sequence-first
        layout used by a transformer built with ``batch_first=False``.
        """
        x = self.dense_shape(x)
        # Assumes SimplePositionalEncoding preserves the tensor shape.
        x = self.pe(x)
        return x.permute(1, 0, 2)

    def encode_sequence(self, x, src_mask=None):
        """Embed ``x`` and run it through the transformer encoder.

        Returns the encoder memory in the sequence-first layout produced by
        ``basic_feature``.
        """
        x = self.basic_feature(x)
        return self.transformer.encoder(x, mask=src_mask)

    def decode_seq(self, mem, t, tgt_mask, seq_size=None):
        """Decode target series ``t`` against encoder memory ``mem``.

        Args:
            mem: Encoder output from ``encode_sequence``.
            t: Target/decoder input series (same layout as the encoder input).
            tgt_mask: Attention mask for the decoder's self-attention.
            seq_size: Number of output steps per row; defaults to
                ``self.sequence_size``.

        Returns:
            A 2-D tensor with ``seq_size`` columns. When the decoder length
            equals ``seq_size``, it has one row per batch element, in batch
            order.
        """
        if seq_size is None:
            seq_size = self.sequence_size
        t = self.basic_feature(t)
        x = self.transformer.decoder(t, mem, tgt_mask=tgt_mask)
        x = self.final_layer(x)
        # The decoder output is sequence-first (tgt_len, batch, 1). Flattening
        # it directly would interleave batch elements within each row, so move
        # batch to the front first. Use reshape (not view) because the permuted
        # tensor is not contiguous.
        return x.permute(1, 0, 2).reshape(-1, seq_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment