@hzhangxyz
Last active July 9, 2025 10:12
learn-autograd.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "427371ec-a216-4aa4-a658-1fcf44fdd746",
"metadata": {},
"source": [
"# 自动微分"
]
},
{
"cell_type": "markdown",
"id": "08628e8d-d11d-426b-8ea4-4579867247b5",
"metadata": {},
"source": [
"## 引言:为什么需要自动微分?"
]
},
{
"cell_type": "markdown",
"id": "39797cf8-5780-41d5-90a6-f99af7496314",
"metadata": {},
"source": [
"对比手动推导、数值微分与符号微分的局限性\n",
"\n",
"- 手动推导:易错、耗时\n",
"- 数值微分(如有限差分):精度问题、计算成本高\n",
"- 符号微分(如Mathematica):表达式膨胀问题\n",
"\n",
"自动微分的优势:高效、精确、适合计算机实现"
]
},
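{
"cell_type": "markdown",
"id": "fd-error-sketch-intro",
"metadata": {},
"source": [
"The precision problem of finite differences shows up in a few lines of plain Python. This sketch (the helper `forward_difference` is our own name, not a library function) approximates the derivative of $\\sin$ at $x = 1$ with a forward difference for several step sizes $h$ and compares it with the exact value $\\cos 1$. A large $h$ suffers from truncation error and a tiny $h$ from round-off error; the error is typically smallest near $h \\approx 10^{-8}$ and grows in both directions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd-error-sketch",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def forward_difference(f, x, h):\n",
"    # Forward-difference approximation of f'(x): (f(x + h) - f(x)) / h\n",
"    return (f(x + h) - f(x)) / h\n",
"\n",
"\n",
"x0 = 1.0\n",
"exact = math.cos(x0)  # d/dx sin(x) = cos(x)\n",
"errors = {}\n",
"for h in [1e-1, 1e-4, 1e-8, 1e-12, 1e-15]:\n",
"    errors[h] = abs(forward_difference(math.sin, x0, h) - exact)\n",
"    print(f'h={h:.0e}  error={errors[h]:.2e}')"
]
},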
{
"cell_type": "markdown",
"id": "d2b70b3a-3024-498e-ac19-cbc7c997f93d",
"metadata": {},
"source": [
"## 计算图"
]
},
{
"attachments": {
"cbb80587-8b64-478d-83c7-bb967d6262d5.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"id": "e35f46ea-a1ea-4a40-acc8-1c37aa8e2d71",
"metadata": {},
"source": [
"![无标题.png](attachment:cbb80587-8b64-478d-83c7-bb967d6262d5.png)"
]
},
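{
"cell_type": "markdown",
"id": "graph-as-data-intro",
"metadata": {},
"source": [
"A computation graph can also be written down directly as data. In this sketch, the tuple layout `(name, op, inputs)` and the function `evaluate` are our own choices, not a standard format. The graph for $z = (x + y) \\cdot x$ is a list of nodes in topological order, so a single front-to-back pass evaluates it. Note that $x$ feeds two nodes; handling this kind of fan-out correctly is what the implementations below work towards."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "graph-as-data",
"metadata": {},
"outputs": [],
"source": [
"import operator\n",
"\n",
"# Each node: (name, operation, names of its inputs), in topological order\n",
"graph = [\n",
"    ('s', operator.add, ('x', 'y')),  # s = x + y\n",
"    ('z', operator.mul, ('s', 'x')),  # z = s * x; x is used a second time\n",
"]\n",
"\n",
"\n",
"def evaluate(graph, inputs):\n",
"    values = dict(inputs)  # start from the leaf values\n",
"    for name, op, args in graph:\n",
"        values[name] = op(*(values[a] for a in args))\n",
"    return values\n",
"\n",
"\n",
"values = evaluate(graph, {'x': 2, 'y': 3})\n",
"print(values)"
]
},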
{
"cell_type": "markdown",
"id": "b10676bd-f7ba-497c-9756-b0f0478c0c82",
"metadata": {},
"source": [
"## 自动微分的基本方法"
]
},
{
"cell_type": "markdown",
"id": "f542d48f-4087-499c-92de-dc12f0c8ea07",
"metadata": {},
"source": [
"自动微分能做的事情:\n",
"\n",
"对于 $y \\leftarrow f(x)$\n",
"对于给定 $x_0$ ,计算 $f'(x)$ 。"
]
},
{
"cell_type": "markdown",
"id": "6c8669e7-1187-4ae5-bbb9-d7661b88e92e",
"metadata": {},
"source": [
"自动微分的对象必须是连续函数\n",
"\n",
"不能是差分函数\n",
"\n",
"因为我们是在单个点的附近通过梯度传播进行计算的,差分的话,高阶项扔不掉"
]
},
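{
"cell_type": "markdown",
"id": "difference-cross-term-intro",
"metadata": {},
"source": [
"A one-line check of the higher-order term (our own illustration). For $z = u v$ with $u = v = x$, i.e. $z = x^2$, the chain rule assembles $dz/dx = v \\cdot 1 + u \\cdot 1 = 2x$ exactly. For finite differences the same assembly is off, because $\\Delta(uv) = u\\,\\Delta v + v\\,\\Delta u + \\Delta u\\,\\Delta v$ and the cross term $\\Delta u\\,\\Delta v$ does not vanish. At $x = 3$ with step $h = 0.5$ the difference quotient is $6.5$, the assembled value is $6$, and the gap is exactly $h$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "difference-cross-term",
"metadata": {},
"outputs": [],
"source": [
"x, h = 3.0, 0.5\n",
"\n",
"# Exact difference quotient of z = x * x\n",
"difference_quotient = ((x + h) * (x + h) - x * x) / h\n",
"\n",
"# Chain-rule style assembly with u = v = x: dz/du * du/dx + dz/dv * dv/dx,\n",
"# where dz/du = v = x, dz/dv = u = x and du/dx = dv/dx = 1\n",
"assembled = x * 1.0 + x * 1.0\n",
"\n",
"# The gap is the dropped cross term (delta_u * delta_v) / h = h * h / h = h\n",
"print(difference_quotient, assembled, difference_quotient - assembled)"
]
},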
{
"cell_type": "markdown",
"id": "82ccbd99-18ac-48f8-bc53-2cbd74912568",
"metadata": {},
"source": [
"有正向模式和反向模式,都是基于链式法则"
]
},
{
"cell_type": "markdown",
"id": "0cfc520c-d876-4ac5-be6a-18b0cd19017d",
"metadata": {},
"source": [
"### 反向模式"
]
},
{
"cell_type": "markdown",
"id": "7b0bee13-2c15-4088-8296-cd5deb2f1adf",
"metadata": {},
"source": [
"如果我们固定感兴趣的量为\n",
"\n",
"$\\frac{\\partial z}{\\partial ?}$\n",
"\n",
"我们能很轻松的计算$z$的直属上游,比如\n",
"\n",
"$\\frac{\\partial z}{\\partial y_1}, \\frac{\\partial z}{\\partial y_2}, \\cdots$\n",
"\n",
"我们也可以通过这些结果计算二阶上游的梯度,比如\n",
"\n",
"$\\frac{\\partial z} {\\partial x_1} = \\frac{\\partial z} {\\partial y_1} \\frac{\\partial y_1} {\\partial x_0} + \\frac{\\partial z} {\\partial y_2} \\frac{\\partial y_2} {\\partial x_0} + \\cdots$\n",
"\n",
"对于每个二阶上游 $x_i$ ,都可以这么计算下来\n",
"\n",
"有了二阶上游的导数,我们就可以接着往下算了。"
]
},
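{
"cell_type": "markdown",
"id": "reverse-worked-example-intro",
"metadata": {},
"source": [
"A worked example of the two levels above (our own choice of functions): $y_1 = x_1 + x_2$, $y_2 = x_1 x_2$, $z = y_1 y_2$ at $x_1 = 2$, $x_2 = 3$. The direct upstream derivatives are $\\partial z / \\partial y_1 = y_2$ and $\\partial z / \\partial y_2 = y_1$; summing over both paths gives the derivatives for $x_1$ and $x_2$, which the code checks against the expanded form $z = x_1^2 x_2 + x_1 x_2^2$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reverse-worked-example",
"metadata": {},
"outputs": [],
"source": [
"x1, x2 = 2.0, 3.0\n",
"\n",
"# Forward pass\n",
"y1 = x1 + x2\n",
"y2 = x1 * x2\n",
"z = y1 * y2\n",
"\n",
"# Level 1: derivatives of z with respect to its direct upstream nodes\n",
"dz_dy1 = y2\n",
"dz_dy2 = y1\n",
"\n",
"# Level 2: sum over the paths through y1 and y2\n",
"# (dy1/dx1 = 1, dy2/dx1 = x2, dy1/dx2 = 1, dy2/dx2 = x1)\n",
"dz_dx1 = dz_dy1 * 1.0 + dz_dy2 * x2\n",
"dz_dx2 = dz_dy1 * 1.0 + dz_dy2 * x1\n",
"\n",
"# Check against the expanded form z = x1^2 x2 + x1 x2^2\n",
"assert dz_dx1 == 2 * x1 * x2 + x2 ** 2\n",
"assert dz_dx2 == x1 ** 2 + 2 * x1 * x2\n",
"print(dz_dx1, dz_dx2)"
]
},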
{
"cell_type": "markdown",
"id": "5579c3dd-986f-473b-a90a-d8ab75e2970e",
"metadata": {},
"source": [
"通过反向模式的自动微分,我们能以正向计算的复杂度,计算某个scalar对所有可能的上游参数的梯度"
]
},
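{
"cell_type": "markdown",
"id": "reverse-tape-intro",
"metadata": {},
"source": [
"The reverse sweep can be sketched as a tape: record every operation during the forward pass, then walk the tape backwards once while accumulating the adjoints $\\partial z / \\partial (\\cdot)$. `Tape` and its methods `var`, `add`, `mul` and `gradient` are hypothetical names for this sketch only. For $z = \\sum_i x_i^2$ with 1000 inputs, a single backward sweep, costing about as much as the forward pass, yields all 1000 partial derivatives $2 x_i$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reverse-tape",
"metadata": {},
"outputs": [],
"source": [
"class Tape:\n",
"    # Records each operation as (output index, [(input index, local derivative), ...])\n",
"\n",
"    def __init__(self):\n",
"        self.values = []\n",
"        self.ops = []\n",
"\n",
"    def var(self, value):\n",
"        self.values.append(value)\n",
"        return len(self.values) - 1  # a variable is an index into self.values\n",
"\n",
"    def add(self, a, b):\n",
"        out = self.var(self.values[a] + self.values[b])\n",
"        self.ops.append((out, [(a, 1.0), (b, 1.0)]))\n",
"        return out\n",
"\n",
"    def mul(self, a, b):\n",
"        out = self.var(self.values[a] * self.values[b])\n",
"        self.ops.append((out, [(a, self.values[b]), (b, self.values[a])]))\n",
"        return out\n",
"\n",
"    def gradient(self, out):\n",
"        adjoint = [0.0] * len(self.values)\n",
"        adjoint[out] = 1.0\n",
"        # Ops were recorded in execution order, so reversing it visits every\n",
"        # consumer before its inputs; each op is processed exactly once\n",
"        for node, inputs in reversed(self.ops):\n",
"            for i, local in inputs:\n",
"                adjoint[i] += adjoint[node] * local\n",
"        return adjoint\n",
"\n",
"\n",
"tape = Tape()\n",
"xs = [tape.var(float(i)) for i in range(1000)]\n",
"z = tape.mul(xs[0], xs[0])\n",
"for x in xs[1:]:\n",
"    z = tape.add(z, tape.mul(x, x))  # z = sum of x_i^2\n",
"\n",
"adjoint = tape.gradient(z)  # one backward sweep\n",
"grads = [adjoint[x] for x in xs]\n",
"print(grads[:5])"
]
},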
{
"cell_type": "markdown",
"id": "cced4a6d-0784-4c2f-9901-a245042eb667",
"metadata": {},
"source": [
"### 正向模式"
]
},
{
"cell_type": "markdown",
"id": "d791e760-9e19-4288-9ba2-07313aa967c9",
"metadata": {},
"source": [
"如果我们固定感兴趣的量为\n",
"\n",
"$\\frac{\\partial ?}{\\partial x}$\n",
"\n",
"我们能很轻松地计算$x$的直属下游,比如\n",
"\n",
"$\\frac{\\partial y_1}{\\partial x}, \\frac{\\partial y_2}{\\partial x}, \\cdots$\n",
"\n",
"我们也可以通过这些结果计算二阶下游的梯度,比如\n",
"\n",
"$\\frac{\\partial z_1} {\\partial x} = \\frac{\\partial z_1} {\\partial y_1} \\frac{\\partial y_1} {\\partial x} + \\frac{\\partial z_1} {\\partial y_2} \\frac{\\partial y_2} {\\partial x} + \\cdots$\n",
"\n",
"对于每个二阶下游 $z_i$ ,都可以这么计算下来\n",
"\n",
"有了二阶上下游的导数,我们就可以接着往下算了。"
]
},
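{
"cell_type": "markdown",
"id": "forward-worked-example-intro",
"metadata": {},
"source": [
"A worked example of the two levels above (our own choice of functions): with $y_1 = x_1 + x_2$, $y_2 = x_1 x_2$ and $z = y_1 y_2$ at $x_1 = 2$, $x_2 = 3$, fix the input $x_1$, first compute the derivatives of its direct downstream nodes $y_1$ and $y_2$, then push them one level further to $z$. The result is checked against the expanded form $z = x_1^2 x_2 + x_1 x_2^2$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "forward-worked-example",
"metadata": {},
"outputs": [],
"source": [
"x1, x2 = 2.0, 3.0\n",
"\n",
"# Forward pass\n",
"y1 = x1 + x2\n",
"y2 = x1 * x2\n",
"\n",
"# Level 1: derivatives of the direct downstream nodes with respect to x1\n",
"dy1_dx1 = 1.0\n",
"dy2_dx1 = x2\n",
"\n",
"# Level 2: z = y1 * y2, so dz/dy1 = y2 and dz/dy2 = y1\n",
"dz_dx1 = y2 * dy1_dx1 + y1 * dy2_dx1\n",
"\n",
"# Check against the expanded form z = x1^2 x2 + x1 x2^2\n",
"assert dz_dx1 == 2 * x1 * x2 + x2 ** 2\n",
"print(dz_dx1)"
]
},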
{
"cell_type": "markdown",
"id": "a4f707a5-7ee3-472e-b296-b2b4bafe8a82",
"metadata": {},
"source": [
"通过正向模式的自动微分,我们能以正向计算的复杂度,计算所有可能的下游对给某个scalar参数的梯度"
]
},
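{
"cell_type": "markdown",
"id": "dual-numbers-intro",
"metadata": {},
"source": [
"Forward mode is commonly implemented with dual numbers: each value carries a pair (value, derivative with respect to the chosen input), and every operation updates both with the chain rule. In this minimal sketch, `Dual` and `f` are hypothetical names and only `+` and `*` are supported. Seeding $x = 2$ with derivative 1, one forward pass gives the derivatives of all three outputs of $f(x) = (x^2,\\ x^3,\\ x + 1)$, namely $4$, $12$ and $1$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dual-numbers",
"metadata": {},
"outputs": [],
"source": [
"class Dual:\n",
"    # A value together with its derivative (tangent) with respect to one input\n",
"    def __init__(self, value, tangent=0.0):\n",
"        self.value = value\n",
"        self.tangent = tangent\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        return Dual(self.value + other.value, self.tangent + other.tangent)\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        # Product rule\n",
"        return Dual(self.value * other.value,\n",
"                    self.tangent * other.value + self.value * other.tangent)\n",
"\n",
"\n",
"def f(x):\n",
"    return [x * x, x * x * x, x + 1]\n",
"\n",
"\n",
"x = Dual(2.0, 1.0)  # seed: dx/dx = 1\n",
"outputs = f(x)  # a single forward pass\n",
"print([(o.value, o.tangent) for o in outputs])"
]
},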
{
"cell_type": "markdown",
"id": "fbca55a8-abc9-4866-a663-41b760c8e81f",
"metadata": {},
"source": [
"## 实现们"
]
},
{
"cell_type": "markdown",
"id": "6c574557-2e26-4f9d-8b1c-93be4a9feaec",
"metadata": {},
"source": [
"pytorch、tensorflow、jax、..."
]
},
{
"cell_type": "markdown",
"id": "28d2b42e-459b-4165-85ad-d20b76b9766e",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 一个简单的实现"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1b4e03ee-359c-4d16-a9b9-4441c03e3435",
"metadata": {},
"outputs": [],
"source": [
"class Leaf:\n",
"\n",
" def __init__(self, value):\n",
" self._value = value\n",
"\n",
" def forward(self):\n",
" return self._value\n",
"\n",
" def backward(self, grad):\n",
" self.grad = grad"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "696d337e-feae-4634-b82f-059ad49890a3",
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"\n",
" def __init__(self, forward, backward, arg1, arg2):\n",
" self._forward = forward\n",
" self._backward = backward\n",
" self._arg1 = arg1\n",
" self._arg2 = arg2\n",
"\n",
" def forward(self):\n",
" context, result = self._forward(self._arg1.forward(),\n",
" self._arg2.forward())\n",
" self._context = context\n",
" return result\n",
"\n",
" def backward(self, grad):\n",
" grad1, grad2 = self._backward(self._context, grad)\n",
" self._arg1.backward(grad1)\n",
" self._arg2.backward(grad2)"
]
},
{
"cell_type": "markdown",
"id": "2064f793-951c-4a68-966e-d0982f9d543a",
"metadata": {},
"source": [
"### 测试一下"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f407c5a1-3a24-40a3-8d6f-51ce1c8fc52d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=6\n",
"x.grad=3, y.grad=2\n"
]
}
],
"source": [
"# result = x * y\n",
"\n",
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"\n",
"x_times_y = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: (g * c[\"b\"], g * c[\"a\"]),\n",
" x,\n",
" y,\n",
")\n",
"\n",
"result = x_times_y.forward()\n",
"print(f\"{result=}\")\n",
"x_times_y.backward(1)\n",
"print(f\"{x.grad=}, {y.grad=}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "87883c4d-7427-440d-9a76-78872072ff8a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=20\n",
"x.grad=4, y.grad=4, z.grad=5\n"
]
}
],
"source": [
"# result = (x + y) * z\n",
"\n",
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"z = Leaf(4)\n",
"\n",
"x_plus_y = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a + b),\n",
" lambda c, g: (g, g),\n",
" x,\n",
" y,\n",
")\n",
"\n",
"x_plus_y_multiplied_by_z = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: (g * c[\"b\"], g * c[\"a\"]),\n",
" x_plus_y,\n",
" z,\n",
")\n",
"\n",
"result = x_plus_y_multiplied_by_z.forward()\n",
"print(f\"{result=}\")\n",
"x_plus_y_multiplied_by_z.backward(1)\n",
"print(f\"{x.grad=}, {y.grad=}, {z.grad=}\")"
]
},
{
"cell_type": "markdown",
"id": "916a8397-9674-49c1-a343-8ca7e9019c6d",
"metadata": {},
"source": [
"这是一个简单的计算图实现示例,展示了如何使用叶节点和节点来进行前向传播和反向传播。\n",
"但是,这个实现不支持一个节点有多个输出的情况。"
]
},
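{
"cell_type": "markdown",
"id": "overwrite-bug-intro",
"metadata": {},
"source": [
"The limitation is easy to trigger. The cell below applies compact copies of the two classes above, renamed `Leaf1` and `Node1` (hypothetical names, so that the cell runs on its own and does not clash with later cells), to $(x + y) \\cdot x$, where $x$ feeds two nodes. The second call to `backward` on $x$ overwrites the first, so `x.grad` ends up as 5 instead of the correct $2x + y = 7$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "overwrite-bug",
"metadata": {},
"outputs": [],
"source": [
"class Leaf1:\n",
"    def __init__(self, value):\n",
"        self._value = value\n",
"\n",
"    def forward(self):\n",
"        return self._value\n",
"\n",
"    def backward(self, grad):\n",
"        self.grad = grad  # overwrites instead of accumulating\n",
"\n",
"\n",
"class Node1:\n",
"    def __init__(self, forward, backward, arg1, arg2):\n",
"        self._forward, self._backward = forward, backward\n",
"        self._arg1, self._arg2 = arg1, arg2\n",
"\n",
"    def forward(self):\n",
"        self._context, result = self._forward(self._arg1.forward(),\n",
"                                              self._arg2.forward())\n",
"        return result\n",
"\n",
"    def backward(self, grad):\n",
"        grad1, grad2 = self._backward(self._context, grad)\n",
"        self._arg1.backward(grad1)\n",
"        self._arg2.backward(grad2)\n",
"\n",
"\n",
"# (x + y) * x: the context is simply the tuple of inputs (a, b)\n",
"x, y = Leaf1(2), Leaf1(3)\n",
"s = Node1(lambda a, b: ((a, b), a + b), lambda c, g: (g, g), x, y)\n",
"z = Node1(lambda a, b: ((a, b), a * b), lambda c, g: (g * c[1], g * c[0]), s, x)\n",
"z.forward()\n",
"z.backward(1)\n",
"print(x.grad, y.grad)  # x.grad keeps only the last contribution"
]
},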
{
"cell_type": "markdown",
"id": "39e11ae5-bbd0-4ad6-9fd8-415ecc333c82",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 支持多输出的版本,需要累加不同下游的梯度"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "94cb905f-7a0f-43f0-984f-ff2ba5ebdd65",
"metadata": {},
"outputs": [],
"source": [
"class Leaf:\n",
"\n",
" def __init__(self, value):\n",
" self._value = value\n",
" self.temp_grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" return self._value\n",
"\n",
" def backward(self, grad):\n",
" self.grad = grad"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c46d0716-28df-4270-b31c-07bceff602a7",
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"\n",
" def __init__(self, forward, backward, arg1, arg2):\n",
" self._forward = forward\n",
" self._backward = backward\n",
" self._arg1 = arg1\n",
" self._arg2 = arg2\n",
" self.temp_grad = None\n",
" self.ref_count = 0\n",
" self._arg1.ref_count += 1\n",
" self._arg2.ref_count += 1\n",
"\n",
" def forward(self):\n",
" context, result = self._forward(self._arg1.forward(),\n",
" self._arg2.forward())\n",
" self._context = context\n",
" return result\n",
"\n",
" def backward(self, grad):\n",
" grad1, grad2 = self._backward(self._context, grad)\n",
" if self._arg1.temp_grad is None:\n",
" self._arg1.temp_grad = grad1\n",
" else:\n",
" self._arg1.temp_grad += grad1\n",
" if self._arg2.temp_grad is None:\n",
" self._arg2.temp_grad = grad2\n",
" else:\n",
" self._arg2.temp_grad += grad2\n",
" self._arg1.ref_count -= 1\n",
" if self._arg1.ref_count == 0:\n",
" self._arg1.backward(self._arg1.temp_grad)\n",
" self._arg1.temp_grad = None\n",
" self._arg2.ref_count -= 1\n",
" if self._arg2.ref_count == 0:\n",
" self._arg2.backward(self._arg2.temp_grad)\n",
" self._arg2.temp_grad = None"
]
},
{
"cell_type": "markdown",
"id": "c31d7b73-b359-400e-a93e-e0e77a30e121",
"metadata": {},
"source": [
"### 测试一下"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "46938abe-2e6f-4cc4-8306-cf77047faee7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=10\n",
"x.grad=7, y.grad=2\n"
]
}
],
"source": [
"# result = (x + y) * x\n",
"\n",
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"\n",
"x_plus_y = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a + b),\n",
" lambda c, g: (g, g),\n",
" x,\n",
" y,\n",
")\n",
"\n",
"x_plus_y_multiplied_by_x = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: (g * c[\"b\"], g * c[\"a\"]),\n",
" x_plus_y,\n",
" x,\n",
")\n",
"\n",
"result = x_plus_y_multiplied_by_x.forward()\n",
"print(f\"{result=}\")\n",
"x_plus_y_multiplied_by_x.backward(1)\n",
"print(f\"{x.grad=}, {y.grad=}\")"
]
},
{
"cell_type": "markdown",
"id": "2e20e522-1091-47c8-9603-3edbe066a382",
"metadata": {},
"source": [
"这个实现支持了一个节点有多个输出的情况。\n",
"这个实现的设计有些不合理,至少temp grad应该在每个节点中管理。"
]
},
{
"cell_type": "markdown",
"id": "c010badb-0cca-46a2-a59b-862323b38986",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 调整代码重复"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f9c6324b-4367-41d5-b36c-fdc090c18586",
"metadata": {},
"outputs": [],
"source": [
"class Leaf:\n",
"\n",
" def __init__(self, value):\n",
" self._value = value\n",
" self.grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" return self._value\n",
"\n",
" def backward(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "aa568882-c9ae-41c6-b97b-4228f5632112",
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"\n",
" def __init__(self, forward, backward, arg1, arg2):\n",
" self._forward = forward\n",
" self._backward = backward\n",
" self._arg1 = arg1\n",
" self._arg2 = arg2\n",
" self.grad = None\n",
" self.ref_count = 0\n",
" self._arg1.ref_count += 1\n",
" self._arg2.ref_count += 1\n",
"\n",
" def forward(self):\n",
" context, result = self._forward(self._arg1.forward(),\n",
" self._arg2.forward())\n",
" self._context = context\n",
" return result\n",
"\n",
" def backward(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1\n",
" if self.ref_count == 0:\n",
" grad1, grad2 = self._backward(self._context, grad)\n",
" self._arg1.backward(grad1)\n",
" self._arg2.backward(grad2)\n",
"\n",
" def backward_interface(self):\n",
" self.ref_count += 1\n",
" self.backward(1)"
]
},
{
"cell_type": "markdown",
"id": "ed83a692-dcb4-4256-b377-a7830366445d",
"metadata": {},
"source": [
"### 测试一下"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "68c8c3d2-c98e-4e22-947f-da6ae97b7161",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=10\n",
"x.grad=7, y.grad=2\n"
]
}
],
"source": [
"# result = (x + y) * x\n",
"\n",
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"\n",
"x_plus_y = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a + b),\n",
" lambda c, g: (g, g),\n",
" x,\n",
" y,\n",
")\n",
"\n",
"x_plus_y_multiplied_by_x = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: (g * c[\"b\"], g * c[\"a\"]),\n",
" x_plus_y,\n",
" x,\n",
")\n",
"\n",
"result = x_plus_y_multiplied_by_x.forward()\n",
"print(f\"{result=}\")\n",
"x_plus_y_multiplied_by_x.backward_interface()\n",
"print(f\"{x.grad=}, {y.grad=}\")"
]
},
{
"cell_type": "markdown",
"id": "ad5f8678-a9fe-4667-b79d-33a300575a4b",
"metadata": {},
"source": [
"这个实现做了一些改进,但是有一个严重的问题。\n",
"如果x -> y, 同时 x -> z,但是只对y进行反向传播,那么z的梯度将不会被计算,结果x将一直等待来自z的梯度。\n",
"我们应在每次backward时,寻找有效子图的ref count。\n",
"而不是在构造计算图时直接生成ref count。"
]
},
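{
"cell_type": "markdown",
"id": "stale-refcount-intro",
"metadata": {},
"source": [
"The flaw can be reproduced with compact copies of the classes above, renamed `Leaf3` and `Node3` (hypothetical names, so that the cell runs on its own); `add_ops` and `mul_ops` are likewise our own helpers returning `(forward, backward)` pairs. We build $s = x + y$, $a = s \\cdot x$ and $b = s \\cdot y$, but backpropagate only from $a$. Because $s$ counted $b$ as a consumer when the graph was built, its reference count stops at 1, it never propagates, and $y$ receives no gradient at all, whereas the correct values are $\\partial a / \\partial x = 2x + y = 7$ and $\\partial a / \\partial y = x = 2$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stale-refcount",
"metadata": {},
"outputs": [],
"source": [
"class Leaf3:\n",
"    def __init__(self, value):\n",
"        self._value, self.grad, self.ref_count = value, None, 0\n",
"\n",
"    def forward(self):\n",
"        return self._value\n",
"\n",
"    def backward(self, grad):\n",
"        self.grad = grad if self.grad is None else self.grad + grad\n",
"        self.ref_count -= 1\n",
"\n",
"\n",
"class Node3:\n",
"    def __init__(self, forward, backward, arg1, arg2):\n",
"        self._forward, self._backward = forward, backward\n",
"        self._arg1, self._arg2 = arg1, arg2\n",
"        self.grad, self.ref_count = None, 0\n",
"        arg1.ref_count += 1  # counted once per consumer, at construction time\n",
"        arg2.ref_count += 1\n",
"\n",
"    def forward(self):\n",
"        self._context, result = self._forward(self._arg1.forward(),\n",
"                                              self._arg2.forward())\n",
"        return result\n",
"\n",
"    def backward(self, grad):\n",
"        self.grad = grad if self.grad is None else self.grad + grad\n",
"        self.ref_count -= 1\n",
"        if self.ref_count == 0:  # all consumers have reported\n",
"            grad1, grad2 = self._backward(self._context, self.grad)\n",
"            self._arg1.backward(grad1)\n",
"            self._arg2.backward(grad2)\n",
"\n",
"    def backward_interface(self):\n",
"        self.ref_count += 1\n",
"        self.backward(1)\n",
"\n",
"\n",
"def add_ops():\n",
"    # (forward, backward) pair for a + b; the context is the tuple (a, b)\n",
"    return (lambda a, b: ((a, b), a + b)), (lambda c, g: (g, g))\n",
"\n",
"\n",
"def mul_ops():\n",
"    # (forward, backward) pair for a * b\n",
"    return (lambda a, b: ((a, b), a * b)), (lambda c, g: (g * c[1], g * c[0]))\n",
"\n",
"\n",
"x, y = Leaf3(2), Leaf3(3)\n",
"s = Node3(*add_ops(), x, y)\n",
"a = Node3(*mul_ops(), s, x)  # a = (x + y) * x\n",
"b = Node3(*mul_ops(), s, y)  # built, but never backpropagated\n",
"a.forward()\n",
"a.backward_interface()\n",
"print(s.ref_count, x.grad, y.grad)  # s is still waiting for b"
]
},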
{
"cell_type": "markdown",
"id": "78f08773-d687-4249-bf5a-8cc787c362b6",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 使用有效子图"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3ca5f026-30f0-41b5-a817-6c5971c29729",
"metadata": {},
"outputs": [],
"source": [
"class Leaf:\n",
"\n",
" def __init__(self, value):\n",
" self._value = value\n",
" self.grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" return self._value\n",
"\n",
" def update_grad(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1\n",
"\n",
" def activate_subgraph(self):\n",
" self.ref_count += 1"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "19b57da8-fe66-446a-88d9-69d63d8fdb6f",
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"\n",
" def __init__(self, forward, backward, arg1, arg2):\n",
" self._forward = forward\n",
" self._backward = backward\n",
" self._arg1 = arg1\n",
" self._arg2 = arg2\n",
" self.grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" context, result = self._forward(self._arg1.forward(),\n",
" self._arg2.forward())\n",
" self._context = context\n",
" return result\n",
"\n",
" def update_grad(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1\n",
" if self.ref_count == 0:\n",
" grad1, grad2 = self._backward(self._context, grad)\n",
" self._arg1.update_grad(grad1)\n",
" self._arg2.update_grad(grad2)\n",
"\n",
" def activate_subgraph(self):\n",
" self.ref_count += 1\n",
" self._arg1.activate_subgraph()\n",
" self._arg2.activate_subgraph()\n",
"\n",
" def backward(self):\n",
" self.activate_subgraph()\n",
" self.update_grad(1)"
]
},
{
"cell_type": "markdown",
"id": "de052196-667c-4171-bec5-0cfba4243f15",
"metadata": {},
"source": [
"### 测试一下"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3c9ca10c-469d-4a69-8c1a-9809d45e3fc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=10\n",
"x.grad=7, y.grad=2\n"
]
}
],
"source": [
"# result = (x + y) * x\n",
"\n",
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"\n",
"x_plus_y = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a + b),\n",
" lambda c, g: (g, g),\n",
" x,\n",
" y,\n",
")\n",
"\n",
"x_plus_y_multiplied_by_x = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: (g * c[\"b\"], g * c[\"a\"]),\n",
" x_plus_y,\n",
" x,\n",
")\n",
"\n",
"another_x_plus_y_multiplied_by_x = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: (g * c[\"b\"], g * c[\"a\"]),\n",
" x_plus_y,\n",
" x,\n",
")\n",
"\n",
"result = x_plus_y_multiplied_by_x.forward()\n",
"print(f\"{result=}\")\n",
"x_plus_y_multiplied_by_x.backward()\n",
"print(f\"{x.grad=}, {y.grad=}\")"
]
},
{
"cell_type": "markdown",
"id": "d27d8614-5f06-460b-aeb1-7e754184053a",
"metadata": {},
"source": [
"目前只支持binary operator,这是不合理的"
]
},
{
"cell_type": "markdown",
"id": "2973836d-a837-438c-8f33-8dd20c541a12",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 支持任意参数的版本"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a60753c2-e705-44cc-b58e-6c491c8d1cb8",
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"import itertools\n",
"\n",
"\n",
"class AutoGrad:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "785b6597-d5c9-415a-a703-f0fe3d8251bf",
"metadata": {},
"outputs": [],
"source": [
"class Leaf(AutoGrad):\n",
"\n",
" def __init__(self, value):\n",
" self._value = value\n",
" self.grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" return self._value\n",
"\n",
" def update_grad(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1\n",
"\n",
" def activate_subgraph(self):\n",
" self.ref_count += 1"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "875e42d1-2311-40f1-8e6c-3775cff6fd35",
"metadata": {},
"outputs": [],
"source": [
"class Node(AutoGrad):\n",
"\n",
" def __init__(self, forward, backward, *args, **kwargs):\n",
" self._forward = forward\n",
" self._backward = backward\n",
" self._args = args\n",
" self._kwargs = kwargs\n",
" self.grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" context, result = self._forward(\n",
" *(arg.forward() if isinstance(arg, AutoGrad) else arg\n",
" for arg in self._args),\n",
" **{\n",
" key: arg.forward() if isinstance(arg, AutoGrad) else arg\n",
" for key, arg in self._kwargs\n",
" },\n",
" )\n",
" self._context = context\n",
" return result\n",
"\n",
" def update_grad(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1\n",
" if self.ref_count == 0:\n",
" self.invoke_backward()\n",
"\n",
" def invoke_backward(self):\n",
" upstream_grad = self._backward(self._context, self.grad)\n",
" sig = inspect.signature(self._forward)\n",
" bound = sig.bind(*self._args, **self._kwargs)\n",
" for name, param in sig.parameters.items():\n",
" if name in bound.arguments:\n",
" value = bound.arguments[name]\n",
" if isinstance(value, AutoGrad):\n",
" value.update_grad(upstream_grad[name])\n",
"\n",
" def activate_subgraph(self):\n",
" self.ref_count += 1\n",
" for arg in itertools.chain(self._args, self._kwargs.values()):\n",
" if isinstance(arg, AutoGrad):\n",
" arg.activate_subgraph()\n",
"\n",
" def backward(self):\n",
" self.activate_subgraph()\n",
" self.update_grad(1)"
]
},
{
"cell_type": "markdown",
"id": "83c18835-113f-4c1f-a57c-36aaa3ed19bc",
"metadata": {},
"source": [
"### 测试一下"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8f463fce-ca24-4505-a315-d684d7dc5bfa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=-10\n",
"x.grad=-7, y.grad=-2\n"
]
}
],
"source": [
"# result = - (x + y) * x\n",
"\n",
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"\n",
"x_plus_y = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a + b),\n",
" lambda c, g: {\n",
" \"a\": g,\n",
" \"b\": g\n",
" },\n",
" x,\n",
" y,\n",
")\n",
"\n",
"x_plus_y_multiplied_by_x = Node(\n",
" lambda a, b: ({\n",
" \"a\": a,\n",
" \"b\": b\n",
" }, a * b),\n",
" lambda c, g: {\n",
" \"a\": g * c[\"b\"],\n",
" \"b\": g * c[\"a\"]\n",
" },\n",
" x_plus_y,\n",
" x,\n",
")\n",
"\n",
"minus_x_plus_y_multiplied_by_x = Node(\n",
" lambda a: ({}, -a),\n",
" lambda c, g: {\"a\": -g},\n",
" x_plus_y_multiplied_by_x,\n",
")\n",
"\n",
"result = minus_x_plus_y_multiplied_by_x.forward()\n",
"print(f\"{result=}\")\n",
"minus_x_plus_y_multiplied_by_x.backward()\n",
"print(f\"{x.grad=}, {y.grad=}\")"
]
},
{
"cell_type": "markdown",
"id": "85a41387-511b-4115-b7c3-6b3d70ce3edb",
"metadata": {},
"source": [
"看起来好多了,但是每次创建node都需要传入lambda函数,封装一下吧"
]
},
{
"cell_type": "markdown",
"id": "73e4e926-0b46-4e2f-93ac-80f1805ad7b4",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 将每个op单独封装"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "89e62586-a4bb-4dbc-a701-cd744cf1c2db",
"metadata": {},
"outputs": [],
"source": [
"class Node(AutoGrad):\n",
"\n",
" def __init__(self, *args, **kwargs):\n",
" self.args = args\n",
" self.kwargs = kwargs\n",
" self.grad = None\n",
" self.ref_count = 0\n",
"\n",
" def forward(self):\n",
" result = self.forward_impl(\n",
" *(arg.forward() if isinstance(arg, AutoGrad) else arg\n",
" for arg in self.args),\n",
" **{\n",
" key: arg.forward() if isinstance(arg, AutoGrad) else arg\n",
" for key, arg in self.kwargs\n",
" },\n",
" )\n",
" return result\n",
"\n",
" def update_grad(self, grad):\n",
" if self.grad is None:\n",
" self.grad = grad\n",
" else:\n",
" self.grad += grad\n",
" self.ref_count -= 1\n",
" if self.ref_count == 0:\n",
" self.backward_wrap()\n",
"\n",
" def backward_wrap(self):\n",
" upstream_grad = self.backward_impl(self.grad)\n",
" sig = inspect.signature(self.forward_impl)\n",
" bound = sig.bind(*self.args, **self.kwargs)\n",
" for name, param in sig.parameters.items():\n",
" if name in bound.arguments:\n",
" value = bound.arguments[name]\n",
" if isinstance(value, AutoGrad):\n",
" value.update_grad(upstream_grad[name])\n",
"\n",
" def activate_subgraph(self):\n",
" self.ref_count += 1\n",
" for arg in itertools.chain(self.args, self.kwargs.values()):\n",
" if isinstance(arg, AutoGrad):\n",
" arg.activate_subgraph()\n",
"\n",
" def backward(self):\n",
" self.activate_subgraph()\n",
" self.update_grad(1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "fce74892-0888-42fb-9e92-39f11338a45d",
"metadata": {},
"outputs": [],
"source": [
"class Plus(Node):\n",
"\n",
" def forward_impl(self, a, b):\n",
" self.a = a\n",
" self.b = b\n",
" return a + b\n",
"\n",
" def backward_impl(self, grad):\n",
" return {\"a\": grad, \"b\": grad}\n",
"\n",
"\n",
"class Time(Node):\n",
"\n",
" def forward_impl(self, a, b):\n",
" self.a = a\n",
" self.b = b\n",
" return a * b\n",
"\n",
" def backward_impl(self, grad):\n",
" return {\"a\": grad * self.b, \"b\": grad * self.a}\n",
"\n",
"\n",
"class Neg(Node):\n",
"\n",
" def forward_impl(self, a):\n",
" self.a = a\n",
" return -a\n",
"\n",
" def backward_impl(self, grad):\n",
" return {\"a\": -grad}"
]
},
{
"cell_type": "markdown",
"id": "f8de0e92-38b3-4659-b8b2-267cd13a1557",
"metadata": {},
"source": [
"### 跑跑看"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "1335344a-92c6-488e-a59f-7ab063970c1d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"result=-9\n",
"x.grad=-7, y.grad=-2\n"
]
}
],
"source": [
"x = Leaf(2)\n",
"y = Leaf(3)\n",
"\n",
"x_plus_y = Plus(x, y)\n",
"x_plus_y_multiplied_by_x = Time(x_plus_y, x)\n",
"minus_x_plus_y_multiplied_by_x = Neg(x_plus_y_multiplied_by_x)\n",
"minus_x_plus_y_multiplied_by_x_plus_one = Plus(minus_x_plus_y_multiplied_by_x, 1)\n",
"\n",
"result = minus_x_plus_y_multiplied_by_x_plus_one.forward()\n",
"print(f\"{result=}\")\n",
"minus_x_plus_y_multiplied_by_x_plus_one.backward()\n",
"print(f\"{x.grad=}, {y.grad=}\")"
]
},
{
"cell_type": "markdown",
"id": "76ab55b0-7b81-4919-b162-96a6a932eca5",
"metadata": {},
"source": [
"## 实现小结"
]
},
{
"cell_type": "markdown",
"id": "03272b44-3f16-4795-aa1c-daadcd39d406",
"metadata": {},
"source": [
"- 构造计算图\n",
"- 反向传播梯度\n",
"- 多下游的梯度进行累加\n",
"- backward时有效子图"
]
},
{
"cell_type": "markdown",
"id": "1577d949-e015-42f4-8e68-258faf9e9a92",
"metadata": {},
"source": [
"## 其他话题"
]
},
{
"cell_type": "markdown",
"id": "a757fef4-22ae-4d86-bccc-6cf64bf95772",
"metadata": {},
"source": [
"### 高阶导数"
]
},
{
"cell_type": "markdown",
"id": "7009aa52-6605-4596-9fb6-1f39ee2c3517",
"metadata": {},
"source": [
"反向传播的同时也捕获计算图即可"
]
},
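{
"cell_type": "markdown",
"id": "higher-order-sketch-intro",
"metadata": {},
"source": [
"A minimal sketch of this idea, supporting only `+` and `*`: if every local derivative is itself a graph node and the backward pass combines them with graph operations rather than plain numbers, the returned gradient is again a recorded expression that can be differentiated once more. `Var` and `grad` are hypothetical names for this sketch. For $f(x) = x^3$ at $x = 2$ it gives $f'(2) = 12$ and $f''(2) = 12$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "higher-order-sketch",
"metadata": {},
"outputs": [],
"source": [
"class Var:\n",
"    # A graph node: its value and (parent, local derivative) pairs, where the\n",
"    # local derivative is itself a Var so that it can be differentiated again\n",
"    def __init__(self, value, parents=()):\n",
"        self.value = value\n",
"        self.parents = parents\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Var) else Var(other)\n",
"        return Var(self.value + other.value, ((self, Var(1.0)), (other, Var(1.0))))\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Var) else Var(other)\n",
"        return Var(self.value * other.value, ((self, other), (other, self)))\n",
"\n",
"\n",
"def grad(out, wrt):\n",
"    # Topologically order the nodes reachable from out (depth-first post-order)\n",
"    order, seen = [], set()\n",
"\n",
"    def visit(v):\n",
"        if id(v) not in seen:\n",
"            seen.add(id(v))\n",
"            for p, _ in v.parents:\n",
"                visit(p)\n",
"            order.append(v)\n",
"\n",
"    visit(out)\n",
"    adjoint = {id(out): Var(1.0)}\n",
"    for v in reversed(order):\n",
"        g = adjoint.get(id(v))\n",
"        if g is None:\n",
"            continue\n",
"        for p, local in v.parents:\n",
"            contribution = g * local  # built from Var ops, so it is recorded too\n",
"            key = id(p)\n",
"            adjoint[key] = adjoint[key] + contribution if key in adjoint else contribution\n",
"    return adjoint.get(id(wrt), Var(0.0))\n",
"\n",
"\n",
"x = Var(2.0)\n",
"y = x * x * x\n",
"dy = grad(y, x)    # 3 x^2, itself a recorded graph\n",
"d2y = grad(dy, x)  # 6 x\n",
"print(dy.value, d2y.value)"
]
},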
{
"cell_type": "markdown",
"id": "b53eac2e-639c-4311-8b9c-70fcebd58427",
"metadata": {},
"source": [
"### Hessian与Jacobian"
]
},
{
"attachments": {
"3d1dc0ab-3343-43db-b3d4-6a7ef1e1b6f9.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"id": "001cba5d-3942-4cee-af5d-582151ffd267",
"metadata": {},
"source": [
"![图片.png](attachment:3d1dc0ab-3343-43db-b3d4-6a7ef1e1b6f9.png)"
]
},
{
"attachments": {
"d2c25cc9-e802-41a7-986b-d63b679af33b.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"id": "d2ff3aff-d54a-4510-80dd-f3c535c6ccaf",
"metadata": {},
"source": [
"![图片.png](attachment:d2c25cc9-e802-41a7-986b-d63b679af33b.png)"
]
},
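{
"cell_type": "markdown",
"id": "jacobian-hessian-definitions",
"metadata": {},
"source": [
"For reference, the usual definitions, in the index convention used by the formulas below (we assume this is what the two figures show): for $f: \\mathbb{R}^n \\to \\mathbb{R}^m$ the Jacobian $J$ is an $m \\times n$ matrix, and for a scalar $f: \\mathbb{R}^n \\to \\mathbb{R}$ the Hessian $H$ is an $n \\times n$ matrix:\n",
"\n",
"$$J_{ij} = \\frac{\\partial f_i}{\\partial x_j}, \\qquad H_{ij} = \\frac{\\partial^2 f}{\\partial x_i \\partial x_j}$$"
]
},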
{
"cell_type": "markdown",
"id": "c747e008-cbd5-4f21-9b1b-35c093fd6d67",
"metadata": {},
"source": [
"H 用于估计二阶近似,J用于诱导两个空间度量"
]
},
{
"cell_type": "markdown",
"id": "2b83a529-c12b-4b6e-b434-36c33c66ae9e",
"metadata": {},
"source": [
"困难在量大,无论正向反向,都需要算多次。"
]
},
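{
"cell_type": "markdown",
"id": "full-jacobian-forward-intro",
"metadata": {},
"source": [
"To make the cost concrete, this sketch (hypothetical `Dual` class and `jacobian_forward` helper, only `+` and `*`) builds the full Jacobian of $f(x_0, x_1) = (x_0 x_1,\\ x_0 + x_1,\\ x_0^2)$ at $(2, 3)$ with forward mode. Each pass seeds one input and yields one column of $J$, so $n$ inputs take $n$ passes; reverse mode would instead yield one row per pass, so $m$ outputs take $m$ passes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "full-jacobian-forward",
"metadata": {},
"outputs": [],
"source": [
"class Dual:\n",
"    # Value and derivative with respect to the currently seeded input\n",
"    def __init__(self, value, tangent=0.0):\n",
"        self.value, self.tangent = value, tangent\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        return Dual(self.value + other.value, self.tangent + other.tangent)\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        return Dual(self.value * other.value,\n",
"                    self.tangent * other.value + self.value * other.tangent)\n",
"\n",
"\n",
"def f(x):\n",
"    return [x[0] * x[1], x[0] + x[1], x[0] * x[0]]\n",
"\n",
"\n",
"def jacobian_forward(f, point):\n",
"    columns = []\n",
"    for j in range(len(point)):  # one forward pass per input\n",
"        x = [Dual(p, 1.0 if i == j else 0.0) for i, p in enumerate(point)]\n",
"        columns.append([out.tangent for out in f(x)])\n",
"    # Transpose the columns into rows: J[i][j] = d f_i / d x_j\n",
"    return [list(row) for row in zip(*columns)]\n",
"\n",
"\n",
"J = jacobian_forward(f, [2.0, 3.0])\n",
"print(J)"
]
},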
{
"cell_type": "markdown",
"id": "5a278c84-d683-4831-9534-e93aeaae3161",
"metadata": {},
"source": [
"但是实际场景下,大多数是J、 H作用在某个向量v上"
]
},
{
"cell_type": "markdown",
"id": "4f9d617b-d35a-4509-b92a-71defa8f66bb",
"metadata": {},
"source": [
"### vJ"
]
},
{
"cell_type": "markdown",
"id": "31381a05-92e7-4c61-97b6-ac635ccaab07",
"metadata": {},
"source": [
"对于 vJ,构造辅助量进行反向梯度计算"
]
},
{
"cell_type": "markdown",
"id": "c5fdae78-9ec6-4b58-91ee-9a0fe413935d",
"metadata": {},
"source": [
"$(vJ)_i = \\sum_j v_j J_{ji}\n",
" = \\sum_j v_j \\frac{\\partial f_j}{\\partial x_i}\n",
" = \\sum_j \\frac{\\partial v_j f_j}{\\partial x_i}\n",
" = \\frac{\\partial \\sum_j v_j f_j}{\\partial x_i}$"
]
},
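{
"cell_type": "markdown",
"id": "vjp-sketch-intro",
"metadata": {},
"source": [
"A sketch of this trick with a tiny reverse-mode engine (hypothetical `Var` class and `backward` function, only `+` and `*`): for $f(x_0, x_1) = (x_0 x_1,\\ x_0 + x_1)$ we form the scalar $s = \\sum_j v_j f_j$ and run a single reverse pass; the gradient of $s$ is $vJ$, and $J$ is never built. At $x = (2, 3)$ and $v = (1, 10)$ we have $J = \\begin{pmatrix} 3 & 2 \\\\ 1 & 1 \\end{pmatrix}$, so $vJ = (13, 12)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vjp-sketch",
"metadata": {},
"outputs": [],
"source": [
"class Var:\n",
"    # A graph node: value, accumulated gradient, and (parent, local derivative) pairs\n",
"    def __init__(self, value, parents=()):\n",
"        self.value, self.grad, self.parents = value, 0.0, parents\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Var) else Var(other)\n",
"        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Var) else Var(other)\n",
"        return Var(self.value * other.value,\n",
"                   ((self, other.value), (other, self.value)))\n",
"\n",
"\n",
"def backward(out):\n",
"    # Reverse sweep over a topological order of the nodes reachable from out\n",
"    order, seen = [], set()\n",
"\n",
"    def visit(v):\n",
"        if id(v) not in seen:\n",
"            seen.add(id(v))\n",
"            for p, _ in v.parents:\n",
"                visit(p)\n",
"            order.append(v)\n",
"\n",
"    visit(out)\n",
"    out.grad = 1.0\n",
"    for v in reversed(order):\n",
"        for p, local in v.parents:\n",
"            p.grad += v.grad * local\n",
"\n",
"\n",
"x = [Var(2.0), Var(3.0)]\n",
"v = [1.0, 10.0]\n",
"f = [x[0] * x[1], x[0] + x[1]]  # f(x) = (x0 x1, x0 + x1)\n",
"s = f[0] * v[0] + f[1] * v[1]  # auxiliary scalar: sum_j v_j f_j\n",
"backward(s)  # a single reverse pass\n",
"vJ = [xi.grad for xi in x]\n",
"print(vJ)"
]
},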
{
"cell_type": "markdown",
"id": "60891391-23b1-48ce-bb42-6430f03f0b73",
"metadata": {},
"source": [
"### Jv"
]
},
{
"cell_type": "markdown",
"id": "2c7c30ee-8a5f-4630-a1be-789f6de514a1",
"metadata": {},
"source": [
"对于Jv,构造辅助量进行正向梯度计算"
]
},
{
"cell_type": "markdown",
"id": "828f8f7c-452e-4c12-90b6-774bceb8f0d9",
"metadata": {},
"source": [
"let $x_j = v_j a + x_j^0$, where $a = 0$"
]
},
{
"cell_type": "markdown",
"id": "5aecd4e0-6015-48fc-bb96-bc9be8b44bae",
"metadata": {},
"source": [
"$(J v)_{i} = \\sum_{j} J_{ij} v_j\n",
" = \\sum_{j} \\frac{\\partial f_i}{\\partial x_j} v_j\n",
" = \\sum_{j} \\frac{\\partial f_i}{\\partial x_j} \\frac{\\partial x_j}{\\partial a}\n",
" = \\frac{\\partial f_i}{\\partial a}$"
]
},
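{
"cell_type": "markdown",
"id": "jvp-sketch-intro",
"metadata": {},
"source": [
"A sketch of the same idea with dual numbers (hypothetical `Dual` class, only `+` and `*`): setting $x_j = v_j a + x_j^0$ and differentiating with respect to $a$ amounts to seeding each input's tangent with $v_j$, after which one forward pass gives $Jv$. For $f(x_0, x_1) = (x_0 x_1,\\ x_0 + x_1)$ at $x^0 = (2, 3)$ and $v = (1, 10)$: $Jv = (x_1 v_0 + x_0 v_1,\\ v_0 + v_1) = (23, 11)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "jvp-sketch",
"metadata": {},
"outputs": [],
"source": [
"class Dual:\n",
"    # Value and derivative with respect to the auxiliary variable a\n",
"    def __init__(self, value, tangent=0.0):\n",
"        self.value, self.tangent = value, tangent\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        return Dual(self.value + other.value, self.tangent + other.tangent)\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        return Dual(self.value * other.value,\n",
"                    self.tangent * other.value + self.value * other.tangent)\n",
"\n",
"\n",
"x0 = [2.0, 3.0]\n",
"v = [1.0, 10.0]\n",
"# x_j = v_j a + x0_j at a = 0: the value is x0_j and dx_j/da = v_j\n",
"x = [Dual(p, t) for p, t in zip(x0, v)]\n",
"f = [x[0] * x[1], x[0] + x[1]]  # a single forward pass\n",
"Jv = [fi.tangent for fi in f]\n",
"print(Jv)"
]
},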
{
"cell_type": "markdown",
"id": "5969d552-9c12-4d8d-be1f-3f27c7151079",
"metadata": {},
"source": [
"### vH"
]
},
{
"cell_type": "markdown",
"id": "0e285511-5ee3-4850-8cce-0a4e2bc440bb",
"metadata": {},
"source": [
"let $x_j = v_j a + x_j^0$ where a = 0"
]
},
{
"cell_type": "markdown",
"id": "c48b4347-fa46-401e-a5a5-ab42c4c43988",
"metadata": {},
"source": [
"$(vH)_i = \\sum_j v_j H_{ji}\n",
" = \\sum_j v_j \\frac{\\partial^2 f}{\\partial x_j \\partial x_i}\n",
" = \\frac{\\partial^2 f}{\\partial a \\partial x_i}$"
]
},
{
"cell_type": "markdown",
"id": "efd00918-5445-471f-96bb-dc9d75e10e0f",
"metadata": {},
"source": [
"然后使用反向梯度计算两次即可"
]
},
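{
"cell_type": "markdown",
"id": "hvp-sketch-intro",
"metadata": {},
"source": [
"A sketch of this recipe, with a backward pass that records its own graph so that it can be differentiated again (hypothetical `Var` and `grad`, only `+` and `*`). For $f(x_0, x_1) = x_0^2 x_1 + x_1^3$ the Hessian is $H = \\begin{pmatrix} 2x_1 & 2x_0 \\\\ 2x_0 & 6x_1 \\end{pmatrix}$, so at $x^0 = (1, 2)$ and $v = (1, 1)$ we expect $vH = (6, 14)$. The first reverse pass gives $\\partial f / \\partial a = \\nabla f \\cdot v = 17$ as a recorded expression; the second differentiates it with respect to each $x_i^0$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "hvp-sketch",
"metadata": {},
"outputs": [],
"source": [
"class Var:\n",
"    # A graph node whose local derivatives are Vars, so gradients stay differentiable\n",
"    def __init__(self, value, parents=()):\n",
"        self.value = value\n",
"        self.parents = parents\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Var) else Var(other)\n",
"        return Var(self.value + other.value, ((self, Var(1.0)), (other, Var(1.0))))\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Var) else Var(other)\n",
"        return Var(self.value * other.value, ((self, other), (other, self)))\n",
"\n",
"\n",
"def grad(out, wrt):\n",
"    # Reverse pass that builds its result out of Var operations (recorded)\n",
"    order, seen = [], set()\n",
"\n",
"    def visit(v):\n",
"        if id(v) not in seen:\n",
"            seen.add(id(v))\n",
"            for p, _ in v.parents:\n",
"                visit(p)\n",
"            order.append(v)\n",
"\n",
"    visit(out)\n",
"    adjoint = {id(out): Var(1.0)}\n",
"    for v in reversed(order):\n",
"        g = adjoint.get(id(v))\n",
"        if g is None:\n",
"            continue\n",
"        for p, local in v.parents:\n",
"            contribution = g * local\n",
"            key = id(p)\n",
"            adjoint[key] = adjoint[key] + contribution if key in adjoint else contribution\n",
"    return adjoint.get(id(wrt), Var(0.0))\n",
"\n",
"\n",
"a = Var(0.0)  # auxiliary variable, evaluated at a = 0\n",
"x0 = [Var(1.0), Var(2.0)]\n",
"v = [1.0, 1.0]\n",
"x = [x0j + a * vj for x0j, vj in zip(x0, v)]  # x_j = v_j a + x0_j\n",
"f = x[0] * x[0] * x[1] + x[1] * x[1] * x[1]  # f = x0^2 x1 + x1^3\n",
"df_da = grad(f, a)  # first reverse pass: grad(f) . v, recorded as a graph\n",
"vH = [grad(df_da, x0j).value for x0j in x0]  # second reverse pass\n",
"print(df_da.value, vH)"
]
},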
{
"cell_type": "markdown",
"id": "62da7e02-7e57-4b99-8f87-a35bc3216e00",
"metadata": {},
"source": [
"## 总结"
]
},
{
"cell_type": "markdown",
"id": "1abd11ea-61be-4f29-9bad-d55d04f06de7",
"metadata": {},
"source": [
"- 计算图\n",
"- 链式法则\n",
"- 正向反向,计算复杂度不变\n",
"- 利用矩阵apply省掉计算"
]
},
{
"cell_type": "markdown",
"id": "3ef6b78c-d978-4d5e-8878-f1a713c097c9",
"metadata": {},
"source": [
"### 结束后的兴趣话题"
]
},
{
"cell_type": "markdown",
"id": "f62c1105-693d-44fa-9370-9b9d7ee2b8ab",
"metadata": {},
"source": [
"Wirtinger derivatives for complex\n",
"\n",
"See: https://docs.pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72226ab1-79aa-4d9c-943f-9567d9220e10",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}