Last active
November 20, 2023 13:24
-
-
Save dittops/da7ad221f21d623754faaba525133f4e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import fire | |
import gradio as gr | |
import torch | |
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer | |
def get_prompt(prompt: str) -> str:
    """Wrap *prompt* in the instruction/response template fed to the model.

    The template consists of a fixed preamble, an ``### Instruction:``
    section holding the user's text, and a trailing ``### Response:``
    header after which the model is expected to write its answer.

    Args:
        prompt: The user's instruction, inserted verbatim.

    Returns:
        The full prompt string, ending with a newline after
        ``### Response:``.
    """
    # Built from explicit pieces (rather than a triple-quoted block) so the
    # exact newlines in the template are visible and cannot pick up stray
    # whitespace from source formatting.
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request\n"
        "### Instruction:\n"
        f"{prompt}\n"
        "### Response:\n"
    )
def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "",
    server_name: str = "0.0.0.0",
    share_gradio: bool = True,
):
    """Load a causal language model and serve it through a Gradio web UI.

    Args:
        load_8bit: Accepted for command-line compatibility. 8-bit loading is
            not wired up in this script, so the flag currently has no effect.
        base_model: Model identifier or local path passed to
            ``from_pretrained`` for both the model and the tokenizer.
            Required.
        lora_weights: Accepted for command-line compatibility. No adapter
            weights are loaded by this script, so the value is unused.
        server_name: Network interface the Gradio server binds to. The
            default ``"0.0.0.0"`` listens on all interfaces.
        share_gradio: Whether Gradio should create a public share link.

    Raises:
        ValueError: If ``base_model`` is empty.
    """
    if not base_model:
        raise ValueError(
            "Please specify a --base_model, e.g. --base_model='org/model-name'"
        )

    # Prefer the GPU but fall back to CPU so the demo still starts on
    # machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(security): trust_remote_code=True executes Python code shipped
    # with the checkpoint. Only point base_model at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        trust_remote_code=True,
    )
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    def evaluate(
        instruction,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate a response for *instruction* and yield it as text.

        The positional parameters line up, in order, with the Gradio input
        components declared below. Extra ``kwargs`` are forwarded to
        ``GenerationConfig``.
        """
        prompt = get_prompt(instruction)
        # Inputs must live on the same device as the model weights.
        input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

        # NOTE(review): temperature/top_p/top_k are sampling settings; per the
        # transformers docs they only influence decoding when do_sample=True
        # is supplied (e.g. through kwargs). Left as-is to keep the existing
        # beam-search behaviour -- confirm whether sampling was intended.
        # Sliders may deliver floats, so integer-valued settings are cast.
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            num_beams=int(num_beams),
            **kwargs,
        )
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            max_new_tokens=int(max_new_tokens),
        )

        # The returned sequence starts with the prompt tokens. Decode only
        # what follows them instead of splitting the text on
        # "### Response:", which fails when the marker is absent and picks
        # the wrong segment when the user's instruction contains it.
        sequence = generation_output.sequences[0]
        new_tokens = sequence[input_ids.shape[-1]:]
        yield tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Query",
                placeholder="Ask me anything",
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.1, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.75, label="Top p"
            ),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=4, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=8192, step=1, value=128, label="Max tokens"
            ),
        ],
        outputs=[
            # Same namespace as the inputs; the legacy gr.inputs module is
            # not available in current Gradio releases.
            gr.components.Textbox(
                lines=5,
                label="Output",
            ),
        ],
        title="Bud Code",
    ).queue().launch(server_name=server_name, share=share_gradio)
if __name__ == "__main__":
    # python-fire turns main()'s keyword arguments into command-line flags,
    # e.g. `--base_model=<name-or-path>`.
    fire.Fire(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment