# Created August 17, 2024 07:32
import yaml
import heapq
import argparse
import numpy as np
import language_tool_python
from functools import lru_cache
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM

class BeamNRPA:
    def __init__(self, config):
        self.model, self.tokenizer = load_finetuned_llm(config['model']['path_or_name'])
        initial_policy = {
            'model': self.model,
            'tokenizer': self.tokenizer,
            'base_prompt': config['model']['initial_prompt'],
            'generation_params': config['generation_params'],
            'max_context_length': min(self.model.config.max_position_embeddings, 1024),
            'forbidden_words': [],
        }
        self.policy_adapter = PolicyAdapter(initial_policy)
        self.action_sampler = ActionSampler(self.policy_adapter.policy)
        self.scorer = Scorer(config['scorer'])
        self.max_level = config['beam_nrpa']['max_level']
        self.beam_width = config['beam_nrpa']['beam_width']
        self.adapt_iterations = config['beam_nrpa']['adapt_iterations']
        self.rollout_limit = config['beam_nrpa']['rollout_limit']
        self.best_solution = None
    def nested_beam_search(self, level):
        # Every level (including rollout) exchanges (score, sequence) tuples.
        if level == 0:
            return [self.rollout() for _ in range(self.beam_width)]
        next_beam = []
        for _ in range(self.beam_width):
            sequence = []
            for _ in range(self.rollout_limit):
                action = self.action_sampler.sample_action([sequence])[0]
                sequence.append(action)
            # Complete the prefix with the best result of a lower-level search.
            child_beam = self.nested_beam_search(level - 1)
            best_child = max(child_beam, key=lambda x: x[0])
            sequence.extend(best_child[1])
            score = self.scorer.score(sequence)
            # A min-heap keyed on negated score keeps only the top beam_width sequences.
            heapq.heappush(next_beam, (-score, sequence))
            if len(next_beam) > self.beam_width:
                heapq.heappop(next_beam)
        return [(-neg_score, seq) for neg_score, seq in next_beam]
    def rollout(self):
        sequence = []
        for _ in range(self.rollout_limit):
            action = self.action_sampler.sample_action([sequence])[0]
            sequence.append(action)
        score = self.scorer.score(sequence)
        # Return (score, sequence) to match the beam format used elsewhere.
        return (score, sequence)
    def run(self):
        for _ in range(self.adapt_iterations):
            beam = self.nested_beam_search(self.max_level)
            best_in_iteration = max(beam, key=lambda x: x[0])
            if self.best_solution is None or best_in_iteration[0] > self.best_solution[0]:
                self.best_solution = best_in_iteration
            self.policy_adapter.adapt_policy(best_in_iteration[1], best_in_iteration[0])
            self.action_sampler.policy = self.policy_adapter.policy
        return self.best_solution
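
# Search contract: nested_beam_search() and rollout() both return
# (score, sequence) tuples. Level 0 performs beam_width independent rollouts;
# each higher level grows beam_width action prefixes, completes each prefix
# with the best sequence found by a level-below search, and keeps the top
# beam_width results. run() repeats this for adapt_iterations rounds,
# adapting the policy toward the best sequence after every round.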

def load_finetuned_llm(model_path_or_name):
    """
    Load a fine-tuned language model from a local path or the Hugging Face model hub.
    Args:
        model_path_or_name (str): Local path or Hugging Face model name
    Returns:
        tuple: (model, tokenizer)
    """
    try:
        print(f"Loading model from {model_path_or_name}")
        # Load tokenizer; batched generation with a decoder-only model needs left
        # padding, and many causal LMs define no pad token, so fall back to EOS.
        tokenizer = AutoTokenizer.from_pretrained(model_path_or_name, padding_side="left")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            model_path_or_name,
            device_map="auto",  # Automatically choose the best device setup
            torch_dtype=torch.float16  # Use float16 for efficiency if supported
        )
        print(f"Model loaded successfully. Model type: {type(model).__name__}")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

class ActionSampler:
    def __init__(self, policy):
        self.policy = policy
        self.model = policy['model']
        self.tokenizer = policy['tokenizer']
        self.cache = {}

    @lru_cache(maxsize=1000)
    def _tokenize(self, text):
        return self.tokenizer(text, return_tensors="pt")

    def _truncate_inputs(self, inputs):
        # Keep only the most recent max_context_length tokens.
        max_length = self.policy['max_context_length']
        if inputs['input_ids'].shape[1] > max_length:
            inputs['input_ids'] = inputs['input_ids'][:, -max_length:]
            inputs['attention_mask'] = inputs['attention_mask'][:, -max_length:]
        return inputs

    def _adaptive_temperature(self, sequence):
        # Raise the temperature slightly as the sequence grows to encourage exploration.
        base_temp = self.policy['generation_params']['temperature']
        return base_temp * (1 + len(sequence) * 0.01)

    def _constrained_decoding(self, texts):
        # Filter out generated texts that contain any forbidden word.
        forbidden_words = self.policy.get('forbidden_words', [])
        for word in forbidden_words:
            texts = [text for text in texts if word not in text]
        return texts
    def sample_action(self, sequences):
        try:
            full_prompts = [self.policy['base_prompt'] + ' '.join(seq) for seq in sequences]
            inputs = self.tokenizer(full_prompts, return_tensors="pt", padding=True, truncation=True)
            # Move inputs to the same device as the model
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    do_sample=True,
                    **self.policy['generation_params'],
                    pad_token_id=self.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens and post-process
            generated_texts = self.tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            generated_texts = [self.post_process(text) for text in generated_texts]
            return generated_texts
        except Exception as e:
            print(f"Error in sample_action: {str(e)}")
            return ["" for _ in sequences]  # Return empty strings in case of error

    def post_process(self, text):
        # Strip surrounding whitespace; add more post-processing steps as needed.
        return text.strip()

class PolicyAdapter:
    def __init__(self, initial_policy):
        self.policy = initial_policy
        self.adaptation_rate = 0.1
        self.successful_sequences = []
        self.token_frequencies = Counter()

    def adapt_policy(self, successful_sequence, score):
        self.successful_sequences.append((successful_sequence, score))
        # Keep a running average score; adjust_param() compares against it.
        self.policy['average_score'] = float(np.mean([s for _, s in self.successful_sequences]))
        # Update token frequencies
        self.update_token_frequencies(successful_sequence)
        # Adapt generation parameters
        self.adapt_generation_params(score)
        # Adapt base prompt
        self.adapt_base_prompt()
        # Adapt forbidden words
        self.adapt_forbidden_words()
        return self.policy

    def update_token_frequencies(self, sequence):
        for token in sequence:
            self.token_frequencies[token] += 1
    def adjust_param(self, current_value, score, lower_bound, upper_bound):
        # Shrink the parameter after above-average scores (exploit),
        # grow it after below-average ones (explore).
        if score > self.policy.get('average_score', 0):
            new_value = current_value * (1 - self.adaptation_rate)
        else:
            new_value = current_value * (1 + self.adaptation_rate)
        return max(lower_bound, min(new_value, upper_bound))

    def adapt_generation_params(self, score):
        params = self.policy['generation_params']
        # Adjust temperature
        params['temperature'] = self.adjust_param(
            params['temperature'], score, lower_bound=0.1, upper_bound=2.0
        )
        # Adjust top_p
        params['top_p'] = self.adjust_param(
            params['top_p'], score, lower_bound=0.1, upper_bound=1.0
        )

    def adapt_base_prompt(self):
        most_common = self.token_frequencies.most_common(5)
        frequent_tokens = [token for token, _ in most_common]
        additional_context = f" Consider using these relevant terms: {', '.join(frequent_tokens)}."
        if len(self.policy['base_prompt'] + additional_context) <= 500:  # Limit prompt length
            self.policy['base_prompt'] += additional_context

    def adapt_forbidden_words(self):
        low_scoring_words = set()
        threshold = np.median([score for _, score in self.successful_sequences])
        for sequence, score in self.successful_sequences[-10:]:  # Consider the last 10 sequences
            if score < threshold:
                low_scoring_words.update(sequence)
        current_forbidden = set(self.policy.get('forbidden_words', []))
        new_forbidden = list(current_forbidden.union(low_scoring_words))[:20]  # Limit to 20 forbidden words
        self.policy['forbidden_words'] = new_forbidden

class Scorer:
    def __init__(self, config):
        # Initialize sentiment analysis model
        self.sentiment_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        self.sentiment_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        self.sentiment_weight = config['sentiment_weight']
        self.grammar_weight = config['grammar_weight']
        self.bleu_weight = config['bleu_weight']
        self.keyword_weight = config['keyword_weight']
        self.length_weight = config['length_weight']
        # Initialize language tool for grammar checking
        self.language_tool = language_tool_python.LanguageTool('en-US')
        # Target/reference sentences for the BLEU score
        self.reference_sentences = [
            "The quick brown fox jumps over the lazy dog.",
            "A journey of a thousand miles begins with a single step.",
            # Add more reference sentences relevant to your task
        ]
        # Keywords that are rewarded if present
        self.keywords = ["innovative", "efficient", "sustainable", "user-friendly"]
    def score(self, sequence):
        text = ' '.join(sequence)
        # Compute individual scores
        sentiment_score = self.get_sentiment_score(text)
        grammar_score = self.get_grammar_score(text)
        bleu_score = self.get_bleu_score(text)
        keyword_score = self.get_keyword_score(text)
        length_score = self.get_length_score(text)
        # Combine scores (adjust the weights in the config based on importance)
        final_score = (
            self.sentiment_weight * sentiment_score +
            self.grammar_weight * grammar_score +
            self.bleu_weight * bleu_score +
            self.keyword_weight * keyword_score +
            self.length_weight * length_score
        )
        return final_score
    def get_sentiment_score(self, text):
        inputs = self.sentiment_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.sentiment_model(**inputs)
        scores = torch.nn.functional.softmax(outputs.logits, dim=1)
        return scores[:, 1].item()  # Positive sentiment score

    def get_grammar_score(self, text):
        matches = self.language_tool.check(text)
        error_count = len(matches)
        # Normalize by word count; guard against empty text.
        word_count = max(len(text.split()), 1)
        return max(0, 1 - (error_count / word_count))

    def get_bleu_score(self, text):
        candidate = word_tokenize(text.lower())
        references = [word_tokenize(ref.lower()) for ref in self.reference_sentences]
        return sentence_bleu(references, candidate)

    def get_keyword_score(self, text):
        # Expect about 10% of words to be keywords; guard against empty text.
        word_count = max(len(text.split()), 1)
        keyword_count = sum(1 for keyword in self.keywords if keyword.lower() in text.lower())
        return min(1, keyword_count / (word_count * 0.1))

    def get_length_score(self, text):
        ideal_length = 100  # Adjust based on your requirements
        actual_length = len(text.split())
        return max(0, 1 - abs(actual_length - ideal_length) / ideal_length)

def initialize_policy():
    # Standalone helper mirroring the policy built in BeamNRPA.__init__.
    model, tokenizer = load_finetuned_llm('path_to_your_model')
    policy = {
        'model': model,
        'tokenizer': tokenizer,
        'base_prompt': "You are an AI assistant tasked with [specific task]. ",
        'generation_params': {
            'temperature': 0.7,
            'top_p': 0.9,
            'max_new_tokens': 100  # generate() expects max_new_tokens, not max_tokens
        },
        'action_history': [],
        'performance_metrics': {
            'cumulative_reward': 0,
            'successful_sequences': []
        }
    }
    return policy

def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)
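
# A minimal example of the expected config.yaml (an illustrative sketch: the
# model name and all numeric values below are placeholder assumptions, but the
# keys mirror exactly what BeamNRPA, ActionSampler, and Scorer read):
EXAMPLE_CONFIG = """\
model:
  path_or_name: gpt2            # any causal LM on the Hugging Face hub, or a local path
  initial_prompt: "You are an AI assistant tasked with [specific task]. "
generation_params:              # forwarded verbatim to model.generate()
  temperature: 0.7
  top_p: 0.9
  max_new_tokens: 20
beam_nrpa:
  max_level: 2
  beam_width: 3
  adapt_iterations: 5
  rollout_limit: 4
scorer:                         # weights for Scorer's combined score
  sentiment_weight: 0.2
  grammar_weight: 0.2
  bleu_weight: 0.2
  keyword_weight: 0.2
  length_weight: 0.2
"""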

def main(args):
    config = load_config(args.config)
    beam_nrpa = BeamNRPA(config)
    # run() returns the best (score, sequence) pair found.
    best_score, best_sequence = beam_nrpa.run()
    print(f"Best sequence: {' '.join(best_sequence)}")
    print(f"Best score: {best_score}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Beam NRPA with a fine-tuned language model.")
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the YAML configuration file')
    args = parser.parse_args()
    main(args)
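
# Example invocation (assuming this script is saved as beam_nrpa.py and the
# example config above as config.yaml; both filenames are assumptions):
#   python beam_nrpa.py --config config.yaml
# Note: nltk's word_tokenize requires the 'punkt' tokenizer models; run
# nltk.download('punkt') once if they are not already installed.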