Last active
August 2, 2022 17:58
-
-
Save yueyericardo/24158433a2021c51eeef9c3e2722df99 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import argparse | |
from functorch import vmap, jacrev, jacfwd | |
import torch | |
import torch.nn as nn | |
# Turn off TF32 for CUDA matmuls (presumably so both Hessian implementations
# are compared at full float32 precision).
torch.backends.cuda.matmul.allow_tf32 = False
# Fixed seed: the random input and the model weights are reproducible.
_ = torch.manual_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

D1 = 2  # x, y
D2 = 3  # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
# Flipped to True by the command-line driver when "--backward" is given.
run_backward = False

# Layer widths from input to output: D1 -> 512 (x6) -> D2.
_widths = [D1] + [512] * 6 + [D2]
_layers = []
for _fan_in, _fan_out in zip(_widths[:-1], _widths[1:]):
    _layers += [nn.Linear(_fan_in, _fan_out), nn.ReLU()]
# The last Linear layer is the output layer, so drop the trailing ReLU.
model = nn.Sequential(*_layers[:-1]).to(device)
def predict(x):
    """Run the module-level ``model`` on ``x`` and return the output twice.

    The forward pass is wrapped in an NVTX range named "forward" when ``x``
    lives on a CUDA device, so it shows up in profiler timelines.

    Args:
        x: Input tensor passed straight to ``model``.

    Returns:
        A ``(out, out)`` tuple. The duplicate is consumed as the auxiliary
        output by ``jacrev`` / ``jacfwd`` when they are called with
        ``has_aux=True``, which lets callers get the prediction alongside
        the derivatives.
    """
    # NVTX is only meaningful (and only guaranteed to be available) on CUDA;
    # skipping it keeps the CPU fallback of this script usable.
    use_nvtx = x.is_cuda
    if use_nvtx:
        torch.cuda.nvtx.range_push("forward")
    try:
        out = model(x)
    finally:
        # Always close the range so the NVTX stack stays balanced even if
        # the forward pass raises.
        if use_nvtx:
            torch.cuda.nvtx.range_pop()
    return out, out
def reference_hessian():
    """Compute the batched Hessian with plain ``torch.autograd.grad`` calls.

    A first round of ``grad`` calls produces one Jacobian row per model
    output; a second round differentiates every Jacobian column again to
    obtain the Hessian rows. When the module-level ``run_backward`` flag is
    set, the summed Hessian is additionally back-propagated.

    Returns:
        ``(hessian, pred)`` where ``hessian`` has shape ``[B, D2 * D1, D1]``
        and ``pred`` is the model output for the cloned input.
    """
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    # NVTX ranges are only emitted on CUDA so the CPU fallback keeps working.
    use_nvtx = x_.is_cuda
    pred, _ = predict(x_)

    # First derivatives: one [B, D1] row per output component.
    jacobian_rows = []
    for i in range(D2):
        if use_nvtx:
            torch.cuda.nvtx.range_push("autograd jacobian")
        try:
            # create_graph=True keeps the graph so it can be differentiated again.
            jacobian_rows.append(
                torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[0]
            )
        finally:
            if use_nvtx:
                torch.cuda.nvtx.range_pop()

    # Second derivatives: differentiate every Jacobian column w.r.t. the input.
    hessian_rows = []
    for jacobian_row in jacobian_rows:
        for j in range(D1):
            if use_nvtx:
                torch.cuda.nvtx.range_push("autograd hesian")
            try:
                hessian_rows.append(
                    torch.autograd.grad(
                        jacobian_row[:, j], x_, ones, create_graph=True
                    )[0]
                )
            finally:
                if use_nvtx:
                    torch.cuda.nvtx.range_pop()

    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        hessian.sum().backward()
    return hessian.transpose(0, 1), pred
def functorch_hessian():
    """Compute the batched Hessian with functorch transforms.

    Forward-over-reverse differentiation (``jacfwd`` of ``jacrev``) is built
    for a single sample and then vectorised over the batch with ``vmap``.
    When the module-level ``run_backward`` flag is set, the summed Hessian is
    additionally back-propagated.

    Returns:
        ``(hessian, pred)`` where ``hessian`` has shape ``[B, D2, D1, D1]``
        and ``pred`` is the auxiliary output forwarded from ``predict``.
    """
    inputs = x.clone().requires_grad_()

    # Reverse-mode Jacobian of one sample; predict's second output is the aux.
    sample_jacobian = jacrev(predict, argnums=0, has_aux=True)
    # Forward-mode derivative of that Jacobian, still forwarding the aux.
    sample_hessian = jacfwd(sample_jacobian, argnums=0, has_aux=True)
    # Map the per-sample function over the leading (batch) dimension.
    batched_hessian = vmap(sample_hessian, in_dims=0)

    hessian, pred = batched_hessian(inputs)  # hessian: [B, D2, D1, D1]
    if run_backward:
        total = hessian.sum()
        total.backward()
    return hessian, pred
def validate_result():
    """Print the largest absolute deviation between both Hessian implementations.

    Runs ``reference_hessian`` and ``functorch_hessian`` once each, reshapes
    the reference Hessian to the functorch layout and reports the maximum
    absolute difference of the predictions and of the Hessians.
    """
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    # The reference result is [B, D2 * D1, D1]; bring it to the functorch layout.
    ref_hes = ref_hes.reshape(ft_hes.shape)
    # Use the absolute difference: a signed max would ignore large negative
    # deviations and under-report the error.
    pred_err = (ref_pred - ft_pred).abs().max()
    hes_err = (ref_hes - ft_hes).abs().max()
    print(f"max pred error: functorch: {pred_err:.2e}")
    print(f"max hessian error: functorch: {hes_err:.2e}")
def benchmark(func, num_iters=20):
    """Time ``func`` and print its average wall-clock time per call.

    Args:
        func: Zero-argument callable to time; its ``__name__`` is used as the
            label in the printed line and as the NVTX range name.
        num_iters: Number of timed calls to average over (default 20).

    Raises:
        ValueError: If ``num_iters`` is smaller than 1.
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")

    # CUDA synchronisation and NVTX ranges are only valid when a CUDA device
    # is present; without this guard the CPU fallback of the script crashes.
    on_cuda = torch.cuda.is_available()
    label = func.__name__

    if on_cuda:
        # Make sure previously queued kernels do not leak into the measurement.
        torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, unlike time.time()
    for _ in range(num_iters):
        if on_cuda:
            torch.cuda.nvtx.range_push(label)
        try:
            func()
        finally:
            if on_cuda:
                torch.cuda.nvtx.range_pop()
    if on_cuda:
        # Wait for all asynchronous kernels before stopping the clock.
        torch.cuda.synchronize()
    time_ms = ((time.perf_counter() - start) / num_iters) * 1000
    print(f"{label}: {time_ms:.3f} ms")
if __name__ == "__main__":
    # Command line: "-b" / "--backward" additionally back-propagates the
    # summed Hessian in both implementations.
    cli = argparse.ArgumentParser()
    cli.add_argument("-b", "--backward", default=False, action="store_true")
    options = cli.parse_args()

    if not options.backward:
        print("===== benchmark without backward =====")
    else:
        run_backward = True
        print("===== benchmark with backward =====")

    # Compare both implementations once before timing them.
    validate_result()

    # Warm up so one-time initialisation costs stay out of the measurements.
    for i in range(10):
        reference_hessian()
        functorch_hessian()

    # Timed runs of both Hessian implementations.
    for hessian_fn in (reference_hessian, functorch_hessian):
        benchmark(hessian_fn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.utils.benchmark as benchmark | |
import torch.nn as nn | |
# --------------------------------------------------------
# Smoke test: run both statements that are benchmarked below once, on a
# freshly built model, before timing them.
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

D1 = 2  # x, y
D2 = 3  # u, v, p
B = 10000
x = torch.randn(B, D1).to(device).requires_grad_()

# Layer widths from input to output: D1 -> 512 (x3) -> D2.
_widths = (D1, 512, 512, 512, D2)
_layers = []
for _fan_in, _fan_out in zip(_widths[:-1], _widths[1:]):
    _layers.extend((nn.Linear(_fan_in, _fan_out), nn.ReLU()))
# The final Linear layer is the output layer, so the trailing ReLU is dropped.
model = nn.Sequential(*_layers[:-1]).to(device)

pred = model(x)
loss = pred.sum()

# retain_graph=True keeps the graph alive so both calls can use it.
torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True)
loss.backward(retain_graph=True)
# --------------------------------------------------------
# Benchmark torch.autograd.grad.
# The Timer runs the setup in its own namespace, so the setup rebuilds the
# input, the model and the loss (with B = 1000 and a three-Linear-layer model).
_autograd_grad_setup = """
import torch
import torch.nn as nn
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2 # x, y
D2 = 3 # u, v, p
B = 1000
x = torch.randn(B, D1).to(device).requires_grad_()
model = nn.Sequential(
nn.Linear(D1, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, D2),
).to(device)
pred = model(x)
loss = pred.sum()
"""
t0 = benchmark.Timer(
    stmt="torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True)",
    setup=_autograd_grad_setup,
    num_threads=1,
)

# Wall-clock measurement first, then the Callgrind collection.
_autograd_grad_timing = t0.blocked_autorange()
print(_autograd_grad_timing)
_autograd_grad_callgrind = t0.collect_callgrind()
print(_autograd_grad_callgrind)
# Benchmark loss.backward.
# The Timer runs the setup in its own namespace, so the setup rebuilds the
# input, the model and the loss (with B = 1000 and a three-Linear-layer model).
_loss_backward_setup = """
import torch
import torch.nn as nn
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2 # x, y
D2 = 3 # u, v, p
B = 1000
x = torch.randn(B, D1).to(device).requires_grad_()
model = nn.Sequential(
nn.Linear(D1, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, D2),
).to(device)
pred = model(x)
loss = pred.sum()
"""
t1 = benchmark.Timer(
    stmt="loss.backward(retain_graph=True)",
    setup=_loss_backward_setup,
    num_threads=1,
)

# Wall-clock measurement first, then the Callgrind collection.
_loss_backward_timing = t1.blocked_autorange()
print(_loss_backward_timing)
_loss_backward_callgrind = t1.collect_callgrind()
print(_loss_backward_callgrind)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment