Skip to content

Instantly share code, notes, and snippets.

@ltiao
Created May 8, 2022 14:08
Show Gist options
  • Save ltiao/5a6cf96cdbd0d228ddab63773e6179f9 to your computer and use it in GitHub Desktop.
Save ltiao/5a6cf96cdbd0d228ddab63773e6179f9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c2bded24",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-08 15:08:21.189672: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
]
}
],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import pandas as pd\n",
"\n",
"from gpflow.kernels import SquaredExponential, Matern12, Matern32, Matern52\n",
"from gpflow.inducing_variables import InducingPoints\n",
"from gpflow.covariances import Kuu, Kuf\n",
"\n",
"from gpflow_decomposed.benchmarking.plotting import plot_predictive, WIDTH, HEIGHT, ASPECT"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3b3f01be",
"metadata": {},
"outputs": [],
"source": [
"rc = {\n",
" \"figure.figsize\": (WIDTH, HEIGHT),\n",
" \"figure.dpi\": 300,\n",
" \"font.serif\": [\"Palatino\"],\n",
" \"text.usetex\": True,\n",
"}\n",
"sns.set(context=\"paper\", style=\"ticks\", palette=\"colorblind\", font=\"serif\", rc=rc)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9b859969",
"metadata": {},
"outputs": [],
"source": [
"input_dim = 1\n",
"\n",
"x_min, x_max = -5., 5.\n",
"z_min, z_max = -3., 3.\n",
"\n",
"n_index_points = 16\n",
"n_inducing = 8\n",
"\n",
"seed = 8888\n",
"\n",
"random_state = np.random.RandomState(seed)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dfcadaa6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-08 15:08:22.806229: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1\n",
"2022-05-08 15:08:22.865383: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:22.865823: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: \n",
"pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3070 Laptop GPU computeCapability: 8.6\n",
"coreClock: 1.62GHz coreCount: 40 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 417.29GiB/s\n",
"2022-05-08 15:08:22.865870: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n",
"2022-05-08 15:08:22.878062: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11\n",
"2022-05-08 15:08:22.878217: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11\n",
"2022-05-08 15:08:22.966931: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcufft.so.10\n",
"2022-05-08 15:08:22.967515: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcurand.so.10\n",
"2022-05-08 15:08:22.968655: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusolver.so.11\n",
"2022-05-08 15:08:22.970360: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusparse.so.11\n",
"2022-05-08 15:08:22.970676: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8\n",
"2022-05-08 15:08:22.970832: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:22.971190: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:22.971424: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0\n",
"2022-05-08 15:08:22.972651: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2022-05-08 15:08:22.975089: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:22.975375: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: \n",
"pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3070 Laptop GPU computeCapability: 8.6\n",
"coreClock: 1.62GHz coreCount: 40 deviceMemorySize: 7.79GiB deviceMemoryBandwidth: 417.29GiB/s\n",
"2022-05-08 15:08:22.975485: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:22.975751: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:22.975955: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0\n",
"2022-05-08 15:08:22.976006: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n",
"2022-05-08 15:08:23.258460: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:\n",
"2022-05-08 15:08:23.258479: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264] 0 \n",
"2022-05-08 15:08:23.258483: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0: N \n",
"2022-05-08 15:08:23.258638: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:23.258740: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:23.258813: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-05-08 15:08:23.258884: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 6144 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3070 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6)\n"
]
},
{
"data": {
"text/html": [
"<gpflow.kernels.stationaries.Matern52 object at 0x7f5f43ca8be0>\n",
"<table>\n",
"<thead>\n",
"<tr><th>name </th><th>class </th><th>transform </th><th>prior </th><th>trainable </th><th>shape </th><th>dtype </th><th style=\"text-align: right;\"> value</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>Matern52.variance </td><td>Parameter</td><td>Softplus </td><td> </td><td>True </td><td>() </td><td>float64</td><td style=\"text-align: right;\"> 1</td></tr>\n",
"<tr><td>Matern52.lengthscales</td><td>Parameter</td><td>Softplus </td><td> </td><td>True </td><td>() </td><td>float64</td><td style=\"text-align: right;\"> 1</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"<gpflow.kernels.stationaries.Matern52 object at 0x7f5f43ca8be0>\n",
"╒═══════════════════════╤═══════════╤═════════════╤═════════╤═════════════╤═════════╤═════════╤═════════╕\n",
"│ name │ class │ transform │ prior │ trainable │ shape │ dtype │ value │\n",
"╞═══════════════════════╪═══════════╪═════════════╪═════════╪═════════════╪═════════╪═════════╪═════════╡\n",
"│ Matern52.variance │ Parameter │ Softplus │ │ True │ () │ float64 │ 1 │\n",
"├───────────────────────┼───────────┼─────────────┼─────────┼─────────────┼─────────┼─────────┼─────────┤\n",
"│ Matern52.lengthscales │ Parameter │ Softplus │ │ True │ () │ float64 │ 1 │\n",
"╘═══════════════════════╧═══════════╧═════════════╧═════════╧═════════════╧═════════╧═════════╧═════════╛"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kernel = Matern52()\n",
"kernel"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "337b0a56",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(16, 1)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_grid = np.linspace(x_min, x_max, n_index_points).reshape(-1, input_dim)\n",
"X_grid.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bdde7c3c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(8, 1)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Z = random_state.uniform(low=z_min, high=z_max, size=(n_inducing, input_dim))\n",
"Z.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2c957594",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"&lt;gpflow.inducing_variables.inducing_variables.InducingPoints object at 0x7f5f317a3130&gt;\n",
"<table>\n",
"<thead>\n",
"<tr><th>name </th><th>class </th><th>transform </th><th>prior </th><th>trainable </th><th>shape </th><th>dtype </th><th>value </th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>InducingPoints.Z</td><td>Parameter</td><td>Identity </td><td> </td><td>True </td><td>(8, 1) </td><td>float64</td><td>[[2.77430592...</td></tr>\n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"<gpflow.inducing_variables.inducing_variables.InducingPoints object at 0x7f5f317a3130>\n",
"╒══════════════════╤═══════════╤═════════════╤═════════╤═════════════╤═════════╤═════════╤═════════════════╕\n",
"│ name │ class │ transform │ prior │ trainable │ shape │ dtype │ value │\n",
"╞══════════════════╪═══════════╪═════════════╪═════════╪═════════════╪═════════╪═════════╪═════════════════╡\n",
"│ InducingPoints.Z │ Parameter │ Identity │ │ True │ (8, 1) │ float64 │ [[2.77430592... │\n",
"╘══════════════════╧═══════════╧═════════════╧═════════╧═════════════╧═════════╧═════════╧═════════════════╛"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inducing_variables = InducingPoints(Z=Z)\n",
"inducing_variables"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "557316cb",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-08 15:08:23.482478: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11\n",
"2022-05-08 15:08:24.050873: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11\n",
"2022-05-08 15:08:24.056961: I tensorflow/core/util/cuda_solvers.cc:180] Creating CudaSolver handles for stream 0x55738329eb10\n",
"2022-05-08 15:08:24.057026: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusolver.so.11\n",
"2022-05-08 15:08:24.099112: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11\n"
]
}
],
"source": [
"Knn = kernel(X_grid, full_cov=False)\n",
"Kmn = Kuf(inducing_variables, kernel, X_grid)\n",
"Kmm = Kuu(inducing_variables, kernel)\n",
"Lm = tf.linalg.cholesky(Kmm)"
]
},
{
"cell_type": "markdown",
"id": "43c6be6c",
"metadata": {},
"source": [
"## Two triangular_solves"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1d2c35c4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(8, 16), dtype=float64, numpy=\n",
"array([[ 3.35932369e-06, 1.26589491e-05, 4.70096909e-05,\n",
" 1.71545511e-04, 6.12815636e-04, 2.13206027e-03,\n",
" 7.17179250e-03, 2.30755338e-02, 6.98374456e-02,\n",
" 1.93269495e-01, 4.63892956e-01, 8.61795895e-01,\n",
" 9.59629746e-01, 5.87704063e-01, 2.61406798e-01,\n",
" 9.81517519e-02],\n",
" [ 8.21358385e-05, 2.97130381e-04, 1.05010197e-03,\n",
" 3.60407069e-03, 1.19097978e-02, 3.74050780e-02,\n",
" 1.09341220e-01, 2.86713356e-01, 6.26541518e-01,\n",
" 9.59717222e-01, 7.36940153e-01, 2.28613292e-01,\n",
" -5.88897025e-02, -8.25591826e-02, -4.43354348e-02,\n",
" -1.80649964e-02],\n",
" [ 1.47410338e-03, 5.01181394e-03, 1.63565275e-02,\n",
" 5.04928460e-02, 1.43905673e-01, 3.62310469e-01,\n",
" 7.34897310e-01, 9.57604256e-01, 5.64527742e-01,\n",
" 7.46830799e-02, -8.56270259e-02, -3.80352964e-02,\n",
" 1.06475053e-02, 1.52879424e-02, 8.30003182e-03,\n",
" 3.40319294e-03],\n",
" [ 1.11020201e-01, 2.91028689e-01, 6.36280798e-01,\n",
" 9.81194181e-01, 8.09562182e-01, 4.00025022e-01,\n",
" 1.25034215e-01, -1.23617712e-03, -1.90421882e-02,\n",
" -3.12925995e-03, 3.76593147e-03, 1.69767299e-03,\n",
" -4.77424183e-04, -6.86383174e-04, -3.72862988e-04,\n",
" -1.52932290e-04],\n",
" [-9.56718570e-02, -2.21754905e-01, -3.67572553e-01,\n",
" -1.55525226e-01, 5.37701639e-01, 4.93959821e-01,\n",
" 1.90986805e-01, -2.04316523e-03, -3.21962531e-02,\n",
" -5.33153214e-03, 6.42920139e-03, 2.89999116e-03,\n",
" -8.15692321e-04, -1.17276472e-03, -6.37094045e-04,\n",
" -2.61311873e-04],\n",
" [ 5.42811298e-04, 1.20161951e-03, 1.80948617e-03,\n",
" 5.50424652e-04, 2.85531023e-03, 2.49966181e-02,\n",
" 4.79559730e-02, -2.46375344e-03, -1.74772629e-01,\n",
" -1.19184520e-01, 4.79487779e-01, 3.49422644e-01,\n",
" -1.11286128e-01, -1.65290648e-01, -9.10814446e-02,\n",
" -3.76582791e-02],\n",
" [-1.91939108e-03, -4.24892829e-03, -6.39827426e-03,\n",
" -1.94618783e-03, -1.00817890e-02, -8.76783525e-02,\n",
" -1.65457751e-01, 8.09638556e-03, 4.59820021e-01,\n",
" 1.25825579e-01, -2.57397454e-02, 4.89512494e-02,\n",
" -1.97432521e-02, -3.08171537e-02, -1.73336030e-02,\n",
" -7.24755458e-03],\n",
" [-3.42374032e-03, -7.57896923e-03, -1.14124476e-02,\n",
" -3.47086298e-03, -1.79037933e-02, -1.52530160e-01,\n",
" -2.72726060e-01, 1.10902763e-02, 1.78718834e-01,\n",
" -2.88809463e-02, 8.26579272e-03, -1.61260396e-02,\n",
" 6.52771374e-03, 1.01964408e-02, 5.73680432e-03,\n",
" 2.39905370e-03]])>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Qmn_white = tf.linalg.triangular_solve(Lm, Kmn, lower=True) # [..., M, N]\n",
"Qmn_white"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "302bdbdc",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(8, 16), dtype=float64, numpy=\n",
"array([[-6.00071698e-05, -1.32832916e-04, -2.00013449e-04,\n",
" -6.08209325e-05, -3.12364166e-04, -2.60391923e-03,\n",
" -4.37757329e-03, 1.34189997e-04, -5.30586670e-03,\n",
" 1.28528241e-02, -3.36718745e-02, 5.66844329e-01,\n",
" 1.05067930e+00, 7.21923692e-01, 3.35127707e-01,\n",
" 1.28577273e-01],\n",
" [-1.62591812e-03, -3.59916093e-03, -5.41944515e-03,\n",
" -1.64796884e-03, -8.46380712e-03, -7.05627561e-02,\n",
" -1.18661517e-01, 3.64332360e-03, -1.50543141e-01,\n",
" 7.93877999e-01, 2.18559183e-01, -3.52908055e-01,\n",
" 1.38691101e-01, 2.15349113e-01, 1.20872205e-01,\n",
" 5.04820759e-02],\n",
" [ 1.51744315e-02, 3.35919499e-02, 5.05864298e-02,\n",
" 1.53893957e-02, 8.00755169e-02, 7.14239147e-01,\n",
" 1.49795409e+00, 9.66565024e-01, -1.74252303e-01,\n",
" 3.93087372e-02, -1.14678157e-02, 2.24006583e-02,\n",
" -9.06920995e-03, -1.41668040e-02, -7.97075154e-03,\n",
" -3.33328467e-03],\n",
" [ 3.77700695e-01, 9.09149633e-01, 1.66091722e+00,\n",
" 1.41590553e+00, -6.77994864e-01, -9.36668033e-01,\n",
" -3.42981817e-01, 1.76906786e-03, -6.78788401e-03,\n",
" 1.66523068e-03, -4.88041344e-04, 9.53594084e-04,\n",
" -3.86091286e-04, -6.03109201e-04, -3.39331943e-04,\n",
" -1.41905305e-04],\n",
" [-2.84320707e-01, -6.58910241e-01, -1.09183617e+00,\n",
" -4.61560056e-01, 1.58669494e+00, 1.41607927e+00,\n",
" 4.81105700e-01, -2.42715493e-03, 9.28444494e-03,\n",
" -2.27737677e-03, 6.67442576e-04, -1.30412925e-03,\n",
" 5.28015973e-04, 8.24808275e-04, 4.64068186e-04,\n",
" 1.94068784e-04],\n",
" [ 4.20669678e-04, 9.31201730e-04, 1.40215902e-03,\n",
" 4.26374450e-04, 2.18977777e-03, 1.82545054e-02,\n",
" 3.06893009e-02, -9.40879160e-04, 3.73460953e-02,\n",
" -9.66861734e-02, 8.72464051e-01, 7.06337786e-01,\n",
" -2.29296456e-01, -3.42128276e-01, -1.88893630e-01,\n",
" -7.81838960e-02],\n",
" [ 4.50045428e-03, 9.96229033e-03, 1.50007615e-02,\n",
" 4.56152271e-03, 2.34312828e-02, 1.95503959e-01,\n",
" 3.29549965e-01, -1.02492209e-02, 5.93558381e-01,\n",
" 3.65129735e-01, -8.08375819e-02, 1.54793966e-01,\n",
" -6.24939584e-02, -9.75657239e-02, -5.48817145e-02,\n",
" -2.29482012e-02],\n",
" [-1.32255811e-02, -2.92768327e-02, -4.40851928e-02,\n",
" -1.34076115e-02, -6.91606399e-02, -5.89209409e-01,\n",
" -1.05351467e+00, 4.28406758e-02, 6.90373750e-01,\n",
" -1.11564331e-01, 3.19299661e-02, -6.22933474e-02,\n",
" 2.52159334e-02, 3.93878750e-02, 2.21607260e-02,\n",
" 9.26731483e-03]])>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Qmn = tf.linalg.triangular_solve(Lm, Qmn_white, lower=True, adjoint=True) # [..., M, N]\n",
"Qmn"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9964ed55",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(8, 16), dtype=float64, numpy=\n",
"array([[-6.00071698e-05, -1.32832916e-04, -2.00013449e-04,\n",
" -6.08209325e-05, -3.12364166e-04, -2.60391923e-03,\n",
" -4.37757329e-03, 1.34189997e-04, -5.30586670e-03,\n",
" 1.28528241e-02, -3.36718745e-02, 5.66844329e-01,\n",
" 1.05067930e+00, 7.21923692e-01, 3.35127707e-01,\n",
" 1.28577273e-01],\n",
" [-1.62591812e-03, -3.59916093e-03, -5.41944515e-03,\n",
" -1.64796884e-03, -8.46380712e-03, -7.05627561e-02,\n",
" -1.18661517e-01, 3.64332360e-03, -1.50543141e-01,\n",
" 7.93877999e-01, 2.18559183e-01, -3.52908055e-01,\n",
" 1.38691101e-01, 2.15349113e-01, 1.20872205e-01,\n",
" 5.04820759e-02],\n",
" [ 1.51744315e-02, 3.35919499e-02, 5.05864298e-02,\n",
" 1.53893957e-02, 8.00755169e-02, 7.14239147e-01,\n",
" 1.49795409e+00, 9.66565024e-01, -1.74252303e-01,\n",
" 3.93087372e-02, -1.14678157e-02, 2.24006583e-02,\n",
" -9.06920995e-03, -1.41668040e-02, -7.97075154e-03,\n",
" -3.33328467e-03],\n",
" [ 3.77700695e-01, 9.09149633e-01, 1.66091722e+00,\n",
" 1.41590553e+00, -6.77994864e-01, -9.36668033e-01,\n",
" -3.42981817e-01, 1.76906786e-03, -6.78788401e-03,\n",
" 1.66523068e-03, -4.88041344e-04, 9.53594084e-04,\n",
" -3.86091286e-04, -6.03109201e-04, -3.39331943e-04,\n",
" -1.41905305e-04],\n",
" [-2.84320707e-01, -6.58910241e-01, -1.09183617e+00,\n",
" -4.61560056e-01, 1.58669494e+00, 1.41607927e+00,\n",
" 4.81105700e-01, -2.42715493e-03, 9.28444494e-03,\n",
" -2.27737677e-03, 6.67442576e-04, -1.30412925e-03,\n",
" 5.28015973e-04, 8.24808275e-04, 4.64068186e-04,\n",
" 1.94068784e-04],\n",
" [ 4.20669678e-04, 9.31201730e-04, 1.40215902e-03,\n",
" 4.26374450e-04, 2.18977777e-03, 1.82545054e-02,\n",
" 3.06893009e-02, -9.40879160e-04, 3.73460953e-02,\n",
" -9.66861734e-02, 8.72464051e-01, 7.06337786e-01,\n",
" -2.29296456e-01, -3.42128276e-01, -1.88893630e-01,\n",
" -7.81838960e-02],\n",
" [ 4.50045428e-03, 9.96229033e-03, 1.50007615e-02,\n",
" 4.56152271e-03, 2.34312828e-02, 1.95503959e-01,\n",
" 3.29549965e-01, -1.02492209e-02, 5.93558381e-01,\n",
" 3.65129735e-01, -8.08375819e-02, 1.54793966e-01,\n",
" -6.24939584e-02, -9.75657239e-02, -5.48817145e-02,\n",
" -2.29482012e-02],\n",
" [-1.32255811e-02, -2.92768327e-02, -4.40851928e-02,\n",
" -1.34076115e-02, -6.91606399e-02, -5.89209409e-01,\n",
" -1.05351467e+00, 4.28406758e-02, 6.90373750e-01,\n",
" -1.11564331e-01, 3.19299661e-02, -6.22933474e-02,\n",
" 2.52159334e-02, 3.93878750e-02, 2.21607260e-02,\n",
" 9.26731483e-03]])>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Qmn = tf.linalg.triangular_solve(tf.linalg.adjoint(Lm), Qmn_white, lower=False) # [..., M, N]\n",
"Qmn"
]
},
{
"cell_type": "markdown",
"id": "02bde07d",
"metadata": {},
"source": [
"## Compared to cholesky_solve"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "903674f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(8, 16), dtype=float64, numpy=\n",
"array([[-6.00071698e-05, -1.32832916e-04, -2.00013449e-04,\n",
" -6.08209325e-05, -3.12364166e-04, -2.60391923e-03,\n",
" -4.37757329e-03, 1.34189997e-04, -5.30586670e-03,\n",
" 1.28528241e-02, -3.36718745e-02, 5.66844329e-01,\n",
" 1.05067930e+00, 7.21923692e-01, 3.35127707e-01,\n",
" 1.28577273e-01],\n",
" [-1.62591812e-03, -3.59916093e-03, -5.41944515e-03,\n",
" -1.64796884e-03, -8.46380712e-03, -7.05627561e-02,\n",
" -1.18661517e-01, 3.64332360e-03, -1.50543141e-01,\n",
" 7.93877999e-01, 2.18559183e-01, -3.52908055e-01,\n",
" 1.38691101e-01, 2.15349113e-01, 1.20872205e-01,\n",
" 5.04820759e-02],\n",
" [ 1.51744315e-02, 3.35919499e-02, 5.05864298e-02,\n",
" 1.53893957e-02, 8.00755169e-02, 7.14239147e-01,\n",
" 1.49795409e+00, 9.66565024e-01, -1.74252303e-01,\n",
" 3.93087372e-02, -1.14678157e-02, 2.24006583e-02,\n",
" -9.06920995e-03, -1.41668040e-02, -7.97075154e-03,\n",
" -3.33328467e-03],\n",
" [ 3.77700695e-01, 9.09149633e-01, 1.66091722e+00,\n",
" 1.41590553e+00, -6.77994864e-01, -9.36668033e-01,\n",
" -3.42981817e-01, 1.76906786e-03, -6.78788401e-03,\n",
" 1.66523068e-03, -4.88041344e-04, 9.53594084e-04,\n",
" -3.86091286e-04, -6.03109201e-04, -3.39331943e-04,\n",
" -1.41905305e-04],\n",
" [-2.84320707e-01, -6.58910241e-01, -1.09183617e+00,\n",
" -4.61560056e-01, 1.58669494e+00, 1.41607927e+00,\n",
" 4.81105700e-01, -2.42715493e-03, 9.28444494e-03,\n",
" -2.27737677e-03, 6.67442576e-04, -1.30412925e-03,\n",
" 5.28015973e-04, 8.24808275e-04, 4.64068186e-04,\n",
" 1.94068784e-04],\n",
" [ 4.20669678e-04, 9.31201730e-04, 1.40215902e-03,\n",
" 4.26374450e-04, 2.18977777e-03, 1.82545054e-02,\n",
" 3.06893009e-02, -9.40879160e-04, 3.73460953e-02,\n",
" -9.66861734e-02, 8.72464051e-01, 7.06337786e-01,\n",
" -2.29296456e-01, -3.42128276e-01, -1.88893630e-01,\n",
" -7.81838960e-02],\n",
" [ 4.50045428e-03, 9.96229033e-03, 1.50007615e-02,\n",
" 4.56152271e-03, 2.34312828e-02, 1.95503959e-01,\n",
" 3.29549965e-01, -1.02492209e-02, 5.93558381e-01,\n",
" 3.65129735e-01, -8.08375819e-02, 1.54793966e-01,\n",
" -6.24939584e-02, -9.75657239e-02, -5.48817145e-02,\n",
" -2.29482012e-02],\n",
" [-1.32255811e-02, -2.92768327e-02, -4.40851928e-02,\n",
" -1.34076115e-02, -6.91606399e-02, -5.89209409e-01,\n",
" -1.05351467e+00, 4.28406758e-02, 6.90373750e-01,\n",
" -1.11564331e-01, 3.19299661e-02, -6.22933474e-02,\n",
" 2.52159334e-02, 3.93878750e-02, 2.21607260e-02,\n",
" 9.26731483e-03]])>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.linalg.cholesky_solve(Lm, Kmn)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment