benchmark conv1d and conv_transpose1d between torch and mlx
Benchmark output. Times are milliseconds per convolution (bench() returns total seconds for 100 timed calls of 10 convolutions each, which is numerically the same as ms per convolution); diff = torch minus mlx, so a negative value means torch on MPS is faster.

conv1d input: 8x256x1000 weight: 256x32x12 groups: 8
conv_transpose1d input: 8x256x1000 weight: 256x1x12 groups: 256
torch(mps) conv1d: 0.943 ms
mlx_conv1d: 3.906 ms
diff: -2.9631301250046818
torch(mps) conv_transpose1d: 2.912 ms
mlx conv_transpose1d: 5.282 ms
diff: -2.3704653340100776
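PyTorch and MLX disagree on conv layouts: torch.conv1d takes input [N, C, L] and weight [Cout, Cin/groups, kH], while MLX is channels-last and takes input [N, L, C] and weight [Cout, kH, Cin/groups]. A minimal sketch of the axis swaps the script below relies on (toy shapes, not the benchmark's; assumes torch and mlx are installed):

import numpy as np
import torch
import mlx.core as mx

x = np.random.randn(2, 16, 100).astype(np.float32)  # torch input  [N, C, L]
w = np.random.randn(32, 16, 5).astype(np.float32)   # torch weight [Cout, Cin, kH]

out_pt = torch.conv1d(torch.from_numpy(x), torch.from_numpy(w))  # [N, Cout, L']
out_mx = mx.conv1d(
    mx.array(x.swapaxes(1, 2)),  # MLX input  [N, L, C]
    mx.array(w.swapaxes(1, 2)),  # MLX weight [Cout, kH, Cin]
)
np.testing.assert_allclose(out_pt.numpy().swapaxes(1, 2), np.array(out_mx), atol=1e-4)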
import time

import numpy as np
import torch
import torch.nn.functional as F

import mlx.core as mx
def sync_if_needed(x):
    """Block until pending GPU work for x has finished."""
    if isinstance(x, torch.Tensor) and x.device != torch.device("cpu"):
        torch.mps.synchronize()
    elif isinstance(x, mx.array):
        mx.eval(x)
def bench(f, *args):
    # Warm up, then time 100 calls. Each call runs 10 convolutions, so the
    # total in seconds for 100 x 10 = 1000 ops is numerically equal to
    # milliseconds per convolution, which is how it is printed below.
    for _ in range(10):
        f(*args)
    s = time.perf_counter()
    for _ in range(100):
        f(*args)
    e = time.perf_counter()
    return e - s
@torch.no_grad()
def conv1d(x, y, groups=1):
    # Queue 10 convolutions, then synchronize once so launch overhead amortizes.
    ys = []
    for _ in range(10):
        ys.append(F.conv1d(x, y, groups=groups))
    sync_if_needed(x)
    return ys[-1]


def mlx_conv1d(x, y, groups=1):
    # MLX is lazy: build 10 convolutions, then force evaluation of all of them.
    ys = []
    for _ in range(10):
        ys.append(mx.conv1d(x, y, groups=groups))
    mx.eval(ys)
    return ys[-1]
@torch.no_grad()
def conv_transpose1d(x, y, groups=1):
    ys = []
    for _ in range(10):
        ys.append(F.conv_transpose1d(x, y, groups=groups))
    sync_if_needed(x)
    return ys[-1]


def mlx_conv_transpose1d(x, y, groups=1):
    ys = []
    for _ in range(10):
        ys.append(mx.conv_transpose1d(x, y, groups=groups))
    mx.eval(ys)
    return ys[-1]
def run_conv_transpose_1D(
    N,
    Cin,
    Cout,
    iH,
    kH,
    stride=1,
    padding=0,
    output_padding=0,
    dilation=1,
    groups=1,
    dtype="float32",
    atol=1e-5,
):
    # Standalone correctness check for conv_transpose1d (not called from
    # __main__). Note: output_padding is accepted but currently unused.
    np_dtype = getattr(np, dtype)
    np.random.seed(0)
    in_np = np.random.normal(0, 1.0 / kH, (N, iH, Cin)).astype(np_dtype)  # MLX layout [N, L, Cin]
    filter_np = np.random.normal(0, 1.0 / kH, (kH,)).astype(np_dtype)
    # wt_np = np.random.normal(0, 1.0 / Cin, (Cout, kH, int(Cin / groups))).astype(np_dtype)
    print(in_np.shape, filter_np.shape)
    # MLX weight layout: [Cout, kH, Cin/groups]
    in_mx, wt_mx = map(mx.array, (in_np, np.broadcast_to(filter_np.reshape(1, -1, 1), (Cout, kH, Cin // groups))))
    # torch layouts: input [N, Cin, L], weight [Cin, Cout/groups, kH]
    in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).to("mps")
    wt_pt = torch.from_numpy(filter_np).reshape(1, 1, -1).expand(Cin, Cout // groups, -1).to("mps")
    print("mx", in_mx.shape, wt_mx.shape)
    out_mx = mx.conv_transpose1d(
        in_mx,
        wt_mx,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    print("torch", in_pt.shape, wt_pt.shape)
    out_pt = torch.conv_transpose1d(
        in_pt,
        wt_pt,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    print(out_pt.shape, out_mx.shape)
    # Compare in MLX layout: move torch channels back to the last axis.
    np.testing.assert_allclose(out_pt.transpose(2, 1).cpu().numpy(), out_mx, atol=atol)
if __name__ == "__main__":
    N = 8
    channels = 256
    kernel_size = 12
    groups = 8
    L = 1000
    torch.manual_seed(0)
    np.random.seed(0)

    # torch layouts: input [N, C, L], weight [Cout, Cin/groups, kernel_size]
    x = np.random.randn(N, channels, L).astype(np.float32)
    y = np.random.normal(0, 1 / kernel_size, (channels, channels // groups, kernel_size)).astype(np.float32)
    print("conv1d input:", "x".join(map(str, x.shape)), "weight:", "x".join(map(str, y.shape)), "groups:", groups)
    x_mps = torch.from_numpy(x).to("mps")
    y_mps = torch.from_numpy(y).to("mps")
    # MLX is channels-last: input [N, L, C], weight [Cout, kernel_size, Cin/groups]
    mx_x = mx.array(x.swapaxes(1, 2))
    mx_y = mx.array(y.swapaxes(1, 2))
    mx.eval(mx_x, mx_y)

    # Correctness check for conv1d before timing.
    z1 = torch.conv1d(x_mps, y_mps, groups=groups)
    z2 = mx.conv1d(mx_x, mx_y, groups=groups)
    torch.mps.synchronize()
    mx.eval(z2)
    np.testing.assert_allclose(z1.cpu().numpy().swapaxes(1, 2), np.array(z2), atol=0.0001)

    # Depthwise transposed conv: one group per channel.
    transpose1d_groups = channels
    y2 = np.random.normal(0, 1 / 50, (1, kernel_size)).astype(np.float32)  # [1 = channels//groups, kernel_size]
    y2_mps = torch.from_numpy(y2).expand(channels, -1, -1).to("mps")  # torch weight [Cin, Cout/groups, kernel_size]
    mx_y2 = mx.array(
        np.broadcast_to(y2.reshape(1, kernel_size, 1), (channels, kernel_size, 1))
    )  # MLX weight [Cout, kernel_size, Cin/groups = 1]
    print(
        "conv_transpose1d input:",
        "x".join(map(str, x.shape)),
        "weight:",
        "x".join(map(str, y2_mps.shape)),
        "groups:",
        transpose1d_groups,
    )
    z1 = F.conv_transpose1d(x_mps, y2_mps, groups=transpose1d_groups)  # [N, C, L + kernel_size - 1]
    z2 = mx.conv_transpose1d(mx_x, mx_y2, groups=transpose1d_groups)
    torch.mps.synchronize()
    mx.eval(z2)
    np.testing.assert_allclose(z1.transpose(2, 1).cpu().numpy(), z2, atol=0.0001)

    # Timings (a negative diff means torch is faster).
    conv1d_ms = bench(conv1d, x_mps, y_mps, groups)
    mlx_conv1d_ms = bench(mlx_conv1d, mx_x, mx_y, groups)
    print(f"torch(mps) conv1d: {conv1d_ms:.3f} ms")
    print(f"mlx_conv1d: {mlx_conv1d_ms:.3f} ms")
    print("diff:", conv1d_ms - mlx_conv1d_ms)
    conv_transpose1d_ms = bench(conv_transpose1d, x_mps, y2_mps, transpose1d_groups)
    print(f"torch(mps) conv_transpose1d: {conv_transpose1d_ms:.3f} ms")
    mlx_conv_transpose1d_ms = bench(mlx_conv_transpose1d, mx_x, mx_y2, transpose1d_groups)
    print(f"mlx conv_transpose1d: {mlx_conv_transpose1d_ms:.3f} ms")
    print("diff:", conv_transpose1d_ms - mlx_conv_transpose1d_ms)