mpt-7b-chat-interactive-optimized.py
# !pip install -qU transformers accelerate einops langchain xformers triton
# !pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python
from torch import cuda, bfloat16
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
print(f"device={device}")
# Initialize the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-chat", trust_remote_code=True)
# Initialize model with Triton optimization. This is supposed to speed up the
# model at the cost of using more memory, but I haven't been able to get it to
# work yet
optimize = True

if optimize:
    config = AutoConfig.from_pretrained(
        'mosaicml/mpt-7b-chat',
        trust_remote_code=True
    )
    config.attn_config['attn_impl'] = 'triton'
    # config.update({"max_seq_len": 100})
    config.init_device = device
else:
    # Build a real config object so init_device is actually applied
    # (from_pretrained does not accept a plain dict as its config)
    config = AutoConfig.from_pretrained('mosaicml/mpt-7b-chat', trust_remote_code=True)
    config.init_device = 'meta'
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    trust_remote_code=True,
    config=config,
    torch_dtype=bfloat16,
)
print('loaded')
# tokenizer.eval() # fails!
# tokenizer.to(device)
# model.eval() # TODO: needed?
model.to(device)
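# Optional sanity check (not in the original gist): confirm which attention
# implementation the loaded config ended up with, and roughly how much GPU
# memory the weights occupy. This assumes MPT's remote config exposes
# `attn_config` as a plain dict, which the hasattr guard hedges against.
if hasattr(model.config, 'attn_config'):
    print(f"attn_impl={model.config.attn_config.get('attn_impl', 'torch')}")
if cuda.is_available():
    print(f"gpu_mem_allocated={cuda.memory_allocated() / 1e9:.1f} GB")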
import time
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from IPython.display import Markdown

# mpt-7b is trained to add "<|endoftext|>" at the end of generations
stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

# Define a custom stopping criteria object that halts generation as soon as
# the end-of-text token is produced
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])

def ask_question(question, max_length=100):
    start_time = time.time()

    # Encode the question and move it to the same device as the model
    input_ids = tokenizer.encode(question, return_tensors='pt')
    input_ids = input_ids.to(device)

    # Generate a response
    output = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,  # sampling must be enabled for temperature to have an effect
        temperature=0.9,
        # pad_token_id=stop_token_ids[0],
        # num_return_sequences=1,
        stopping_criteria=stopping_criteria,
        # top_p=0.15,  # select from top tokens whose probability adds up to 15%
        # top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
        # max_new_tokens=64,  # max number of tokens to generate in the output
        # repetition_penalty=1.1,  # without this the output begins repeating
    )

    # Decode only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    end_time = time.time()
    duration = end_time - start_time

    print(response)
    print("Function duration:", duration, "seconds")
# Ask a question
# ask_question("What is the capital of France?")
# ask_question("Explain to me the difference between nuclear fission and fusion.", 200)
# ask_question("write python code that converts a csv into a pdf", 400)
termination_condition = False

while not termination_condition:
    user_input = input("User: ")

    # Check for termination condition (e.g., if the user enters 'exit')
    if user_input.lower() == 'exit':
        termination_condition = True
    else:
        ask_question(user_input)
print("done")