Skip to content

Instantly share code, notes, and snippets.

@andyfaff
Created September 30, 2024 00:39
Show Gist options
  • Save andyfaff/14dc7bfe5215a5d7e156c19b6b419b3d to your computer and use it in GitHub Desktop.
Save andyfaff/14dc7bfe5215a5d7e156c19b6b419b3d to your computer and use it in GitHub Desktop.
torch + Metal GPU
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "00defb49-487a-4aeb-8a7a-bd20c79e2dbf",
"metadata": {},
"outputs": [],
"source": [
"# -*- coding: utf-8 -*-\n",
"import math\n",
"from functools import reduce\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch import Tensor\n",
"\n",
"\n",
"def abeles(\n",
" q: Tensor,\n",
" thickness: Tensor,\n",
" roughness: Tensor,\n",
" sld: Tensor,\n",
"):\n",
" \"\"\"Simulate reflectivity curves for SLD profiles with a box-model parameterization using the Abeles matrix method.\n",
"\n",
" Args:\n",
" q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]\n",
" thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]\n",
" roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]\n",
" sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom). The tensor shape should be one of the following:\n",
" - [batch_size, n_layers + 1]: in this case, the ambient SLD is not included but assumed to be 0\n",
" - [batch_size, n_layers + 2]: this shape includes the ambient SLD as the first element in the tensor\n",
"\n",
" Returns:\n",
" Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]\n",
" \"\"\"\n",
" c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64\n",
"\n",
" batch_size, num_layers = thickness.shape\n",
"\n",
" if sld.shape[-1] == num_layers + 1:\n",
" # add zero ambient sld\n",
" sld = torch.cat([torch.zeros(batch_size, 1).to(sld), sld], -1)\n",
" if sld.shape[-1] != num_layers + 2:\n",
" raise ValueError(\n",
" \"Number of SLD values does not equal to num_layers + 2 (substrate + ambient).\"\n",
" )\n",
"\n",
" sld = sld[:, None]\n",
"\n",
" # add zero thickness for ambient layer:\n",
" thickness = torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1)[\n",
" :, None\n",
" ]\n",
"\n",
" roughness = roughness[:, None] ** 2\n",
"\n",
" sld = (sld - sld[..., :1]) * 1e-6 + 1e-36j\n",
"\n",
" k_z0 = (q / 2).to(c_dtype)\n",
"\n",
" if k_z0.dim() == 1:\n",
" k_z0.unsqueeze_(0)\n",
"\n",
" if k_z0.dim() == 2:\n",
" k_z0.unsqueeze_(-1)\n",
"\n",
" k_n = torch.sqrt(k_z0**2 - 4 * math.pi * sld)\n",
"\n",
" # k_n.shape - (batch, q, layers)\n",
"\n",
" k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]\n",
"\n",
" beta = 1j * thickness * k_n\n",
"\n",
" exp_beta = torch.exp(beta)\n",
" exp_m_beta = torch.exp(-beta)\n",
"\n",
" rn = (k_n - k_np1) / (k_n + k_np1) * torch.exp(-2 * k_n * k_np1 * roughness)\n",
"\n",
" c_matrices = torch.stack(\n",
" [\n",
" torch.stack([exp_beta, rn * exp_m_beta], -1),\n",
" torch.stack([rn * exp_beta, exp_m_beta], -1),\n",
" ],\n",
" -1,\n",
" )\n",
"\n",
" c_matrices = [c.squeeze(-3) for c in c_matrices.split(1, -3)]\n",
"\n",
" m = reduce(torch.matmul, c_matrices)\n",
"\n",
" r = (m[..., 1, 0] / m[..., 0, 0]).abs() ** 2\n",
" r = torch.clamp_max_(r, 1.0)\n",
"\n",
" return r\n",
"\n",
"\n",
"# @torch.jit.script # disabled for now due to an issue with complex numbers\n",
"def abeles_compiled(\n",
" q: Tensor,\n",
" thickness: Tensor,\n",
" roughness: Tensor,\n",
" sld: Tensor,\n",
"):\n",
" return abeles(q, thickness, roughness, sld)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f1a23b7e-d4cf-4df6-8fb3-f45fb354f6ac",
"metadata": {},
"outputs": [],
"source": [
"mps_device = torch.device(\"mps\")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "03845bb8-cc8d-4ef9-aabc-8e3941513815",
"metadata": {},
"outputs": [],
"source": [
"sub = [[100, 3.47, 0, 3], [500, -0.5, 0.00001, 3]]\n",
"\n",
"w = np.r_[[[0, 2.07, 0, 0]], sub * 75, [[0, 6.36, 0, 4]]]\n",
"nq = np.geomspace(0.01, 0.3, 20001)\n",
"\n",
"q = torch.from_numpy(np.atleast_2d(nq.astype(np.float32))).to(mps_device)\n",
"thickness = torch.from_numpy(np.atleast_2d(w[1:-1, 0].astype(np.float32))).to(mps_device)\n",
"rough = torch.from_numpy(np.atleast_2d(w[1:, -1].astype(np.float32))).to(mps_device)\n",
"sld = torch.from_numpy(np.atleast_2d(np.array(w[:, 1] + 1J*w[:, 2], dtype=np.complex64))).to(mps_device)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "c69b92d9-f5fa-4702-92ee-c12511eb2fd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"94.6 ms ± 133 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit abeles(q, thickness, rough, sld)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "8cbad460-704f-4cd2-892e-1e64b6286275",
"metadata": {},
"outputs": [],
"source": [
"from refnx.reflect import abeles as rnx_abeles\n",
"from refnx.reflect import available_backends, use_reflect_backend"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "1668dbe7-c453-4f8b-99b4-8abe9b5d19ed",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"18.8 ms ± 97.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit f(nq, w, threads=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5be564d-f9d1-4ac9-9f99-bf5a540aecbd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment