Created June 29, 2025 13:40
Should you use `functools.partial` to build a neural network in Python?
#!/usr/bin/env uv run
# /// script
# dependencies = [
#     "torch>=2",
# ]
# ///
import functools
import time

import torch
from torch import nn


class BasicBlock(nn.Module):
    """A minimal residual block: Conv3x3 -> BN -> Act -> Conv3x3 -> BN -> Add."""

    expansion = 1

    def __init__(
        self, in_planes, planes, stride=1, activation_function=nn.ReLU(inplace=True)
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.activation_function = activation_function  # <-- activation is injected
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.skip = (
            nn.Identity()
            if stride == 1 and in_planes == planes
            else nn.Conv2d(in_planes, planes, 1, stride, bias=False)
        )

    def forward(self, x):
        out = self.activation_function(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.activation_function(out + self.skip(x))


class TinyResNet(nn.Module):
    """Roughly ResNet-18 but narrower and only 3 stages for speed."""

    def __init__(self, activation_function=nn.ReLU(inplace=True)):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            activation_function,
        )
        # 3 stages of two blocks each: 32→64→128 channels
        self.stage1 = nn.Sequential(
            *[
                BasicBlock(32, 32, activation_function=activation_function)
                for _ in range(2)
            ]
        )
        self.stage2 = nn.Sequential(
            BasicBlock(32, 64, stride=2, activation_function=activation_function),
            BasicBlock(64, 64, activation_function=activation_function),
        )
        self.stage3 = nn.Sequential(
            BasicBlock(64, 128, stride=2, activation_function=activation_function),
            BasicBlock(128, 128, activation_function=activation_function),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)


# Regular, highly optimized ReLU:
relu_layer = nn.ReLU(inplace=True)

# Raymond Hettinger's functional ReLU suggestion:
# https://x.com/raymondh/status/1939023200404361507
hettinger_relu = functools.partial(max, 0.0)
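
# Quick sanity check: partial(max, 0.0) pins max()'s first positional argument
# to 0.0, so calling it with one more scalar computes max(0.0, x) -- i.e. ReLU
# on a single Python float:
assert hettinger_relu(-3.0) == 0.0
assert hettinger_relu(2.5) == 2.5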

# ...which has to be wrapped in an nn.Module for compatibility with Torch:
class PyReLU(nn.Module):
    def forward(self, x):
        # Apply the scalar max() element by element in pure Python, then
        # rebuild a tensor of the original shape (no vectorization at all).
        return torch.tensor([hettinger_relu(v.item()) for v in x.view(-1)]).view_as(x)


pytorch_net = TinyResNet().eval()
hettinger_net = TinyResNet(activation_function=PyReLU()).eval()

batch = 32
inp = torch.randn(batch, 3, 32, 32)
real_iterations = 20

for name, net in (("PyTorch", pytorch_net), ("functools.partial", hettinger_net)):
    with torch.no_grad():  # inference-only; obviously we can't differentiate functools.partial
        for iteration in range(real_iterations):
            a = time.time()
            net(inp)
            b = time.time()
            print(f"One inference with {name}: {1000.0 * (b - a):.2f}ms")