@ptrblck
Created July 5, 2022 23:23
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
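# Custom channels-first LayerNorm: statistics are computed over the channel
# dimension (dim 1) of an NCHW tensor, with a learnable per-channel affine.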
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
# setup
iteration_count = 100
N, C, H, W = 2, 3, 4, 4
x = torch.randn(N, C, H, W, device='cuda')
norm = LayerNorm(C).to('cuda')
parameters = list(norm.parameters())  # materialize the generator so it can be iterated in every benchmark iteration
norm_scripted = torch.jit.script(norm)
grad_output = torch.rand_like(x)
x_channels_last = x.permute(0, 2, 3, 1).contiguous()
grad_output_channels_last = grad_output.permute(0, 2, 3, 1).contiguous()
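# NHWC copies of the input and incoming gradient: with channels last, the channel
# dimension is the innermost one, so the native F.layer_norm kernel (which
# normalizes over the trailing normalized_shape dims) computes the same
# statistics as the custom module does on the NCHW tensor.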
backward = True
# reference
out = norm(x)
out_scripted = norm_scripted(x)
print((out_scripted - out).abs().max())
# tensor(2.9802e-07, device='cuda:0', grad_fn=<MaxBackward1>)
out_channels_last = F.layer_norm(x_channels_last, norm.normalized_shape, norm.weight, norm.bias, norm.eps)
print((out_channels_last.permute(0, 3, 1, 2) - out).abs().max())
# tensor(3.5763e-07, device='cuda:0', grad_fn=<MaxBackward1>)
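# Each benchmark below follows the same pattern: a few warm-up iterations,
# torch.cuda.synchronize() before starting and stopping the host-side timer,
# one forward (and optionally backward) pass per iteration, and clearing the
# parameter gradients so gradient accumulation is not part of the measurement.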
# Eager
# Perform warm-up iterations
for _ in range(3):
    # Run model, forward and backward
    output = norm(x)
    if backward:
        output.backward(grad_output)
        # delete gradients to avoid profiling the gradient accumulation
        for p in parameters:
            p.grad = None

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    # Run model, forward and backward
    output = norm(x)
    if backward:
        output.backward(grad_output)
        # delete gradients to avoid profiling the gradient accumulation
        for p in parameters:
            p.grad = None
# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))
# Scripted
# Perform warm-up iterations
for _ in range(3):
    # Run model, forward and backward
    output = norm_scripted(x)
    if backward:
        output.backward(grad_output)
        # delete gradients to avoid profiling the gradient accumulation
        for p in parameters:
            p.grad = None

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    # Run model, forward and backward
    output = norm_scripted(x)
    if backward:
        output.backward(grad_output)
        # delete gradients to avoid profiling the gradient accumulation
        for p in parameters:
            p.grad = None
# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))
# Channels-last with native implementation
# Perform warm-up iterations
for _ in range(3):
    # Run model, forward and backward
    output = F.layer_norm(x_channels_last, norm.normalized_shape, norm.weight, norm.bias, norm.eps)
    if backward:
        output.backward(grad_output_channels_last)
        # delete gradients to avoid profiling the gradient accumulation
        for p in parameters:
            p.grad = None

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    # Run model, forward and backward
    output = F.layer_norm(x_channels_last, norm.normalized_shape, norm.weight, norm.bias, norm.eps)
    if backward:
        output.backward(grad_output_channels_last)
        # delete gradients to avoid profiling the gradient accumulation
        for p in parameters:
            p.grad = None
# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))