tinyllama_model_merge_wip: well, I thought I could make this work... maybe leave it for another time
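For context: the payload at the bottom of the script declares a linear `merge_strategy` and per-adapter `weights`, but the class never uses either; adapters are simply stacked and then folded in wholesale by `merge_and_unload()`. Recent `peft` releases expose `add_weighted_adapter` for an explicit weighted combination, which is probably closer to what was intended. The snippet below is only a rough sketch of that route, not part of the gist: the adapter name `"merge"` is made up, `combination_type="linear"` assumes all three adapters share the same LoRA rank (otherwise `"cat"` or `"ties"` are the usual fallbacks), it loads the base in full precision rather than 8-bit, and it ignores the vocabulary-size mismatch the class below tries to handle.

# Rough sketch of a weighted LoRA merge via peft's add_weighted_adapter.
# Assumes a recent transformers/peft; "merge" is an arbitrary adapter name.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
adapter_ids = [
    "smangrul/tinyllama_lora_norobots",
    "smangrul/tinyllama_lora_sql",
    "smangrul/tinyllama_lora_adcopy",
]

base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_id)

# Attach the first adapter, then load the others under their own names
model = PeftModel.from_pretrained(base, adapter_ids[0], adapter_name="norobots")
model.load_adapter(adapter_ids[1], adapter_name="sql")
model.load_adapter(adapter_ids[2], adapter_name="adcopy")

# Combine the three adapters with explicit weights into a new adapter,
# activate it, and fold it into the base weights
model.add_weighted_adapter(
    adapters=["norobots", "sql", "adcopy"],
    weights=[1.0, 1.0, 1.0],
    adapter_name="merge",
    combination_type="linear",  # requires equal LoRA ranks; use "cat" or "ties" otherwise
)
model.set_adapter("merge")
merged = model.merge_and_unload()

The full WIP script follows.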
import os
import torch
import gc
import json

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    LlamaForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
from huggingface_hub import snapshot_download
from loguru import logger
from dotenv import load_dotenv

load_dotenv('../.env')


class LoRAAdapterMerger:
    def __init__(self, payload):
        self.payload = payload
        self.base_model = None
        self.quantization = payload.get("adapter_parameters", {}).get("quantization", None)
        self.debug_mode = payload.get("debug_mode", False)

        load_dotenv()
        self.hf_token = os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.error("HF_TOKEN not found in environment variables")
            raise ValueError("Missing HF_TOKEN in environment variables")
        logger.info(f"Initialized with HF token: {self.hf_token[:8]}...")

        self.tokenizer = None
        self.project_dir = payload.get("project_dir", "/tmp/lorax")
        self.default_target_modules = ["q_proj", "v_proj", "o_proj"]
        logger.add("merger.log", rotation="500 MB")
    def ensure_model_or_adapter_exists(self, model_id):
        """Ensure the model or adapter exists locally, downloading it if needed."""
        try:
            logger.info(f"Checking for model/adapter `{model_id}`...")
            cache_dir = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
            hub_cache = os.path.join(cache_dir, "hub")
            # huggingface_hub stores repos under hub/models--{org}--{name}
            model_cache_dir = os.path.join(hub_cache, "models--" + model_id.replace("/", "--"))

            # Log whether the repo is already cached (snapshot_download reuses it either way)
            if os.path.exists(model_cache_dir):
                logger.info(f"Found cached `{model_id}` at `{model_cache_dir}`.")
            else:
                logger.info(f"Downloading `{model_id}`...")

            # Always resolve through snapshot_download so we return the snapshot
            # folder that actually contains config.json / adapter_config.json
            snapshot_path = snapshot_download(
                repo_id=model_id,
                cache_dir=hub_cache,
                token=self.hf_token,
                allow_patterns=["*.json", "*.bin", "*.safetensors", "*.model", "tokenizer*", "vocab*", "merges*"]
            )
            logger.info(f"`{model_id}` available at `{snapshot_path}`.")
            return snapshot_path
        except Exception as e:
            logger.error(f"Failed to download `{model_id}`: {str(e)}")
            raise
    def validate_adapter(self, adapter_id, base_model_name):
        """Validate if the adapter is trained on the expected base model."""
        try:
            logger.info(f"Validating adapter `{adapter_id}`...")
            adapter_path = self.ensure_model_or_adapter_exists(adapter_id)
            config_path = os.path.join(adapter_path, "adapter_config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"`adapter_config.json` missing for {adapter_id}")

            with open(config_path, "r") as f:
                config = json.load(f)

            trained_on = config.get("base_model_name_or_path", "Unknown")
            if trained_on.lower() != base_model_name.lower():
                logger.warning(
                    f"Adapter `{adapter_id}` trained on `{trained_on}`, not `{base_model_name}`. Compatibility issues may occur."
                )
            else:
                logger.info(f"Adapter `{adapter_id}` is compatible with `{base_model_name}`")
        except Exception as e:
            logger.error(f"Error validating adapter `{adapter_id}`: {str(e)}")
    def resize_model_embeddings(self, adapter_vocab_size):
        """
        Resize the base model embeddings to match the adapter vocabulary size.

        Args:
            adapter_vocab_size (int): The vocabulary size of the adapter.
        """
        base_vocab_size = self.base_model.config.vocab_size
        if adapter_vocab_size > base_vocab_size:
            logger.warning(
                f"Resizing base model vocabulary from {base_vocab_size} to {adapter_vocab_size}."
            )
            self.base_model.resize_token_embeddings(adapter_vocab_size)
            self.tokenizer.add_tokens([f"<extra_token_{i}>" for i in range(adapter_vocab_size - base_vocab_size)])
        elif adapter_vocab_size < base_vocab_size:
            logger.error(
                f"Adapter vocabulary size ({adapter_vocab_size}) is smaller than the base model's size ({base_vocab_size})."
            )
            raise ValueError("Adapter vocabulary size must not be smaller than the base model's vocabulary size.")
    def load_adapter(self, adapter_path):
        """
        Load the adapter and ensure the base model is resized appropriately.

        Args:
            adapter_path (str): Path to the adapter directory.
        """
        try:
            # Load adapter configuration to get the vocabulary size
            config_path = os.path.join(adapter_path, "adapter_config.json")
            with open(config_path, "r") as f:
                adapter_config = json.load(f)

            # Get vocab sizes (adapter configs do not always record one, so fall back)
            adapter_vocab_size = adapter_config.get("vocab_size", 32005)
            base_vocab_size = self.base_model.config.vocab_size
            logger.info(f"Base model vocab size: {base_vocab_size}")
            logger.info(f"Adapter vocab size: {adapter_vocab_size}")

            # Resize the model (and tokenizer) to the adapter's vocabulary size if they differ
            if adapter_vocab_size != base_vocab_size:
                logger.info(f"Resizing model vocabulary from {base_vocab_size} to {adapter_vocab_size}")
                self.base_model.resize_token_embeddings(adapter_vocab_size)

                # Add placeholder tokens to the tokenizer if the adapter expects more
                if adapter_vocab_size > base_vocab_size:
                    new_tokens = [f"<extra_token_{i}>" for i in range(base_vocab_size, adapter_vocab_size)]
                    self.tokenizer.add_tokens(new_tokens)
                logger.info(f"Model and tokenizer resized to vocabulary size {adapter_vocab_size}")

            # Load the adapter onto the (resized) base model
            logger.info(f"Loading adapter from {adapter_path}...")
            self.base_model = PeftModel.from_pretrained(
                self.base_model,   # the base model to adapt
                adapter_path,      # path to the adapter
                is_trainable=False,
                adapter_name="default"
            )
            logger.info("Adapter loaded successfully")
        except Exception as e:
            logger.error(f"Error loading adapter: {str(e)}")
            raise
    def load_resources(self):
        """Load base model and tokenizer."""
        logger.info("Loading resources...")
        model_name = self.payload["model"]
        model_path = self.ensure_model_or_adapter_exists(model_name)

        logger.info(f"Loading base model from {model_path}...")
        if self.quantization == "8bit":
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                quantization_config=quantization_config,
                low_cpu_mem_usage=True  # reduce peak CPU memory while loading sharded checkpoints
            )
        else:
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                low_cpu_mem_usage=True  # reduce peak CPU memory while loading sharded checkpoints
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info("Base model and tokenizer loaded successfully.")
    def merge_adapters(self):
        """Merge multiple adapters into the base model."""
        logger.info("Merging adapters...")
        try:
            # Load the first adapter (this also resizes embeddings if needed)
            first_adapter = self.payload["adapter_parameters"]["adapter_ids"][0]
            adapter_path = self.ensure_model_or_adapter_exists(first_adapter)
            self.load_adapter(adapter_path)

            # Load subsequent adapters one by one
            for adapter_id in self.payload["adapter_parameters"]["adapter_ids"][1:]:
                logger.info(f"Loading adapter: {adapter_id}")
                adapter_path = self.ensure_model_or_adapter_exists(adapter_id)

                # Load adapter config
                with open(os.path.join(adapter_path, "adapter_config.json"), "r") as f:
                    adapter_config = json.load(f)

                # Add adapter
                self.base_model.load_adapter(
                    adapter_path,
                    adapter_name=adapter_id,
                    is_trainable=False
                )

            # List all loaded adapters
            adapter_names = list(self.base_model.peft_config.keys())
            logger.info(f"Adapters to merge: {adapter_names}")

            # Validate while the model is still a PeftModel; after merge_and_unload
            # the LoRA modules are gone and the check would always fail
            self.validate_adapter_usage()

            # Merge and unload
            logger.info("Performing final merge...")
            merged_model = self.base_model.merge_and_unload()

            # Convert to base model type
            logger.info("Converting to base model type...")
            config = merged_model.config
            state_dict = merged_model.state_dict()
            base_model = LlamaForCausalLM(config)
            missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=False)
            logger.info(f"Missing keys during conversion: {missing_keys}")
            logger.info(f"Unexpected keys during conversion: {unexpected_keys}")

            self.base_model = base_model.eval().to("cuda", dtype=torch.float16)
            logger.info("Adapters merged successfully.")
        except Exception as e:
            logger.error(f"Error during adapter merging: {str(e)}")
            raise
    def perform_inference(self, prompt):
        """Perform inference using the merged model."""
        logger.info("Performing inference on the merged model...")
        try:
            # Ensure tokenizer is properly configured
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"

            # Create pipeline with the merged model
            generator = pipeline(
                "text-generation",
                model=self.base_model,  # already merged in merge_adapters()
                tokenizer=self.tokenizer,
                torch_dtype=torch.float16,
                device_map="auto",
                max_new_tokens=self.payload.get("generation_config", {}).get("max_tokens", 100),
                do_sample=self.payload.get("generation_config", {}).get("do_sample", True),
                temperature=self.payload.get("generation_config", {}).get("temperature", 0.7),
                top_p=self.payload.get("generation_config", {}).get("top_p", 0.9),
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
            )

            # Format prompt
            formatted_prompt = f"<s>[INST] {prompt} [/INST]"
            response = generator(
                formatted_prompt,
                return_full_text=False,
                num_return_sequences=1,
                clean_up_tokenization_spaces=True
            )
            logger.info(f"LLM Response: {response}")
            return response
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise
    def validate_adapter_usage(self):
        """Validate that adapters are actually being used in the model."""
        logger.info("Validating adapter usage...")

        # 1. Check if model is actually a PeftModel
        if not isinstance(self.base_model, PeftModel):
            raise ValueError("Model is not a PeftModel - adapters were not properly loaded")

        # 2. Check active adapters
        active_adapters = self.base_model.active_adapters
        logger.info(f"Active adapters: {active_adapters}")
        if not active_adapters:
            raise ValueError("No active adapters found")

        # 3. Check model parameters for LoRA weights
        lora_params = 0
        base_params = 0
        for name, param in self.base_model.named_parameters():
            if 'lora_' in name:
                lora_params += 1
                logger.info(f"Found LoRA parameter: {name} with shape {param.shape}")
            else:
                base_params += 1
        logger.info(f"Found {lora_params} LoRA parameters and {base_params} base parameters")

        # 4. Check if target modules were modified
        target_modules = self.payload["adapter_parameters"]["target_modules"]
        modified_modules = []
        for name, module in self.base_model.named_modules():
            if any(target in name for target in target_modules):
                if hasattr(module, 'active_adapter'):
                    modified_modules.append(name)
                    logger.info(f"Module {name} has active adapter: {module.active_adapter}")
        logger.info(f"Modified modules: {modified_modules}")
        if not modified_modules:
            raise ValueError("No target modules were modified by adapters")
        return True
    def execute(self):
        """Main execution logic."""
        try:
            logger.info("Starting execution...")
            self.load_resources()

            # Validate adapters before merging
            for adapter_id in self.payload["adapter_parameters"]["adapter_ids"]:
                self.validate_adapter(adapter_id, self.payload["model"])

            # Merge adapters (adapter usage is validated inside merge_adapters,
            # while the model is still a PeftModel)
            self.merge_adapters()

            # Run inference tests with different prompts
            prompts = [
                """<s> Table: 2-11365528-2
                Columns: ['Team', 'Head Coach', 'President', 'Home Ground', 'Location']
                Natural Query: Who is the Head Coach of the team whose President is Mario Volarevic?
                SQL Query:""",
                # # A prompt that should trigger finance knowledge
                # """<s> What is the difference between EBITDA and net income?""",
                # # A prompt about rugby
                # """<s> How many players are on a touch rugby team?"""
            ]
            responses = []
            for prompt in prompts:
                logger.info(f"\nTesting with prompt: {prompt}")
                response = self.perform_inference(prompt)
                responses.append(response)
            return responses
        except Exception as e:
            logger.error(f"Execution failed: {str(e)}")
            raise
        finally:
            del self.base_model
            torch.cuda.empty_cache()
            gc.collect()


# Example Usage
if __name__ == "__main__":
    payload = {
        "project_dir": "/tmp/lorax",
        "debug_mode": True,
        "model": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "adapter_parameters": {
            "adapter_ids": [
                "smangrul/tinyllama_lora_norobots",
                "smangrul/tinyllama_lora_sql",
                "smangrul/tinyllama_lora_adcopy",
            ],
            "weights": [1.0, 1.0, 1.0],
            "target_modules": ["q_proj", "v_proj", "o_proj"],
            "merge_strategy": "linear",
            "quantization": "8bit"
        },
        "generation_config": {
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "num_return_sequences": 1,
            "do_sample": True,
        }
    }

    merger = LoRAAdapterMerger(payload)
    merger.execute()
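One more gap worth flagging: the script runs an inference smoke test and then frees the model in `execute()`'s `finally` block, so the merged weights are never written to disk. If a reusable checkpoint is the end goal, a save step would have to slot in right after the inference loop, before the cleanup. A minimal sketch of that addition, with an illustrative output path:

            # Hypothetical addition at the end of execute()'s try block, before the
            # finally clause deletes self.base_model: persist the merged weights and
            # the (possibly resized) tokenizer so the checkpoint can be reloaded later
            # with AutoModelForCausalLM.from_pretrained.
            output_dir = os.path.join(self.project_dir, "merged_model")  # illustrative path
            self.base_model.save_pretrained(output_dir, safe_serialization=True)
            self.tokenizer.save_pretrained(output_dir)
            logger.info(f"Merged model saved to {output_dir}")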