Remember with MLX LM
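A small command-line memory chat built on MLX LM. Notes are stored in a plain text file at ~/.cache/mlx-lm/memory.txt and prefilled once into a reusable prompt cache; each question is then answered against a copy of that cache, so the memory only has to be processed when it changes. Use /r <text> to save a note and /h for help; anything else is treated as a question.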
import argparse
import copy
from pathlib import Path

import mlx.core as mx
from mlx_lm import load, stream_generate
from mlx_lm.generate import generate_step
from mlx_lm.models.cache import make_prompt_cache

DEFAULT_MAX_TOKENS = 2048
DEFAULT_MODEL = "mlx-community/Qwen3-4B-Instruct-2507-4bit"

SYSTEM_PROMPT = (
    "You are a helpful memory assistant. You have access to a file `memory.txt` with "
    "a bunch of important information that the user wants you to help them remember "
    "and retrieve from.\n"
    "The user may ask you a question which you should try and use the memory to "
    "answer if it makes sense. Otherwise answer to the best of your ability."
)


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    model, tokenizer = load(args.model)

    def print_help():
        print("The command list:")
        print("- '/r <text>' to remember something")
        print("- '/h' for help")

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Load any previously saved memory from disk.
    memory_file = Path.home() / ".cache/mlx-lm/memory.txt"
    memory_file.parent.mkdir(parents=True, exist_ok=True)
    if memory_file.exists():
        with open(memory_file, "r") as f:
            memory_text = f.read()
    else:
        memory_text = ""

    def append_to_memory(text):
        # Persist the new note and keep the in-memory copy in sync.
        nonlocal memory_text
        text = text + "\n---\n"
        with open(memory_file, "a") as f:
            f.write(text)
        memory_text += text
        return memory_text

    # Number of prompt tokens already processed into the prompt cache.
    prompt_len = 0

    def generate_prompt(text, prefill_only=True):
        local_messages = list(messages)
        if len(text) > 0:
            local_messages.append({"role": "user", "content": text})
        # Slice off the tokens that are already in the prompt cache so only
        # the new suffix needs to be processed.
        prompt = tokenizer.apply_chat_template(
            local_messages,
            add_generation_prompt=not prefill_only,
            continue_final_message=prefill_only and len(text) > 0,
        )[prompt_len:]
        return prompt

    def prefill_callback(n, total):
        if n == total:
            print(" " * 50, end="\r")
        else:
            print(f"[Prefilling memory tokens {n}/{total}]", end="\r")

    def prefill(prompt):
        # Extend the shared prompt cache with the memory tokens that are not
        # yet cached (everything past prompt_len).
        nonlocal prompt_len
        prompt = generate_prompt(memory_text)
        prompt_len += len(prompt)
        for _ in generate_step(
            mx.array(prompt),
            model,
            max_tokens=0,
            prompt_cache=prompt_cache,
            prompt_progress_callback=prefill_callback,
        ):
            pass

    prompt_cache = make_prompt_cache(model)
    prefill(memory_text)

    print(f"[INFO] Resuming memory session with {args.model}.")
    print_help()

    while True:
        query = input(">> ")
        if query.startswith("/r"):
            # Add everything after /r to the memory file
            # and process the additional prompt
            prefill(append_to_memory(query[2:].strip()))
        elif query.startswith("/h"):
            print_help()
        else:
            # It's a query: generate a response from a copy of the cache so the
            # cached memory prefix stays intact for the next question.
            prompt = generate_prompt(memory_text + "\n\n" + query, False)
            for response in stream_generate(
                model,
                tokenizer,
                prompt,
                max_tokens=args.max_tokens,
                prompt_cache=copy.deepcopy(prompt_cache),
            ):
                print(response.text, flush=True, end="")
            print()


if __name__ == "__main__":
    main()
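To try it (assuming the script is saved locally, e.g. as remember.py, and mlx-lm is installed via pip install mlx-lm): run python remember.py, optionally passing --model with a local path or Hugging Face repo and --max-tokens to cap response length, then chat at the >> prompt.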