Recall, MLX is lazy. No actual computation happens until you explicitly or implicitly evaluate the graph. Even loading arrays from a file is lazy:
weights = mx.load("model.safetensors")
""" | |
A minimal, fast example generating text with Llama 3.1 in MLX. | |
To run, install the requirements: | |
pip install -U mlx transformers fire | |
Then generate text with: | |
python l3min.py "How tall is K2?" |
# Requires:
# pip install pyobjc-framework-Metal
import numpy as np
import Metal

# Get the default GPU device
device = Metal.MTLCreateSystemDefaultDevice()

# Make a command queue to encode command buffers to
command_queue = device.newCommandQueue()
from typing import Callable, Tuple
import operator
from functools import reduce
from itertools import product

import mlx.core as mx


def _interpolate(
    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
):
This is a short article on a common type of not-yet-supported operation in MLX: ops where the output shape depends on the input data. Here's an outline:
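Before the outline, a concrete illustration of such an op, sketched in NumPy (where it is supported): the output shape of `np.nonzero` depends on the array's values, not just its shape, so it cannot be known before the data is inspected.

```python
import numpy as np

# Two inputs with the same shape produce outputs of different shapes,
# because the number of nonzero entries differs.
a = np.nonzero(np.array([0, 3, 0, 5]))[0]
b = np.nonzero(np.array([1, 0, 0, 0]))[0]
print(a.shape)  # (2,)
print(b.shape)  # (1,)
```

This data dependence is exactly what makes such ops awkward for a lazy framework, which wants to know output shapes when building the graph.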
import time
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn


@dataclass
class ModelArgs: