Skip to content

Instantly share code, notes, and snippets.

import argparse
import json
import os
import tempfile
from pathlib import Path
from typing import Union
import lm_eval
from mlx_lm.utils import convert
@barronalex
barronalex / gist:33d9956a866fdfd4ee20b8185bfa0c80
Last active January 15, 2025 16:00
MLX Grid Sample Custom Kernel
import mlx.core as mx
@mx.custom_function
def grid_sample(x, grid):
"""Grid sample that matches torch.nn.functional.grid_sample with default arguments."""
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."