@psobot
Created June 29, 2025 13:40
Should you use `functools.partial` to build a neural network in Python?
#!/usr/bin/env uv run
# /// script
# dependencies = [
# "torch>=2",
# ]
# ///
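# To run this script (assuming a uv release recent enough to understand the
# inline metadata block above): `uv run <path-to-this-file>`, or mark the file
# executable and launch it directly via the shebang on the first line.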
import functools
import time

import torch
from torch import nn


class BasicBlock(nn.Module):
    """A minimal residual block: Conv3x3 -> BN -> Act -> Conv3x3 -> BN -> Add."""

    expansion = 1

    def __init__(
        self, in_planes, planes, stride=1, activation_function=nn.ReLU(inplace=True)
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.activation_function = activation_function  # <-- activation is injected
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.skip = (
            nn.Identity()
            if stride == 1 and in_planes == planes
            else nn.Conv2d(in_planes, planes, 1, stride, bias=False)
        )

    def forward(self, x):
        out = self.activation_function(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.activation_function(out + self.skip(x))


class TinyResNet(nn.Module):
    """Roughly ResNet-18 but narrower and only 3 stages for speed."""

    def __init__(self, activation_function=nn.ReLU(inplace=True)):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            activation_function,
        )
        # 3 stages of two blocks each: 32→64→128 channels
        self.stage1 = nn.Sequential(
            *[
                BasicBlock(32, 32, activation_function=activation_function)
                for _ in range(2)
            ]
        )
        self.stage2 = nn.Sequential(
            BasicBlock(32, 64, stride=2, activation_function=activation_function),
            BasicBlock(64, 64, activation_function=activation_function),
        )
        self.stage3 = nn.Sequential(
            BasicBlock(64, 128, stride=2, activation_function=activation_function),
            BasicBlock(128, 128, activation_function=activation_function),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)

# Regular, highly optimized ReLU:
relu_layer = nn.ReLU(inplace=True)
# Raymond Hettinger's functional ReLU suggestion:
# https://x.com/raymondh/status/1939023200404361507
hettinger_relu = functools.partial(max, 0.0)
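# On Python scalars this behaves like ReLU: hettinger_relu(-3.0) == 0.0 and
# hettinger_relu(2.5) == 2.5. Unlike nn.ReLU, it does not operate on whole
# tensors, only on one number at a time.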
# ...which has to be wrapped in an nn.Module for compatibility with Torch:
class PyReLU(nn.Module):
    def forward(self, x):
        return torch.tensor([hettinger_relu(v.item()) for v in x.view(-1)]).view_as(x)
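# Note that PyReLU drops to a pure-Python loop: each element is extracted with
# .item(), pushed through hettinger_relu individually, and the results are
# packed back into a new tensor, so none of the work stays vectorized.
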
pytorch_net = TinyResNet().eval()
hettinger_net = TinyResNet(activation_function=PyReLU()).eval()

batch = 32
inp = torch.randn(batch, 3, 32, 32)
real_iterations = 20

for name, net in (("PyTorch", pytorch_net), ("functools.partial", hettinger_net)):
    with torch.no_grad():  # inference-only; obviously we can't differentiate functools.partial
        for iteration in range(real_iterations):
            a = time.time()
            net(inp)
            b = time.time()
            print(f"One inference with {name}: {1000.0 * (b - a):.2f}ms")