text = ...   # Tokenized Text Corresponding to Recording Transcript
audio = ...  # Mel Spectrogram of the Recording
# Only Train Connector and Projection
self.encoder.freeze()
self.llama.freeze()
# Convert Raw Audio Signal to 1500 Embeddings with Whisper Encoder (CNN+Transformer)
audio_features = self.encoder(audio)
# Learned Query Tokens that will serve as the Q in QKV Cross Attention with the Above
static_query_tokens = query_tokens + query_position_embeds
# Cross Attention Between Query Tokens and Extracted Audio Features
virt_whisper_tokens = self.connector.transformer(
    queries=static_query_tokens,
    cross_attn_keys_and_values=audio_features,
)
# Linear Projection From Whisper Embedding Space to Llama Input Embedding Space
virtual_audio_tokens = self.projection(virt_whisper_tokens, axis="embed")
# Ground Truth Embedding of the Transcript
text_embeds = self.llama.embeddings.embed(text)
# Get the Output Embedding of Just the Final Token in Response to the Audio and Text Embeddings
audio_pred = self.llama.transformer(virtual_audio_tokens)[-1]
text_pred = self.llama.transformer(text_embeds)[-1]
# L2 Loss Between the Final Embeddings
diff_distill = audio_pred - text_pred
loss = hax.dot(diff_distill, diff_distill, axis="embed") ** 0.5
"""" | |
Minimizing the above loss is equivalent to minimizing the KL Divergence of the final token.
Back of the envelope proof:
KL Divergence Loss is defined as loss_{kl} = sum_v P_{target}(v) * (log P_{target}(v) - log P_{source}(v)).
This function achieves its global minimum at P_{target} = P_{source}.
For neural models, we just define P as softmax(matmul(OutputHiddenState, EmbeddingMatrix)).
So in this case, KL Divergence is minimized at softmax(matmul(OutputHiddenState_{target}, EmbeddingMatrix)) = softmax(matmul(OutputHiddenState_{source}, EmbeddingMatrix)).
If EmbeddingMatrix is held constant (e.g. we are doing LoRA training), this condition is satisfied whenever OutputHiddenState_{target} = OutputHiddenState_{source}.
This means that minimizing the L2 loss of these output hidden states should also reach a global minimum of the KL, and it is much cheaper and more stable to compute when vocabulary size > embedding dimension.
Simplified Proof of Concept Notebook: https://colab.research.google.com/drive/1g1BEegIJzoZ1PHY_PeRNJkuveIQn7QIp?usp=sharing
""" |