import streamlit as st
from mlx_lm import load
from mlx_lm.utils import generate_step
import mlx.core as mx
from tqdm import tqdm
import pennylane as qml
import numpy as np
import re
import nltk
from textblob import TextBlob
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
from collections import Counter

# word_tokenize/pos_tag need NLTK model data on disk; fetch it once if missing.
for resource, path in (("punkt", "tokenizers/punkt"),
                       ("averaged_perceptron_tagger", "taggers/averaged_perceptron_tagger")):
    try:
        nltk.data.find(path)
    except LookupError:
        nltk.download(resource, quiet=True)

# Create a disabled tqdm instance up front so stray progress bars stay silent.
tqdm(disable=True, total=0)
title = "MLX Chat"
ver = "0.7.3"
debug = False
st.set_page_config(page_title=title, page_icon="🤖", layout="wide", initial_sidebar_state="expanded")
st.title(title)
assistant_greeting = "How may I help you?"
temp = 0.8
model_ref = st.sidebar.text_input("model", "mlx-community/Nous-Hermes-2-Mixtral-8x7B-DPO-4bit")
n_ctx = st.sidebar.number_input('context length', value=300, min_value=100, step=100, max_value=32000)
actions = st.sidebar.columns(2)
@st.cache_resource(show_spinner=True)
def load_model(ref):
    return load(ref)


model, tokenizer = load_model(model_ref)
@qml.qnode(qml.device("default.qubit", wires=4))
def quantum_circuit(color_code, amplitude):
    # Map an RGB hex color and a sentiment amplitude onto four qubit rotations.
    r, g, b = (int(color_code[i:i + 2], 16) for i in (0, 2, 4))
    r, g, b = r / 255.0, g / 255.0, b / 255.0
    qml.RY(r * np.pi, wires=0)
    qml.RY(g * np.pi, wires=1)
    qml.RY(b * np.pi, wires=2)
    qml.RY(amplitude * np.pi, wires=3)
    # Entangle neighboring qubits in a chain.
    qml.CNOT(wires=[0, 1])
    qml.CNOT(wires=[1, 2])
    qml.CNOT(wires=[2, 3])
    return qml.probs(wires=[0, 1, 2, 3])
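# Hypothetical usage sketch (the circuit is defined but never invoked elsewhere
# in this app): a hex color plus a sentiment amplitude yields 16 basis-state
# probabilities over the four entangled qubits, e.g.:
#   probs = quantum_circuit("cc6600", 0.5)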
def sentiment_to_amplitude(text):
    # Map TextBlob polarity in [-1, 1] onto an amplitude in [0, 1].
    # (The original declared this async, but it contains no await and is
    # never awaited, so a plain function is correct.)
    analysis = TextBlob(text)
    return (analysis.sentiment.polarity + 1) / 2
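# Illustrative values (TextBlob polarities are approximate):
#   sentiment_to_amplitude("I love this")  # polarity ~ +0.5 -> ~0.75
#   sentiment_to_amplitude("I hate this")  # polarity ~ -0.8 -> ~0.10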
def is_code_like(chunk):
    # Heuristic: programming keywords or operator characters suggest source code.
    code_patterns = r'\b(def|class|import|if|else|for|while|return|function|var|let|const|print)\b|[\{\}\(\)=><\+\-\*/]'
    return bool(re.search(code_patterns, chunk))
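# Illustrative behavior of the heuristic:
#   is_code_like("def foo(): return 1")   # True  (keyword + symbols)
#   is_code_like("The quick brown fox")   # False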
def determine_token(chunk, memory, max_words_to_check=500):
    combined_chunk = f"{memory} {chunk}"
    # strip() is needed here: the f-string always contains a space, so the
    # bare truthiness check in the original could never fire.
    if not combined_chunk.strip():
        return "[attention]"
    if is_code_like(combined_chunk):
        return "[code]"
    words = word_tokenize(combined_chunk)[:max_words_to_check]
    tagged_words = pos_tag(words)
    # Count coarse part-of-speech categories (first two letters of each Penn tag).
    pos_counts = Counter(tag[:2] for _, tag in tagged_words)
    most_common_pos, _ = pos_counts.most_common(1)[0]
    if most_common_pos == 'VB':
        return "[action]"
    elif most_common_pos == 'NN':
        return "[subject]"
    elif most_common_pos in ['JJ', 'RB']:
        return "[description]"
    else:
        return "[general]"
def generate(the_prompt, the_model):
    tokens = []
    skip = 0
    memory = ""
    # Stream tokens from the model, capped at n_ctx, stopping early at EOS.
    for token, _ in zip(generate_step(mx.array(tokenizer.encode(the_prompt)), the_model, temp), range(n_ctx)):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        s = tokenizer.decode(tokens)
        chunk = s[skip:]
        # Tag the newly decoded text and fold it into the running transcript.
        token_type = determine_token(chunk, memory)
        memory += f" {token_type} {chunk}"
        yield memory
        skip = len(s)
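# Each yield is the full annotated transcript so far, so a consumer can simply
# overwrite its display with the latest value; a sketch outside Streamlit:
#   for partial in generate("user\nHi\n\nassistant\n", model):
#       print(partial)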
def show_chat(the_prompt, previous=""):
with st.chat_message("assistant"):
message_placeholder = st.empty()
response = previous
for chunk in generate(the_prompt, model):
response = chunk.replace('�', '')
message_placeholder.markdown(response + "▌")
message_placeholder.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
def remove_last_occurrence_in_array(array_of_dicts, criteria):
    # Walk backwards and delete the last element matching the predicate.
    for i in reversed(range(len(array_of_dicts))):
        if criteria(array_of_dicts[i]):
            del array_of_dicts[i]
            break
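# Example, mirroring the "Continue" handler below: drop the newest assistant
# turn from the chat history.
#   remove_last_occurrence_in_array(st.session_state.messages,
#                                   lambda m: m.get("role") == "assistant")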
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": assistant_greeting}]
for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])
if prompt := st.chat_input():
st.session_state.messages.append({"role": "user", "content": prompt})
st.chat_message("user").write(prompt)
full_prompt = f"user\n{prompt}\n\nassistant\n"
show_chat(full_prompt)
if st.session_state.messages and sum(msg["role"] == "assistant" for msg in st.session_state.messages) > 1:
    if actions[0].button("Reset", key='reset'):
        st.session_state.messages = [{"role": "assistant", "content": assistant_greeting}]
        st.rerun()

if st.session_state.messages and sum(msg["role"] == "assistant" for msg in st.session_state.messages) > 1:
    if actions[1].button("Continue", key='continue'):
        user_prompts = [msg["content"] for msg in st.session_state.messages if msg["role"] == "user"]
        # Guard against an empty history instead of indexing unconditionally.
        last_prompt = user_prompts[-1] if user_prompts else "Please continue your response."
        assistant_responses = [msg["content"] for msg in st.session_state.messages if msg["role"] == "assistant"]
        remove_last_occurrence_in_array(st.session_state.messages, lambda item: item.get("role") == "assistant")
        last_assistant_response = assistant_responses[-1] if assistant_responses else ""
        # Drop a possibly truncated final line before asking the model to continue.
        last_assistant_response_lines = last_assistant_response.split('\n')
        if len(last_assistant_response_lines) > 1:
            last_assistant_response_lines.pop()
            last_assistant_response = "\n".join(last_assistant_response_lines)
        full_prompt = f"user\n{last_prompt}\nassistant\n{last_assistant_response}"
        show_chat(full_prompt, last_assistant_response)