tinyllama_model_merge_wip: well, I thought I could make this work... maybe leave it for another time
import gc
import json
import os

import torch
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from loguru import logger
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    pipeline,
)

load_dotenv('../.env')

class LoRAAdapterMerger:
    def __init__(self, payload):
        self.payload = payload
        self.base_model = None
        self.quantization = payload.get("adapter_parameters", {}).get("quantization", None)
        self.debug_mode = payload.get("debug_mode", False)

        load_dotenv()
        self.hf_token = os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.error("HF_TOKEN not found in environment variables")
            raise ValueError("Missing HF_TOKEN in environment variables")
        logger.info(f"Initialized with HF token: {self.hf_token[:8]}...")

        self.tokenizer = None
        self.project_dir = payload.get("project_dir", "/tmp/lorax")
        self.default_target_modules = ["q_proj", "v_proj", "o_proj"]
        logger.add("merger.log", rotation="500 MB")

    def ensure_model_or_adapter_exists(self, model_id):
        """Ensure model or adapter exists locally or download it."""
        try:
            logger.info(f"Checking for model/adapter `{model_id}`...")
            cache_dir = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

            # snapshot_download reuses files that are already in the local cache,
            # so it serves as both the cache check and the download step and
            # always returns the snapshot directory that from_pretrained expects.
            logger.info(f"Resolving `{model_id}`...")
            snapshot_path = snapshot_download(
                repo_id=model_id,
                cache_dir=cache_dir,
                token=self.hf_token,
                allow_patterns=["*.json", "*.bin", "*.safetensors", "*.model", "tokenizer*", "vocab*", "merges*"],
            )
            logger.info(f"`{model_id}` is available at `{snapshot_path}`.")
            return snapshot_path
        except Exception as e:
            logger.error(f"Failed to obtain `{model_id}`: {str(e)}")
            raise

    def validate_adapter(self, adapter_id, base_model_name):
        """Validate if the adapter is trained on the expected base model."""
        try:
            logger.info(f"Validating adapter `{adapter_id}`...")
            adapter_path = self.ensure_model_or_adapter_exists(adapter_id)
            config_path = os.path.join(adapter_path, "adapter_config.json")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"`adapter_config.json` missing for {adapter_id}")

            with open(config_path, "r") as f:
                config = json.load(f)

            trained_on = config.get("base_model_name_or_path", "Unknown")
            if trained_on.lower() != base_model_name.lower():
                logger.warning(
                    f"Adapter `{adapter_id}` trained on `{trained_on}`, not `{base_model_name}`. "
                    "Compatibility issues may occur."
                )
            else:
                logger.info(f"Adapter `{adapter_id}` is compatible with `{base_model_name}`")
        except Exception as e:
            logger.error(f"Error validating adapter `{adapter_id}`: {str(e)}")

    def resize_model_embeddings(self, adapter_vocab_size):
        """
        Resize the base model embeddings to match the adapter vocabulary size.

        Args:
            adapter_vocab_size (int): The vocabulary size of the adapter.
        """
        base_vocab_size = self.base_model.config.vocab_size
        if adapter_vocab_size > base_vocab_size:
            logger.warning(
                f"Resizing base model vocabulary from {base_vocab_size} to {adapter_vocab_size}."
            )
            self.base_model.resize_token_embeddings(adapter_vocab_size)
            self.tokenizer.add_tokens([f"<extra_token_{i}>" for i in range(adapter_vocab_size - base_vocab_size)])
        elif adapter_vocab_size < base_vocab_size:
            logger.error(
                f"Adapter vocabulary size ({adapter_vocab_size}) is smaller than the base model's size ({base_vocab_size})."
            )
            raise ValueError("Adapter vocabulary size must not be smaller than the base model's vocabulary size.")

    def load_adapter(self, adapter_path):
        """
        Load the adapter and ensure the base model is resized appropriately.

        Args:
            adapter_path (str): Path to the adapter directory.
        """
        try:
            # Load the adapter configuration to get its vocabulary size
            config_path = os.path.join(adapter_path, "adapter_config.json")
            with open(config_path, "r") as f:
                adapter_config = json.load(f)

            # Get vocab sizes
            adapter_vocab_size = adapter_config.get("vocab_size", 32005)  # Default to the larger vocab size
            base_vocab_size = self.base_model.config.vocab_size
            logger.info(f"Base model vocab size: {base_vocab_size}")
            logger.info(f"Adapter vocab size: {adapter_vocab_size}")

            # Always resize to the larger vocabulary size
            if adapter_vocab_size != base_vocab_size:
                logger.info(f"Resizing model vocabulary from {base_vocab_size} to {adapter_vocab_size}")
                # Resize both model and tokenizer
                self.base_model.resize_token_embeddings(adapter_vocab_size)
                # Add new tokens to the tokenizer if needed
                if adapter_vocab_size > base_vocab_size:
                    new_tokens = [f"<extra_token_{i}>" for i in range(base_vocab_size, adapter_vocab_size)]
                    self.tokenizer.add_tokens(new_tokens)
                logger.info(f"Model and tokenizer resized to vocabulary size {adapter_vocab_size}")

            # Load the adapter onto the (possibly resized) base model
            logger.info(f"Loading adapter from {adapter_path}...")
            self.base_model = PeftModel.from_pretrained(
                self.base_model,          # The base model to adapt
                adapter_path,             # Path to the adapter
                is_trainable=False,
                adapter_name="default",   # Explicit adapter name
            )
            logger.info("Adapter loaded successfully")
        except Exception as e:
            logger.error(f"Error loading adapter: {str(e)}")
            raise

    def load_resources(self):
        """Load base model and tokenizer."""
        logger.info("Loading resources...")
        model_name = self.payload["model"]
        model_path = self.ensure_model_or_adapter_exists(model_name)
        logger.info(f"Loading base model from {model_path}...")

        if self.quantization == "8bit":
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                quantization_config=quantization_config,
                low_cpu_mem_usage=True,  # Reduce peak CPU memory while loading sharded checkpoints
            )
        else:
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                low_cpu_mem_usage=True,  # Reduce peak CPU memory while loading sharded checkpoints
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info("Base model and tokenizer loaded successfully.")

    def merge_adapters(self):
        """
        Merge multiple adapters into the base model.
        """
        logger.info("Merging adapters...")
        try:
            # Load the first adapter (this wraps the base model in a PeftModel)
            first_adapter = self.payload["adapter_parameters"]["adapter_ids"][0]
            adapter_path = self.ensure_model_or_adapter_exists(first_adapter)
            self.load_adapter(adapter_path)

            # Load subsequent adapters one by one
            for adapter_id in self.payload["adapter_parameters"]["adapter_ids"][1:]:
                logger.info(f"Loading adapter: {adapter_id}")
                adapter_path = self.ensure_model_or_adapter_exists(adapter_id)
                self.base_model.load_adapter(
                    adapter_path,
                    adapter_name=adapter_id,
                    is_trainable=False,
                )

            # Merge all adapters
            logger.info("Merging all adapters...")
            adapter_names = list(self.base_model.peft_config.keys())
            logger.info(f"Adapters to merge: {adapter_names}")

            # Validate adapter usage while the model is still a PeftModel;
            # after merge_and_unload() the LoRA modules no longer exist.
            self.validate_adapter_usage()

            # Merge and unload
            logger.info("Performing final merge...")
            merged_model = self.base_model.merge_and_unload()

            # Convert to base model type
            logger.info("Converting to base model type...")
            config = merged_model.config
            state_dict = merged_model.state_dict()
            base_model = LlamaForCausalLM(config)
            missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=False)
            logger.info(f"Missing keys during conversion: {missing_keys}")
            logger.info(f"Unexpected keys during conversion: {unexpected_keys}")

            self.base_model = base_model.eval().to("cuda", dtype=torch.float16)
            logger.info("Adapters merged successfully.")
        except Exception as e:
            logger.error(f"Error during adapter merging: {str(e)}")
            raise

    def perform_inference(self, prompt):
        """Perform inference using the merged model."""
        logger.info("Performing inference on the merged model...")
        try:
            # Ensure the tokenizer is properly configured for generation
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"

            # Create a pipeline around the merged model
            generation_config = self.payload.get("generation_config", {})
            generator = pipeline(
                "text-generation",
                model=self.base_model,  # Already merged in merge_adapters()
                tokenizer=self.tokenizer,
                torch_dtype=torch.float16,
                device_map="auto",
                max_new_tokens=generation_config.get("max_tokens", 100),
                do_sample=generation_config.get("do_sample", True),
                temperature=generation_config.get("temperature", 0.7),
                top_p=generation_config.get("top_p", 0.9),
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
            )

            # Wrap the prompt in the Llama instruction format
            formatted_prompt = f"<s>[INST] {prompt} [/INST]"
            response = generator(
                formatted_prompt,
                return_full_text=False,
                num_return_sequences=1,
                clean_up_tokenization_spaces=True,
            )
            logger.info(f"LLM Response: {response}")
            return response
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise

    def validate_adapter_usage(self):
        """Validate that adapters are actually being used in the model."""
        logger.info("Validating adapter usage...")

        # 1. Check that the model is actually a PeftModel
        if not isinstance(self.base_model, PeftModel):
            raise ValueError("Model is not a PeftModel - adapters were not properly loaded")

        # 2. Check active adapters
        active_adapters = self.base_model.active_adapters
        logger.info(f"Active adapters: {active_adapters}")
        if not active_adapters:
            raise ValueError("No active adapters found")

        # 3. Check model parameters for LoRA weights
        lora_params = 0
        base_params = 0
        for name, param in self.base_model.named_parameters():
            if 'lora_' in name:
                lora_params += 1
                logger.info(f"Found LoRA parameter: {name} with shape {param.shape}")
            else:
                base_params += 1
        logger.info(f"Found {lora_params} LoRA parameters and {base_params} base parameters")

        # 4. Check whether the target modules were modified
        target_modules = self.payload["adapter_parameters"]["target_modules"]
        modified_modules = []
        for name, module in self.base_model.named_modules():
            if any(target in name for target in target_modules):
                if hasattr(module, 'active_adapter'):
                    modified_modules.append(name)
                    logger.info(f"Module {name} has active adapter: {module.active_adapter}")
        logger.info(f"Modified modules: {modified_modules}")
        if not modified_modules:
            raise ValueError("No target modules were modified by adapters")
        return True

    def execute(self):
        """Main execution logic."""
        try:
            logger.info("Starting execution...")
            self.load_resources()

            # Validate adapters before merging
            for adapter_id in self.payload["adapter_parameters"]["adapter_ids"]:
                self.validate_adapter(adapter_id, self.payload["model"])

            # Merge adapters (adapter usage is validated inside merge_adapters,
            # before the final merge converts the model back to a plain LlamaForCausalLM)
            self.merge_adapters()

            # Do multiple inference tests with different prompts
            prompts = [
                """<s> Table: 2-11365528-2
                Columns: ['Team', 'Head Coach', 'President', 'Home Ground', 'Location']
                Natural Query: Who is the Head Coach of the team whose President is Mario Volarevic?
                SQL Query:""",
                # # Add a prompt that should trigger finance knowledge
                # """<s> What is the difference between EBITDA and net income?""",
                # # Add a prompt about rugby
                # """<s> How many players are on a touch rugby team?"""
            ]
            responses = []
            for prompt in prompts:
                logger.info(f"\nTesting with prompt: {prompt}")
                response = self.perform_inference(prompt)
                responses.append(response)
            return responses
        except Exception as e:
            logger.error(f"Execution failed: {str(e)}")
            raise
        finally:
            del self.base_model
            torch.cuda.empty_cache()
            gc.collect()

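
# NOTE: The payload below carries "weights" and "merge_strategy": "linear", but
# merge_adapters() above folds every loaded adapter in at full strength and ignores
# those fields. The sketch below is one way they could be honoured using PEFT's
# add_weighted_adapter; it is an assumption layered on a recent peft release, not
# part of the original gist, and the helper name `merge_adapters_weighted` is
# hypothetical and untested here.
def merge_adapters_weighted(peft_model, adapter_ids, weights, combination_type="linear"):
    """Combine already-loaded adapters into one weighted adapter, then merge it."""
    # Build a single combined adapter from the named adapters and their weights
    peft_model.add_weighted_adapter(
        adapters=adapter_ids,
        weights=weights,
        adapter_name="weighted_merge",
        combination_type=combination_type,
    )
    # Activate the combined adapter before folding it into the base weights
    peft_model.set_adapter("weighted_merge")
    return peft_model.merge_and_unload()
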
# Example Usage
if __name__ == "__main__":
    payload = {
        "project_dir": "/tmp/lorax",
        "debug_mode": True,
        "model": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "adapter_parameters": {
            "adapter_ids": [
                "smangrul/tinyllama_lora_norobots",
                "smangrul/tinyllama_lora_sql",
                "smangrul/tinyllama_lora_adcopy",
            ],
            "weights": [1.0, 1.0, 1.0],
            "target_modules": ["q_proj", "v_proj", "o_proj"],
            "merge_strategy": "linear",
            "quantization": "8bit",
        },
        "generation_config": {
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "num_return_sequences": 1,
            "do_sample": True,
        },
    }
    merger = LoRAAdapterMerger(payload)
    merger.execute()
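
    # A minimal follow-up sketch (not in the original gist): persist the merged
    # weights if the model is still around. With execute() as written the model is
    # deleted in its `finally` block, so this only fires if that cleanup is removed
    # or moved; the output directory is just a placeholder.
    if getattr(merger, "base_model", None) is not None:
        output_dir = os.path.join(payload["project_dir"], "merged_tinyllama")  # hypothetical path
        merger.base_model.save_pretrained(output_dir)
        merger.tokenizer.save_pretrained(output_dir)
        logger.info(f"Merged model saved to {output_dir}")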