Skip to content

Instantly share code, notes, and snippets.

@rijulg
Last active July 18, 2020 08:06
Show Gist options
  • Save rijulg/25a350612d6a1e5fb8e8446f6459918c to your computer and use it in GitHub Desktop.
Save rijulg/25a350612d6a1e5fb8e8446f6459918c to your computer and use it in GitHub Desktop.
PyTorch embedding layer that supports gradient propagation to its inputs
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python37764bitbaseconda4a4f2856d96346b89bf455c97059c11c",
"display_name": "Python 3.7.7 64-bit ('base': conda)"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class HotEmbedding(torch.nn.Module):\n",
"    \"\"\"Embedding layer that lets gradients flow back to its inputs.\n",
"\n",
"    An index lookup such as ``torch.nn.Embedding`` takes integer indices,\n",
"    so no gradient reaches them. Instead, each input value ``x`` is compared\n",
"    with the grid ``k = 0, 1, ..., max_val - 1`` to build ``max_val``\n",
"    features ``1 / ((x - k) + eps)``, which a ``Linear`` layer projects to\n",
"    ``embedding_dim``. The feature for ``k == x`` equals ``1 / eps``.\n",
"\n",
"    Args:\n",
"        max_val: number of grid points compared against each input.\n",
"        embedding_dim: size of each output embedding vector.\n",
"        eps: offset added to every difference ``x - k``.\n",
"\n",
"    NOTE(review): the denominator is zero when ``x == k - eps``, so inputs\n",
"    at or near such a point give huge (or infinite) features and gradients.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, max_val, embedding_dim, eps=1e-2):\n",
"        super().__init__()\n",
"        # Registered as a buffer so it moves with .to(device) and is stored\n",
"        # in the state_dict; the names \"A\" and \"B\" are kept so existing\n",
"        # checkpoints still load.\n",
"        self.register_buffer(\"A\", torch.arange(max_val))\n",
"        self.B = torch.nn.Linear(max_val, embedding_dim)\n",
"        self.eps = eps\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map ``x`` of shape ``S`` to embeddings of shape ``S + (embedding_dim,)``.\"\"\"\n",
"        # unsqueeze(-1) appends the grid axis last, so batched and 0-d inputs\n",
"        # broadcast correctly; for 1-D input it is identical to unsqueeze(1).\n",
"        features = 1 / ((x.unsqueeze(-1) - self.A) + self.eps)\n",
"        return self.B(features)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "tensor([1., 2., 3., 1., 2., 3.], requires_grad=True)\ntensor([[-27.3928, -27.0610, 18.1991, 16.2886, 2.6034],\n [ 26.3739, 20.9340, -11.5895, -12.7277, -28.9630],\n [-17.8610, 18.7509, -18.1535, 0.1145, -25.8138],\n [-27.3928, -27.0610, 18.1991, 16.2886, 2.6034],\n [ 26.3739, 20.9340, -11.5895, -12.7277, -28.9630],\n [-17.8610, 18.7509, -18.1535, 0.1145, -25.8138]],\n grad_fn=<AddmmBackward>)\n"
}
],
"source": [
"layer = HotEmbedding(10, 5)\n",
"# Repeated values should map to identical embedding rows.\n",
"x = torch.tensor([1., 2., 3.] * 2, requires_grad=True)\n",
"y = layer(x)\n",
"for shown in (x, y):\n",
"    print(shown)"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment