Updated the code so a model from mlx-community can be loaded for fine-tuning; see https://github.com/ml-explore/mlx-examples/issues/232.
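For orientation, here is a rough sketch of how a converted checkpoint could be loaded into the Model defined in the file below. It assumes the config.json / weights.npz layout produced by the mlx-examples conversion scripts and that the file below is saved as models.py; the paths are illustrative. For a quantized mlx-community checkpoint the nn.Linear layers also have to be converted to nn.QuantizedLinear before the weights are applied (the lora.py script in mlx-examples handles that step); that conversion is omitted here.

import json
from pathlib import Path

import mlx.core as mx
from mlx.utils import tree_unflatten

from models import Model, ModelArgs  # assumes the file below is saved as models.py

model_dir = Path("mlx_model")  # illustrative path to a converted checkpoint

with open(model_dir / "config.json", "r") as f:
    config = json.load(f)

# Keep only the keys ModelArgs defines (drops e.g. a "quantization" section)
config = {k: v for k, v in config.items() if k in ModelArgs.__dataclass_fields__}
model = Model(ModelArgs(**config))

# Load the flat weight dictionary and copy it into the module tree
weights = mx.load(str(model_dir / "weights.npz"))
model.update(tree_unflatten(list(weights.items())))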
# Copyright © 2023 Apple Inc.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten


@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_traditional: bool
    model_type: str
    head_dim: Optional[int] = None  # derived as dim // n_heads when not provided


class LoRALinear(nn.Module):
    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8):
        # TODO remove when input_dims and output_dims are attributes
        # on linear and quantized linear
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            # Quantized weights pack 32 // bits values per element, so recover
            # the true input dimension before building the LoRA matrices.
            input_dims *= 32 // linear.bits
        lora_lin = LoRALinear(input_dims, output_dims, rank)
        lora_lin.linear = linear
        return lora_lin

    def __init__(
        self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
    ):
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        # Low-rank LoRA weights
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, lora_rank),
        )
        self.lora_b = mx.zeros(shape=(lora_rank, output_dims))

    def __call__(self, x):
        dtype = self.linear.weight.dtype
        if isinstance(self.linear, nn.QuantizedLinear):
            dtype = self.linear.scales.dtype
        y = self.linear(x.astype(dtype))
        z = (x @ self.lora_a) @ self.lora_b
        # Frozen linear output plus the scaled low-rank update
        return y + 2.0 * z


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        # Derive head_dim from the config and store it on args for reuse
        args.head_dim = args.dim // self.n_heads
        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(args.head_dim, traditional=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        if self.repeats > 1:
            keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)


class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.tok_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.output(self.norm(h)), cache
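As a usage illustration (not part of the file above), here is a minimal, self-contained sketch of how LoRA adapters might be attached for fine-tuning, mirroring the approach used by the mlx-examples LoRA example: freeze the base model, then replace the query and value projections of the last few blocks with LoRALinear wrappers so that only the low-rank lora_a / lora_b matrices remain trainable. The tiny ModelArgs values are illustrative, not a real model configuration, and the snippet assumes the definitions above are in scope.

# Toy configuration purely for illustration
args = ModelArgs(
    dim=128,
    n_layers=4,
    hidden_dim=256,
    n_heads=8,
    n_kv_heads=2,
    norm_eps=1e-5,
    vocab_size=1000,
    rope_traditional=True,
    model_type="toy",
)
model = Model(args)

# Freeze the base weights, then attach LoRA to wq/wv of the last two blocks
model.freeze()
for layer in model.layers[-2:]:
    layer.attention.wq = LoRALinear.from_linear(layer.attention.wq, rank=8)
    layer.attention.wv = LoRALinear.from_linear(layer.attention.wv, rank=8)

# A forward pass returns logits and the per-layer (keys, values) cache
inputs = mx.array([[1, 2, 3, 4]])
logits, cache = model(inputs)
print(logits.shape)  # (1, 4, 1000)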