This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import os | |
import tempfile | |
from pathlib import Path | |
from typing import Union | |
import lm_eval | |
from mlx_lm.utils import convert |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlx.core as mx | |
@mx.custom_function | |
def grid_sample(x, grid): | |
"""Grid sample that matches torch.nn.functional.grid_sample with default arguments.""" | |
assert x.ndim == 4, "`x` must be 4D." | |
assert grid.ndim == 4, "`grid` must be 4D." |