Code to accompany my blog post at https://bit.ly/2PmRjiC
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "toc": true | |
| }, | |
| "source": [ | |
| "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n", | |
| "<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Matrix-Transpose\" data-toc-modified-id=\"Matrix-Transpose-1\"><span class=\"toc-item-num\">1 </span>Matrix Transpose</a></span></li><li><span><a href=\"#Extracting-the-Diagonal-Elements-of-a-Matrix\" data-toc-modified-id=\"Extracting-the-Diagonal-Elements-of-a-Matrix-2\"><span class=\"toc-item-num\">2 </span>Extracting the Diagonal Elements of a Matrix</a></span></li><li><span><a href=\"#Summing-the-Diagonal-Elements-of-a-Matrix\" data-toc-modified-id=\"Summing-the-Diagonal-Elements-of-a-Matrix-3\"><span class=\"toc-item-num\">3 </span>Summing the Diagonal Elements of a Matrix</a></span></li><li><span><a href=\"#Summing-All-Elements-in-a-Matrix\" data-toc-modified-id=\"Summing-All-Elements-in-a-Matrix-4\"><span class=\"toc-item-num\">4 </span>Summing All Elements in a Matrix</a></span></li><li><span><a href=\"#Matrix-Multiplication-with-Transpose\" data-toc-modified-id=\"Matrix-Multiplication-with-Transpose-5\"><span class=\"toc-item-num\">5 </span>Matrix Multiplication with Transpose</a></span></li><li><span><a href=\"#Outer-Product\" data-toc-modified-id=\"Outer-Product-6\"><span class=\"toc-item-num\">6 </span>Outer Product</a></span></li><li><span><a href=\"#Dot-Product\" data-toc-modified-id=\"Dot-Product-7\"><span class=\"toc-item-num\">7 </span>Dot Product</a></span></li><li><span><a href=\"#Matrix-Column-Sum\" data-toc-modified-id=\"Matrix-Column-Sum-8\"><span class=\"toc-item-num\">8 </span>Matrix Column Sum</a></span></li><li><span><a href=\"#Matrix-Row-Sum\" data-toc-modified-id=\"Matrix-Row-Sum-9\"><span class=\"toc-item-num\">9 </span>Matrix Row Sum</a></span></li><li><span><a href=\"#Element-wise-matrix-multiplication\" data-toc-modified-id=\"Element-wise-matrix-multiplication-10\"><span class=\"toc-item-num\">10 </span>Element-wise matrix multiplication</a></span></li><li><span><a href=\"#Column-wise-dot-product\" data-toc-modified-id=\"Column-wise-dot-product-11\"><span class=\"toc-item-num\">11 </span>Column-wise dot product</a></span></li></ul></div>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'1.0.0'" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.__version__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Matrix Transpose" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 4, 7],\n", | |
| " [2, 5, 8],\n", | |
| " [3, 6, 9]])" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "einsum_transpose = torch.einsum('ji->ij', A)\n", | |
| "einsum_transpose" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 4, 7],\n", | |
| " [2, 5, 8],\n", | |
| " [3, 6, 9]])" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "builtin_tranpose = A.transpose(1, 0)\n", | |
| "builtin_tranpose" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "assert torch.equal(einsum_transpose, builtin_tranpose)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Extracting the Diagonal Elements of a Matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "A_diag = torch.einsum('ii->i', A)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([1, 5, 9])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "einsum_extract_diag = torch.einsum('ii->i', A)\n", | |
| "einsum_extract_diag" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([1, 5, 9])" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "builtin_diag = torch.diag(A)\n", | |
| "builtin_diag" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "assert torch.equal(einsum_extract_diag, builtin_diag)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Summing the Diagonal Elements of a Matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(15)" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y = torch.einsum('ii->', A)\n", | |
| "y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(15)" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "einsum_sum_diag = torch.einsum('ii->', A)\n", | |
| "einsum_sum_diag" | |
| ] | |
| }, | |
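| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A quick check in the style of the earlier asserts (sketch): summing the diagonal is the trace, so `torch.trace` should give the same scalar." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: torch.trace sums the diagonal, so it should match einsum('ii->')\n", | |
| "builtin_trace = torch.trace(A)\n", | |
| "assert torch.equal(einsum_sum_diag, builtin_trace)\n", | |
| "builtin_trace" | |
| ] | |
| }, | |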
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Summing All Elements in a Matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(45)" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('ij->', A)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(45)" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.sum(A)" | |
| ] | |
| }, | |
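| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Sketch of a check: the full einsum reduction should match `torch.sum`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: summing every element with einsum should agree with torch.sum\n", | |
| "assert torch.equal(torch.einsum('ij->', A), torch.sum(A))" | |
| ] | |
| }, | |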
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Matrix Multiplication with Transpose" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 1, 2, 3, 4, 5],\n", | |
| " [ 6, 7, 8, 9, 10],\n", | |
| " [11, 12, 13, 14, 15]])" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 16).view(3, 5)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[10, 11, 12, 13, 14],\n", | |
| " [15, 16, 17, 18, 19],\n", | |
| " [20, 21, 22, 23, 24]])" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "B = torch.arange(10, 25).view(3, 5)\n", | |
| "B" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 190, 265, 340],\n", | |
| " [ 490, 690, 890],\n", | |
| " [ 790, 1115, 1440]])" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A @ B.transpose(1, 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 190, 265, 340],\n", | |
| " [ 490, 690, 890],\n", | |
| " [ 790, 1115, 1440]])" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('ik,jk->ij', A, B)" | |
| ] | |
| }, | |
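| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Sketch of a check: `ik,jk->ij` should equal the built-in product `A @ B.transpose(1, 0)` shown above." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: 'ik,jk->ij' multiplies A by the transpose of B\n", | |
| "assert torch.equal(torch.einsum('ik,jk->ij', A, B), A @ B.transpose(1, 0))" | |
| ] | |
| }, | |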
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Column-wise dot product" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])\n", | |
| "tensor([[10, 11, 12],\n", | |
| " [13, 14, 15],\n", | |
| " [16, 17, 18]])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "B = torch.arange(10, 19).view(3, 3)\n", | |
| "print(A)\n", | |
| "print(B)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([174, 228, 288])" | |
| ] | |
| }, | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('ij,ij->j', A, B)" | |
| ] | |
| }, | |
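| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A plausible built-in formulation (sketch): multiply element-wise, then sum each column; it should match the einsum result above." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: column-wise dot products = element-wise product followed by a column sum\n", | |
| "columnwise = (A * B).sum(dim=0)\n", | |
| "assert torch.equal(torch.einsum('ij,ij->j', A, B), columnwise)\n", | |
| "columnwise" | |
| ] | |
| }, | |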
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Outer Product" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor([1, 2, 3]), tensor([4, 5, 6]))" | |
| ] | |
| }, | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "a = torch.arange(1, 4)\n", | |
| "b = torch.arange(4, 7)\n", | |
| "a, b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 4, 5, 6],\n", | |
| " [ 8, 10, 12],\n", | |
| " [12, 15, 18]])" | |
| ] | |
| }, | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('i,j->ij', a, b)" | |
| ] | |
| }, | |
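| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The same outer product via broadcasting (sketch); the built-in `torch.ger` (renamed `torch.outer` in later releases) should also work." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: outer product via broadcasting; should match einsum('i,j->ij')\n", | |
| "broadcast_outer = a.unsqueeze(1) * b.unsqueeze(0)\n", | |
| "assert torch.equal(torch.einsum('i,j->ij', a, b), broadcast_outer)\n", | |
| "broadcast_outer" | |
| ] | |
| }, | |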
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Dot Product" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor([1, 2, 3]), tensor([4, 5, 6]))" | |
| ] | |
| }, | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "a = torch.arange(1, 4)\n", | |
| "b = torch.arange(4, 7)\n", | |
| "a, b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(32)" | |
| ] | |
| }, | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.dot(a, b)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(32)" | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('i,i', a, b)" | |
| ] | |
| }, | |
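| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Sketch of a check: both dot-product formulations should return the same 0-dim tensor." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: torch.dot and einsum('i,i') should produce the same scalar tensor\n", | |
| "assert torch.equal(torch.dot(a, b), torch.einsum('i,i', a, b))" | |
| ] | |
| }, | |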
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Matrix Column Sum" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])" | |
| ] | |
| }, | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([12, 15, 18])" | |
| ] | |
| }, | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('ji->i', A)" | |
| ] | |
| }, | |
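| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Summing over the row index `j` is a column sum (sketch): `torch.sum(A, dim=0)` should agree." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: 'ji->i' sums over rows, i.e. a column sum\n", | |
| "builtin_col_sum = torch.sum(A, dim=0)\n", | |
| "assert torch.equal(torch.einsum('ji->i', A), builtin_col_sum)\n", | |
| "builtin_col_sum" | |
| ] | |
| }, | |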
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Matrix Row Sum" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 1, 2, 3, 4, 5, 6, 7],\n", | |
| " [ 8, 9, 10, 11, 12, 13, 14]])" | |
| ] | |
| }, | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 15).view(2, 7)\n", | |
| "A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([28, 77])" | |
| ] | |
| }, | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('ij->i', A)" | |
| ] | |
| }, | |
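| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Summing over the column index `j` is a row sum (sketch): `torch.sum(A, dim=1)` should agree." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: 'ij->i' sums over columns, i.e. a row sum\n", | |
| "builtin_row_sum = torch.sum(A, dim=1)\n", | |
| "assert torch.equal(torch.einsum('ij->i', A), builtin_row_sum)\n", | |
| "builtin_row_sum" | |
| ] | |
| }, | |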
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Element-wise matrix multiplication" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tensor([[1, 2, 3],\n", | |
| " [4, 5, 6],\n", | |
| " [7, 8, 9]])\n", | |
| "tensor([[10, 11, 12],\n", | |
| " [13, 14, 15],\n", | |
| " [16, 17, 18]])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "A = torch.arange(1, 10).view(3, 3)\n", | |
| "B = torch.arange(10, 19).view(3, 3)\n", | |
| "print(A)\n", | |
| "print(B)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 10, 22, 36],\n", | |
| " [ 52, 70, 90],\n", | |
| " [112, 136, 162]])" | |
| ] | |
| }, | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.einsum('ij,ij->ij', A, B)" | |
| ] | |
| }, | |
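| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Element-wise multiplication is also just the `*` operator on tensors (sketch): it should match the einsum above." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: element-wise (Hadamard) product via the * operator\n", | |
| "elementwise = A * B\n", | |
| "assert torch.equal(torch.einsum('ij,ij->ij', A, B), elementwise)\n", | |
| "elementwise" | |
| ] | |
| } | |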
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "pytorch", | |
| "language": "python", | |
| "name": "pytorch" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.8" | |
| }, | |
| "toc": { | |
| "nav_menu": {}, | |
| "number_sections": true, | |
| "sideBar": true, | |
| "skip_h1_title": false, | |
| "toc_cell": true, | |
| "toc_position": {}, | |
| "toc_section_display": "block", | |
| "toc_window_display": true | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |