langchain-exllama
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, Dict, Generator, List, Optional
from pydantic import Field, root_validator
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob, time, json, sys, logging
class Exllama(LLM):
    client: Any  #: :meta private:
    model_path: str
    """The path to the GPTQ model folder."""
    cache: Optional[ExLlamaCache] = None

    disallowed_tokens: Optional[List[str]] = Field(None, description="List of tokens to disallow during generation.")
    temperature: Optional[float] = Field(0.95, description="Temperature for sampling diversity.")
    top_k: Optional[int] = Field(40, description="Consider only the top_k most probable tokens, 0 to disable top_k sampling.")
    top_p: Optional[float] = Field(0.65, description="Consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling.")
    min_p: Optional[float] = Field(0.0, description="Do not consider tokens with probability less than this.")
    typical: Optional[float] = Field(0.0, description="Locally typical sampling threshold, 0.0 to disable typical sampling.")
    repetition_penalty_max: Optional[float] = Field(1.15, description="Repetition penalty for the most recent tokens.")
    repetition_penalty_sustain: Optional[int] = Field(256, description="Number of most recent tokens the repetition penalty applies to, -1 to apply to the whole context.")
    repetition_penalty_decay: Optional[int] = Field(128, description="Gradually decrease the penalty over this many tokens.")
    beams: Optional[int] = Field(1, description="Number of beams for beam search.")
    beam_length: Optional[int] = Field(1, description="Length of beams for beam search.")
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        model_path = values["model_path"]
        tokenizer_path = os.path.join(model_path, "tokenizer.model")
        model_config_path = os.path.join(model_path, "config.json")

        # Point the config at the first .safetensors file found in the model folder.
        st_pattern = os.path.join(model_path, "*.safetensors")
        model_path = glob.glob(st_pattern)[0]

        config = ExLlamaConfig(model_config_path)
        config.model_path = model_path
        tokenizer = ExLlamaTokenizer(tokenizer_path)

        model_param_names = [
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "typical",
            "repetition_penalty_max",
            "repetition_penalty_sustain",
            "repetition_penalty_decay",
            "beams",
            "beam_length",
        ]
        model_params = {k: values.get(k) for k in model_param_names}

        model = ExLlama(config)
        # The generator is kept on the class because it is not a declared pydantic field.
        cls.generator = ExLlamaGenerator(model, tokenizer, ExLlamaCache(model))  # create generator

        # Copy the sampling parameters onto the generator's settings object.
        for key, value in model_params.items():
            setattr(cls.generator.settings, key, value)

        cls.generator.disallow_tokens(values.get("disallowed_tokens"))
        values["client"] = model
        return values
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Exllama"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Stop sequences and callbacks are accepted for interface compatibility
        # but are not used by this simple wrapper.
        outputs = self.generator.generate_simple(prompt, max_new_tokens=200)
        return outputs
# Quick smoke test: load the model folder passed on the command line and run a prompt.
llm = Exllama(model_path=os.path.abspath(sys.argv[1]))
op = llm("Test")
print(op)
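
As a rough usage sketch (not part of the original gist), the wrapper can be dropped into a standard LangChain chain. The PromptTemplate/LLMChain calls below assume the langchain API as of mid-2023, and the model folder path is a placeholder.

# Sketch: plugging the Exllama wrapper into a LangChain LLMChain.
# Assumes the mid-2023 langchain API; the model path below is a placeholder.
from langchain import PromptTemplate, LLMChain

template = "Question: {question}\n\nAnswer:"
prompt = PromptTemplate(template=template, input_variables=["question"])

llm = Exllama(model_path="/path/to/gptq-model-folder")
chain = LLMChain(prompt=prompt, llm=llm)

print(chain.run("What is the capital of France?"))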