langchain-exllama
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob, sys
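
# LangChain LLM wrapper around ExLlama: loads a GPTQ model folder (safetensors
# weights, tokenizer.model, config.json) and exposes an ExLlamaGenerator through
# the standard LangChain LLM interface.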
class Exllama(LLM):
    client: Any  #: :meta private:
    generator: Any = None  #: :meta private:
    model_path: str
    """The path to the GPTQ model folder."""
    cache: ExLlamaCache = None
    disallowed_tokens: Optional[List[str]] = Field(None, description="List of tokens to disallow during generation.")
    temperature: Optional[float] = Field(0.95, description="Temperature for sampling diversity.")
    top_k: Optional[int] = Field(40, description="Consider the most probable top_k samples, 0 to disable top_k sampling.")
    top_p: Optional[float] = Field(0.65, description="Consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling.")
    min_p: Optional[float] = Field(0.0, description="Do not consider tokens with probability less than this.")
    typical: Optional[float] = Field(0.0, description="Locally typical sampling threshold, 0.0 to disable typical sampling.")
    repetition_penalty_max: Optional[float] = Field(1.15, description="Repetition penalty for the most recent tokens.")
    repetition_penalty_sustain: Optional[int] = Field(256, description="Number of most recent tokens to apply the repetition penalty to, -1 to apply to the whole context.")
    repetition_penalty_decay: Optional[int] = Field(128, description="Gradually decrease the penalty over this many tokens.")
    beams: Optional[int] = Field(1, description="Number of beams for beam search.")
    beam_length: Optional[int] = Field(1, description="Length of beams for beam search.")
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Build the ExLlama model, tokenizer, and generator from the model folder."""
        model_path = values["model_path"]
        tokenizer_path = os.path.join(model_path, "tokenizer.model")
        model_config_path = os.path.join(model_path, "config.json")

        # The GPTQ weights are the first .safetensors file found in the model folder.
        st_pattern = os.path.join(model_path, "*.safetensors")
        model_path = glob.glob(st_pattern)[0]

        config = ExLlamaConfig(model_config_path)
        config.model_path = model_path
        tokenizer = ExLlamaTokenizer(tokenizer_path)

        # Sampling parameters forwarded to the generator's settings object.
        model_param_names = [
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "typical",
            "repetition_penalty_max",
            "repetition_penalty_sustain",
            "repetition_penalty_decay",
            "beams",
            "beam_length",
        ]
        model_params = {k: values.get(k) for k in model_param_names}

        model = ExLlama(config)
        generator = ExLlamaGenerator(model, tokenizer, ExLlamaCache(model))  # create generator
        for key, value in model_params.items():
            setattr(generator.settings, key, value)
        generator.disallow_tokens(values.get("disallowed_tokens"))

        # Keep the model and generator on the instance rather than the class.
        values["client"] = model
        values["generator"] = generator
        return values
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Exllama"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Single non-streaming generation pass, capped at 200 new tokens.
        outputs = self.generator.generate_simple(prompt, max_new_tokens=200)
        return outputs
llm = Exllama(model_path=os.path.abspath(sys.argv[1]))
op = llm("Test")
print(op)
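
# A minimal sketch of plugging this wrapper into a LangChain chain. The
# PromptTemplate / LLMChain imports below assume the langchain API available
# at the time of writing; adjust them if your version differs.
from langchain import PromptTemplate, LLMChain

template = "Question: {question}\n\nAnswer:"
prompt = PromptTemplate(template=template, input_variables=["question"])
chain = LLMChain(prompt=prompt, llm=llm)
print(chain.run("What is a GPTQ quantized model?"))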