@goerz
Last active March 20, 2022 03:54
Analytical Reverse-AD for the determinant of a complex matrix (https://github.com/JuliaDiff/ChainRules.jl/issues/600)
{
"cells": [
{
"cell_type": "markdown",
"id": "3620199d",
"metadata": {},
"source": [
"# Reverse-AD for the determinant of a complex matrix"
]
},
{
"cell_type": "markdown",
"id": "1c3cfc00",
"metadata": {},
"source": [
"Here, we illustrate for the example of a complex 2×2 matrix $U$ that, for $\\Omega = \\det(U)$ and a perturbation $\\Delta\\Omega$, the correct `ChainRules.rrule` is $\\bar\\Omega\\, \\Delta\\Omega\\, ({U^{-1}})^\\dagger$, not $\\Omega\\, \\Delta\\Omega\\, ({U^{-1}})^\\dagger$. That is, it involves the complex conjugate of the determinant, not the determinant itself."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c6a9a4b2",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.721701Z",
"start_time": "2022-03-20T03:24:43.379277Z"
}
},
"outputs": [],
"source": [
"from sympy import *"
]
},
{
"cell_type": "markdown",
"id": "22f328a0",
"metadata": {},
"source": [
"## Helper Routines"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7a207be4",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.726000Z",
"start_time": "2022-03-20T03:24:43.722858Z"
}
},
"outputs": [],
"source": [
"def csym(name, i=None, j=None, part=None):\n",
"    \"\"\"Create a symbolic complex number.\"\"\"\n",
"    if part is None:\n",
"        # Assemble the complex number from two real-valued symbols\n",
"        return csym(name, i, j, 're') + I * csym(name, i, j, 'im')\n",
"    else:\n",
"        if i is None or j is None:\n",
"            return symbols('{name}^{part}'.format(name=name, part=part), real=True)\n",
"        else:\n",
"            return symbols('{name}_{i}{j}^{part}'.format(name=name, i=i, j=j, part=part), real=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "47d597ab",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.728423Z",
"start_time": "2022-03-20T03:24:43.726911Z"
}
},
"outputs": [],
"source": [
"def real(z):\n",
"    return (z + z.conjugate()) / 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9d2ce0ff",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.731476Z",
"start_time": "2022-03-20T03:24:43.729824Z"
}
},
"outputs": [],
"source": [
"def imag(z):\n",
"    return (z - z.conjugate()) / (2 * I)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "43660fb7",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.734266Z",
"start_time": "2022-03-20T03:24:43.732476Z"
}
},
"outputs": [],
"source": [
"conj = conjugate"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "42dad762",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.736840Z",
"start_time": "2022-03-20T03:24:43.735178Z"
}
},
"outputs": [],
"source": [
"def dagger(U):\n",
"    return conj(transpose(U))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f4c06038",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.740421Z",
"start_time": "2022-03-20T03:24:43.737935Z"
}
},
"outputs": [],
"source": [
"def deriv_scalar_matrix(J, M):\n",
"    r\"\"\"Return the matrix that is the derivative of the real-valued scalar\n",
"    $J$ with respect to the real-valued matrix $M$. The $ij$ element of\n",
"    the result is $\\frac{\\partial J}{\\partial M_{ij}}$, so the result has\n",
"    the same shape as $M$, see\n",
"    https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-matrix\n",
"\n",
"    Note that this is the transpose of the numerator-layout convention!\n",
"    \"\"\"\n",
"    n, m = M.shape\n",
"    return Matrix(\n",
"        [\n",
"            [J.diff(M[i, j]) for j in range(m)]\n",
"            for i in range(n)\n",
"        ]\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d6e74317",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.743100Z",
"start_time": "2022-03-20T03:24:43.741305Z"
}
},
"outputs": [],
"source": [
"def u(i, j, part=None):\n",
"    \"\"\"Complex elements of the matrix U\"\"\"\n",
"    return csym(\"u\", i, j, part)"
]
},
{
"cell_type": "markdown",
"id": "6cdeb9ba",
"metadata": {},
"source": [
"## Definitions"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5926cf50",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.874856Z",
"start_time": "2022-03-20T03:24:43.743838Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}i u^{im}_{11} + u^{re}_{11} & i u^{im}_{12} + u^{re}_{12}\\\\i u^{im}_{21} + u^{re}_{21} & i u^{im}_{22} + u^{re}_{22}\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[I*u_11^im + u_11^re, I*u_12^im + u_12^re],\n",
"[I*u_21^im + u_21^re, I*u_22^im + u_22^re]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"U = Matrix([[u(1, 1), u(1, 2)], [u(2, 1), u(2, 2)]])\n",
"U"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "675c7fe4",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.886047Z",
"start_time": "2022-03-20T03:24:43.877263Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} + u^{re}_{11} u^{re}_{22} + u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} - u^{re}_{12} u^{re}_{21}$"
],
"text/plain": [
"-u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im + u_11^re*u_22^re + u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im - u_12^re*u_21^re"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Ω = det(U)\n",
"Ω"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7f961285",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.890044Z",
"start_time": "2022-03-20T03:24:43.886863Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle i \\Delta\\Omega^{im} + \\Delta\\Omega^{re}$"
],
"text/plain": [
"I*\\Delta\\Omega^im + \\Delta\\Omega^re"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ΔΩ = csym(r\"\\Delta\\Omega\");\n",
"ΔΩ"
]
},
{
"cell_type": "markdown",
"id": "09ed58e0",
"metadata": {},
"source": [
"## Left Hand Side"
]
},
{
"cell_type": "markdown",
"id": "6f25fa15",
"metadata": {},
"source": [
"On the left-hand side, we have the equation for `rrule` from https://juliadiff.org/ChainRulesCore.jl/dev/maths/complex.html"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "66decf6f",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:43.921962Z",
"start_time": "2022-03-20T03:24:43.890903Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}\\Delta\\Omega^{im} u^{im}_{22} + i \\Delta\\Omega^{im} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{22} + \\Delta\\Omega^{re} u^{re}_{22} & - \\Delta\\Omega^{im} u^{im}_{21} - i \\Delta\\Omega^{im} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{21} - \\Delta\\Omega^{re} u^{re}_{21}\\\\- \\Delta\\Omega^{im} u^{im}_{12} - i \\Delta\\Omega^{im} u^{re}_{12} + i \\Delta\\Omega^{re} u^{im}_{12} - \\Delta\\Omega^{re} u^{re}_{12} & \\Delta\\Omega^{im} u^{im}_{11} + i \\Delta\\Omega^{im} u^{re}_{11} - i \\Delta\\Omega^{re} u^{im}_{11} + \\Delta\\Omega^{re} u^{re}_{11}\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[ \\Delta\\Omega^im*u_22^im + I*\\Delta\\Omega^im*u_22^re - I*\\Delta\\Omega^re*u_22^im + \\Delta\\Omega^re*u_22^re, -\\Delta\\Omega^im*u_21^im - I*\\Delta\\Omega^im*u_21^re + I*\\Delta\\Omega^re*u_21^im - \\Delta\\Omega^re*u_21^re],\n",
"[-\\Delta\\Omega^im*u_12^im - I*\\Delta\\Omega^im*u_12^re + I*\\Delta\\Omega^re*u_12^im - \\Delta\\Omega^re*u_12^re, \\Delta\\Omega^im*u_11^im + I*\\Delta\\Omega^im*u_11^re - I*\\Delta\\Omega^re*u_11^im + \\Delta\\Omega^re*u_11^re]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lhs = (\n",
" real(ΔΩ) * deriv_scalar_matrix(real(Ω), real(U))\n",
" + imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), real(U))\n",
" + I * real(ΔΩ) * deriv_scalar_matrix(real(Ω), imag(U))\n",
" + I * imag(ΔΩ) * deriv_scalar_matrix(imag(Ω), imag(U))\n",
");\n",
"lhs"
]
},
{
"cell_type": "markdown",
"id": "bbfca86f",
"metadata": {},
"source": [
"## Right Hand Side"
]
},
{
"cell_type": "markdown",
"id": "fb426a79",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-19T21:48:47.126422Z",
"start_time": "2022-03-19T21:48:47.113734Z"
}
},
"source": [
"On the right-hand side, we have the formula corresponding to the code at https://github.com/JuliaDiff/ChainRules.jl/blob/9023d898a0b957bd9b3baab6bc38b54822d6963a/src/rulesets/LinearAlgebra/dense.jl#L132"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e67fe9d3",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:46.681988Z",
"start_time": "2022-03-20T03:24:43.922856Z"
}
},
"outputs": [],
"source": [
"rhs_old = simplify(Ω * ΔΩ * dagger(U.inv())).expand()\n",
"rhs_old;"
]
},
{
"cell_type": "markdown",
"id": "2a2ddafc",
"metadata": {},
"source": [
"and for the corrected version:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "00ba42a0",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:47.224201Z",
"start_time": "2022-03-20T03:24:46.682738Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}\\Delta\\Omega^{im} u^{im}_{22} + i \\Delta\\Omega^{im} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{22} + \\Delta\\Omega^{re} u^{re}_{22} & - \\Delta\\Omega^{im} u^{im}_{21} - i \\Delta\\Omega^{im} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{21} - \\Delta\\Omega^{re} u^{re}_{21}\\\\- \\Delta\\Omega^{im} u^{im}_{12} - i \\Delta\\Omega^{im} u^{re}_{12} + i \\Delta\\Omega^{re} u^{im}_{12} - \\Delta\\Omega^{re} u^{re}_{12} & \\Delta\\Omega^{im} u^{im}_{11} + i \\Delta\\Omega^{im} u^{re}_{11} - i \\Delta\\Omega^{re} u^{im}_{11} + \\Delta\\Omega^{re} u^{re}_{11}\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[ \\Delta\\Omega^im*u_22^im + I*\\Delta\\Omega^im*u_22^re - I*\\Delta\\Omega^re*u_22^im + \\Delta\\Omega^re*u_22^re, -\\Delta\\Omega^im*u_21^im - I*\\Delta\\Omega^im*u_21^re + I*\\Delta\\Omega^re*u_21^im - \\Delta\\Omega^re*u_21^re],\n",
"[-\\Delta\\Omega^im*u_12^im - I*\\Delta\\Omega^im*u_12^re + I*\\Delta\\Omega^re*u_12^im - \\Delta\\Omega^re*u_12^re, \\Delta\\Omega^im*u_11^im + I*\\Delta\\Omega^im*u_11^re - I*\\Delta\\Omega^re*u_11^im + \\Delta\\Omega^re*u_11^re]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rhs_new = simplify(conj(Ω) * ΔΩ * dagger(U.inv())).expand()\n",
"rhs_new"
]
},
{
"cell_type": "markdown",
"id": "77b64d5c",
"metadata": {},
"source": [
"## Check"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "08a600bb",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.681910Z",
"start_time": "2022-03-20T03:24:47.225137Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}\\frac{2 \\left(i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{22} u^{re}_{22} - \\Delta\\Omega^{im} u^{im}_{11} \\left(u^{re}_{22}\\right)^{2} + i \\Delta\\Omega^{im} u^{re}_{11} \\left(u^{im}_{22}\\right)^{2} - \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{22} u^{re}_{22} - i \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{21} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{re}_{22} + \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{22} u^{re}_{22} + i \\Delta\\Omega^{re} u^{im}_{11} \\left(u^{re}_{22}\\right)^{2} + \\Delta\\Omega^{re} u^{re}_{11} \\left(u^{im}_{22}\\right)^{2} + i \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{22} u^{re}_{22} - \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{21} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{re}_{22}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}} & \\frac{2 \\left(- i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{21} u^{re}_{22} + \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{21} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{21} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{21} u^{im}_{22} + i \\Delta\\Omega^{im} u^{im}_{12} u^{im}_{21} u^{re}_{21} - \\Delta\\Omega^{im} u^{im}_{12} \\left(u^{re}_{21}\\right)^{2} + i \\Delta\\Omega^{im} u^{re}_{12} \\left(u^{im}_{21}\\right)^{2} - \\Delta\\Omega^{im} u^{re}_{12} u^{im}_{21} u^{re}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{21} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{21} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{21} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{21} u^{im}_{22} + 
\\Delta\\Omega^{re} u^{im}_{12} u^{im}_{21} u^{re}_{21} + i \\Delta\\Omega^{re} u^{im}_{12} \\left(u^{re}_{21}\\right)^{2} + \\Delta\\Omega^{re} u^{re}_{12} \\left(u^{im}_{21}\\right)^{2} + i \\Delta\\Omega^{re} u^{re}_{12} u^{im}_{21} u^{re}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}}\\\\\\frac{2 \\left(- i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{12} u^{re}_{22} + \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{12} u^{re}_{22} - i \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{12} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{12} u^{im}_{22} + i \\Delta\\Omega^{im} \\left(u^{im}_{12}\\right)^{2} u^{re}_{21} + i \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{12} u^{im}_{21} - \\Delta\\Omega^{im} u^{im}_{12} u^{re}_{12} u^{re}_{21} - \\Delta\\Omega^{im} \\left(u^{re}_{12}\\right)^{2} u^{im}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{12} u^{re}_{22} - i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{12} u^{re}_{22} - \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{12} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{12} u^{im}_{22} + \\Delta\\Omega^{re} \\left(u^{im}_{12}\\right)^{2} u^{re}_{21} + \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{12} u^{im}_{21} + i \\Delta\\Omega^{re} u^{im}_{12} u^{re}_{12} u^{re}_{21} + i \\Delta\\Omega^{re} \\left(u^{re}_{12}\\right)^{2} u^{im}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}} & \\frac{2 \\left(i \\Delta\\Omega^{im} \\left(u^{im}_{11}\\right)^{2} u^{re}_{22} + i \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{11} u^{im}_{22} - \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{11} u^{re}_{22} - i \\Delta\\Omega^{im} u^{im}_{11} u^{im}_{12} u^{re}_{21} - i \\Delta\\Omega^{im} u^{im}_{11} u^{re}_{12} 
u^{im}_{21} - \\Delta\\Omega^{im} \\left(u^{re}_{11}\\right)^{2} u^{im}_{22} + \\Delta\\Omega^{im} u^{re}_{11} u^{im}_{12} u^{re}_{21} + \\Delta\\Omega^{im} u^{re}_{11} u^{re}_{12} u^{im}_{21} + \\Delta\\Omega^{re} \\left(u^{im}_{11}\\right)^{2} u^{re}_{22} + \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{11} u^{im}_{22} + i \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{11} u^{re}_{22} - \\Delta\\Omega^{re} u^{im}_{11} u^{im}_{12} u^{re}_{21} - \\Delta\\Omega^{re} u^{im}_{11} u^{re}_{12} u^{im}_{21} + i \\Delta\\Omega^{re} \\left(u^{re}_{11}\\right)^{2} u^{im}_{22} - i \\Delta\\Omega^{re} u^{re}_{11} u^{im}_{12} u^{re}_{21} - i \\Delta\\Omega^{re} u^{re}_{11} u^{re}_{12} u^{im}_{21}\\right)}{u^{im}_{11} u^{im}_{22} + i u^{im}_{11} u^{re}_{22} + i u^{re}_{11} u^{im}_{22} - u^{re}_{11} u^{re}_{22} - u^{im}_{12} u^{im}_{21} - i u^{im}_{12} u^{re}_{21} - i u^{re}_{12} u^{im}_{21} + u^{re}_{12} u^{re}_{21}}\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[ 2*(I*\\Delta\\Omega^im*u_11^im*u_22^im*u_22^re - \\Delta\\Omega^im*u_11^im*u_22^re**2 + I*\\Delta\\Omega^im*u_11^re*u_22^im**2 - \\Delta\\Omega^im*u_11^re*u_22^im*u_22^re - I*\\Delta\\Omega^im*u_12^im*u_21^re*u_22^im + \\Delta\\Omega^im*u_12^im*u_21^re*u_22^re - I*\\Delta\\Omega^im*u_12^re*u_21^im*u_22^im + \\Delta\\Omega^im*u_12^re*u_21^im*u_22^re + \\Delta\\Omega^re*u_11^im*u_22^im*u_22^re + I*\\Delta\\Omega^re*u_11^im*u_22^re**2 + \\Delta\\Omega^re*u_11^re*u_22^im**2 + I*\\Delta\\Omega^re*u_11^re*u_22^im*u_22^re - \\Delta\\Omega^re*u_12^im*u_21^re*u_22^im - I*\\Delta\\Omega^re*u_12^im*u_21^re*u_22^re - \\Delta\\Omega^re*u_12^re*u_21^im*u_22^im - I*\\Delta\\Omega^re*u_12^re*u_21^im*u_22^re)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re), 2*(-I*\\Delta\\Omega^im*u_11^im*u_21^im*u_22^re + \\Delta\\Omega^im*u_11^im*u_21^re*u_22^re - I*\\Delta\\Omega^im*u_11^re*u_21^im*u_22^im + \\Delta\\Omega^im*u_11^re*u_21^re*u_22^im + I*\\Delta\\Omega^im*u_12^im*u_21^im*u_21^re - \\Delta\\Omega^im*u_12^im*u_21^re**2 + I*\\Delta\\Omega^im*u_12^re*u_21^im**2 - \\Delta\\Omega^im*u_12^re*u_21^im*u_21^re - \\Delta\\Omega^re*u_11^im*u_21^im*u_22^re - I*\\Delta\\Omega^re*u_11^im*u_21^re*u_22^re - \\Delta\\Omega^re*u_11^re*u_21^im*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_21^re*u_22^im + \\Delta\\Omega^re*u_12^im*u_21^im*u_21^re + I*\\Delta\\Omega^re*u_12^im*u_21^re**2 + \\Delta\\Omega^re*u_12^re*u_21^im**2 + I*\\Delta\\Omega^re*u_12^re*u_21^im*u_21^re)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re)],\n",
"[2*(-I*\\Delta\\Omega^im*u_11^im*u_12^im*u_22^re + \\Delta\\Omega^im*u_11^im*u_12^re*u_22^re - I*\\Delta\\Omega^im*u_11^re*u_12^im*u_22^im + \\Delta\\Omega^im*u_11^re*u_12^re*u_22^im + I*\\Delta\\Omega^im*u_12^im**2*u_21^re + I*\\Delta\\Omega^im*u_12^im*u_12^re*u_21^im - \\Delta\\Omega^im*u_12^im*u_12^re*u_21^re - \\Delta\\Omega^im*u_12^re**2*u_21^im - \\Delta\\Omega^re*u_11^im*u_12^im*u_22^re - I*\\Delta\\Omega^re*u_11^im*u_12^re*u_22^re - \\Delta\\Omega^re*u_11^re*u_12^im*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_12^re*u_22^im + \\Delta\\Omega^re*u_12^im**2*u_21^re + \\Delta\\Omega^re*u_12^im*u_12^re*u_21^im + I*\\Delta\\Omega^re*u_12^im*u_12^re*u_21^re + I*\\Delta\\Omega^re*u_12^re**2*u_21^im)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re), 2*(I*\\Delta\\Omega^im*u_11^im**2*u_22^re + I*\\Delta\\Omega^im*u_11^im*u_11^re*u_22^im - \\Delta\\Omega^im*u_11^im*u_11^re*u_22^re - I*\\Delta\\Omega^im*u_11^im*u_12^im*u_21^re - I*\\Delta\\Omega^im*u_11^im*u_12^re*u_21^im - \\Delta\\Omega^im*u_11^re**2*u_22^im + \\Delta\\Omega^im*u_11^re*u_12^im*u_21^re + \\Delta\\Omega^im*u_11^re*u_12^re*u_21^im + \\Delta\\Omega^re*u_11^im**2*u_22^re + \\Delta\\Omega^re*u_11^im*u_11^re*u_22^im + I*\\Delta\\Omega^re*u_11^im*u_11^re*u_22^re - \\Delta\\Omega^re*u_11^im*u_12^im*u_21^re - \\Delta\\Omega^re*u_11^im*u_12^re*u_21^im + I*\\Delta\\Omega^re*u_11^re**2*u_22^im - I*\\Delta\\Omega^re*u_11^re*u_12^im*u_21^re - I*\\Delta\\Omega^re*u_11^re*u_12^re*u_21^im)/(u_11^im*u_22^im + I*u_11^im*u_22^re + I*u_11^re*u_22^im - u_11^re*u_22^re - u_12^im*u_21^im - I*u_12^im*u_21^re - I*u_12^re*u_21^im + u_12^re*u_21^re)]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diff_old = (lhs - rhs_old).simplify();\n",
"diff_old"
]
},
{
"cell_type": "markdown",
"id": "7ea2458c",
"metadata": {},
"source": [
"Note that for a real-valued matrix $U \\in \\mathbb{R}^{2 \\times 2}$, the old definition works:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d4729acd",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.707578Z",
"start_time": "2022-03-20T03:24:49.682794Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}0 & 0\\\\0 & 0\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[0, 0],\n",
"[0, 0]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diff_old.subs({sym: 0 for sym in imag(U).free_symbols})"
]
},
{
"cell_type": "markdown",
"id": "35a77f90",
"metadata": {},
"source": [
"(Even without plugging in values, this is pretty obvious: for a real matrix, $\\det U \\in \\mathbb{R}$, so that $\\bar\\Omega = \\Omega$.)"
]
},
{
"cell_type": "markdown",
"id": "4723de22",
"metadata": {},
"source": [
"The new RHS works unconditionally:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "192d6a30",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.711744Z",
"start_time": "2022-03-20T03:24:49.708496Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}0 & 0\\\\0 & 0\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[0, 0],\n",
"[0, 0]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(lhs - rhs_new).simplify()"
]
},
{
"cell_type": "markdown",
"id": "0f0edf33",
"metadata": {},
"source": [
"## A \"Well-conditioned\" matrix"
]
},
{
"cell_type": "markdown",
"id": "d56dd253",
"metadata": {},
"source": [
"The current tests in `ChainRules` actually include a check for the `rrule` of `det` for a complex matrix. However, the test matrix is a \"well-conditioned matrix\" defined as follows:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6e08800c",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.714291Z",
"start_time": "2022-03-20T03:24:49.712576Z"
}
},
"outputs": [],
"source": [
"def v(i, j, part=None):\n",
"    \"\"\"Complex elements of the matrix V\"\"\"\n",
"    return csym(\"V\", i, j, part)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a723dd40",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.717562Z",
"start_time": "2022-03-20T03:24:49.715232Z"
}
},
"outputs": [],
"source": [
"V = Matrix([[v(1, 1), v(1, 2)], [v(2, 1), v(2, 2)]])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "38f09ee6",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.826463Z",
"start_time": "2022-03-20T03:24:49.718728Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}\\left(V^{im}_{11}\\right)^{2} + \\left(V^{re}_{11}\\right)^{2} + \\left(V^{im}_{12}\\right)^{2} + \\left(V^{re}_{12}\\right)^{2} + 1 & - \\left(i V^{im}_{11} + V^{re}_{11}\\right) \\left(i V^{im}_{21} - V^{re}_{21}\\right) - \\left(i V^{im}_{12} + V^{re}_{12}\\right) \\left(i V^{im}_{22} - V^{re}_{22}\\right)\\\\- \\left(i V^{im}_{11} - V^{re}_{11}\\right) \\left(i V^{im}_{21} + V^{re}_{21}\\right) - \\left(i V^{im}_{12} - V^{re}_{12}\\right) \\left(i V^{im}_{22} + V^{re}_{22}\\right) & \\left(V^{im}_{21}\\right)^{2} + \\left(V^{re}_{21}\\right)^{2} + \\left(V^{im}_{22}\\right)^{2} + \\left(V^{re}_{22}\\right)^{2} + 1\\end{matrix}\\right]$"
],
"text/plain": [
"Matrix([\n",
"[ V_11^im**2 + V_11^re**2 + V_12^im**2 + V_12^re**2 + 1, -(I*V_11^im + V_11^re)*(I*V_21^im - V_21^re) - (I*V_12^im + V_12^re)*(I*V_22^im - V_22^re)],\n",
"[-(I*V_11^im - V_11^re)*(I*V_21^im + V_21^re) - (I*V_12^im - V_12^re)*(I*V_22^im + V_22^re), V_21^im**2 + V_21^re**2 + V_22^im**2 + V_22^re**2 + 1]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W = simplify(V * dagger(V) + eye(2));\n",
"W"
]
},
{
"cell_type": "markdown",
"id": "3f317de6",
"metadata": {},
"source": [
"We find that $\\det(W) \\in \\mathbb{R}$:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a142e80b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.847133Z",
"start_time": "2022-03-20T03:24:49.828828Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 0$"
],
"text/plain": [
"0"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"imag(det(W))"
]
},
{
"cell_type": "markdown",
"id": "790caf8f",
"metadata": {},
"source": [
"This is unlike the determinant of an arbitrary matrix:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "691257d8",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-20T03:24:49.856447Z",
"start_time": "2022-03-20T03:24:49.848120Z"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - \\frac{i \\left(2 i u^{im}_{11} u^{re}_{22} + 2 i u^{re}_{11} u^{im}_{22} - 2 i u^{im}_{12} u^{re}_{21} - 2 i u^{re}_{12} u^{im}_{21}\\right)}{2}$"
],
"text/plain": [
"-I*(2*I*u_11^im*u_22^re + 2*I*u_11^re*u_22^im - 2*I*u_12^im*u_21^re - 2*I*u_12^re*u_21^im)/2"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"imag(det(U))"
]
},
{
"cell_type": "markdown",
"id": "83b565ca",
"metadata": {},
"source": [
"The fact that $\\det(W) \\in \\mathbb{R}$, so that $\\bar\\Omega = \\Omega$ for the test matrix, explains why this bug is not caught by the current tests."
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
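
The symbolic result can also be cross-checked numerically. The following is a minimal sketch in plain Python, not part of the notebook above: the helper names (`det2`, `inv2_dagger`, `partial`, `pullback_reference`, `rule`) and the test values are chosen here purely for illustration. It assembles the left-hand side from central finite differences of the real and imaginary parts and compares it against both candidate rules, first for a generic complex matrix and then for a matrix of the form $U U^\dagger + 1$.

```python
# Numerical cross-check for a 2x2 complex matrix, standard library only.
# All helper names and test values are illustrative assumptions.

def det2(U):
    (a, b), (c, d) = U
    return a * d - b * c


def inv2_dagger(U):
    """Return the conjugate transpose of the inverse of the 2x2 matrix U."""
    (a, b), (c, d) = U
    om = det2(U)
    inv = [[d / om, -b / om], [-c / om, a / om]]
    return [[inv[j][i].conjugate() for j in range(2)] for i in range(2)]


def partial(U, i, j, direction, h=1e-3):
    """Central difference of det2 when U[i][j] is shifted along `direction`
    (1 for the real part, 1j for the imaginary part)."""
    up = [row[:] for row in U]
    dn = [row[:] for row in U]
    up[i][j] += h * direction
    dn[i][j] -= h * direction
    return (det2(up) - det2(dn)) / (2 * h)


def pullback_reference(U, dOm):
    """Left-hand side: the pullback assembled from real partial derivatives."""
    out = [[0j, 0j], [0j, 0j]]
    for i in range(2):
        for j in range(2):
            d_re = partial(U, i, j, 1)    # derivative w.r.t. Re(U_ij)
            d_im = partial(U, i, j, 1j)   # derivative w.r.t. Im(U_ij)
            out[i][j] = (
                dOm.real * d_re.real + dOm.imag * d_re.imag
                + 1j * (dOm.real * d_im.real + dOm.imag * d_im.imag)
            )
    return out


def rule(U, dOm, conjugate_det):
    """Right-hand side, with or without conjugating the determinant."""
    om = det2(U)
    factor = (om.conjugate() if conjugate_det else om) * dOm
    return [[factor * x for x in row] for row in inv2_dagger(U)]


def max_abs_diff(A, B):
    return max(abs(A[i][j] - B[i][j]) for i in range(2) for j in range(2))


U = [[1 + 2j, 0.5 - 1j], [-0.3 + 0.7j, 2 - 0.4j]]
dOm = 0.8 - 0.6j

ref = pullback_reference(U, dOm)
err_new = max_abs_diff(ref, rule(U, dOm, conjugate_det=True))
err_old = max_abs_diff(ref, rule(U, dOm, conjugate_det=False))
print("conj(det) rule:", err_new)
print("det rule:      ", err_old)

# W = U U^dagger + 1 has a real determinant, so both rules agree for it
W = [
    [sum(U[i][k] * U[j][k].conjugate() for k in range(2)) + (1 if i == j else 0)
     for j in range(2)]
    for i in range(2)
]
ref_W = pullback_reference(W, dOm)
err_old_W = max_abs_diff(ref_W, rule(W, dOm, conjugate_det=False))
print("det rule on W: ", err_old_W)
```

Since the determinant is affine in each individual matrix entry, the central difference carries no truncation error, so the deviations for the conjugated rule (and for the unconjugated rule on `W`) should be at the level of rounding noise, while the unconjugated rule on the generic `U` is off by an amount of order one.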