A basic implementation of a custom wrapper class for using the grammar synthesis text2text models.
""" | |
Class for correcting text using a pretrained model grammar synthesis model. | |
- models are available here: https://hf.co/models?other=grammar%20synthesis | |
requirements for this snippet: | |
pip install -U transformers accelerate | |
NOTE: if you want to use 9-bit to fit the model on a smaller GPU, you need bitsandbytes: | |
pip install -U transformers accelerate bitsandbytes | |
""" | |
import warnings | |
import torch | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
class GrammarSynthesizer: | |
""" | |
Class for correcting text using a pretrained model grammar synthesis model. | |
models are available here: https://hf.co/models?other=grammar%20synthesis | |
# Example usage with the XL | |
corrector = GrammarSynthesizer("pszemraj/flan-t5-xl-grammar-synthesis") | |
raw_text = 'sky is blu.' | |
results = corrector(raw_text, num_beams=2) | |
print(results) | |
""" | |
DEFAULT_MAX_INPUT_LENGTH = 384 | |
DEFAULT_MAX_LENGTH = 128 | |
DEFAULT_NUM_BEAMS = 4 | |
def __init__( | |
self, | |
model_name_or_path: str, | |
should_compile: bool = True, | |
load_in_8bit: bool = False, | |
): | |
""" | |
Initializes the GrammarSynthesizer. | |
Args: | |
model_name_or_path: The name or path of the pretrained model. | |
should_compile: If True, tries to compile the model for faster inference. | |
load_in_8bit: If True, loads model in 8-bit precision (lower memory usage). requires bitsandbytes | |
""" | |
self.model_name_or_path = model_name_or_path | |
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) | |
self.model = self._load_and_compile_model(model_name_or_path, should_compile) | |
def _load_and_compile_model( | |
self, model_name_or_path: str, should_compile: bool, load_in_8bit: bool | |
): | |
""" | |
Load and compile the model. | |
Args: | |
model_name_or_path: The name or path of the pretrained model. | |
should_compile: If True, tries to compile the model for faster inference. | |
load_in_8bit: If True, loads model in 8-bit precision (lower memory usage). requires bitsandbytes | |
Returns: | |
The loaded and potentially compiled model. | |
""" | |
model = AutoModelForSeq2SeqLM.from_pretrained( | |
model_name_or_path, load_in_8bit=load_in_8bit, device_map="auto" | |
) | |
if should_compile: | |
try: | |
model = torch.compile(model) | |
except Exception as e: | |
print(f"Unable to compile model for faster inference. Reason: {e}") | |
should_compile = False | |
self.compiled_model = should_compile | |
return model | |
def _prepare_inputs(self, input_text: str): | |
""" | |
Prepares the inputs for the model. | |
Args: | |
input_text: The input text to prepare. | |
Returns: | |
The prepared inputs. | |
""" | |
inputs = self.tokenizer.encode(input_text, return_tensors="pt").to( | |
self.model.device | |
) | |
if len(inputs) > self.DEFAULT_MAX_INPUT_LENGTH: | |
warnings.warn( | |
"Input is longer than model training data. Unexpected behavior may occur. " | |
"Consider batch-processing smaller chunks." | |
) | |
return inputs | |
def generate_text( | |
self, | |
input_text: str, | |
max_length: int = DEFAULT_MAX_LENGTH, | |
num_beams: int = DEFAULT_NUM_BEAMS, | |
**kwargs, | |
): | |
""" | |
Generates text from the input. | |
Args: | |
input_text: The input text to generate from. | |
max_length: The maximum length of the generated text. | |
num_beams: The number of beams for beam search. | |
Returns: | |
The generated text. | |
""" | |
if len(input_text) < 2: | |
warnings.warn( | |
f"input text is too short to correct, returning:\t{input_text}" | |
) | |
return input_text | |
inputs = self._prepare_inputs(input_text) | |
outputs = self.model.generate( | |
inputs, max_length=max_length, num_beams=num_beams, **kwargs | |
) | |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
def __call__(self, input_text: str, **kwargs): | |
return self.generate_text(input_text, **kwargs) |
Clickable link for the models:
See here: https://hf.co/models?other=grammar%20synthesis
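For inputs longer than the 384-token window that _prepare_inputs warns about, one approach is the chunked processing the warning suggests: split the text into sentences and correct each piece. The helper below is a hedged sketch, not part of the gist; its regex splitter is a naive stand-in for a proper sentence segmenter such as nltk's sent_tokenize:

import re

def correct_long_text(corrector, text: str) -> str:
    """Naively split text on sentence-ending punctuation and correct each piece.

    Sketch only: the regex will mis-split abbreviations like "e.g.", so a
    real sentence segmenter is preferable for production use.
    """
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    corrected = [corrector(s) for s in sentences if s]
    return " ".join(corrected)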