Created
September 30, 2024 00:39
-
-
Save andyfaff/14dc7bfe5215a5d7e156c19b6b419b3d to your computer and use it in GitHub Desktop.
torch + Metal GPU
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "00defb49-487a-4aeb-8a7a-bd20c79e2dbf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# -*- coding: utf-8 -*-\n", | |
"import math\n", | |
"from functools import reduce\n", | |
"\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"from torch import Tensor\n", | |
"\n", | |
"\n", | |
"def abeles(\n", | |
"    q: Tensor,\n", | |
"    thickness: Tensor,\n", | |
"    roughness: Tensor,\n", | |
"    sld: Tensor,\n", | |
"):\n", | |
"    \"\"\"Simulates reflectivity curves for SLD profiles with box model parameterization using the Abeles matrix method\n", | |
"\n", | |
"    Args:\n", | |
"        q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]\n", | |
"        thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]\n", | |
"        roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]\n", | |
"        sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom). The tensor shape should be one of the following:\n", | |
"            - [batch_size, n_layers + 1]: in this case, the ambient SLD is not included but assumed to be 0\n", | |
"            - [batch_size, n_layers + 2]: this shape includes the ambient SLD as the first element in the tensor\n", | |
"\n", | |
"    Returns:\n", | |
"        Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points], clamped to at most 1\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if the last dimension of ``sld`` is neither n_layers + 1 nor n_layers + 2\n", | |
"\n", | |
"    Notes:\n", | |
"        - SLD values are taken relative to the ambient SLD and multiplied by 1e-6 before use.\n", | |
"        - The calculation runs in complex128 when ``q`` is float64, otherwise in complex64.\n", | |
"    \"\"\"\n", | |
"    c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64\n", | |
"\n", | |
"    batch_size, num_layers = thickness.shape\n", | |
"\n", | |
"    if sld.shape[-1] == num_layers + 1:\n", | |
"        # add zero ambient sld; new_zeros allocates directly with sld's dtype and device\n", | |
"        # instead of building a CPU tensor and copying it to the device on every call\n", | |
"        sld = torch.cat([sld.new_zeros(batch_size, 1), sld], -1)\n", | |
"    if sld.shape[-1] != num_layers + 2:\n", | |
"        raise ValueError(\n", | |
"            \"Number of SLD values does not equal to num_layers + 2 (substrate + ambient).\"\n", | |
"        )\n", | |
"\n", | |
"    # insert a q axis: per-layer tensors become (batch, 1, layers) and broadcast over q\n", | |
"    sld = sld[:, None]\n", | |
"\n", | |
"    # add zero thickness for ambient layer:\n", | |
"    thickness = torch.cat([thickness.new_zeros(batch_size, 1), thickness], -1)[:, None]\n", | |
"\n", | |
"    # squared roughness, as used in the interfacial damping factor below\n", | |
"    roughness = roughness[:, None] ** 2\n", | |
"\n", | |
"    # SLD relative to the ambient, scaled by 1e-6; adding 1e-36j also promotes a real sld to a complex dtype\n", | |
"    sld = (sld - sld[..., :1]) * 1e-6 + 1e-36j\n", | |
"\n", | |
"    k_z0 = (q / 2).to(c_dtype)\n", | |
"\n", | |
"    # reshape k_z0 to (batch or 1, n_points, 1) so it broadcasts against (batch, 1, layers)\n", | |
"    if k_z0.dim() == 1:\n", | |
"        k_z0 = k_z0.unsqueeze(0)\n", | |
"    if k_z0.dim() == 2:\n", | |
"        k_z0 = k_z0.unsqueeze(-1)\n", | |
"\n", | |
"    k_n = torch.sqrt(k_z0**2 - 4 * math.pi * sld)\n", | |
"\n", | |
"    # k_n.shape - (batch, q, layers)\n", | |
"    k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]\n", | |
"\n", | |
"    beta = 1j * thickness * k_n\n", | |
"\n", | |
"    exp_beta = torch.exp(beta)\n", | |
"    exp_m_beta = torch.exp(-beta)\n", | |
"\n", | |
"    # Fresnel coefficient of each interface with a Nevot-Croce roughness factor\n", | |
"    rn = (k_n - k_np1) / (k_n + k_np1) * torch.exp(-2 * k_n * k_np1 * roughness)\n", | |
"\n", | |
"    # per-layer characteristic matrices, shape (batch, q, layers, 2, 2):\n", | |
"    # [[exp(beta), rn * exp(beta)], [rn * exp(-beta), exp(-beta)]]\n", | |
"    c_matrices = torch.stack(\n", | |
"        [\n", | |
"            torch.stack([exp_beta, rn * exp_m_beta], -1),\n", | |
"            torch.stack([rn * exp_beta, exp_m_beta], -1),\n", | |
"        ],\n", | |
"        -1,\n", | |
"    )\n", | |
"\n", | |
"    # split into one (batch, q, 2, 2) matrix per layer and multiply them in order (top to bottom)\n", | |
"    c_matrices = [c.squeeze(-3) for c in c_matrices.split(1, -3)]\n", | |
"    m = reduce(torch.matmul, c_matrices)\n", | |
"\n", | |
"    r = (m[..., 1, 0] / m[..., 0, 0]).abs() ** 2\n", | |
"    r = torch.clamp_max_(r, 1.0)\n", | |
"\n", | |
"    return r\n", | |
"\n", | |
"\n", | |
"# @torch.jit.script  # left disabled for now because of an issue with complex numbers\n", | |
"def abeles_compiled(q: Tensor, thickness: Tensor, roughness: Tensor, sld: Tensor):\n", | |
"    \"\"\"Forward all arguments unchanged to ``abeles`` and return its result.\n", | |
"\n", | |
"    Presumably meant to become the TorchScript-compiled entry point once the\n", | |
"    ``@torch.jit.script`` decorator above can be re-enabled.\n", | |
"    \"\"\"\n", | |
"    return abeles(q=q, thickness=thickness, roughness=roughness, sld=sld)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f1a23b7e-d4cf-4df6-8fb3-f45fb354f6ac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Apple GPU via PyTorch's Metal Performance Shaders (MPS) backend; the benchmark tensors are moved here\n", | |
"mps_device = torch.device(type=\"mps\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"id": "03845bb8-cc8d-4ef9-aabc-8e3941513815", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Repeating bilayer; each row is [thickness, SLD, iSLD, roughness] (refnx ``w`` layout)\n", | |
"sub = [[100, 3.47, 0, 3], [500, -0.5, 0.00001, 3]]\n", | |
"\n", | |
"# fronting medium (first row), 75 bilayer repeats, backing medium (last row)\n", | |
"w = np.vstack([[[0, 2.07, 0, 0]], sub * 75, [[0, 6.36, 0, 4]]])\n", | |
"nq = np.geomspace(0.01, 0.3, 20001)\n", | |
"\n", | |
"# single precision (MPS has no float64 support) with a leading batch axis of size 1\n", | |
"q = torch.as_tensor(nq.astype(np.float32)[np.newaxis], device=mps_device)\n", | |
"thickness = torch.as_tensor(w[1:-1, 0].astype(np.float32)[np.newaxis], device=mps_device)\n", | |
"rough = torch.as_tensor(w[1:, -1].astype(np.float32)[np.newaxis], device=mps_device)\n", | |
"sld = torch.as_tensor((w[:, 1] + 1j * w[:, 2]).astype(np.complex64)[np.newaxis], device=mps_device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"id": "c69b92d9-f5fa-4702-92ee-c12511eb2fd1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"94.6 ms ± 133 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# MPS kernels run asynchronously; synchronize so each timed loop includes the GPU work\n", | |
"%timeit abeles(q, thickness, rough, sld); torch.mps.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"id": "8cbad460-704f-4cd2-892e-1e64b6286275", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from refnx.reflect import abeles as rnx_abeles\n", | |
"from refnx.reflect import available_backends, use_reflect_backend" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 54, | |
"id": "1668dbe7-c453-4f8b-99b4-8abe9b5d19ed", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"18.8 ms ± 97.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# refnx reference calculation on the same q-grid and layer structure, for comparison\n", | |
"%timeit rnx_abeles(nq, w, threads=-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c5be564d-f9d1-4ac9-9f99-bf5a540aecbd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment