Last active
May 3, 2023 01:47
-
-
Save ariG23498/fd76bf197f71cd044e40aa2ffe2b6aee to your computer and use it in GitHub Desktop.
Custom RNN logic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/fd76bf197f71cd044e40aa2ffe2b6aee/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "lIYdn1woOS1n" | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"from tensorflow import keras\n", | |
"from tensorflow.keras import layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class CustomCell(layers.Layer):\n", | |
"    \"\"\"RNN cell whose per-step update is multi-head self-attention over tokens.\n", | |
"\n", | |
"    Intended to be wrapped in `layers.RNN`; each call processes one time step.\n", | |
"\n", | |
"    Args:\n", | |
"        num_heads: Number of attention heads.\n", | |
"        key_dim: Size of each attention head for query and key.\n", | |
"        dropout: Dropout rate used inside the attention module.\n", | |
"        token_size: Number of tokens in each time step.\n", | |
"        units: Feature size of the cell's output and of its state.\n", | |
"        **kwargs: Forwarded to `layers.Layer`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, num_heads, key_dim, dropout, token_size, units, **kwargs):\n", | |
"        # Initialize the base Layer before storing attributes so Keras'\n", | |
"        # attribute and sub-layer tracking is in place first.\n", | |
"        super().__init__(**kwargs)\n", | |
"        self.token_size = token_size\n", | |
"        self.units = units\n", | |
"        # layers.RNN reads these to build the initial state and to describe\n", | |
"        # the per-step output.\n", | |
"        self.state_size = tf.TensorShape([token_size, units])\n", | |
"        self.output_size = tf.TensorShape([token_size, units])\n", | |
"\n", | |
"        # Custom per-step logic, used in place of the usual dense/MLP update.\n", | |
"        # output_shape=units projects the attention result to `units` features.\n", | |
"        # By default it would keep the input feature size, which matches\n", | |
"        # state_size/output_size (and makes `outputs + prev_state` below\n", | |
"        # shape-compatible) only when the inputs happen to have `units` features.\n", | |
"        self.attention_module = layers.MultiHeadAttention(\n", | |
"            num_heads,\n", | |
"            key_dim,\n", | |
"            dropout=dropout,\n", | |
"            output_shape=units,\n", | |
"        )\n", | |
"\n", | |
"    def call(self, inputs, states, training=None):\n", | |
"        \"\"\"Process one time step.\n", | |
"\n", | |
"        Args:\n", | |
"            inputs: Tensor for the current step, (batch, token_size, dims).\n", | |
"            states: One-element list holding the previous state,\n", | |
"                (batch, token_size, units).\n", | |
"            training: Forwarded to the attention module so its dropout runs\n", | |
"                only in training mode; layers.RNN passes it in because it\n", | |
"                appears in this signature.\n", | |
"\n", | |
"        Returns:\n", | |
"            Tuple `(outputs, [new_state])`; `outputs` and `new_state` are both\n", | |
"            (batch, token_size, units).\n", | |
"        \"\"\"\n", | |
"        prev_state = states[0]\n", | |
"\n", | |
"        # Self-attention: the step's tokens are both query and value\n", | |
"        # (the key defaults to the value).\n", | |
"        outputs = self.attention_module(inputs, inputs, training=training)\n", | |
"        # NOTE(review): the state accumulates attention outputs but never feeds\n", | |
"        # back into `outputs`, so each step's output depends only on the current\n", | |
"        # input. Confirm this is intended before relying on the recurrence.\n", | |
"        new_state = outputs + prev_state\n", | |
"\n", | |
"        return outputs, [new_state]" | |
], | |
"metadata": { | |
"id": "Zco4Hq1rNlG-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"keras.backend.clear_session()\n", | |
"\n", | |
"batch_size = 8\n", | |
"num_frames = 32\n", | |
"token_size = 543\n", | |
"dims = 3   # feature size of each input token\n", | |
"units = 3  # feature size of the cell's output and state\n", | |
"\n", | |
"inputs = tf.random.normal(\n", | |
"    (batch_size, num_frames, token_size, dims)\n", | |
")\n", | |
"\n", | |
"cell = CustomCell(num_heads=1, key_dim=3, dropout=0.1, token_size=token_size, units=units)\n", | |
"rnn = layers.RNN(cell)\n", | |
"# layers.RNN returns only the last time step's output, (batch_size, token_size, units).\n", | |
"# For this cell that output is not the hidden state; use\n", | |
"# layers.RNN(cell, return_state=True) to also get the final state.\n", | |
"rnn(inputs).shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "6jiaDke1NuHz", | |
"outputId": "9af33977-7a75-4925-ca7a-4c7f14033381" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"TensorShape([8, 543, 3])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"name": "scratchpad", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment