Skip to content

Instantly share code, notes, and snippets.

@kishida
Created August 11, 2023 03:31
Show Gist options
  • Save kishida/d88c40d0d73325dbca641822becd9b01 to your computer and use it in GitHub Desktop.
Save kishida/d88c40d0d73325dbca641822becd9b01 to your computer and use it in GitHub Desktop.
StableCode Instruction UI
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face Hub id of the instruction-tuned StableCode checkpoint.
model_name = "stabilityai/stablecode-instruct-alpha-3b"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal LM and move it to the GPU in one expression.
# trust_remote_code=True lets model code shipped with the checkpoint run
# locally, so only use it with repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # load_in_8bit=True,
    torch_dtype="auto",
).cuda()
# print(model)
# inputs = tokenizer("import torch\nimport torch.nn as nn", return_tensors="pt").to("cuda")
def generate(input):
    """Generate a StableCode answer for a single instruction.

    The instruction is wrapped in the ``###Instruction`` / ``###Response``
    prompt template, tokenized, and passed to the module-level ``model`` on
    CUDA with sampling enabled (``temperature=0.2``, at most 512 new tokens).

    Args:
        input: Instruction text entered by the user. The name shadows the
            builtin ``input`` but is kept so existing keyword callers keep
            working.

    Returns:
        The decoded continuation only: special tokens are skipped and the
        prompt is not included.
    """
    # Explicit "\n" escapes keep the prompt bytes independent of how this
    # function body happens to be indented.
    prompt = f"###Instruction\n{input}\n###Response\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Some tokenizers define no pad token. transformers then substitutes the
    # EOS id itself and logs a warning on every call; do the same fallback
    # explicitly so the call is quiet and the behavior is visible here.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    tokens = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=512,
        temperature=0.2,
        do_sample=True,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count rather than by len(prompt) on the decoded string: decoding is not
    # guaranteed to reproduce the prompt character for character, and any
    # drift would shift the slice into (or short of) the actual answer.
    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(tokens[0][prompt_length:], skip_special_tokens=True)
# NOTE: this is a bare module-level string expression, not a docstring; it is
# evaluated and discarded. It holds pasted warning output from transformers'
# generate() (presumably captured from an earlier form of the call above,
# before attention_mask / pad_token_id were passed explicitly -- TODO confirm).
"""
The following `model_kwargs` are not used by the model: ['token_type_ids'] (note: typos in the generate arguments will also show up in this list)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
"""
# print(tokenizer.decode(tokens[0], skip_special_tokens=True))
import gradio as gr
# Minimal text-in / text-out web UI around generate(); launch() starts serving it.
demo = gr.Interface(fn=generate, inputs="text", outputs="text")
demo.launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment