Last active: January 13, 2024 23:32
-
-
Save aminnj/c1d66cc7d5be4f14a9f1e093731d7f75 to your computer and use it in GitHub Desktop.
Evaluate layers based on a list of indices with MLX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py | |
index 9b9a602..5fd5146 100644 | |
--- a/llms/mistral/mistral.py | |
+++ b/llms/mistral/mistral.py | |
@@ -144,6 +144,7 @@ class Mistral(nn.Module): | |
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] | |
self.norm = RMSNorm(args.dim, eps=args.norm_eps) | |
self.output = nn.Linear(args.dim, args.vocab_size, bias=False) | |
+ self.ilayers = list(range(len(self.layers))) | |
def __call__( | |
self, | |
@@ -158,9 +159,10 @@ class Mistral(nn.Module): | |
mask = mask.astype(h.dtype) | |
if cache is None: | |
- cache = [None] * len(self.layers) | |
+ cache = [None] * len(self.ilayers) | |
- for e, layer in enumerate(self.layers): | |
+ for e, ilayer in enumerate(self.ilayers): | |
+ layer = self.layers[ilayer] | |
h, cache[e] = layer(h, mask, cache[e]) | |
return self.output(self.norm(h)), cache | |
@@ -267,6 +269,21 @@ if __name__ == "__main__": | |
print("[INFO] Loading model from disk.") | |
model, tokenizer = load_model(args.model_path) | |
+ # default = list(range(model.n_layers)) | |
+ # model.ilayers = default | |
+ | |
+ overlap_8_by_4 = ( | |
+ [] | |
+ + list(range(0,8)) | |
+ + list(range(4,12)) | |
+ + list(range(8,16)) | |
+ + list(range(12,20)) | |
+ + list(range(16,24)) | |
+ + list(range(20,28)) | |
+ + list(range(24,32)) | |
+ ) | |
+ model.ilayers = overlap_8_by_4 | |
+ | |
print("[INFO] Starting generation...") | |
tic = time.time() | |
print(args.prompt, end="", flush=True) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
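Requirements: MLX (`pip install mlx`) on an Apple Silicon Mac, and a checkout of `mlx-examples` with the patch above applied to `llms/mistral/mistral.py`.

The patch adds a model attribute, `ilayers`. It is a list of 0-based transformer layer indices that are evaluated in order. For example, `ilayers = [0, 1, 2, 0, 1, 2]` will stack the first 3 layers twice. Repeated indices reuse the same layer object, so doubling each layer via `ilayers` will not double the memory used by the weights. The KV cache, however, has one entry per element of `ilayers`, so it does grow with `len(ilayers)`.

Example run:

    python mistral.py --model-path ../../../../../../lora/mistral-mlx/ --max-tokens 100 --prompt "A sci-fi story about aliens. Title: Alien virus. Story:"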