Last active
May 11, 2022 13:57
-
-
Save hsm207/4c68c63f5f295fd62fc35d452fedf68a to your computer and use it in GitHub Desktop.
Code to accompany my blog post at https://bit.ly/2PmRjiC
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"toc": true | |
}, | |
"source": [ | |
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n", | |
"<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Matrix-Transpose\" data-toc-modified-id=\"Matrix-Transpose-1\"><span class=\"toc-item-num\">1 </span>Matrix Transpose</a></span></li><li><span><a href=\"#Extracting-the-Diagonal-Elements-of-a-Matrix\" data-toc-modified-id=\"Extracting-the-Diagonal-Elements-of-a-Matrix-2\"><span class=\"toc-item-num\">2 </span>Extracting the Diagonal Elements of a Matrix</a></span></li><li><span><a href=\"#Summing-the-Diagonal-Elements-of-a-Matrix\" data-toc-modified-id=\"Summing-the-Diagonal-Elements-of-a-Matrix-3\"><span class=\"toc-item-num\">3 </span>Summing the Diagonal Elements of a Matrix</a></span></li><li><span><a href=\"#Summing-All-Elements-in-a-Matrix\" data-toc-modified-id=\"Summing-All-Elements-in-a-Matrix-4\"><span class=\"toc-item-num\">4 </span>Summing All Elements in a Matrix</a></span></li><li><span><a href=\"#Matrix-Multiplication-with-Transpose\" data-toc-modified-id=\"Matrix-Multiplication-with-Transpose-5\"><span class=\"toc-item-num\">5 </span>Matrix Multiplication with Transpose</a></span></li><li><span><a href=\"#Outer-Product\" data-toc-modified-id=\"Outer-Product-6\"><span class=\"toc-item-num\">6 </span>Outer Product</a></span></li><li><span><a href=\"#Dot-Product\" data-toc-modified-id=\"Dot-Product-7\"><span class=\"toc-item-num\">7 </span>Dot Product</a></span></li><li><span><a href=\"#Matrix-Column-Sum\" data-toc-modified-id=\"Matrix-Column-Sum-8\"><span class=\"toc-item-num\">8 </span>Matrix Column Sum</a></span></li><li><span><a href=\"#Matrix-Row-Sum\" data-toc-modified-id=\"Matrix-Row-Sum-9\"><span class=\"toc-item-num\">9 </span>Matrix Row Sum</a></span></li><li><span><a href=\"#Element-wise-matrix-multiplication\" data-toc-modified-id=\"Element-wise-matrix-multiplication-10\"><span class=\"toc-item-num\">10 </span>Element-wise matrix multiplication</a></span></li><li><span><a href=\"#Column-wise-dot-product\" 
data-toc-modified-id=\"Column-wise-dot-product-11\"><span class=\"toc-item-num\">11 </span>Column-wise dot product</a></span></li></ul></div>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'1.0.0'" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.__version__" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Matrix Transpose" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 4, 7],\n", | |
" [2, 5, 8],\n", | |
" [3, 6, 9]])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_transpose = torch.einsum('ji->ij', A)\n", | |
"einsum_transpose" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 4, 7],\n", | |
" [2, 5, 8],\n", | |
" [3, 6, 9]])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"builtin_transpose = A.transpose(1, 0)\n", | |
"builtin_transpose" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_transpose, builtin_transpose)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Extracting the Diagonal Elements of a Matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([1, 5, 9])" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_extract_diag = torch.einsum('ii->i', A)\n", | |
"einsum_extract_diag" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([1, 5, 9])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"builtin_diag = torch.diag(A)\n", | |
"builtin_diag" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_extract_diag, builtin_diag)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Summing the Diagonal Elements of a Matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(15)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_sum_diag = torch.einsum('ii->', A)\n", | |
"einsum_sum_diag" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_sum_diag, torch.trace(A))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Summing All Elements in a Matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(45)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_sum_all = torch.einsum('ij->', A)\n", | |
"einsum_sum_all" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(45)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"builtin_sum_all = torch.sum(A)\n", | |
"builtin_sum_all" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_sum_all, builtin_sum_all)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Matrix Multiplication with Transpose" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 1, 2, 3, 4, 5],\n", | |
" [ 6, 7, 8, 9, 10],\n", | |
" [11, 12, 13, 14, 15]])" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 16).view(3, 5)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[10, 11, 12, 13, 14],\n", | |
" [15, 16, 17, 18, 19],\n", | |
" [20, 21, 22, 23, 24]])" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"B = torch.arange(10, 25).view(3, 5)\n", | |
"B" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 190, 265, 340],\n", | |
" [ 490, 690, 890],\n", | |
" [ 790, 1115, 1440]])" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"builtin_matmul_transpose = A @ B.transpose(1, 0)\n", | |
"builtin_matmul_transpose" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 190, 265, 340],\n", | |
" [ 490, 690, 890],\n", | |
" [ 790, 1115, 1440]])" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_matmul_transpose = torch.einsum('ik,jk->ij', A, B)\n", | |
"einsum_matmul_transpose" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_matmul_transpose, builtin_matmul_transpose)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Column-wise dot product" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])\n", | |
"tensor([[10, 11, 12],\n", | |
" [13, 14, 15],\n", | |
" [16, 17, 18]])\n" | |
] | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"B = torch.arange(10, 19).view(3, 3)\n", | |
"print(A)\n", | |
"print(B)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([174, 228, 288])" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_col_dot = torch.einsum('ij,ij->j', A, B)\n", | |
"einsum_col_dot" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_col_dot, (A * B).sum(dim=0))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Outer Product" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([1, 2, 3]), tensor([4, 5, 6]))" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"a = torch.arange(1, 4)\n", | |
"b = torch.arange(4, 7)\n", | |
"a, b" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 4, 5, 6],\n", | |
" [ 8, 10, 12],\n", | |
" [12, 15, 18]])" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_outer = torch.einsum('i,j->ij', a, b)\n", | |
"einsum_outer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_outer, a[:, None] * b)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Dot Product" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([1, 2, 3]), tensor([4, 5, 6]))" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"a = torch.arange(1, 4)\n", | |
"b = torch.arange(4, 7)\n", | |
"a, b" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(32)" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"builtin_dot = torch.dot(a, b)\n", | |
"builtin_dot" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(32)" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_dot = torch.einsum('i,i->', a, b)\n", | |
"einsum_dot" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_dot, builtin_dot)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Matrix Column Sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([12, 15, 18])" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_col_sum = torch.einsum('ji->i', A)\n", | |
"einsum_col_sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_col_sum, A.sum(dim=0))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Matrix Row Sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 1, 2, 3, 4, 5, 6, 7],\n", | |
" [ 8, 9, 10, 11, 12, 13, 14]])" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 15).view(2, 7)\n", | |
"A" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([28, 77])" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_row_sum = torch.einsum('ij->i', A)\n", | |
"einsum_row_sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_row_sum, A.sum(dim=1))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Element-wise matrix multiplication" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor([[1, 2, 3],\n", | |
" [4, 5, 6],\n", | |
" [7, 8, 9]])\n", | |
"tensor([[10, 11, 12],\n", | |
" [13, 14, 15],\n", | |
" [16, 17, 18]])\n" | |
] | |
} | |
], | |
"source": [ | |
"A = torch.arange(1, 10).view(3, 3)\n", | |
"B = torch.arange(10, 19).view(3, 3)\n", | |
"print(A)\n", | |
"print(B)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 10, 22, 36],\n", | |
" [ 52, 70, 90],\n", | |
" [112, 136, 162]])" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"einsum_elementwise = torch.einsum('ij,ij->ij', A, B)\n", | |
"einsum_elementwise" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"assert torch.equal(einsum_elementwise, A * B)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "pytorch", | |
"language": "python", | |
"name": "pytorch" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": true, | |
"toc_position": {}, | |
"toc_section_display": "block", | |
"toc_window_display": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment