Created: October 12, 2022 10:05
import enum

import torch
import torch.nn as nn

import numpy as np
import einops
import opt_einsum as oe
import itertools as it

from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class eInput(enum.IntEnum):
    # column indices of the network inputs: spatial coordinates and time
    x = 0
    y = 1
    t = 2


class eOutput(enum.IntEnum):
    # column indices of the network outputs: velocity components and pressure
    u = 0
    v = 1
    p = 2
class Model(nn.Module):

    def __init__(self, nu: float, nf: int, layers: list[int]) -> None:
        super().__init__()

        # hardcode parameters for now
        self.nu = nu
        self.nf = nf

        # fully-connected layers: layers[0] inputs -> layers[-1] outputs
        _layer_list = []
        for i in range(1, len(layers)):
            _layer_list.append(nn.Linear(layers[i - 1], layers[i]))

        self.layers = nn.ModuleList(_layer_list)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # tanh on all hidden layers, linear output layer
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))

        x = self.layers[-1](x)
        return x
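

def _example_forward_pass() -> None:
    # Hedged usage sketch, not part of the original gist: the layer sizes, the
    # nu / nf values, and the uniform sampling of a 2*pi-periodic domain below
    # are assumptions made purely for illustration.
    model = Model(nu=1.0, nf=4, layers=[3, 64, 64, 3]).to(DEVICE)
    xyt = 2.0 * np.pi * torch.rand(16, 3, device=DEVICE)  # columns (x, y, t), see eInput
    uvp = model(xyt)                                       # columns (u, v, p), see eOutput
    assert uvp.shape == (16, 3)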
def batched_jacobian(y, x, create_graph=False, batched=True):
    # batched Jacobian via repeated vector-Jacobian products; `batched` is
    # currently unused. The result has shape x.shape + y.shape[1:], i.e.
    # jac[b, i, j] = d y[b, j] / d x[b, i].
    jac = []
    flat_y = einops.rearrange(y, 'b ... -> b (...)')
    grad_y = torch.zeros_like(flat_y, requires_grad=True)

    for i in range(flat_y.shape[1]):
        # select the i-th output component and pull its gradient back onto x
        grad_y.data[..., i] = 1.0
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y.data[..., i] = 0.0

    return torch.stack(jac, -1).reshape(x.shape + y.shape[1:])
def batched_jac_hess(y, x):
    # first and second derivatives of y w.r.t. x; create_graph=True on both
    # passes so that losses built from either can backpropagate to the model
    # parameters
    jac = batched_jacobian(y, x, create_graph=True)
    hess = batched_jacobian(jac, x, create_graph=True)
    return jac, hess
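

def _example_jacobian_shapes() -> None:
    # Hedged shape check, not part of the original gist: for n collocation
    # points with 3 inputs (x, y, t) and 3 outputs (u, v, p), the Jacobian has
    # shape (n, 3, 3) with jac[b, i, j] = d y_j / d x_i, and the Hessian has
    # shape (n, 3, 3, 3) with hess[b, k, i, j] = d^2 y_j / (d x_k d x_i).
    model = Model(nu=1.0, nf=4, layers=[3, 32, 3])
    x = torch.rand(8, 3, requires_grad=True)
    jac, hess = batched_jac_hess(model(x), x)
    assert jac.shape == (8, 3, 3) and hess.shape == (8, 3, 3, 3)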
def compute_residual_loss(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # momentum residual (no viscosity factor multiplies the Laplacian here, i.e. nu = 1):
    #   r = u_t + (u . grad) u + grad p - laplacian(u) - f,   with f = (sin(4 y), 0)
    jac, hess = batched_jac_hess(y, x)

    u = y[..., :(eOutput.v + 1)]
    u_t = jac[..., eInput.t, :(eOutput.v + 1)]
    u_x = jac[..., :(eInput.y + 1), :(eOutput.v + 1)]
    p_x = jac[..., :(eInput.y + 1), eOutput.p]

    # forcing acts on the x-momentum equation only
    fx = torch.zeros_like(x[:, :(eInput.y + 1)])
    fx[..., 0] = torch.sin(4 * x[:, eInput.y])

    # convective term: ((u . grad) u)_i = sum_j u_j du_i/dx_j
    u_dot_nabla_u = oe.contract('bj, bji -> bi', u, u_x)

    # vector Laplacian, per velocity component: laplacian(u)_j = sum_i d^2 u_j / dx_i^2
    laplacian_u = oe.contract('biij -> bj', hess[..., :(eInput.y + 1), :(eInput.y + 1), :(eOutput.v + 1)])

    residual = u_t + u_dot_nabla_u + p_x - laplacian_u - fx
    residual_loss = oe.contract('b n -> ', residual ** 2) / residual.numel()

    return residual_loss
def compute_continuity_loss(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # continuity residual: div(u) = du/dx + dv/dy should vanish everywhere
    jac = batched_jacobian(y, x, create_graph=True)

    div_u = oe.contract('bii -> b', jac[..., :(eInput.y + 1), :(eOutput.v + 1)])
    continuity_loss = oe.contract('... -> ', div_u ** 2) / div_u.numel()

    return continuity_loss
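

def _example_physics_loss() -> None:
    # Hedged sketch, not part of the original gist: one way the two physics
    # losses might be combined for a single batch of collocation points. The
    # equal weighting, batch size, and uniform sampling of a 2*pi-periodic
    # domain are assumptions.
    model = Model(nu=1.0, nf=4, layers=[3, 64, 64, 3]).to(DEVICE)
    colloc = (2.0 * np.pi * torch.rand(256, 3, device=DEVICE)).requires_grad_(True)

    pred = model(colloc)
    loss = compute_residual_loss(pred, colloc) + compute_continuity_loss(pred, colloc)
    loss.backward()  # gradients w.r.t. the network parameters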
def periodic_equivalents(xt: torch.Tensor, ndim: int = 2) -> torch.Tensor:
    # shift each point into the neighbouring copies of the 2*pi-periodic domain
    eqs = []
    non_central = lambda idx: any(q != 0 for q in idx)

    for ij in filter(non_central, it.product(*(range(-1, 2) for _ in range(ndim)))):
        # iterate over indices denoting neighbouring domains
        ej = torch.zeros_like(xt)
        for p, q in zip(range(ndim), ij):
            ej[:, p] = 2.0 * np.pi * q

        eqs.append(xt + ej)

    return torch.stack(eqs, 0)
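

def _example_periodic_equivalents() -> None:
    # Hedged usage sketch, not part of the original gist: for a domain that is
    # 2*pi-periodic in x and y, every point gets 8 shifted copies (the 3 x 3
    # grid of neighbouring domains minus the central one); t is left unchanged.
    xt = 2.0 * np.pi * torch.rand(4, 3)  # columns (x, y, t)
    shifted = periodic_equivalents(xt, ndim=2)
    assert shifted.shape == (8, 4, 3)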
# ...