Skip to content

Instantly share code, notes, and snippets.

@danielkelshaw
Created October 12, 2022 10:05
Show Gist options
  • Save danielkelshaw/c195efd4ce28ea2e49f15531ec704656 to your computer and use it in GitHub Desktop.
Save danielkelshaw/c195efd4ce28ea2e49f15531ec704656 to your computer and use it in GitHub Desktop.
import enum
import torch
import torch.nn as nn
import numpy as np
import einops
import opt_einsum as oe
import itertools as it
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable
# Module-wide compute device: CUDA when torch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
@enum.unique
class eInput(enum.IntEnum):
    """Positions of the coordinates along the input (last) axis of ``x``.

    Being an ``IntEnum``, members can be used directly as tensor indices
    and in integer arithmetic (e.g. ``eInput.y + 1`` as a slice bound).
    """

    x = 0  # first spatial coordinate
    y = 1  # second spatial coordinate
    t = 2  # time
@enum.unique
class eOutput(enum.IntEnum):
    """Positions of the predicted fields along the output (last) axis of ``y``.

    Being an ``IntEnum``, members can be used directly as tensor indices
    and in integer arithmetic (e.g. ``eOutput.v + 1`` as a slice bound).
    """

    u = 0  # first velocity component
    v = 1  # second velocity component
    p = 2  # pressure
class Model(nn.Module):
    """Fully connected network: linear layers with tanh between them.

    The final linear layer is applied without an activation.

    Args:
        nu: stored on the instance as ``self.nu``; not used by ``forward``
            (presumably a viscosity -- confirm against the training code).
        nf: stored on the instance as ``self.nf``; not used by ``forward``.
        layers: layer widths; consecutive pairs ``(layers[i - 1], layers[i])``
            define the in/out features of each ``nn.Linear``.
    """

    def __init__(self, nu: float, nf: int, layers: list[int]) -> None:
        super().__init__()

        # hardcode parameters for now
        self.nu = nu
        self.nf = nf

        # one Linear per consecutive pair of widths, in order
        widths_in, widths_out = layers[:-1], layers[1:]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths_in, widths_out)]
        )
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer but the last with tanh, then the last layer bare."""
        hidden = self.layers[:-1]
        out = x
        for linear in hidden:
            out = self.activation(linear(out))

        head = self.layers[-1]
        return head(out)
def batched_jacobian(y, x, create_graph=False, batched=True):
    """Differentiate every non-batch component of ``y`` with respect to ``x``.

    ``y`` is flattened to ``(batch, n_out)`` and one reverse-mode pass is run
    per output column, seeded with a one-hot ``grad_outputs`` that selects
    that column for all batch rows at once. The result is a per-sample
    Jacobian provided ``y[b]`` depends only on ``x[b]`` (no cross-sample
    coupling) -- the function does not check this.

    A fresh seed tensor is built for every pass. With ``create_graph=True``
    the recorded double-backward graph may keep a reference to the seed
    (e.g. when ``y`` is the direct output of an activation); reusing one
    buffer and overwriting it in place through ``.data`` would then silently
    corrupt later second-derivative computations.

    Args:
        y: tensor computed from ``x``; axis 0 is the batch axis.
        x: tensor to differentiate with respect to; must be part of the
            autograd graph of ``y``.
        create_graph: forwarded to ``torch.autograd.grad``; pass ``True`` when
            the returned Jacobian will itself be differentiated.
        batched: unused; kept so existing call sites remain valid.

    Returns:
        Tensor of shape ``x.shape + y.shape[1:]``. For 2-D ``x`` the entry
        ``[b, j, ...]`` is the derivative of ``y[b, ...]`` with respect to
        ``x[b, j]``.
    """
    # (batch, n_out); an empty trailing shape flattens to a single column
    flat_y = y.reshape(y.shape[0], y.shape[1:].numel())

    columns = []
    for i in range(flat_y.shape[1]):
        # one-hot selector for output column i, private to this pass
        seed = torch.zeros_like(flat_y)
        seed[:, i] = 1.0

        grad_x, = torch.autograd.grad(
            flat_y, x, seed, retain_graph=True, create_graph=create_graph
        )
        columns.append(grad_x.reshape(x.shape))

    return torch.stack(columns, -1).reshape(x.shape + y.shape[1:])
def batched_jac_hess(y, x):
    """Return the first and second derivatives of ``y`` with respect to ``x``.

    The first-order result is computed with ``create_graph=True`` so that it
    can be differentiated once more; the second-order result is obtained by
    applying ``batched_jacobian`` to it.

    Returns:
        ``(jac, hess)`` where ``jac`` has shape ``x.shape + y.shape[1:]`` and
        ``hess`` has shape ``x.shape + jac.shape[1:]``.
    """
    first_order = batched_jacobian(y, x, create_graph=True)
    second_order = batched_jacobian(first_order, x)
    return first_order, second_order
def compute_residual_loss(y: torch.Tensor, x: torch.Tensor, nu: float = 1.0) -> torch.Tensor:
    """Mean-squared residual of the forced 2-D momentum equations.

    For each velocity component ``i`` the residual is::

        du_i/dt + sum_j u_j * du_i/dx_j + dp/dx_i - nu * laplacian(u_i) - f_i

    with forcing ``f = (sin(4 * y), 0)``.

    Args:
        y: network output; the last axis is indexed by ``eOutput``.
        x: network input that ``y`` was computed from; the last axis is
            indexed by ``eInput`` (assumes 2-D ``(batch, inputs)`` -- the
            einsum patterns below require it).
        nu: coefficient of the Laplacian term. Defaults to ``1.0``, which
            matches the previously hard-coded unit coefficient.

    Returns:
        Scalar tensor: sum of squared residuals divided by their count.
    """
    # jac[b, j, i]     = d y_i / d x_j
    # hess[b, k, j, i] = d^2 y_i / (d x_k d x_j)
    jac, hess = batched_jac_hess(y, x)

    n_space = eInput.y + 1   # spatial inputs (x, y)
    n_vel = eOutput.v + 1    # velocity outputs (u, v)

    u = y[..., :n_vel]
    u_t = jac[..., eInput.t, :n_vel]
    u_x = jac[..., :n_space, :n_vel]
    p_x = jac[..., :n_space, eOutput.p]

    # forcing acts on the first velocity component only
    fx = torch.zeros_like(x[:, :n_space])
    fx[..., 0] = torch.sin(4 * x[:, eInput.y])

    u_dot_nabla_u = oe.contract('bj, bji -> bi', u, u_x)

    # Laplacian per velocity component: sum the repeated spatial second
    # derivatives (k == j) and keep the output index i. Contracting all three
    # indices together ('biii') would instead give d2u/dx2 + d2v/dy2, a single
    # scalar that is not the Laplacian of either component.
    laplacian_u = oe.contract('bjji -> bi', hess[..., :n_space, :n_space, :n_vel])

    residual = u_t + u_dot_nabla_u + p_x - nu * laplacian_u - fx
    residual_loss = oe.contract('b n -> ', residual ** 2) / residual.numel()

    return residual_loss
def compute_continuity_loss(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Mean-squared divergence of the velocity part of ``y``.

    The divergence is the trace of the spatial block of the Jacobian, i.e.
    ``du/dx + dv/dy`` with the axis conventions of ``eInput``/``eOutput``.

    Args:
        y: network output; the last axis is indexed by ``eOutput``.
        x: network input that ``y`` was computed from; the last axis is
            indexed by ``eInput``.

    Returns:
        Scalar tensor: sum of squared divergences divided by their count.
    """
    n_space = eInput.y + 1   # spatial inputs (x, y)
    n_vel = eOutput.v + 1    # velocity outputs (u, v)

    # create_graph=True keeps the derivative differentiable for training
    velocity_grad = batched_jacobian(y, x, create_graph=True)[..., :n_space, :n_vel]

    divergence = oe.contract('bii -> b', velocity_grad)
    squared_total = oe.contract('... -> ', divergence ** 2)
    return squared_total / divergence.numel()
def periodic_equivalents(xt: torch.Tensor, ndim: int = 2) -> torch.Tensor:
    """Return copies of ``xt`` shifted into every neighbouring periodic cell.

    Each of the first ``ndim`` columns of ``xt`` is shifted by ``-2*pi``,
    ``0`` or ``+2*pi``; all combinations are produced except the all-zero
    shift (the original points themselves). Remaining columns are unchanged.

    Args:
        xt: 2-D tensor whose first ``ndim`` columns are the periodic
            coordinates.
        ndim: number of leading columns treated as periodic.

    Returns:
        Tensor of shape ``(3 ** ndim - 1, *xt.shape)``, ordered as
        ``itertools.product`` enumerates the offsets.
    """
    period = 2.0 * np.pi
    images = []

    # iterate over indices denoting neighbouring domains
    for offsets in it.product(range(-1, 2), repeat=ndim):
        if not any(offsets):
            # all-zero offset is the central domain itself
            continue

        shift = torch.zeros_like(xt)
        for axis, step in enumerate(offsets):
            shift[:, axis] = period * step

        images.append(xt + shift)

    return torch.stack(images, 0)
# ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment