Skip to content

Instantly share code, notes, and snippets.

@SergioEanX
Created November 8, 2024 07:24
Show Gist options
  • Select an option

  • Save SergioEanX/a9697bcffc6726d1d840dc5c3c4f688a to your computer and use it in GitHub Desktop.

Select an option

Save SergioEanX/a9697bcffc6726d1d840dc5c3c4f688a to your computer and use it in GitHub Desktop.
Inference function for a trained LLM and tokenizer, fixing the error: "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results. Setting pad_token_id to eos_token_id:0 for open-end generation."
def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=100):
    """Generate a continuation of ``text`` with a Hugging Face model/tokenizer.

    Ensures the tokenizer has a pad token and passes an explicit
    ``attention_mask`` and ``pad_token_id`` to ``generate``, which silences
    the "attention mask and pad token id were not set" warning.

    Args:
        text (str): The input text prompt.
        model: The pre-trained causal language model (must expose ``generate``).
        tokenizer: The tokenizer corresponding to the model.
        max_input_tokens (int, optional): Maximum number of tokens kept from
            the input. Defaults to 1000.
        max_output_tokens (int, optional): Maximum number of tokens to
            generate. Defaults to 100.

    Returns:
        str: The generated continuation, or "" if an error occurred.
    """
    try:
        # Ensure the tokenizer has a pad_token. If not, reuse eos_token,
        # or add a brand-new [PAD] token as a last resort.
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                # Aliasing pad to eos adds no new token, so no embedding
                # resize is needed (the original resized unnecessarily here).
                tokenizer.pad_token = tokenizer.eos_token
                print(f"Set pad_token to eos_token: '{tokenizer.pad_token}'")
            else:
                # A genuinely new token requires growing the embedding matrix.
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                model.resize_token_embeddings(len(tokenizer))
                print("Added a new pad_token: '[PAD]'")

        # Tokenize the input text; returns both input_ids and attention_mask.
        inputs = tokenizer(
            text,
            return_tensors="pt",      # PyTorch tensors
            truncation=True,          # Truncate sequences longer than max_length
            max_length=max_input_tokens,
            padding='max_length'      # Pad sequences to max_length
        )

        # Move tensors to the same device as the model.
        device = model.device
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        # BUG FIX: pad_token_id must come from the tokenizer's pad token.
        # The original always passed eos_token_id, which is None in the
        # branch where a new [PAD] token was added because eos was absent.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # BUG FIX: the original probed `generate.__code__.co_varnames` for
        # 'max_new_tokens', but HF `generate` takes it via **kwargs /
        # GenerationConfig, so that check always failed and the max_length
        # fallback always ran. Try max_new_tokens first instead, and fall
        # back only if the model's generate rejects the keyword.
        try:
            generated_tokens_with_prompt = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_output_tokens,  # Tokens generated beyond the prompt
                pad_token_id=pad_token_id,
                do_sample=True,                    # Enable sampling for diversity
                top_p=0.95,                        # Nucleus sampling
                top_k=50                           # Top-K sampling
            )
        except TypeError:
            generated_tokens_with_prompt = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_input_tokens + max_output_tokens,  # input + output
                pad_token_id=pad_token_id,
                do_sample=True,
                top_p=0.95,
                top_k=50
            )

        # Decode the generated tokens back to text.
        generated_text_with_prompt = tokenizer.batch_decode(
            generated_tokens_with_prompt,
            skip_special_tokens=True
        )

        # Strip the original prompt to return only the continuation.
        # NOTE(review): slicing by len(text) assumes decoding reproduces the
        # prompt byte-for-byte; special-token stripping can break that.
        return generated_text_with_prompt[0][len(text):].strip()
    except Exception as e:
        # Best-effort API: report the failure and return an empty string
        # rather than propagating, matching the original contract.
        print(f"An error occurred during inference: {e}")
        return ""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment