Remember with MLX LM
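
A small interactive chat script for MLX LM that keeps a persistent plain-text memory. Notes added with `/r` are appended to `~/.cache/mlx-lm/memory.txt` and prefilled into a reusable prompt cache, so later questions can be answered against the stored memory without reprocessing it for every query.
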
import argparse
import copy
import mlx.core as mx
from pathlib import Path
from mlx_lm import load, stream_generate
from mlx_lm.generate import generate_step
from mlx_lm.models.cache import make_prompt_cache

DEFAULT_MAX_TOKENS = 2048
DEFAULT_MODEL = "mlx-community/Qwen3-4B-Instruct-2507-4bit"

SYSTEM_PROMPT = (
    "You are a helpful memory assistant. You have access to a file `memory.txt` with "
    "a bunch of important information that the user wants you to help them remember "
    "and retrieve from.\n"
    "The user may ask you a question which you should try and use the memory to "
    "answer if it makes sense. Otherwise answer to the best of your ability."
)


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    model, tokenizer = load(args.model)

    def print_help():
        print("The command list:")
        print("- '/r <text>' to remember something")
        print("- '/h' for help")

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Load any previously saved memory from disk.
    memory_file = Path.home() / ".cache/mlx-lm/memory.txt"
    memory_file.parent.mkdir(parents=True, exist_ok=True)
    if memory_file.exists():
        with open(memory_file, "r") as f:
            memory_text = f.read()
    else:
        memory_text = ""

    def append_to_memory(text):
        # Persist the new memory to disk and keep the in-memory copy in sync.
        nonlocal memory_text
        text = text + "\n---\n"
        with open(memory_file, "a") as f:
            f.write(text)
        memory_text += text
        return memory_text

    # Number of prompt tokens already processed into the prompt cache.
    prompt_len = 0

    def generate_prompt(text, prefill_only=True):
        # Tokenize the chat and return only the suffix that is not yet cached.
        local_messages = list(messages)
        if len(text) > 0:
            local_messages.append({"role": "user", "content": text})
        prompt = tokenizer.apply_chat_template(
            local_messages,
            add_generation_prompt=not prefill_only,
            continue_final_message=prefill_only and len(text) > 0,
        )[prompt_len:]
        return prompt

    def prefill_callback(n, total):
        # Show prefill progress, then clear the line when done.
        if n == total:
            print(" " * 50, end="\r")
        else:
            print(f"[Prefilling memory tokens {n}/{total}]", end="\r")

    def prefill(prompt):
        # Process the current memory into the prompt cache without producing
        # any output tokens, so later queries can reuse the cached computation.
        nonlocal prompt_len
        prompt = generate_prompt(memory_text)
        prompt_len += len(prompt)
        for _ in generate_step(
            mx.array(prompt),
            model,
            max_tokens=0,
            prompt_cache=prompt_cache,
            prompt_progress_callback=prefill_callback,
        ):
            pass

    prompt_cache = make_prompt_cache(model)
    prefill(memory_text)

    print(f"[INFO] Resuming memory session with {args.model}.")
    print_help()

    while True:
        query = input(">> ")
        if query.startswith("/r"):
            # Add everything after /r to the memory file
            # and process the additional prompt
            prefill(append_to_memory(query[2:].strip()))
        elif query.startswith("/h"):
            print_help()
        else:
            # It's a query, generate a response. Use a copy of the prompt
            # cache so the cached memory prefix is not modified by the query.
            prompt = generate_prompt(memory_text + "\n\n" + query, False)
            for response in stream_generate(
                model,
                tokenizer,
                prompt,
                max_tokens=args.max_tokens,
                prompt_cache=copy.deepcopy(prompt_cache),
            ):
                print(response.text, flush=True, end="")
            print()


if __name__ == "__main__":
    main()
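
Usage sketch (the filename `memory.py` is an assumption; the gist does not name the script): with `mlx-lm` installed, run `python memory.py`, optionally passing `--model <path-or-repo>` and `--max-tokens <n>`. At the `>> ` prompt, `/r <text>` appends a note to `~/.cache/mlx-lm/memory.txt` and prefills it into the prompt cache, `/h` prints the command list, and anything else is answered using the stored memory. The memory is prefilled into the cache once at startup (and incrementally after each `/r`), and each question reuses a deep copy of that cache, so the memory text is not reprocessed for every query.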