mpt-7b-chat-interactive-optimized.py
# !pip install -qU transformers accelerate einops langchain xformers triton
# !pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python

from torch import cuda, bfloat16
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
print(f"device={device}")
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-chat", trust_remote_code=True)
# Initialize the model with the Triton attention implementation. This is supposed
# to speed up the model at the cost of using more memory, but I haven't been able
# to get it to work yet.
optimize = True

if optimize:
    config = AutoConfig.from_pretrained(
        'mosaicml/mpt-7b-chat',
        trust_remote_code=True
    )
    config.attn_config['attn_impl'] = 'triton'
    # config.update({"max_seq_len": 100})
    config.init_device = device
else:
    # Note: from_pretrained expects a PretrainedConfig rather than a plain dict.
    # The 'meta' init path is untested here and would still need the weights
    # materialized before the model.to(device) call below.
    config = AutoConfig.from_pretrained(
        'mosaicml/mpt-7b-chat',
        trust_remote_code=True
    )
    config.init_device = 'meta'
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    trust_remote_code=True,
    config=config,
    torch_dtype=bfloat16,
)
print('loaded')

# tokenizer.eval()  # fails! (tokenizers are not torch modules, so there is no eval mode)
# tokenizer.to(device)
# model.eval()  # TODO: needed?
model.to(device)
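# (Added sketch, not part of the original gist) A quick, commented-out sanity
# check: generate a handful of tokens to confirm the tokenizer/model/device
# wiring before dropping into the interactive loop.
# sample_ids = tokenizer("Hello", return_tensors='pt').input_ids.to(device)
# print(tokenizer.decode(model.generate(sample_ids, max_new_tokens=5)[0]))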
import time
from IPython.display import Markdown
def ask_question(question, max_length=100):
    start_time = time.time()

    # Encode the question and move it to the model's device
    input_ids = tokenizer.encode(question, return_tensors='pt')
    input_ids = input_ids.to(device)
    # input_ids = input_ids.to('cuda')

    # mpt-7b is trained to add "<|endoftext|>" at the end of generations
    stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    # Custom stopping criteria: stop as soon as the last generated token is a stop token
    class StopOnTokens(StoppingCriteria):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False

    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
    # Generate a response
    output = model.generate(
        input_ids,
        max_length=max_length,
        # max_length=1000,
        do_sample=True,  # without this, generation is greedy and temperature is ignored
        temperature=0.9,
        # pad_token_id=stop_token_ids[0],
        # num_return_sequences=1,
        stopping_criteria=stopping_criteria,
        # top_p=0.15,  # sample only from the top tokens whose probabilities add up to 15%
        # top_k=0,  # disable top-k so that top_p alone controls the candidate set
        # max_new_tokens=64,  # max number of new tokens to generate in the output
        # repetition_penalty=1.1,  # without this the output begins repeating
    )
    # Decode only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    end_time = time.time()
    duration = end_time - start_time
    print(response)
    print("Function duration:", duration, "seconds")
# Ask a question
# ask_question("What is the capital of France?")
# ask_question("Explain to me the difference between nuclear fission and fusion.", 200)
# ask_question("write python code that converts a csv into a pdf", 400)

# Interactive loop: keep answering questions until the user types 'exit'
termination_condition = False
while not termination_condition:
    user_input = input("User: ")
    # Check for the termination condition (e.g., the user enters 'exit')
    if user_input.lower() == 'exit':
        termination_condition = True
    else:
        ask_question(user_input)
print("done")