Last active
March 20, 2022 03:54
Analytical Reverse-AD for the determinant of a complex matrix (https://github.com/JuliaDiff/ChainRules.jl/issues/600)
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "3620199d", | |
| "metadata": {}, | |
| "source": [ | |
| "# Reverse-AD for the determinant of a complex matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1c3cfc00", | |
| "metadata": {}, | |
| "source": [ | |
| "We illustrate, using a complex 2×2 matrix $U$ as an example, that for $\\Omega = \\det(U)$ and a perturbation (cotangent) $\\Delta\\Omega$, the correct result of the `ChainRules.rrule` is $\\bar\\Omega\\, \\Delta\\Omega\\, ({U^{-1}})^\\dagger$, not $\\Omega\\, \\Delta\\Omega\\, ({U^{-1}})^\\dagger$. That is, it involves the complex conjugate of the determinant, not the determinant itself." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "c6a9a4b2", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.721701Z", | |
| "start_time": "2022-03-20T03:24:43.379277Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sympy import *" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "22f328a0", | |
| "metadata": {}, | |
| "source": [ | |
| "## Helper Routines" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "7a207be4", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.726000Z", | |
| "start_time": "2022-03-20T03:24:43.722858Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def csym(name, i=None, j=None, part=None):\n", | |
| "    \"\"\"Create a symbolic complex number.\"\"\"\n", | |
| "    if part is None:\n", | |
| "        return csym(name, i, j, 're') + I * csym(name, i, j, 'im')\n", | |
| "    else:\n", | |
| "        if i is None or j is None:\n", | |
| "            return symbols('{name}^{part}'.format(name=name, part=part), real=True)\n", | |
| "        else:\n", | |
| "            return symbols('{name}_{i}{j}^{part}'.format(name=name, i=i, j=j, part=part), real=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "47d597ab", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.728423Z", | |
| "start_time": "2022-03-20T03:24:43.726911Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def real(z):\n", | |
| "    return (z + z.conjugate()) / 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "9d2ce0ff", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.731476Z", | |
| "start_time": "2022-03-20T03:24:43.729824Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def imag(z):\n", | |
| "    return (z - z.conjugate()) / (2 * I)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "43660fb7", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.734266Z", | |
| "start_time": "2022-03-20T03:24:43.732476Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "conj = conjugate" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "42dad762", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.736840Z", | |
| "start_time": "2022-03-20T03:24:43.735178Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def dagger(U):\n", | |
| "    return conj(transpose(U))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "f4c06038", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.740421Z", | |
| "start_time": "2022-03-20T03:24:43.737935Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def deriv_scalar_matrix(J, M):\n", | |
| "    r\"\"\"Return the matrix that is the derivative of the real-valued scalar\n", | |
| "    $J$ with respect to the real-valued matrix $M$. The $ij$ element of\n", | |
| "    the result is $\\frac{\\partial J}{\\partial M_{ij}}$, so the result has\n", | |
| "    the same shape as $M$ (the denominator layout of\n", | |
| "    https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-matrix).\n", | |
| "\n", | |
| "    Note that this is the transpose of the numerator-layout definition!\n", | |
| "    \"\"\"\n", | |
| "    n, m = M.shape\n", | |
| "    return Matrix(\n", | |
| "        [\n", | |
| "            [J.diff(M[i, j]) for j in range(m)]\n", | |
| "            for i in range(n)\n", | |
| "        ]\n", | |
| "    )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "d6e74317", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.743100Z", | |
| "start_time": "2022-03-20T03:24:43.741305Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def u(i, j, part=None):\n", | |
| "    \"\"\"Complex elements of the matrix U\"\"\"\n", | |
| "    return csym(\"u\", i, j, part)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "6cdeb9ba", | |
| "metadata": {}, | |
| "source": [ | |
| "## Definitions" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "5926cf50", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.874856Z", | |
| "start_time": "2022-03-20T03:24:43.743838Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}i u^{im}_{11} + u^{re}_{11} & i u^{im}_{12} + u^{re}_{12}\\\\i u^{im}_{21} + u^{re}_{21} & i u^{im}_{22} + u^{re}_{22}\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[I*u_11^im + u_11^re, I*u_12^im + u_12^re],\n", | |
| "[I*u_21^im + u_21^re, I*u_22^im + u_22^re]])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "U = Matrix([[u(1, 1), u(1, 2)], [u(2, 1), u(2, 2)]])\n", | |
| "U" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "675c7fe4", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.886047Z", | |
| "start_time": "2022-03-20T03:24:43.877263Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle - u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} + u^{re}_{11} u^{re}_{22} + u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} - u^{re}_{12} u^{re}_{21}$" | |
| ], | |
| "text/plain": [ | |
| "-u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im + u_11^re*u_22^re + u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im - u_12^re*u_21^re" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "Ω = det(U)\n", | |
| "Ω" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "7f961285", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.890044Z", | |
| "start_time": "2022-03-20T03:24:43.886863Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle i \\Delta\\Omega^{im} + \\Delta\\Omega^{re}$" | |
| ], | |
| "text/plain": [ | |
| "I*\\Delta\\Omega^im + \\Delta\\Omega^re" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ΔΩ = csym(r\"\\Delta\\Omega\");\n", | |
| "ΔΩ" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "09ed58e0", | |
| "metadata": {}, | |
| "source": [ | |
| "## Left Hand Side" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "6f25fa15", | |
| "metadata": {}, | |
| "source": [ | |
| "On the left-hand side, we have the equation for `rrule` from https://juliadiff.org/ChainRulesCore.jl/dev/maths/complex.html" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "66decf6f", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:43.921962Z", | |
| "start_time": "2022-03-20T03:24:43.890903Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}\\Delta\\Omega^{im} u^{im}_{22} + i \\Delta\\Omega^{im} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{22} + \\Delta\\Omega^{re} u^{re}_{22} & - \\Delta\\Omega^{im} u^{im}_{21} - i \\Delta\\Omega^{im} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{21} - \\Delta\\Omega^{re} u^{re}_{21}\\\\- \\Delta\\Omega^{im} u^{im}_{12} - i \\Delta\\Omega^{im} u^{re}_{12} + i \\Delta\\Omega^{re} u^{im}_{12} - \\Delta\\Omega^{re} u^{re}_{12} & \\Delta\\Omega^{im} u^{im}_{11} + i \\Delta\\Omega^{im} u^{re}_{11} - i \\Delta\\Omega^{re} u^{im}_{11} + \\Delta\\Omega^{re} u^{re}_{11}\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[ \\Delta\\Omega^im*u_22^im + I*\\Delta\\Omega^im*u_22^re - I*\\Delta\\Omega^re*u_22^im + \\Delta\\Omega^re*u_22^re, -\\Delta\\Omega^im*u_21^im - I*\\Delta\\Omega^im*u_21^re + I*\\Delta\\Omega^re*u_21^im - \\Delta\\Omega^re*u_21^re],\n", | |
| "[-\\Delta\\Omega^im*u_12^im - I*\\Delta\\Omega^im*u_12^re + I*\\Delta\\Omega^re*u_12^im - \\Delta\\Omega^re*u_12^re, \\Delta\\Omega^im*u_11^im + I*\\Delta\\Omega^im*u_11^re - I*\\Delta\\Omega^re*u_11^im + \\Delta\\Omega^re*u_11^re]])" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "lhs = (\n", | |
| "    real(ΔΩ) * deriv_scalar_matrix(real(Ω), real(U))\n", | |
| "    + imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), real(U))\n", | |
| "    + I * real(ΔΩ) * deriv_scalar_matrix(real(Ω), imag(U))\n", | |
| "    + I * imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), imag(U))\n", | |
| ");\n", | |
| "lhs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "bbfca86f", | |
| "metadata": {}, | |
| "source": [ | |
| "## Right Hand Side" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "fb426a79", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-19T21:48:47.126422Z", | |
| "start_time": "2022-03-19T21:48:47.113734Z" | |
| } | |
| }, | |
| "source": [ | |
| "On the right-hand side, we have the formula corresponding to the code at https://github.com/JuliaDiff/ChainRules.jl/blob/9023d898a0b957bd9b3baab6bc38b54822d6963a/src/rulesets/LinearAlgebra/dense.jl#L132" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "e67fe9d3", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:46.681988Z", | |
| "start_time": "2022-03-20T03:24:43.922856Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "rhs_old = simplify(Ω * ΔΩ * dagger(U.inv())).expand()\n", | |
| "rhs_old;" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2a2ddafc", | |
| "metadata": {}, | |
| "source": [ | |
| "or, respectively, the corrected version:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "00ba42a0", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:47.224201Z", | |
| "start_time": "2022-03-20T03:24:46.682738Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}\\Delta\\Omega^{im} u^{im}_{22} + i \\Delta\\Omega^{im} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{22} + \\Delta\\Omega^{re} u^{re}_{22} & - \\Delta\\Omega^{im} u^{im}_{21} - i \\Delta\\Omega^{im} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{21} - \\Delta\\Omega^{re} u^{re}_{21}\\\\- \\Delta\\Omega^{im} u^{im}_{12} - i \\Delta\\Omega^{im} u^{re}_{12} + i \\Delta\\Omega^{re} u^{im}_{12} - \\Delta\\Omega^{re} u^{re}_{12} & \\Delta\\Omega^{im} u^{im}_{11} + i \\Delta\\Omega^{im} u^{re}_{11} - i \\Delta\\Omega^{re} u^{im}_{11} + \\Delta\\Omega^{re} u^{re}_{11}\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[ \\Delta\\Omega^im*u_22^im + I*\\Delta\\Omega^im*u_22^re - I*\\Delta\\Omega^re*u_22^im + \\Delta\\Omega^re*u_22^re, -\\Delta\\Omega^im*u_21^im - I*\\Delta\\Omega^im*u_21^re + I*\\Delta\\Omega^re*u_21^im - \\Delta\\Omega^re*u_21^re],\n", | |
| "[-\\Delta\\Omega^im*u_12^im - I*\\Delta\\Omega^im*u_12^re + I*\\Delta\\Omega^re*u_12^im - \\Delta\\Omega^re*u_12^re, \\Delta\\Omega^im*u_11^im + I*\\Delta\\Omega^im*u_11^re - I*\\Delta\\Omega^re*u_11^im + \\Delta\\Omega^re*u_11^re]])" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "rhs_new = simplify(conj(Ω) * ΔΩ * dagger(U.inv())).expand()\n", | |
| "rhs_new" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "77b64d5c", | |
| "metadata": {}, | |
| "source": [ | |
| "## Check" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "08a600bb", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.681910Z", | |
| "start_time": "2022-03-20T03:24:47.225137Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}\\frac{2 \\left(i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{22} u^{re}_{22} - \\Delta\\Omega^{im} u^{im}_{11} \\left(u^{re}_{22}\\right)^{2} + i \\Delta\\Omega^{im} u^{re}_{11} \\left(u^{im}_{22}\\right)^{2} - \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{22} u^{re}_{22} - i \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{21} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{re}_{22} + \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{22} u^{re}_{22} + i \\Delta\\Omega^{re} u^{im}_{11} \\left(u^{re}_{22}\\right)^{2} + \\Delta\\Omega^{re} u^{re}_{11} \\left(u^{im}_{22}\\right)^{2} + i \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{22} u^{re}_{22} - \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{21} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{re}_{22}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}} & \\frac{2 \\left(- i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{21} u^{re}_{22} + \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{21} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{21} u^{im}_{22} + i \\Delta\\Omega^{im} u^{im}_{12} u^{im}_{21} u^{re}_{21} - \\Delta\\Omega^{im} u^{im}_{12} \\left(u^{re}_{21}\\right)^{2} + i \\Delta\\Omega^{im} u^{re}_{12} \\left(u^{im}_{21}\\right)^{2} - \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{re}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{21} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{21} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{21} u^{im}_{22} + 
\\Delta\\Omega^{re} u^{im}_{12} u^{im}_{21} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{12} \\left(u^{re}_{21}\\right)^{2} + \\Delta\\Omega^{re} u^{re}_{12} \\left(u^{im}_{21}\\right)^{2} + i \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{re}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}}\\\\\\frac{2 \\left(- i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{12} u^{re}_{22} + \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{12} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{12} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{12} u^{im}_{22} + i \\Delta\\Omega^{im} \\left(u^{im}_{12}\\right)^{2} u^{re}_{21} + i \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{12} u^{im}_{21} - \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{12} u^{re}_{21} - \\Delta\\Omega^{im} \\left(u^{re}_{12}\\right)^{2} u^{im}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{12} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{12} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{12} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{12} u^{im}_{22} + \\Delta\\Omega^{re} \\left(u^{im}_{12}\\right)^{2} u^{re}_{21} + \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{12} u^{im}_{21} + i \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{12} u^{re}_{21} + i \\Delta\\Omega^{re} \\left(u^{re}_{12}\\right)^{2} u^{im}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}} & \\frac{2 \\left(i \\Delta\\Omega^{im} \\left(u^{im}_{11}\\right)^{2} u^{re}_{22} + i \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{11} u^{im}_{22} - \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{11} u^{re}_{22} - i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{12} u^{re}_{21} - i \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{12} 
u^{im}_{21} - \\Delta\\Omega^{im} \\left(u^{re}_{11}\\right)^{2} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{12} u^{re}_{21} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{12} u^{im}_{21} + \\Delta\\Omega^{re} \\left(u^{im}_{11}\\right)^{2} u^{re}_{22} + \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{11} u^{im}_{22} + i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{11} u^{re}_{22} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{12} u^{re}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{12} u^{im}_{21} + i \\Delta\\Omega^{re} \\left(u^{re}_{11}\\right)^{2} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{12} u^{re}_{21} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{12} u^{im}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}}\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[ 2*(I*\\Delta\\Omega^im*u_11^im*u_22^im*u_22^re - \\Delta\\Omega^im*u_11^im*u_22^re**2 + I*\\Delta\\Omega^im*u_11^re*u_22^im**2 - \\Delta\\Omega^im*u_11^re*u_22^im*u_22^re - I*\\Delta\\Omega^im*u_12^im*u_21^re*u_22^im + \\Delta\\Omega^im*u_12^im*u_21^re*u_22^re - I*\\Delta\\Omega^im*u_12^re*u_21^im*u_22^im + \\Delta\\Omega^im*u_12^re*u_21^im*u_22^re + \\Delta\\Omega^re*u_11^im*u_22^im*u_22^re + I*\\Delta\\Omega^re*u_11^im*u_22^re**2 + \\Delta\\Omega^re*u_11^re*u_22^im**2 + I*\\Delta\\Omega^re*u_11^re*u_22^im*u_22^re - \\Delta\\Omega^re*u_12^im*u_21^re*u_22^im - I*\\Delta\\Omega^re*u_12^im*u_21^re*u_22^re - \\Delta\\Omega^re*u_12^re*u_21^im*u_22^im - I*\\Delta\\Omega^re*u_12^re*u_21^im*u_22^re)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re), 2*(-I*\\Delta\\Omega^im*u_11^im*u_21^im*u_22^re + \\Delta\\Omega^im*u_11^im*u_21^re*u_22^re - I*\\Delta\\Omega^im*u_11^re*u_21^im*u_22^im + \\Delta\\Omega^im*u_11^re*u_21^re*u_22^im + I*\\Delta\\Omega^im*u_12^im*u_21^im*u_21^re - \\Delta\\Omega^im*u_12^im*u_21^re**2 + I*\\Delta\\Omega^im*u_12^re*u_21^im**2 - \\Delta\\Omega^im*u_12^re*u_21^im*u_21^re - \\Delta\\Omega^re*u_11^im*u_21^im*u_22^re - I*\\Delta\\Omega^re*u_11^im*u_21^re*u_22^re - \\Delta\\Omega^re*u_11^re*u_21^im*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_21^re*u_22^im + \\Delta\\Omega^re*u_12^im*u_21^im*u_21^re + I*\\Delta\\Omega^re*u_12^im*u_21^re**2 + \\Delta\\Omega^re*u_12^re*u_21^im**2 + I*\\Delta\\Omega^re*u_12^re*u_21^im*u_21^re)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re)],\n", | |
| "[2*(-I*\\Delta\\Omega^im*u_11^im*u_12^im*u_22^re + \\Delta\\Omega^im*u_11^im*u_12^re*u_22^re - I*\\Delta\\Omega^im*u_11^re*u_12^im*u_22^im + \\Delta\\Omega^im*u_11^re*u_12^re*u_22^im + I*\\Delta\\Omega^im*u_12^im**2*u_21^re + I*\\Delta\\Omega^im*u_12^im*u_12^re*u_21^im - \\Delta\\Omega^im*u_12^im*u_12^re*u_21^re - \\Delta\\Omega^im*u_12^re**2*u_21^im - \\Delta\\Omega^re*u_11^im*u_12^im*u_22^re - I*\\Delta\\Omega^re*u_11^im*u_12^re*u_22^re - \\Delta\\Omega^re*u_11^re*u_12^im*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_12^re*u_22^im + \\Delta\\Omega^re*u_12^im**2*u_21^re + \\Delta\\Omega^re*u_12^im*u_12^re*u_21^im + I*\\Delta\\Omega^re*u_12^im*u_12^re*u_21^re + I*\\Delta\\Omega^re*u_12^re**2*u_21^im)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re), 2*(I*\\Delta\\Omega^im*u_11^im**2*u_22^re + I*\\Delta\\Omega^im*u_11^im*u_11^re*u_22^im - \\Delta\\Omega^im*u_11^im*u_11^re*u_22^re - I*\\Delta\\Omega^im*u_11^im*u_12^im*u_21^re - I*\\Delta\\Omega^im*u_11^im*u_12^re*u_21^im - \\Delta\\Omega^im*u_11^re**2*u_22^im + \\Delta\\Omega^im*u_11^re*u_12^im*u_21^re + \\Delta\\Omega^im*u_11^re*u_12^re*u_21^im + \\Delta\\Omega^re*u_11^im**2*u_22^re + \\Delta\\Omega^re*u_11^im*u_11^re*u_22^im + I*\\Delta\\Omega^re*u_11^im*u_11^re*u_22^re - \\Delta\\Omega^re*u_11^im*u_12^im*u_21^re - \\Delta\\Omega^re*u_11^im*u_12^re*u_21^im + I*\\Delta\\Omega^re*u_11^re**2*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_12^im*u_21^re - I*\\Delta\\Omega^re*u_11^re*u_12^re*u_21^im)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re)]])" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "diff_old = (lhs - rhs_old).simplify();\n", | |
| "diff_old" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "7ea2458c", | |
| "metadata": {}, | |
| "source": [ | |
| "Note that for a real-valued matrix $U$ (all imaginary parts set to zero), the old definition works:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "d4729acd", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.707578Z", | |
| "start_time": "2022-03-20T03:24:49.682794Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}0 & 0\\\\0 & 0\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[0, 0],\n", | |
| "[0, 0]])" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "diff_old.subs({sym: 0 for sym in imag(U).free_symbols})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "35a77f90", | |
| "metadata": {}, | |
| "source": [ | |
| "(Even without plugging in values, this is pretty obvious: for real $U$, $\\det U \\in \\mathbb{R}$, so $\\bar\\Omega = \\Omega$ and the two formulas coincide.)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4723de22", | |
| "metadata": {}, | |
| "source": [ | |
| "The new RHS works unconditionally:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "192d6a30", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.711744Z", | |
| "start_time": "2022-03-20T03:24:49.708496Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}0 & 0\\\\0 & 0\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[0, 0],\n", | |
| "[0, 0]])" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "(lhs - rhs_new).simplify()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0f0edf33", | |
| "metadata": {}, | |
| "source": [ | |
| "## A \"Well-conditioned\" matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d56dd253", | |
| "metadata": {}, | |
| "source": [ | |
| "The current tests in `ChainRules` actually include a check for the `rrule` of `det` for a complex matrix. However, the test matrix is a \"well-conditioned matrix\" defined as follows:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "6e08800c", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.714291Z", | |
| "start_time": "2022-03-20T03:24:49.712576Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def v(i, j, part=None):\n", | |
| "    \"\"\"Complex elements of the matrix V\"\"\"\n", | |
| "    return csym(\"V\", i, j, part)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "a723dd40", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.717562Z", | |
| "start_time": "2022-03-20T03:24:49.715232Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "V = Matrix([[v(1, 1), v(1, 2)], [v(2, 1), v(2, 2)]])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "38f09ee6", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.826463Z", | |
| "start_time": "2022-03-20T03:24:49.718728Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle \\left[\\begin{matrix}\\left(V^{im}_{11}\\right)^{2} + \\left(V^{re}_{11}\\right)^{2} + \\left(V^{im}_{12}\\right)^{2} + \\left(V^{re}_{12}\\right)^{2} + 1 & - \\left(i V^{im}_{11} + V^{re}_{11}\\right) \\left(i V^{im}_{21} - V^{re}_{21}\\right) - \\left(i V^{im}_{12} + V^{re}_{12}\\right) \\left(i V^{im}_{22} - V^{re}_{22}\\right)\\\\- \\left(i V^{im}_{11} - V^{re}_{11}\\right) \\left(i V^{im}_{21} + V^{re}_{21}\\right) - \\left(i V^{im}_{12} - V^{re}_{12}\\right) \\left(i V^{im}_{22} + V^{re}_{22}\\right) & \\left(V^{im}_{21}\\right)^{2} + \\left(V^{re}_{21}\\right)^{2} + \\left(V^{im}_{22}\\right)^{2} + \\left(V^{re}_{22}\\right)^{2} + 1\\end{matrix}\\right]$" | |
| ], | |
| "text/plain": [ | |
| "Matrix([\n", | |
| "[ V_11^im**2 + V_11^re**2 + V_12^im**2 + V_12^re**2 + 1, -(I*V_11^im + V_11^re)*(I*V_21^im - V_21^re) - (I*V_12^im + V_12^re)*(I*V_22^im - V_22^re)],\n", | |
| "[-(I*V_11^im - V_11^re)*(I*V_21^im + V_21^re) - (I*V_12^im - V_12^re)*(I*V_22^im + V_22^re), V_21^im**2 + V_21^re**2 + V_22^im**2 + V_22^re**2 + 1]])" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "W = simplify(V * dagger(V) + eye(2));\n", | |
| "W" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3f317de6", | |
| "metadata": {}, | |
| "source": [ | |
| "We find that $\\det(W) \\in \\mathbb{R}$:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "a142e80b", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.847133Z", | |
| "start_time": "2022-03-20T03:24:49.828828Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle 0$" | |
| ], | |
| "text/plain": [ | |
| "0" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "imag(det(W))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "790caf8f", | |
| "metadata": {}, | |
| "source": [ | |
| "This is unlike the determinant of an arbitrary matrix:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "691257d8", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2022-03-20T03:24:49.856447Z", | |
| "start_time": "2022-03-20T03:24:49.848120Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/latex": [ | |
| "$\\displaystyle - \\frac{i \\left(2 i u^{im}_{11} u^{re}_{22} + 2 i u^{re}_{11} u^{im}_{22} - 2 i u^{im}_{12} u^{re}_{21} - 2 i u^{re}_{12} u^{im}_{21}\\right)}{2}$" | |
| ], | |
| "text/plain": [ | |
| "-I*(2*I*u_11^im*u_22^re + 2*I*u_11^re*u_22^im - 2*I*u_12^im*u_21^re - 2*I*u_12^re*u_21^im)/2" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "imag(det(U))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "83b565ca", | |
| "metadata": {}, | |
| "source": [ | |
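| "Since $\\bar\\Omega = \\Omega$ whenever $\\Omega$ is real, the old and the corrected formula agree for $W$, which explains why this bug is not caught by the current tests." | |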
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "hide_input": false, | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.7" | |
| }, | |
| "toc": { | |
| "base_numbering": 1, | |
| "nav_menu": {}, | |
| "number_sections": true, | |
| "sideBar": true, | |
| "skip_h1_title": false, | |
| "title_cell": "Table of Contents", | |
| "title_sidebar": "Contents", | |
| "toc_cell": false, | |
| "toc_position": {}, | |
| "toc_section_display": true, | |
| "toc_window_display": false | |
| }, | |
| "varInspector": { | |
| "cols": { | |
| "lenName": 16, | |
| "lenType": 16, | |
| "lenVar": 40 | |
| }, | |
| "kernels_config": { | |
| "python": { | |
| "delete_cmd_postfix": "", | |
| "delete_cmd_prefix": "del ", | |
| "library": "var_list.py", | |
| "varRefreshCmd": "print(var_dic_list())" | |
| }, | |
| "r": { | |
| "delete_cmd_postfix": ") ", | |
| "delete_cmd_prefix": "rm(", | |
| "library": "var_list.r", | |
| "varRefreshCmd": "cat(var_dic_list()) " | |
| } | |
| }, | |
| "types_to_exclude": [ | |
| "module", | |
| "function", | |
| "builtin_function_or_method", | |
| "instance", | |
| "_Feature" | |
| ], | |
| "window_display": false | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |