Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created March 17, 2025 18:10
Show Gist options
  • Save ricardoV94/005fb4b48d274bcf383a3eb5456462d3 to your computer and use it in GitHub Desktop.
Save ricardoV94/005fb4b48d274bcf383a3eb5456462d3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells" : [ {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:20.197470Z",
"start_time" : "2025-03-17T18:08:19.379418Z"
}
},
"cell_type" : "code",
"source" : [ "import pytensor\n", "import pytensor.tensor as pt\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "L = 1.0  # domain length: x in [0, L]\n", "T = 0.1  # final simulation time\n", "\n", "def build_heat_equation_fn(mode=None):\n", "    \"\"\"Compile a solver for the 1D heat equation u_t = alpha * u_xx on [0, L].\n", "\n", "    Takes implicit (backward) Euler steps on a fixed 50-point grid, starting\n", "    from u(x, 0) = sin(pi * x) with u = 0 held at both ends.\n", "\n", "    Parameters\n", "    ----------\n", "    mode : str or pytensor Mode, optional\n", "        Compilation mode passed to ``pytensor.function``. ``None`` means\n", "        ``pytensor.config.mode``.\n", "\n", "    Returns\n", "    -------\n", "    Compiled function ``fn(nt, alpha)`` that takes ``nt`` steps of size\n", "    ``T / nt`` and returns an ``(nt, 50)`` array holding the solution after\n", "    each step (initial condition excluded, zero boundary values included).\n", "    \"\"\"\n", "    nx = 50\n", "    nt = pt.scalar(\"nt\", dtype=int)\n", "    dx = L / (nx - 1)\n", "    dt = T / nt\n", "\n", "    alpha = pt.scalar(\"alpha\")\n", "    r = alpha * dt / dx**2\n", "\n", "    x = np.linspace(0, L, nx)\n", "    u = np.sin(np.pi * x)\n", "    # Check bounds are zero\n", "    np.testing.assert_allclose(u[[0, -1]], 0, atol=1e-15)\n", "\n", "    # Backward-Euler system matrix for the interior points: tridiag(-r, 1 + 2r, -r)\n", "    A = pt.zeros((nx - 2, nx - 2))\n", "    row_idxs, col_idxs = np.diag_indices(nx - 2)\n", "    A = A[row_idxs, col_idxs].set(1 + 2 * r)\n", "    A = A[row_idxs[:-1], col_idxs[1:]].set(-r)\n", "    A = A[row_idxs[1:], col_idxs[:-1]].set(-r)\n", "\n", "    # `mode` may be None (pytensor.function then uses config.mode) or a Mode\n", "    # object, so only call .lower() on an actual string name.\n", "    mode_name = pytensor.config.mode if mode is None else mode\n", "    assume_a = \"tridiagonal\"\n", "    if isinstance(mode_name, str) and mode_name.lower() == \"numba\":\n", "        # Next best thing: A is symmetric by construction. Presumably the\n", "        # tridiagonal solve is not supported by the Numba backend (TODO confirm).\n", "        assume_a = \"symmetric\"\n", "\n", "    # Each step solves A @ u_t = u_{t-1} for the interior points.\n", "    u_history, _ = pytensor.scan(\n", "        fn=lambda utm1, A: pt.linalg.solve(A, utm1, assume_a=assume_a),\n", "        outputs_info=[u[1:-1]],\n", "        non_sequences=[A],\n", "        n_steps=nt,\n", "        strict=True,\n", "    )\n", "    # Add the fixed edges\n", "    u_history = pt.pad(u_history, [[0, 0], [1, 1]], mode=\"constant\")\n", "\n", "    fn = pytensor.function([nt, alpha], u_history, mode=mode)\n", "    return fn\n" ],
"id" : "8b38bfaeeb055722",
"outputs" : [ ],
"execution_count" : 1
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:23.030771Z",
"start_time" : "2025-03-17T18:08:20.203522Z"
}
},
"cell_type" : "code",
"source" : [ "# Baseline: compile the solver with pytensor's FAST_RUN mode\n", "fn = build_heat_equation_fn(\"FAST_RUN\")" ],
"id" : "c46e787114deeb3",
"outputs" : [ ],
"execution_count" : 2
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:23.356448Z",
"start_time" : "2025-03-17T18:08:23.137105Z"
}
},
"cell_type" : "code",
"source" : [ "# 50 implicit time steps with diffusivity alpha = 2.5\n", "u_history_res = fn(alpha=2.5, nt=50)" ],
"id" : "51625468a0c5107",
"outputs" : [ ],
"execution_count" : 3
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:23.562571Z",
"start_time" : "2025-03-17T18:08:23.365995Z"
}
},
"cell_type" : "code",
"source" : [ "# Heatmap of temperature: rows are time steps, columns are grid points.\n", "fig, ax = plt.subplots(figsize=(8, 5))\n", "# extent maps array indices to physical coordinates: x runs over [0, L] left to\n", "# right; time runs over [0, T] with t = 0 at the top, so time increases downward.\n", "extent = [0, L, T, 0]\n", "image = ax.imshow(u_history_res, extent=extent, aspect='auto', cmap='hot')\n", "fig.colorbar(image, ax=ax, label='Temperature')\n", "ax.set_xlabel('Position, x')\n", "ax.set_ylabel('Time, t')\n", "ax.set_title('Heat Equation: Temperature Evolution')\n", "plt.show()" ],
"id" : "97373d7c55d6085",
"outputs" : [ {
"data" : {
"text/plain" : [ "<Figure size 800x500 with 2 Axes>" ],
"image/png" : ""
},
"metadata" : { },
"output_type" : "display_data"
} ],
"execution_count" : 4
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:23.577161Z",
"start_time" : "2025-03-17T18:08:23.572046Z"
}
},
"cell_type" : "code",
"source" : "# fn.dprint(print_shape=True, print_memory_map=True)",
"id" : "6a4ce34cb39e8cb5",
"outputs" : [ ],
"execution_count" : 5
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:26.824757Z",
"start_time" : "2025-03-17T18:08:23.617272Z"
}
},
"cell_type" : "code",
"source" : [ "# trust_input skips pytensor's input validation/conversion, so the test inputs\n", "# must already have the exact dtypes of the graph inputs (int for nt, floatX for alpha).\n", "nt_test = np.array(50, dtype=int)\n", "alpha_test = np.array(1.0, dtype=pytensor.config.floatX)\n", "fn.trust_input = True\n", "try:\n", "    %timeit fn(nt_test, alpha_test)\n", "finally:\n", "    # Re-enable input checks even if the benchmark raises\n", "    fn.trust_input = False" ],
"id" : "cf9033d5598071da",
"outputs" : [ {
"name" : "stdout",
"output_type" : "stream",
"text" : [ "3.89 ms ± 104 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ]
} ],
"execution_count" : 6
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:38.878998Z",
"start_time" : "2025-03-17T18:08:26.835721Z"
}
},
"cell_type" : "code",
"source" : [ "fn_numba = build_heat_equation_fn(mode=\"NUMBA\")\n", "# Warm-up call: the first execution triggers Numba JIT compilation (slow), keeping\n", "# compile time out of the benchmark below. Relies on nt_test/alpha_test defined above.\n", "fn_numba(nt_test, alpha_test);" ],
"id" : "83c05e38d2e02ff7",
"outputs" : [ {
"name" : "stderr",
"output_type" : "stream",
"text" : [ "/tmp/tmp5xa7er2j:21: NumbaWarning: \u001B[1m\u001B[1mCannot cache compiled function \"scan\" as it uses dynamic globals (such as ctypes pointers and large global arrays)\u001B[0m\u001B[0m\n", " tensor_variable_9 = scan(nt, tensor_variable_8, tensor_variable_5)\n" ]
} ],
"execution_count" : 7
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:38.891948Z",
"start_time" : "2025-03-17T18:08:38.888553Z"
}
},
"cell_type" : "code",
"source" : "# fn_numba.dprint(print_memory_map=True, print_shape=True)",
"id" : "fdc0cb93efa8c937",
"outputs" : [ ],
"execution_count" : 8
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:40.840805Z",
"start_time" : "2025-03-17T18:08:38.935833Z"
}
},
"cell_type" : "code",
"source" : [ "fn_numba.trust_input = True\n", "try:\n", "    %timeit fn_numba(nt_test, alpha_test)\n", "finally:\n", "    # Re-enable input checks even if the benchmark raises\n", "    fn_numba.trust_input = False" ],
"id" : "51391ca854a2ecde",
"outputs" : [ {
"name" : "stdout",
"output_type" : "stream",
"text" : [ "2.33 ms ± 87.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ]
} ],
"execution_count" : 9
}, {
"metadata" : {
"ExecuteTime" : {
"end_time" : "2025-03-17T18:08:40.855892Z",
"start_time" : "2025-03-17T18:08:40.853318Z"
}
},
"cell_type" : "code",
"source" : "",
"id" : "5eaa889dcc5d584a",
"outputs" : [ ],
"execution_count" : null
} ],
"metadata" : {
"kernelspec" : {
"display_name" : "Python 3",
"language" : "python",
"name" : "python3"
},
"language_info" : {
"codemirror_mode" : {
"name" : "ipython",
"version" : 2
},
"file_extension" : ".py",
"mimetype" : "text/x-python",
"name" : "python",
"nbconvert_exporter" : "python",
"pygments_lexer" : "ipython2",
"version" : "2.7.6"
}
},
"nbformat" : 4,
"nbformat_minor" : 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment