Skip to content

Instantly share code, notes, and snippets.

@bearpelican
Created August 1, 2019 10:29
Show Gist options
  • Save bearpelican/e51681bd46c5a7b13dda241d4a9879f9 to your computer and use it in GitHub Desktop.
Save bearpelican/e51681bd46c5a7b13dda241d4a9879f9 to your computer and use it in GitHub Desktop.
class MultiTransformer(nn.Module):
    """Multitask Transformer for training mask, next word, and sequence 2 sequence.

    A single ``encoder``/``decoder`` pair and a single output ``head`` are
    shared by every task. ``forward`` runs only the tasks whose key is present
    (and not ``None``) in the input dict and returns the head outputs under
    the same keys.

    Args:
        encoder: module called as ``encoder(x, pos)``.
        decoder: module called as ``decoder(x, pos)`` for next-word prediction
            and as ``decoder(x, pos, enc_out)`` for the seq2seq tasks.
        head: module applied to every encoder/decoder output.
        mem_len: stored as ``default_mem_len``; not read by this class itself
            (presumably used by memory-length management elsewhere — confirm).
    """

    def __init__(self, encoder, decoder, head, mem_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.head = head
        self.default_mem_len = mem_len
        self.current_mem_len = None

    def forward(self, inp):
        """Run every task present in ``inp`` and collect the head outputs.

        Args:
            inp: dict whose optional keys select tasks:
                ``'msk'`` -> dict with ``'x'`` and ``'pos'`` (encoder only);
                ``'lm'``  -> dict with ``'x'`` and ``'pos'`` (decoder only);
                ``'c2m'`` / ``'m2c'`` -> dicts with ``'enc'``, ``'enc_pos'``,
                ``'dec'`` and ``'dec_pos'`` (encoder output fed to decoder).
                Absent keys or ``None`` values are skipped.

        Returns:
            dict mapping each task key that ran to its ``head(...)`` output,
            inserted in the order msk, lm, c2m, m2c.
        """
        # data order: mask, next word, melody, chord
        outputs = {}
        msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]
        if msk is not None:
            outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
        if lm is not None:
            outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))
        if c2m is not None:
            outputs['c2m'] = self._seq2seq(c2m)
        if m2c is not None:
            outputs['m2c'] = self._seq2seq(m2c)
        return outputs

    def _seq2seq(self, batch):
        """Reset submodules, encode ``batch['enc']``, decode ``batch['dec']`` against it, apply the head."""
        # Reset before each seq2seq pass — presumably so state cached by
        # earlier tasks/calls does not leak into this one (confirm).
        self.reset()
        enc_out = self.encoder(batch['enc'], batch['enc_pos'])
        dec_out = self.decoder(batch['dec'], batch['dec_pos'], enc_out)
        return self.head(dec_out)

    def reset(self):
        """Call ``reset()`` on every submodule that defines a callable ``reset``.

        ``forward`` calls ``self.reset()`` before each seq2seq task, and
        ``nn.Module`` provides no ``reset``, so this method is required for
        those tasks to run. ``self.modules()`` yields each descendant once
        (shared submodules are not visited twice); ``self`` is skipped to
        avoid infinite recursion.
        """
        for module in self.modules():
            if module is self:
                continue
            reset_fn = getattr(module, 'reset', None)
            if callable(reset_fn):
                reset_fn()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment