Last active
March 20, 2022 03:54
Analytical Reverse-AD for the determinant of a complex matrix (https://github.com/JuliaDiff/ChainRules.jl/issues/600)
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "3620199d", | |
"metadata": {}, | |
"source": [ | |
"# Reverse-AD for the determinant of a complex matrix" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c3cfc00", | |
"metadata": {}, | |
"source": [ | |
"Using a complex 2×2 matrix $U$ as an example, we show that for $\\Omega = \\det(U)$ and a perturbation $\\Delta\\Omega$, the correct `ChainRules.rrule` is $\\bar\\Omega\\, \\Delta\\Omega\\, ({U^{-1}})^\\dagger$, not $\\Omega\\, \\Delta\\Omega\\, ({U^{-1}})^\\dagger$. That is, it involves the complex conjugate of the determinant, not the determinant itself." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "c6a9a4b2", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.721701Z", | |
"start_time": "2022-03-20T03:24:43.379277Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from sympy import *" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "22f328a0", | |
"metadata": {}, | |
"source": [ | |
"## Helper Routines" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "7a207be4", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.726000Z", | |
"start_time": "2022-03-20T03:24:43.722858Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def csym(name, i=None, j=None, part=None):\n", | |
"    \"\"\"Create a symbolic complex number.\"\"\"\n", | |
"    if part is None:\n", | |
"        return csym(name, i, j, 're') + I * csym(name, i, j, 'im')\n", | |
"    else:\n", | |
"        if i is None or j is None:\n", | |
"            return symbols('{name}^{part}'.format(name=name, part=part), real=True)\n", | |
"        else:\n", | |
"            return symbols('{name}_{i}{j}^{part}'.format(name=name, i=i, j=j, part=part), real=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "47d597ab", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.728423Z", | |
"start_time": "2022-03-20T03:24:43.726911Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def real(z):\n", | |
"    return (z + z.conjugate()) / 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "9d2ce0ff", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.731476Z", | |
"start_time": "2022-03-20T03:24:43.729824Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def imag(z):\n", | |
"    return (z - z.conjugate()) / (2 * I)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "43660fb7", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.734266Z", | |
"start_time": "2022-03-20T03:24:43.732476Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"conj = conjugate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "42dad762", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.736840Z", | |
"start_time": "2022-03-20T03:24:43.735178Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def dagger(U):\n", | |
"    return conj(transpose(U))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f4c06038", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.740421Z", | |
"start_time": "2022-03-20T03:24:43.737935Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def deriv_scalar_matrix(J, M):\n", | |
"    \"\"\"Return the matrix that is the derivative of the real-valued scalar\n", | |
"    $J$ with respect to the real-valued matrix $M$. The $ij$ element of\n", | |
"    the result is $\\frac{\\partial J}{\\partial M_{ij}}$, so the result\n", | |
"    has the same shape as $M$ (denominator layout), see\n", | |
"    https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-matrix\n", | |
"\n", | |
"    Note that there is no transpose, unlike in the numerator layout!\n", | |
"    \"\"\"\n", | |
"    n, m = M.shape\n", | |
"    return Matrix(\n", | |
"        [\n", | |
"            [J.diff(M[i, j]) for j in range(m)]\n", | |
"            for i in range(n)\n", | |
"        ]\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "d6e74317", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.743100Z", | |
"start_time": "2022-03-20T03:24:43.741305Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def u(i, j, part=None):\n", | |
"    \"\"\"Complex elements of the matrix U\"\"\"\n", | |
"    return csym(\"u\", i, j, part)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6cdeb9ba", | |
"metadata": {}, | |
"source": [ | |
"## Definitions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "5926cf50", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.874856Z", | |
"start_time": "2022-03-20T03:24:43.743838Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}i u^{im}_{11} + u^{re}_{11} & i u^{im}_{12} + u^{re}_{12}\\\\i u^{im}_{21} + u^{re}_{21} & i u^{im}_{22} + u^{re}_{22}\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[I*u_11^im + u_11^re, I*u_12^im + u_12^re],\n", | |
"[I*u_21^im + u_21^re, I*u_22^im + u_22^re]])" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"U = Matrix([[u(1, 1), u(1, 2)], [u(2, 1), u(2, 2)]])\n", | |
"U" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "675c7fe4", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.886047Z", | |
"start_time": "2022-03-20T03:24:43.877263Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle - u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} + u^{re}_{11} u^{re}_{22} + u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} - u^{re}_{12} u^{re}_{21}$" | |
], | |
"text/plain": [ | |
"-u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im + u_11^re*u_22^re + u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im - u_12^re*u_21^re" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Ω = det(U)\n", | |
"Ω" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "7f961285", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.890044Z", | |
"start_time": "2022-03-20T03:24:43.886863Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle i \\Delta\\Omega^{im} + \\Delta\\Omega^{re}$" | |
], | |
"text/plain": [ | |
"I*\\Delta\\Omega^im + \\Delta\\Omega^re" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ΔΩ = csym(r\"\\Delta\\Omega\");\n", | |
"ΔΩ" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "09ed58e0", | |
"metadata": {}, | |
"source": [ | |
"## Left Hand Side" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6f25fa15", | |
"metadata": {}, | |
"source": [ | |
"On the left-hand side, we have the equation for `rrule` from https://juliadiff.org/ChainRulesCore.jl/dev/maths/complex.html" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "66decf6f", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:43.921962Z", | |
"start_time": "2022-03-20T03:24:43.890903Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}\\Delta\\Omega^{im} u^{im}_{22} + i \\Delta\\Omega^{im} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{22} + \\Delta\\Omega^{re} u^{re}_{22} & - \\Delta\\Omega^{im} u^{im}_{21} - i \\Delta\\Omega^{im} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{21} - \\Delta\\Omega^{re} u^{re}_{21}\\\\- \\Delta\\Omega^{im} u^{im}_{12} - i \\Delta\\Omega^{im} u^{re}_{12} + i \\Delta\\Omega^{re} u^{im}_{12} - \\Delta\\Omega^{re} u^{re}_{12} & \\Delta\\Omega^{im} u^{im}_{11} + i \\Delta\\Omega^{im} u^{re}_{11} - i \\Delta\\Omega^{re} u^{im}_{11} + \\Delta\\Omega^{re} u^{re}_{11}\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[ \\Delta\\Omega^im*u_22^im + I*\\Delta\\Omega^im*u_22^re - I*\\Delta\\Omega^re*u_22^im + \\Delta\\Omega^re*u_22^re, -\\Delta\\Omega^im*u_21^im - I*\\Delta\\Omega^im*u_21^re + I*\\Delta\\Omega^re*u_21^im - \\Delta\\Omega^re*u_21^re],\n", | |
"[-\\Delta\\Omega^im*u_12^im - I*\\Delta\\Omega^im*u_12^re + I*\\Delta\\Omega^re*u_12^im - \\Delta\\Omega^re*u_12^re, \\Delta\\Omega^im*u_11^im + I*\\Delta\\Omega^im*u_11^re - I*\\Delta\\Omega^re*u_11^im + \\Delta\\Omega^re*u_11^re]])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lhs = (\n", | |
"    real(ΔΩ) * deriv_scalar_matrix(real(Ω), real(U))\n", | |
"    + imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), real(U))\n", | |
"    + I * real(ΔΩ) * deriv_scalar_matrix(real(Ω), imag(U))\n", | |
"    + I * imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), imag(U))\n", | |
");\n", | |
"lhs" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bbfca86f", | |
"metadata": {}, | |
"source": [ | |
"## Right Hand Side" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fb426a79", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-19T21:48:47.126422Z", | |
"start_time": "2022-03-19T21:48:47.113734Z" | |
} | |
}, | |
"source": [ | |
"On the right-hand side, we have the formula corresponding to the code at https://github.com/JuliaDiff/ChainRules.jl/blob/9023d898a0b957bd9b3baab6bc38b54822d6963a/src/rulesets/LinearAlgebra/dense.jl#L132" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "e67fe9d3", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:46.681988Z", | |
"start_time": "2022-03-20T03:24:43.922856Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"rhs_old = simplify(Ω * ΔΩ * dagger(U.inv())).expand()\n", | |
"rhs_old;" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2a2ddafc", | |
"metadata": {}, | |
"source": [ | |
"and, respectively, the corrected version with the complex conjugate of the determinant:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "00ba42a0", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:47.224201Z", | |
"start_time": "2022-03-20T03:24:46.682738Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}\\Delta\\Omega^{im} u^{im}_{22} + i \\Delta\\Omega^{im} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{22} + \\Delta\\Omega^{re} u^{re}_{22} & - \\Delta\\Omega^{im} u^{im}_{21} - i \\Delta\\Omega^{im} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{21} - \\Delta\\Omega^{re} u^{re}_{21}\\\\- \\Delta\\Omega^{im} u^{im}_{12} - i \\Delta\\Omega^{im} u^{re}_{12} + i \\Delta\\Omega^{re} u^{im}_{12} - \\Delta\\Omega^{re} u^{re}_{12} & \\Delta\\Omega^{im} u^{im}_{11} + i \\Delta\\Omega^{im} u^{re}_{11} - i \\Delta\\Omega^{re} u^{im}_{11} + \\Delta\\Omega^{re} u^{re}_{11}\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[ \\Delta\\Omega^im*u_22^im + I*\\Delta\\Omega^im*u_22^re - I*\\Delta\\Omega^re*u_22^im + \\Delta\\Omega^re*u_22^re, -\\Delta\\Omega^im*u_21^im - I*\\Delta\\Omega^im*u_21^re + I*\\Delta\\Omega^re*u_21^im - \\Delta\\Omega^re*u_21^re],\n", | |
"[-\\Delta\\Omega^im*u_12^im - I*\\Delta\\Omega^im*u_12^re + I*\\Delta\\Omega^re*u_12^im - \\Delta\\Omega^re*u_12^re, \\Delta\\Omega^im*u_11^im + I*\\Delta\\Omega^im*u_11^re - I*\\Delta\\Omega^re*u_11^im + \\Delta\\Omega^re*u_11^re]])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rhs_new = simplify(conj(Ω) * ΔΩ * dagger(U.inv())).expand()\n", | |
"rhs_new" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "77b64d5c", | |
"metadata": {}, | |
"source": [ | |
"## Check" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "08a600bb", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.681910Z", | |
"start_time": "2022-03-20T03:24:47.225137Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}\\frac{2 \\left(i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{22} u^{re}_{22} - \\Delta\\Omega^{im} u^{im}_{11} \\left(u^{re}_{22}\\right)^{2} + i \\Delta\\Omega^{im} u^{re}_{11} \\left(u^{im}_{22}\\right)^{2} - \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{22} u^{re}_{22} - i \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{21} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{re}_{22} + \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{22} u^{re}_{22} + i \\Delta\\Omega^{re} u^{im}_{11} \\left(u^{re}_{22}\\right)^{2} + \\Delta\\Omega^{re} u^{re}_{11} \\left(u^{im}_{22}\\right)^{2} + i \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{22} u^{re}_{22} - \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{21} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{re}_{22}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}} & \\frac{2 \\left(- i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{21} u^{re}_{22} + \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{21} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{21} u^{im}_{22} + i \\Delta\\Omega^{im} u^{im}_{12} u^{im}_{21} u^{re}_{21} - \\Delta\\Omega^{im} u^{im}_{12} \\left(u^{re}_{21}\\right)^{2} + i \\Delta\\Omega^{im} u^{re}_{12} \\left(u^{im}_{21}\\right)^{2} - \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{re}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{21} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{21} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{21} u^{im}_{22} + 
\\Delta\\Omega^{re} u^{im}_{12} u^{im}_{21} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{12} \\left(u^{re}_{21}\\right)^{2} + \\Delta\\Omega^{re} u^{re}_{12} \\left(u^{im}_{21}\\right)^{2} + i \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{re}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}}\\\\\\frac{2 \\left(- i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{12} u^{re}_{22} + \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{12} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{12} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{12} u^{im}_{22} + i \\Delta\\Omega^{im} \\left(u^{im}_{12}\\right)^{2} u^{re}_{21} + i \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{12} u^{im}_{21} - \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{12} u^{re}_{21} - \\Delta\\Omega^{im} \\left(u^{re}_{12}\\right)^{2} u^{im}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{12} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{12} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{12} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{12} u^{im}_{22} + \\Delta\\Omega^{re} \\left(u^{im}_{12}\\right)^{2} u^{re}_{21} + \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{12} u^{im}_{21} + i \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{12} u^{re}_{21} + i \\Delta\\Omega^{re} \\left(u^{re}_{12}\\right)^{2} u^{im}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}} & \\frac{2 \\left(i \\Delta\\Omega^{im} \\left(u^{im}_{11}\\right)^{2} u^{re}_{22} + i \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{11} u^{im}_{22} - \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{11} u^{re}_{22} - i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{12} u^{re}_{21} - i \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{12} 
u^{im}_{21} - \\Delta\\Omega^{im} \\left(u^{re}_{11}\\right)^{2} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{12} u^{re}_{21} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{12} u^{im}_{21} + \\Delta\\Omega^{re} \\left(u^{im}_{11}\\right)^{2} u^{re}_{22} + \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{11} u^{im}_{22} + i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{11} u^{re}_{22} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{12} u^{re}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{12} u^{im}_{21} + i \\Delta\\Omega^{re} \\left(u^{re}_{11}\\right)^{2} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{12} u^{re}_{21} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{12} u^{im}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}}\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[ 2*(I*\\Delta\\Omega^im*u_11^im*u_22^im*u_22^re - \\Delta\\Omega^im*u_11^im*u_22^re**2 + I*\\Delta\\Omega^im*u_11^re*u_22^im**2 - \\Delta\\Omega^im*u_11^re*u_22^im*u_22^re - I*\\Delta\\Omega^im*u_12^im*u_21^re*u_22^im + \\Delta\\Omega^im*u_12^im*u_21^re*u_22^re - I*\\Delta\\Omega^im*u_12^re*u_21^im*u_22^im + \\Delta\\Omega^im*u_12^re*u_21^im*u_22^re + \\Delta\\Omega^re*u_11^im*u_22^im*u_22^re + I*\\Delta\\Omega^re*u_11^im*u_22^re**2 + \\Delta\\Omega^re*u_11^re*u_22^im**2 + I*\\Delta\\Omega^re*u_11^re*u_22^im*u_22^re - \\Delta\\Omega^re*u_12^im*u_21^re*u_22^im - I*\\Delta\\Omega^re*u_12^im*u_21^re*u_22^re - \\Delta\\Omega^re*u_12^re*u_21^im*u_22^im - I*\\Delta\\Omega^re*u_12^re*u_21^im*u_22^re)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re), 2*(-I*\\Delta\\Omega^im*u_11^im*u_21^im*u_22^re + \\Delta\\Omega^im*u_11^im*u_21^re*u_22^re - I*\\Delta\\Omega^im*u_11^re*u_21^im*u_22^im + \\Delta\\Omega^im*u_11^re*u_21^re*u_22^im + I*\\Delta\\Omega^im*u_12^im*u_21^im*u_21^re - \\Delta\\Omega^im*u_12^im*u_21^re**2 + I*\\Delta\\Omega^im*u_12^re*u_21^im**2 - \\Delta\\Omega^im*u_12^re*u_21^im*u_21^re - \\Delta\\Omega^re*u_11^im*u_21^im*u_22^re - I*\\Delta\\Omega^re*u_11^im*u_21^re*u_22^re - \\Delta\\Omega^re*u_11^re*u_21^im*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_21^re*u_22^im + \\Delta\\Omega^re*u_12^im*u_21^im*u_21^re + I*\\Delta\\Omega^re*u_12^im*u_21^re**2 + \\Delta\\Omega^re*u_12^re*u_21^im**2 + I*\\Delta\\Omega^re*u_12^re*u_21^im*u_21^re)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re)],\n", | |
"[2*(-I*\\Delta\\Omega^im*u_11^im*u_12^im*u_22^re + \\Delta\\Omega^im*u_11^im*u_12^re*u_22^re - I*\\Delta\\Omega^im*u_11^re*u_12^im*u_22^im + \\Delta\\Omega^im*u_11^re*u_12^re*u_22^im + I*\\Delta\\Omega^im*u_12^im**2*u_21^re + I*\\Delta\\Omega^im*u_12^im*u_12^re*u_21^im - \\Delta\\Omega^im*u_12^im*u_12^re*u_21^re - \\Delta\\Omega^im*u_12^re**2*u_21^im - \\Delta\\Omega^re*u_11^im*u_12^im*u_22^re - I*\\Delta\\Omega^re*u_11^im*u_12^re*u_22^re - \\Delta\\Omega^re*u_11^re*u_12^im*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_12^re*u_22^im + \\Delta\\Omega^re*u_12^im**2*u_21^re + \\Delta\\Omega^re*u_12^im*u_12^re*u_21^im + I*\\Delta\\Omega^re*u_12^im*u_12^re*u_21^re + I*\\Delta\\Omega^re*u_12^re**2*u_21^im)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re), 2*(I*\\Delta\\Omega^im*u_11^im**2*u_22^re + I*\\Delta\\Omega^im*u_11^im*u_11^re*u_22^im - \\Delta\\Omega^im*u_11^im*u_11^re*u_22^re - I*\\Delta\\Omega^im*u_11^im*u_12^im*u_21^re - I*\\Delta\\Omega^im*u_11^im*u_12^re*u_21^im - \\Delta\\Omega^im*u_11^re**2*u_22^im + \\Delta\\Omega^im*u_11^re*u_12^im*u_21^re + \\Delta\\Omega^im*u_11^re*u_12^re*u_21^im + \\Delta\\Omega^re*u_11^im**2*u_22^re + \\Delta\\Omega^re*u_11^im*u_11^re*u_22^im + I*\\Delta\\Omega^re*u_11^im*u_11^re*u_22^re - \\Delta\\Omega^re*u_11^im*u_12^im*u_21^re - \\Delta\\Omega^re*u_11^im*u_12^re*u_21^im + I*\\Delta\\Omega^re*u_11^re**2*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_12^im*u_21^re - I*\\Delta\\Omega^re*u_11^re*u_12^re*u_21^im)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re)]])" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"diff_old = (lhs - rhs_old).simplify();\n", | |
"diff_old" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7ea2458c", | |
"metadata": {}, | |
"source": [ | |
"Note that for a real-valued matrix, $U \\in \\mathbb{R}^{2 \\times 2}$, the old definition works:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "d4729acd", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.707578Z", | |
"start_time": "2022-03-20T03:24:49.682794Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}0 & 0\\\\0 & 0\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[0, 0],\n", | |
"[0, 0]])" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"diff_old.subs({sym: 0 for sym in imag(U).free_symbols})" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "35a77f90", | |
"metadata": {}, | |
"source": [ | |
"(Even without plugging in values, this is pretty obvious: for real $U$, $\\det U \\in \\mathbb{R}$, so $\\bar\\Omega = \\Omega$ and the two formulas coincide.)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4723de22", | |
"metadata": {}, | |
"source": [ | |
"The new RHS works unconditionally:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "192d6a30", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.711744Z", | |
"start_time": "2022-03-20T03:24:49.708496Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}0 & 0\\\\0 & 0\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[0, 0],\n", | |
"[0, 0]])" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(lhs - rhs_new).simplify()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0f0edf33", | |
"metadata": {}, | |
"source": [ | |
"## A \"Well-conditioned\" matrix" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d56dd253", | |
"metadata": {}, | |
"source": [ | |
"The current tests in `ChainRules` actually include a check for the `rrule` of `det` for a complex matrix. However, the test matrix is a \"well-conditioned matrix\" defined as follows:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "6e08800c", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.714291Z", | |
"start_time": "2022-03-20T03:24:49.712576Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def v(i, j, part=None):\n", | |
"    \"\"\"Complex elements of the matrix V\"\"\"\n", | |
"    return csym(\"V\", i, j, part)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "a723dd40", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.717562Z", | |
"start_time": "2022-03-20T03:24:49.715232Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"V = Matrix([[v(1, 1), v(1, 2)], [v(2, 1), v(2, 2)]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "38f09ee6", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.826463Z", | |
"start_time": "2022-03-20T03:24:49.718728Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle \\left[\\begin{matrix}\\left(V^{im}_{11}\\right)^{2} + \\left(V^{re}_{11}\\right)^{2} + \\left(V^{im}_{12}\\right)^{2} + \\left(V^{re}_{12}\\right)^{2} + 1 & - \\left(i V^{im}_{11} + V^{re}_{11}\\right) \\left(i V^{im}_{21} - V^{re}_{21}\\right) - \\left(i V^{im}_{12} + V^{re}_{12}\\right) \\left(i V^{im}_{22} - V^{re}_{22}\\right)\\\\- \\left(i V^{im}_{11} - V^{re}_{11}\\right) \\left(i V^{im}_{21} + V^{re}_{21}\\right) - \\left(i V^{im}_{12} - V^{re}_{12}\\right) \\left(i V^{im}_{22} + V^{re}_{22}\\right) & \\left(V^{im}_{21}\\right)^{2} + \\left(V^{re}_{21}\\right)^{2} + \\left(V^{im}_{22}\\right)^{2} + \\left(V^{re}_{22}\\right)^{2} + 1\\end{matrix}\\right]$" | |
], | |
"text/plain": [ | |
"Matrix([\n", | |
"[ V_11^im**2 + V_11^re**2 + V_12^im**2 + V_12^re**2 + 1, -(I*V_11^im + V_11^re)*(I*V_21^im - V_21^re) - (I*V_12^im + V_12^re)*(I*V_22^im - V_22^re)],\n", | |
"[-(I*V_11^im - V_11^re)*(I*V_21^im + V_21^re) - (I*V_12^im - V_12^re)*(I*V_22^im + V_22^re), V_21^im**2 + V_21^re**2 + V_22^im**2 + V_22^re**2 + 1]])" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W = simplify(V * dagger(V) + eye(2));\n", | |
"W" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3f317de6", | |
"metadata": {}, | |
"source": [ | |
"We find that $\\det(W) \\in \\mathbb{R}$, as expected for a Hermitian matrix:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "a142e80b", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.847133Z", | |
"start_time": "2022-03-20T03:24:49.828828Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle 0$" | |
], | |
"text/plain": [ | |
"0" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"imag(det(W))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "790caf8f", | |
"metadata": {}, | |
"source": [ | |
"This is unlike the determinant of an arbitrary matrix:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "691257d8", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2022-03-20T03:24:49.856447Z", | |
"start_time": "2022-03-20T03:24:49.848120Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/latex": [ | |
"$\\displaystyle - \\frac{i \\left(2 i u^{im}_{11} u^{re}_{22} + 2 i u^{re}_{11} u^{im}_{22} - 2 i u^{im}_{12} u^{re}_{21} - 2 i u^{re}_{12} u^{im}_{21}\\right)}{2}$" | |
], | |
"text/plain": [ | |
"-I*(2*I*u_11^im*u_22^re + 2*I*u_11^re*u_22^im - 2*I*u_12^im*u_21^re - 2*I*u_12^re*u_21^im)/2" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"imag(det(U))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "83b565ca", | |
"metadata": {}, | |
"source": [ | |
"and explains why this bug is not caught by the current tests." | |
] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
}, | |
"varInspector": { | |
"cols": { | |
"lenName": 16, | |
"lenType": 16, | |
"lenVar": 40 | |
}, | |
"kernels_config": { | |
"python": { | |
"delete_cmd_postfix": "", | |
"delete_cmd_prefix": "del ", | |
"library": "var_list.py", | |
"varRefreshCmd": "print(var_dic_list())" | |
}, | |
"r": { | |
"delete_cmd_postfix": ") ", | |
"delete_cmd_prefix": "rm(", | |
"library": "var_list.r", | |
"varRefreshCmd": "cat(var_dic_list()) " | |
} | |
}, | |
"types_to_exclude": [ | |
"module", | |
"function", | |
"builtin_function_or_method", | |
"instance", | |
"_Feature" | |
], | |
"window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
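As a quick numerical cross-check of the notebook's result, the following sketch compares both formulas against a finite-difference evaluation of the left-hand side for one arbitrarily chosen complex 2×2 matrix. It is plain Python and not part of the notebook; all names (`det2`, `inv_dagger2`, `pullback_fd`, `scaled`, `max_abs_diff`) and the numerical values of `U` and `dOmega` are hypothetical choices made for illustration.

```python
def det2(U):
    """Determinant of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = U
    return a * d - b * c


def inv_dagger2(U):
    """Conjugate transpose of the inverse of a 2x2 matrix."""
    (a, b), (c, d) = U
    det = det2(U)
    inv = [[d / det, -b / det], [-c / det, a / det]]
    return [[inv[j][i].conjugate() for j in range(2)] for i in range(2)]


def pullback_fd(U, dOmega, h=1e-3):
    """Left-hand side of the notebook, with the derivatives of Re(det) and
    Im(det) with respect to Re(u_ij) and Im(u_ij) taken by central differences."""
    result = [[0j, 0j], [0j, 0j]]
    for i in range(2):
        for j in range(2):
            def shifted(delta):
                V = [row[:] for row in U]
                V[i][j] += delta
                return det2(V)
            d_re = (shifted(h) - shifted(-h)) / (2 * h)            # d det / d Re(u_ij)
            d_im = (shifted(1j * h) - shifted(-1j * h)) / (2 * h)  # d det / d Im(u_ij)
            result[i][j] = (
                dOmega.real * d_re.real + dOmega.imag * d_re.imag
                + 1j * (dOmega.real * d_im.real + dOmega.imag * d_im.imag)
            )
    return result


def scaled(factor, M):
    """Multiply every element of M by a scalar."""
    return [[factor * x for x in row] for row in M]


def max_abs_diff(A, B):
    """Largest elementwise absolute difference between two 2x2 matrices."""
    return max(abs(A[i][j] - B[i][j]) for i in range(2) for j in range(2))


# Arbitrary test values (assumptions, not taken from the notebook)
U = [[1 + 2j, 0.5 - 1j], [-0.3 + 0.7j, 2 - 0.5j]]
dOmega = 0.4 - 1.3j

Omega = det2(U)
lhs = pullback_fd(U, dOmega)
rhs_old = scaled(Omega * dOmega, inv_dagger2(U))
rhs_new = scaled(Omega.conjugate() * dOmega, inv_dagger2(U))

err_old = max_abs_diff(lhs, rhs_old)
err_new = max_abs_diff(lhs, rhs_new)
print("old formula, max deviation:", err_old)
print("new formula, max deviation:", err_new)
```

Because the determinant is linear in each matrix element, the central differences are exact up to rounding. For this matrix, `err_new` should therefore be at rounding-error level, while `err_old` stays clearly nonzero because the determinant has a nonzero imaginary part.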