Created August 1, 2019 10:29
-
-
Save bearpelican/e51681bd46c5a7b13dda241d4a9879f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultiTransformer(nn.Module):
    """Multitask Transformer for training mask, next word, and sequence 2 sequence.

    A shared ``encoder``/``decoder`` pair and a shared output ``head`` serve
    four tasks. The key present in the input dict passed to :meth:`forward`
    selects which tasks run:

    * ``'msk'`` -- encoder only, batch ``{'x': ..., 'pos': ...}``
    * ``'lm'``  -- decoder only, batch ``{'x': ..., 'pos': ...}``
    * ``'c2m'`` / ``'m2c'`` -- encoder then decoder, batch
      ``{'enc': ..., 'enc_pos': ..., 'dec': ..., 'dec_pos': ...}``

    Args:
        encoder: module called as ``encoder(x, pos)``.
        decoder: module called as ``decoder(x, pos)`` or
            ``decoder(x, pos, encoder_output)``.
        head: module applied to every encoder/decoder output.
        mem_len: stored as ``default_mem_len``; not read by the code in this
            class (presumably consumed elsewhere -- TODO confirm).
    """

    def __init__(self, encoder, decoder, head, mem_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.head = head
        self.default_mem_len = mem_len
        # Not read or updated anywhere in this class; kept for external users.
        self.current_mem_len = None

    def reset(self):
        """Call ``reset()`` on every submodule that defines a callable ``reset``.

        ``forward`` calls this before each seq2seq task. ``nn.Module`` has no
        ``reset`` method, so without this definition those branches raised
        ``AttributeError``. Presumably the encoder/decoder keep recurrent
        memory that their ``reset`` clears -- TODO confirm against their
        implementations.

        ``self.modules()`` walks the whole submodule tree. A nested module
        whose own ``reset`` also recurses may therefore be reset more than
        once. That is assumed harmless -- verify if ``reset`` is not idempotent.
        """
        for module in self.modules():
            if module is self:
                continue  # skip ourselves to avoid infinite recursion
            reset_fn = getattr(module, 'reset', None)
            if callable(reset_fn):
                reset_fn()

    def _seq2seq(self, batch):
        """Reset state, encode ``batch['enc']``, decode ``batch['dec']`` against it, apply head."""
        self.reset()
        enc_out = self.encoder(batch['enc'], batch['enc_pos'])
        dec_out = self.decoder(batch['dec'], batch['dec_pos'], enc_out)
        return self.head(dec_out)

    def forward(self, inp):
        """Run every task whose key is present in ``inp``.

        Args:
            inp: dict with any subset of the keys ``'msk'``, ``'lm'``, ``'c2m'``
                and ``'m2c'``. Missing keys and ``None`` values are skipped.

        Returns:
            dict mapping each task key that was run to the ``head`` output,
            inserted in the order ``msk``, ``lm``, ``c2m``, ``m2c``.
        """
        # data order: mask, next word, melody, chord
        outputs = {}
        msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]
        if msk is not None:
            outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
        if lm is not None:
            outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))
        # Each seq2seq task resets submodule state before encoding.
        for key, batch in (('c2m', c2m), ('m2c', m2c)):
            if batch is not None:
                outputs[key] = self._seq2seq(batch)
        return outputs
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.