Created
November 8, 2024 07:24
-
-
Save SergioEanX/a9697bcffc6726d1d840dc5c3c4f688a to your computer and use it in GitHub Desktop.
Inference function for a trained LLM and tokenizer, fixing the error: "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results. Setting pad_token_id to eos_token_id:0 for open-end generation."
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=100):
    """
    Generate a continuation of the input text using the provided model and tokenizer.

    Args:
        text (str): The input text prompt.
        model: The pre-trained causal language model used for generation.
        tokenizer: The tokenizer corresponding to the model.
        max_input_tokens (int, optional): Maximum number of tokens kept from the
            input prompt (longer prompts are truncated). Defaults to 1000.
        max_output_tokens (int, optional): Maximum number of new tokens to
            generate. Defaults to 100.

    Returns:
        str: The generated continuation with the original prompt removed,
        or "" if generation fails.
    """
    try:
        # Ensure the tokenizer has a pad_token so generate() can build a valid
        # attention mask and stop warning "The attention mask and the pad token
        # id were not set".
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                # Aliasing pad to the existing eos token adds no new token,
                # so the embedding matrix does NOT need to be resized.
                tokenizer.pad_token = tokenizer.eos_token
                print(f"Set pad_token to eos_token: '{tokenizer.pad_token}'")
            else:
                # No eos token either: register a brand-new pad token and grow
                # the embedding matrix to make room for its id.
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                model.resize_token_embeddings(len(tokenizer))
                print("Added a new pad_token: '[PAD]'")

        # Tokenize the prompt. A single sequence needs no padding — padding a
        # causal LM on the right (the tokenizer default) would place pad tokens
        # between the prompt and the generated text and corrupt the output —
        # so only truncation to max_input_tokens is applied.
        inputs = tokenizer(
            text,
            return_tensors="pt",   # return PyTorch tensors
            truncation=True,
            max_length=max_input_tokens,
        )

        # Move tensors to the same device as the model.
        device = model.device
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        # max_new_tokens bounds only the generated continuation, independently
        # of the prompt length. (The previous co_varnames introspection check
        # was unreliable: generate() receives generation kwargs via **kwargs,
        # so 'max_new_tokens' never appeared in co_varnames and the code
        # silently fell back to the max_length path.)
        generated_tokens_with_prompt = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_output_tokens,
            # Use the tokenizer's actual pad id: eos_token_id would be None on
            # the branch above where a fresh '[PAD]' token had to be added.
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,   # enable sampling for diversity
            top_p=0.95,       # nucleus sampling
            top_k=50,         # top-k sampling
        )

        # Decode the generated tokens and strip the original prompt so only
        # the newly generated answer is returned.
        generated_text_with_prompt = tokenizer.batch_decode(
            generated_tokens_with_prompt,
            skip_special_tokens=True,
        )
        return generated_text_with_prompt[0][len(text):].strip()
    except Exception as e:
        # Best-effort API: report the failure and return an empty answer
        # rather than propagating the exception to the caller.
        print(f"An error occurred during inference: {e}")
        return ""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment