@TheodoreGalanos
Created August 17, 2024 07:32
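# Beam NRPA: a beam-search variant of Nested Rollout Policy Adaptation that
# builds text sequences with a fine-tuned causal language model (ActionSampler),
# scores them with a weighted mix of sentiment, grammar, BLEU, keyword and
# length signals (Scorer), and adapts the generation policy toward the
# best-scoring sequences (PolicyAdapter). Configuration is read from a YAML
# file; an example layout is sketched in the comment block at the end of this file.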
import yaml
import heapq
import argparse
import numpy as np
import language_tool_python
from functools import lru_cache
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
class BeamNRPA:
    def __init__(self, config):
        self.model, self.tokenizer = load_finetuned_llm(config['model']['path_or_name'])
        initial_policy = {
            'model': self.model,
            'tokenizer': self.tokenizer,
            'base_prompt': config['model']['initial_prompt'],
            'generation_params': config['generation_params'],
            'max_context_length': min(self.model.config.max_position_embeddings, 1024),
            'forbidden_words': [],
        }
        self.policy_adapter = PolicyAdapter(initial_policy)
        self.action_sampler = ActionSampler(self.policy_adapter.policy)
        self.scorer = Scorer(config['scorer'])
        self.max_level = config['beam_nrpa']['max_level']
        self.beam_width = config['beam_nrpa']['beam_width']
        self.adapt_iterations = config['beam_nrpa']['adapt_iterations']
        self.rollout_limit = config['beam_nrpa']['rollout_limit']
        self.best_solution = None

    def nested_beam_search(self, level):
        # Every level returns a list of (sequence, score) tuples.
        if level == 0:
            return [self.rollout() for _ in range(self.beam_width)]
        next_beam = []
        for _ in range(self.beam_width):
            sequence = []
            for _ in range(self.rollout_limit):
                action = self.action_sampler.sample_action([sequence])[0]
                sequence.append(action)
            # Recurse one level down and extend with the best child continuation.
            child_beam = self.nested_beam_search(level - 1)
            best_child = max(child_beam, key=lambda x: x[1])
            sequence.extend(best_child[0])
            score = self.scorer.score(sequence)
            # Min-heap keyed on score: popping when the heap overflows discards
            # the lowest-scoring candidate, keeping the top beam_width sequences.
            heapq.heappush(next_beam, (score, sequence))
            if len(next_beam) > self.beam_width:
                heapq.heappop(next_beam)
        return [(seq, score) for score, seq in next_beam]

    def rollout(self):
        sequence = []
        for _ in range(self.rollout_limit):
            action = self.action_sampler.sample_action([sequence])[0]
            sequence.append(action)
        score = self.scorer.score(sequence)
        return (sequence, score)

    def run(self):
        for _ in range(self.adapt_iterations):
            beam = self.nested_beam_search(self.max_level)
            best_in_iteration = max(beam, key=lambda x: x[1])
            if self.best_solution is None or best_in_iteration[1] > self.best_solution[1]:
                self.best_solution = best_in_iteration
            self.policy_adapter.adapt_policy(best_in_iteration[0], best_in_iteration[1])
            self.action_sampler.policy = self.policy_adapter.policy
        return self.best_solution

def load_finetuned_llm(model_path_or_name):
    """
    Load a fine-tuned language model from a local path or the Hugging Face model hub.

    Args:
        model_path_or_name (str): Local path or Hugging Face model name

    Returns:
        tuple: (model, tokenizer)
    """
    try:
        print(f"Loading model from {model_path_or_name}")
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
        # Many causal LM tokenizers ship without a pad token; fall back to EOS so
        # that batched generation with padding=True works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            model_path_or_name,
            device_map="auto",           # Automatically choose the best device setup
            torch_dtype=torch.float16    # Use float16 for efficiency if supported
        )
        print(f"Model loaded successfully. Model type: {type(model).__name__}")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

class ActionSampler:
    def __init__(self, policy):
        self.policy = policy
        self.model = policy['model']
        self.tokenizer = policy['tokenizer']
        self.cache = {}

    @lru_cache(maxsize=1000)
    def _tokenize(self, text):
        return self.tokenizer(text, return_tensors="pt")

    def _truncate_inputs(self, inputs):
        max_length = self.policy['max_context_length']
        if inputs.input_ids.shape[1] > max_length:
            inputs.input_ids = inputs.input_ids[:, -max_length:]
            inputs.attention_mask = inputs.attention_mask[:, -max_length:]
        return inputs

    def _adaptive_temperature(self, sequence):
        # Adaptive temperature: increase it slightly as the sequence grows.
        base_temp = self.policy['generation_params']['temperature']
        return base_temp * (1 + len(sequence) * 0.01)

    def _constrained_decoding(self, texts):
        # Constrained decoding: drop generated texts that contain forbidden words.
        forbidden_words = self.policy.get('forbidden_words', [])
        return [text for text in texts if not any(word in text for word in forbidden_words)]

    def sample_action(self, sequences):
        try:
            full_prompts = [self.policy['base_prompt'] + ' '.join(seq) for seq in sequences]
            inputs = self.tokenizer(full_prompts, return_tensors="pt", padding=True, truncation=True)
            # Move inputs to the same device as the model
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    do_sample=True,
                    **self.policy['generation_params'],
                    pad_token_id=self.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens and post-process
            generated_texts = self.tokenizer.batch_decode(
                outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
            )
            generated_texts = [self.post_process(text) for text in generated_texts]
            return generated_texts
        except Exception as e:
            print(f"Error in sample_action: {str(e)}")
            return ["" for _ in sequences]  # Return empty strings in case of error

    def post_process(self, text):
        # Post-processing: trim surrounding whitespace; add more steps as needed.
        text = text.strip()
        return text

class PolicyAdapter:
    def __init__(self, initial_policy):
        self.policy = initial_policy
        self.adaptation_rate = 0.1
        self.successful_sequences = []
        self.token_frequencies = Counter()

    def adapt_policy(self, successful_sequence, score):
        self.successful_sequences.append((successful_sequence, score))
        # Update token frequencies
        self.update_token_frequencies(successful_sequence)
        # Adapt generation parameters
        self.adapt_generation_params(score)
        # Adapt base prompt
        self.adapt_base_prompt()
        # Adapt forbidden words
        self.adapt_forbidden_words()
        return self.policy

    def update_token_frequencies(self, sequence):
        for token in sequence:
            self.token_frequencies[token] += 1

    def adjust_param(self, current_value, score, lower_bound, upper_bound):
        if score > self.policy.get('average_score', 0):
            new_value = current_value * (1 - self.adaptation_rate)
        else:
            new_value = current_value * (1 + self.adaptation_rate)
        return max(lower_bound, min(new_value, upper_bound))

    def adapt_generation_params(self, score):
        params = self.policy['generation_params']
        # Adjust temperature
        params['temperature'] = self.adjust_param(
            params['temperature'], score, lower_bound=0.1, upper_bound=2.0
        )
        # Adjust top_p
        params['top_p'] = self.adjust_param(
            params['top_p'], score, lower_bound=0.1, upper_bound=1.0
        )

    def adapt_base_prompt(self):
        most_common = self.token_frequencies.most_common(5)
        frequent_tokens = [token for token, _ in most_common]
        additional_context = f" Consider using these relevant terms: {', '.join(frequent_tokens)}."
        if len(self.policy['base_prompt'] + additional_context) <= 500:  # Limit prompt length
            self.policy['base_prompt'] += additional_context

    def adapt_forbidden_words(self):
        low_scoring_words = set()
        threshold = np.median([score for _, score in self.successful_sequences])
        for sequence, score in self.successful_sequences[-10:]:  # Consider last 10 sequences
            if score < threshold:
                low_scoring_words.update(sequence)
        current_forbidden = set(self.policy.get('forbidden_words', []))
        new_forbidden = list(current_forbidden.union(low_scoring_words))[:20]  # Limit to 20 forbidden words
        self.policy['forbidden_words'] = new_forbidden

class Scorer:
    def __init__(self, config):
        # Initialize sentiment analysis model
        self.sentiment_model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english"
        )
        self.sentiment_tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english"
        )
        self.sentiment_weight = config['sentiment_weight']
        self.grammar_weight = config['grammar_weight']
        self.bleu_weight = config['bleu_weight']
        self.keyword_weight = config['keyword_weight']
        self.length_weight = config['length_weight']
        # Initialize language tool for grammar checking
        self.language_tool = language_tool_python.LanguageTool('en-US')
        # Target/reference sentences for the BLEU score
        self.reference_sentences = [
            "The quick brown fox jumps over the lazy dog.",
            "A journey of a thousand miles begins with a single step.",
            # Add more reference sentences relevant to your task
        ]
        # Keywords that should be rewarded if present
        self.keywords = ["innovative", "efficient", "sustainable", "user-friendly"]

    def score(self, sequence):
        text = ' '.join(sequence)
        # Compute individual scores
        sentiment_score = self.get_sentiment_score(text)
        grammar_score = self.get_grammar_score(text)
        bleu_score = self.get_bleu_score(text)
        keyword_score = self.get_keyword_score(text)
        length_score = self.get_length_score(text)
        # Combine scores (adjust weights based on importance)
        final_score = (
            self.sentiment_weight * sentiment_score +
            self.grammar_weight * grammar_score +
            self.bleu_weight * bleu_score +
            self.keyword_weight * keyword_score +
            self.length_weight * length_score
        )
        return final_score

    def get_sentiment_score(self, text):
        inputs = self.sentiment_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.sentiment_model(**inputs)
        scores = torch.nn.functional.softmax(outputs.logits, dim=1)
        return scores[:, 1].item()  # Positive sentiment score

    def get_grammar_score(self, text):
        matches = self.language_tool.check(text)
        error_count = len(matches)
        word_count = max(1, len(text.split()))  # Guard against empty text
        return max(0, 1 - (error_count / word_count))  # Normalize by word count

    def get_bleu_score(self, text):
        candidate = word_tokenize(text.lower())
        references = [word_tokenize(ref.lower()) for ref in self.reference_sentences]
        return sentence_bleu(references, candidate)

    def get_keyword_score(self, text):
        word_count = max(1, len(text.split()))  # Guard against empty text
        keyword_count = sum(1 for keyword in self.keywords if keyword.lower() in text.lower())
        return min(1, keyword_count / (word_count * 0.1))  # Expect about 10% of words to be keywords

    def get_length_score(self, text):
        ideal_length = 100  # Adjust based on your requirements
        actual_length = len(text.split())
        return max(0, 1 - abs(actual_length - ideal_length) / ideal_length)

def initialize_policy():
    # Unpack the (model, tokenizer) pair returned by load_finetuned_llm.
    model, tokenizer = load_finetuned_llm('path_to_your_model')
    policy = {
        'model': model,
        'tokenizer': tokenizer,
        'base_prompt': "You are an AI assistant tasked with [specific task]. ",
        'generation_params': {
            'temperature': 0.7,
            'top_p': 0.9,
            'max_new_tokens': 100
        },
        'action_history': [],
        'performance_metrics': {
            'cumulative_reward': 0,
            'successful_sequences': []
        }
    }
    return policy

def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)

def main(args):
    config = load_config(args.config)
    beam_nrpa = BeamNRPA(config)
    best_sequence, best_score = beam_nrpa.run()
    print(f"Best sequence: {' '.join(best_sequence)}")
    print(f"Best score: {best_score}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Beam NRPA with a fine-tuned language model.")
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the YAML configuration file')
    args = parser.parse_args()
    main(args)
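# A minimal sketch of the config.yaml this script expects, inferred from the
# keys read above (values are illustrative placeholders, not tuned settings):
#
# model:
#   path_or_name: "gpt2"                # local path or Hugging Face hub name
#   initial_prompt: "You are an AI assistant tasked with [specific task]. "
# generation_params:                    # forwarded directly to model.generate()
#   temperature: 0.7
#   top_p: 0.9
#   max_new_tokens: 50
# beam_nrpa:
#   max_level: 2
#   beam_width: 3
#   adapt_iterations: 5
#   rollout_limit: 4
# scorer:
#   sentiment_weight: 0.3
#   grammar_weight: 0.2
#   bleu_weight: 0.2
#   keyword_weight: 0.2
#   length_weight: 0.1
#
# Invoke with, e.g.: python <this_script>.py --config config.yaml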