Llama2 Python code completion web UI
from threading import Event, Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig

model_name = "codellama/CodeLlama-7b-Python-hf"

print(f"Starting to load the model {model_name} into memory")
tok = AutoTokenizer.from_pretrained(
    model_name,
)
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    # load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    # device_map={"": 0}
)
print(f"Successfully loaded the model {model_name} into memory")
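# The two commented-out kwargs above hint at a quantized GPU variant. A
# minimal sketch of that variant, assuming the bitsandbytes and accelerate
# packages plus a CUDA device are available (not part of the original setup):
#
#   m = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       load_in_4bit=True,
#       device_map={"": 0},
#   )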

max_new_tokens = 512


def generate(input_code, temperature, top_p, top_k, repetition_penalty):
    print(f"input code: {input_code}")
    # Tokenize the input code string
    input_ids = tok(input_code, return_tensors="pt").input_ids
    print("input shape:", input_ids.shape)
    streamer = TextIteratorStreamer(
        tok, timeout=None, skip_prompt=True, skip_special_tokens=True)
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # A temperature of 0 disables sampling and falls back to greedy decoding.
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        # Llama tokenizers ship without a pad token, so one is set explicitly.
        pad_token_id=1,
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        generation_config=generation_config,
        streamer=streamer,
    )
    stream_complete = Event()

    def generate_and_signal_complete():
        print("start generating...")
        try:
            m.generate(**generate_kwargs)
            print("generation done...")
        finally:
            stream_complete.set()

    # model.generate() blocks until the whole completion is produced, so it
    # runs on a background thread while this generator streams partial text.
    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    partial_text = ''
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
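
# A quick way to smoke-test the streaming generator outside the UI (a sketch;
# the prompt and sampling values here are illustrative, not from the original):
#
#   for text in generate("def fib(n):", temperature=0.2, top_p=0.9,
#                        top_k=0, repetition_penalty=1.0):
#       print(text)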

# Set up the gradio demo.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown(
        """Llama2 Python Code Demo
        """
    )
    with gr.Row():
        with gr.Column():
            input = gr.Textbox(
                label="Code Input",
                placeholder="import sys",
                show_label=True,
                lines=5,
            )
            gen = gr.Button("Generate")
            output = gr.Textbox(
                label="Code Output",
                placeholder="",
                show_label=True,
            )
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.5,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=0.9,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of the top-k tokens; set to 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.0,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repeated tokens; set to 1.0 to disable.",
                        )
    with gr.Row():
        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )
    with gr.Row():
        gr.Markdown(
            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
            elem_classes=["disclaimer"],
        )
    submit_event = input.submit(
        fn=generate,
        inputs=[
            input,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=output,
        queue=True,
    )
    submit_click_event = gen.click(
        fn=generate,
        inputs=[
            input,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=output,
        queue=True,
    )

demo.queue(max_size=128, concurrency_count=2)
# Launch the demo.
demo.launch()
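
# If the app needs to be reachable beyond localhost, gradio's standard launch
# options can be used instead (a sketch; the port value is illustrative):
#
#   demo.launch(server_name="0.0.0.0", server_port=7860)  # serve on the LAN
#   demo.launch(share=True)  # temporary public gradio.live URL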