@pszemraj
Created October 3, 2024 02:54
Compare two T5 models of the same architecture: compute layer-wise mean absolute weight differences between a base and a fine-tuned checkpoint and visualize them as per-component heatmaps.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Define model names for base and fine-tuned versions
base_model_name = "pszemraj/tFINE-900m-e16-d32-1024ctx"
ft_model_name = "BEE-spoke-data/tFINE-900m-e16-d32-instruct_2e"

# Load the models with appropriate dtype
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, torch_dtype=torch.float32)
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_name, torch_dtype=torch.float32)


# Mean absolute weight difference between two tensors of the same shape
def calculate_weight_diff(base_weight, chat_weight):
    return torch.abs(base_weight - chat_weight).mean().item()


# Layer-wise weight differences for encoder and decoder
def calculate_layer_diffs(base_model, ft_model):
    encoder_diffs = []
    decoder_diffs = []

    # Encoder blocks: self-attention (layer[0]) and feed-forward (layer[1])
    for base_layer, chat_layer in zip(base_model.encoder.block, ft_model.encoder.block):
        layer_diff = {
            'SelfAttention.q': calculate_weight_diff(base_layer.layer[0].SelfAttention.q.weight, chat_layer.layer[0].SelfAttention.q.weight),
            'SelfAttention.k': calculate_weight_diff(base_layer.layer[0].SelfAttention.k.weight, chat_layer.layer[0].SelfAttention.k.weight),
            'SelfAttention.v': calculate_weight_diff(base_layer.layer[0].SelfAttention.v.weight, chat_layer.layer[0].SelfAttention.v.weight),
            'SelfAttention.o': calculate_weight_diff(base_layer.layer[0].SelfAttention.o.weight, chat_layer.layer[0].SelfAttention.o.weight),
            'DenseReluDense.wi_0': calculate_weight_diff(base_layer.layer[1].DenseReluDense.wi_0.weight, chat_layer.layer[1].DenseReluDense.wi_0.weight),
            'DenseReluDense.wi_1': calculate_weight_diff(base_layer.layer[1].DenseReluDense.wi_1.weight, chat_layer.layer[1].DenseReluDense.wi_1.weight),
            'DenseReluDense.wo': calculate_weight_diff(base_layer.layer[1].DenseReluDense.wo.weight, chat_layer.layer[1].DenseReluDense.wo.weight),
            'layer_norm': calculate_weight_diff(base_layer.layer[1].layer_norm.weight, chat_layer.layer[1].layer_norm.weight),
        }
        encoder_diffs.append(layer_diff)

    # Decoder blocks: self-attention (layer[0]), cross-attention (layer[1]), feed-forward (layer[2])
    for base_layer, chat_layer in zip(base_model.decoder.block, ft_model.decoder.block):
        layer_diff = {
            'SelfAttention.q': calculate_weight_diff(base_layer.layer[0].SelfAttention.q.weight, chat_layer.layer[0].SelfAttention.q.weight),
            'SelfAttention.k': calculate_weight_diff(base_layer.layer[0].SelfAttention.k.weight, chat_layer.layer[0].SelfAttention.k.weight),
            'SelfAttention.v': calculate_weight_diff(base_layer.layer[0].SelfAttention.v.weight, chat_layer.layer[0].SelfAttention.v.weight),
            'SelfAttention.o': calculate_weight_diff(base_layer.layer[0].SelfAttention.o.weight, chat_layer.layer[0].SelfAttention.o.weight),
            'EncDecAttention.q': calculate_weight_diff(base_layer.layer[1].EncDecAttention.q.weight, chat_layer.layer[1].EncDecAttention.q.weight),
            'EncDecAttention.k': calculate_weight_diff(base_layer.layer[1].EncDecAttention.k.weight, chat_layer.layer[1].EncDecAttention.k.weight),
            'EncDecAttention.v': calculate_weight_diff(base_layer.layer[1].EncDecAttention.v.weight, chat_layer.layer[1].EncDecAttention.v.weight),
            'EncDecAttention.o': calculate_weight_diff(base_layer.layer[1].EncDecAttention.o.weight, chat_layer.layer[1].EncDecAttention.o.weight),
            'DenseReluDense.wi_0': calculate_weight_diff(base_layer.layer[2].DenseReluDense.wi_0.weight, chat_layer.layer[2].DenseReluDense.wi_0.weight),
            'DenseReluDense.wi_1': calculate_weight_diff(base_layer.layer[2].DenseReluDense.wi_1.weight, chat_layer.layer[2].DenseReluDense.wi_1.weight),
            'DenseReluDense.wo': calculate_weight_diff(base_layer.layer[2].DenseReluDense.wo.weight, chat_layer.layer[2].DenseReluDense.wo.weight),
            'layer_norm': calculate_weight_diff(base_layer.layer[2].layer_norm.weight, chat_layer.layer[2].layer_norm.weight),
        }
        decoder_diffs.append(layer_diff)

    return encoder_diffs, decoder_diffs


# Visualization function for encoder and decoder separately
def visualize_layer_diffs(encoder_diffs, decoder_diffs):
    def plot_layer_diffs(layer_diffs, title):
        num_layers = len(layer_diffs)
        num_components = len(layer_diffs[0])
        fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
        fig.suptitle(title, fontsize=16)
        for i, component in enumerate(layer_diffs[0].keys()):
            # One single-column heatmap per component: rows are layer indices
            component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
            sns.heatmap(
                component_diffs,
                annot=True,
                fmt=".6f",
                cmap="YlGnBu",
                ax=axs[i],
                cbar_kws={"shrink": 0.8},
            )
            axs[i].set_title(component)
            axs[i].set_xlabel("Component")
            axs[i].set_ylabel("Layer Idx")
            axs[i].set_xticks([])
            axs[i].set_yticks(range(num_layers))
            axs[i].set_yticklabels(range(num_layers))
            axs[i].invert_yaxis()
        plt.tight_layout()
        plt.show()

    # Plot encoder and decoder differences separately
    plot_layer_diffs(encoder_diffs, "Encoder Layer Weight Differences")
    plot_layer_diffs(decoder_diffs, "Decoder Layer Weight Differences")


# Calculate and visualize layer differences
encoder_diffs, decoder_diffs = calculate_layer_diffs(base_model, ft_model)
visualize_layer_diffs(encoder_diffs, decoder_diffs)
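
Not part of the original gist, but since encoder_diffs and decoder_diffs are plain lists of per-layer dicts, a minimal sketch (assuming pandas is available; the CSV file names are arbitrary) could also dump the numbers to disk for later comparison:

# Hypothetical add-on, not in the original script: save the per-layer diffs as CSVs.
# Assumes pandas is installed; output file names are arbitrary.
import pandas as pd

pd.DataFrame(encoder_diffs).to_csv("encoder_weight_diffs.csv", index_label="layer")
pd.DataFrame(decoder_diffs).to_csv("decoder_weight_diffs.csv", index_label="layer")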
pszemraj (Author) commented Oct 3, 2024
output

base model: pszemraj/tFINE-900m-e16-d32-1024ctx

fine-tuned model: BEE-spoke-data/tFINE-900m-e16-d32-instruct_2e

encoder

[image: encoder layer-wise weight-difference heatmaps]

decoder

[image: decoder layer-wise weight-difference heatmaps]
