Skip to content

Instantly share code, notes, and snippets.

@yrom
Last active May 16, 2025 03:48
Show Gist options
  • Save yrom/e42dad506fd4a98cec9fb4b88acd016d to your computer and use it in GitHub Desktop.
Save yrom/e42dad506fd4a98cec9fb4b88acd016d to your computer and use it in GitHub Desktop.
benchmark conv1d and conv_transpose1d between torch and mlx
conv1d input: 8x256x1000 weight: 256x32x12 groups: 8
conv_transpose1d input: 8x256x1000 weight: 256x1x12 groups: 256
torch(mps) conv1d: 0.943 ms
mlx_conv1d: 3.906 ms
diff: -2.9631301250046818
torch(mps) conv_transpose1d: 2.912 ms
mlx conv_transpose1d: 5.282 ms
diff: -2.3704653340100776
import time
import numpy as np
import torch
import torch.nn.functional as F
import mlx.core as mx
def sync_if_needed(x):
    """Block until pending device work for ``x`` has been completed.

    Torch tensors are synchronized with the backend they live on (MPS or
    CUDA); tensors on any other device, including CPU, need no action here.
    ``mx.array`` inputs are passed to ``mx.eval``. Every other object is
    ignored.

    Args:
        x: A ``torch.Tensor``, an ``mx.array``, or any other object.
    """
    if isinstance(x, torch.Tensor):
        # Dispatch on the tensor's own device type. A plain "not CPU" check
        # would call the MPS synchronize for every accelerator (and for
        # non-executing devices such as "meta").
        device_type = x.device.type
        if device_type == "mps":
            torch.mps.synchronize()
        elif device_type == "cuda":
            torch.cuda.synchronize(x.device)
    elif isinstance(x, mx.array):
        mx.eval(x)
def bench(f, *args):
    """Time ``f(*args)`` over 100 calls, after 10 untimed warm-up calls.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f`` on every call.

    Returns:
        Total wall-clock duration of the 100 timed calls, in seconds, taken
        from ``time.perf_counter``.

    Note:
        The workloads benchmarked in this file each run 10 ops per call, so
        the total for 100 calls (1000 ops) in seconds is numerically the
        per-op time in milliseconds.
    """
    warmup_calls, timed_calls = 10, 100

    # Warm-up: results are discarded and not timed.
    for _ in range(warmup_calls):
        f(*args)

    started = time.perf_counter()
    for _ in range(timed_calls):
        f(*args)
    finished = time.perf_counter()

    return finished - started
@torch.no_grad()
def conv1d(x, y, groups=1):
    """Run ``F.conv1d(x, y, groups=groups)`` ten times, then synchronize.

    All ten outputs are held in a list until ``sync_if_needed(x)`` has
    returned; the last one is handed back to the caller.

    Args:
        x: Input tensor passed to ``F.conv1d``.
        y: Weight tensor passed to ``F.conv1d``.
        groups: Forwarded as the ``groups`` argument.

    Returns:
        The output of the tenth convolution.
    """
    outputs = [F.conv1d(x, y, groups=groups) for _ in range(10)]
    sync_if_needed(x)
    return outputs[-1]
def mlx_conv1d(x, y, groups=1):
    """Run ``mx.conv1d(x, y, groups=groups)`` ten times, then evaluate.

    The ten outputs are collected in a list, and the whole list is passed to
    ``mx.eval`` before the last output is returned.

    Args:
        x: Input array passed to ``mx.conv1d``.
        y: Weight array passed to ``mx.conv1d``.
        groups: Forwarded as the ``groups`` argument.

    Returns:
        The output of the tenth convolution.
    """
    outputs = [mx.conv1d(x, y, groups=groups) for _ in range(10)]
    mx.eval(outputs)
    return outputs[-1]
@torch.no_grad()
def conv_transpose1d(x, y, groups=1):
    """Run ``F.conv_transpose1d(x, y, groups=groups)`` ten times, then synchronize.

    All ten outputs are held in a list until ``sync_if_needed(x)`` has
    returned; the last one is handed back to the caller.

    Args:
        x: Input tensor passed to ``F.conv_transpose1d``.
        y: Weight tensor passed to ``F.conv_transpose1d``.
        groups: Forwarded as the ``groups`` argument.

    Returns:
        The output of the tenth transposed convolution.
    """
    outputs = [F.conv_transpose1d(x, y, groups=groups) for _ in range(10)]
    sync_if_needed(x)
    return outputs[-1]
def mlx_conv_transpose1d(x, y, groups=1):
    """Run ``mx.conv_transpose1d(x, y, groups=groups)`` ten times, then evaluate.

    The ten outputs are collected in a list, and the whole list is passed to
    ``mx.eval`` before the last output is returned.

    Args:
        x: Input array passed to ``mx.conv_transpose1d``.
        y: Weight array passed to ``mx.conv_transpose1d``.
        groups: Forwarded as the ``groups`` argument.

    Returns:
        The output of the tenth transposed convolution.
    """
    outputs = [mx.conv_transpose1d(x, y, groups=groups) for _ in range(10)]
    mx.eval(outputs)
    return outputs[-1]
def run_conv_transpose_1D(
    N,
    Cin,
    Cout,
    iH,
    kH,
    stride=1,
    padding=0,
    output_padding=0,
    dilation=1,
    groups=1,
    dtype="float32",
    atol=1e-5,
):
    """Compare ``mx.conv_transpose1d`` against ``torch.conv_transpose1d`` on MPS.

    A random input of shape ``(N, iH, Cin)`` and a single random filter of
    length ``kH`` are drawn with a fixed NumPy seed. The same filter is
    broadcast to every weight slot: shape ``(Cout, kH, Cin // groups)`` for
    MLX and ``(Cin, Cout // groups, kH)`` for torch. The torch input is the
    same data transposed to ``(N, Cin, iH)``.

    Args:
        N: Batch size.
        Cin: Number of input channels.
        Cout: Number of output channels.
        iH: Input length.
        kH: Kernel length.
        stride: Forwarded to both backends.
        padding: Forwarded to both backends.
        output_padding: Forwarded to both backends when non-zero. With the
            default of 0 the argument is omitted from both calls.
        dilation: Forwarded to both backends.
        groups: Forwarded to both backends.
        dtype: Name of the NumPy dtype used for the generated data.
        atol: Absolute tolerance for the final comparison.

    Raises:
        AssertionError: If the two outputs differ beyond the tolerance.
    """
    np_dtype = getattr(np, dtype)
    np.random.seed(0)
    in_np = np.random.normal(0, 1.0 / kH, (N, iH, Cin)).astype(np_dtype)
    filter_np = np.random.normal(0, 1.0 / kH, (kH,)).astype(np_dtype)
    print(in_np.shape, filter_np.shape)

    wt_np = np.broadcast_to(filter_np.reshape(1, -1, 1), (Cout, kH, Cin // groups))
    in_mx, wt_mx = map(mx.array, (in_np, wt_np))
    in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).to("mps")
    wt_pt = torch.from_numpy(filter_np).reshape(1, 1, -1).expand(Cin, Cout // groups, -1).to("mps")

    # output_padding used to be accepted and then dropped. Forward it only
    # when it is non-zero so the default call stays exactly as before (and
    # keeps working with backends that lack the keyword).
    extra_kwargs = {"output_padding": output_padding} if output_padding else {}

    print("mx", in_mx.shape, wt_mx.shape)
    out_mx = mx.conv_transpose1d(
        in_mx,
        wt_mx,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        **extra_kwargs,
    )
    print("torch", in_pt.shape, wt_pt.shape)
    out_pt = torch.conv_transpose1d(
        in_pt,
        wt_pt,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        **extra_kwargs,
    )
    print(out_pt.shape, out_mx.shape)
    # Swap the torch output's last two axes so it matches the MLX layout.
    np.testing.assert_allclose(out_pt.transpose(2, 1).cpu().numpy(), out_mx, atol=atol)
if __name__ == "__main__":

    def _shape_str(arr):
        """Render an array's shape as e.g. ``8x256x1000``."""
        return "x".join(str(dim) for dim in arr.shape)

    # ---- problem size ----------------------------------------------------
    batch = 8
    channels = 256
    kernel_size = 12
    groups = 8
    length = 1000

    torch.manual_seed(0)
    np.random.seed(0)

    # ---- grouped conv1d: data --------------------------------------------
    # Input is [N, C, L]; weight is [C, C/groups, kernel_size].
    x = np.random.randn(batch, channels, length).astype(np.float32)
    weight_shape = (channels, channels // groups, kernel_size)
    y = np.random.normal(0, 1 / kernel_size, weight_shape).astype(np.float32)
    print("conv1d input:", _shape_str(x), "weight:", _shape_str(y), "groups:", groups)

    x_mps = torch.from_numpy(x).to("mps")
    y_mps = torch.from_numpy(y).to("mps")
    # The MLX copies have axes 1 and 2 swapped relative to the torch copies.
    mx_x = mx.array(x.swapaxes(1, 2))
    mx_y = mx.array(y.swapaxes(1, 2))
    mx.eval(mx_x, mx_y)

    # ---- grouped conv1d: agreement check ---------------------------------
    torch_out = torch.conv1d(x_mps, y_mps, groups=groups)
    mlx_out = mx.conv1d(mx_x, mx_y, groups=groups)
    torch.mps.synchronize()
    mx.eval(mlx_out)
    np.testing.assert_allclose(
        torch_out.cpu().numpy().swapaxes(1, 2),
        np.array(mlx_out),
        atol=1e-4,
    )

    # ---- conv_transpose1d with groups == channels: data ------------------
    transpose1d_groups = channels
    # One [1, kernel_size] filter shared by every channel.
    y2 = np.random.normal(0, 1 / 50, (1, kernel_size)).astype(np.float32)
    # torch weight: [Cin, Cout // groups, kernel_size]
    y2_mps = torch.from_numpy(y2).expand(channels, -1, -1).to("mps")
    # MLX weight: [Cout, kernel_size, Cin/groups = 1]
    y2_channels_last = np.broadcast_to(
        y2.reshape(1, kernel_size, 1), (channels, kernel_size, 1)
    )
    mx_y2 = mx.array(y2_channels_last)
    print(
        "conv_transpose1d input:",
        _shape_str(x),
        "weight:",
        _shape_str(y2_mps),
        "groups:",
        transpose1d_groups,
    )

    # ---- conv_transpose1d: agreement check -------------------------------
    torch_out = F.conv_transpose1d(x_mps, y2_mps, groups=transpose1d_groups)  # [N, C, L]
    mlx_out = mx.conv_transpose1d(mx_x, mx_y2, groups=transpose1d_groups)
    torch.mps.synchronize()
    mx.eval(mlx_out)
    np.testing.assert_allclose(
        torch_out.transpose(2, 1).cpu().numpy(), mlx_out, atol=1e-4
    )

    # ---- benchmarks ------------------------------------------------------
    conv1d_ms = bench(conv1d, x_mps, y_mps, groups)
    mlx_conv1d_ms = bench(mlx_conv1d, mx_x, mx_y, groups)
    print(f"torch(mps) conv1d: {conv1d_ms:.3f} ms")
    print(f"mlx_conv1d: {mlx_conv1d_ms:.3f} ms")
    print("diff:", conv1d_ms - mlx_conv1d_ms)

    conv_transpose1d_ms = bench(conv_transpose1d, x_mps, y2_mps, transpose1d_groups)
    print(f"torch(mps) conv_transpose1d: {conv_transpose1d_ms:.3f} ms")
    mlx_conv_transpose1d_ms = bench(mlx_conv_transpose1d, mx_x, mx_y2, transpose1d_groups)
    print(f"mlx conv_transpose1d: {mlx_conv_transpose1d_ms:.3f} ms")
    print("diff:", conv_transpose1d_ms - mlx_conv_transpose1d_ms)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment