Compare two T5 models of the same architecture: compute the mean absolute difference between corresponding weights in each encoder and decoder block, then visualize the per-layer differences as heatmaps.
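The comparison only makes sense when both checkpoints share an architecture, so it can be worth diffing the configs before downloading and loading the full weights. A minimal sketch using `AutoConfig` (the field names are the standard `T5Config` keys; nothing in the script below depends on this check):

from transformers import AutoConfig

base_cfg = AutoConfig.from_pretrained("pszemraj/tFINE-900m-e16-d32-1024ctx")
ft_cfg = AutoConfig.from_pretrained("BEE-spoke-data/tFINE-900m-e16-d32-instruct_2e")

# These fields determine the block structure that the script below iterates over
for key in ("num_layers", "num_decoder_layers", "d_model", "num_heads", "d_ff", "feed_forward_proj"):
    assert getattr(base_cfg, key) == getattr(ft_cfg, key), f"config mismatch on {key}"

The main script: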
from transformers import AutoModelForSeq2SeqLM
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Define model names for base and fine-tuned versions
base_model_name = "pszemraj/tFINE-900m-e16-d32-1024ctx"
ft_model_name = "BEE-spoke-data/tFINE-900m-e16-d32-instruct_2e"

# Load the models with appropriate dtype
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, torch_dtype=torch.float32)
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_name, torch_dtype=torch.float32)


# Mean absolute difference between two weight tensors
def calculate_weight_diff(base_weight, chat_weight):
    return torch.abs(base_weight - chat_weight).mean().item()


# Layer-wise weight differences for the encoder and decoder stacks
def calculate_layer_diffs(base_model, ft_model):
    encoder_diffs = []
    decoder_diffs = []

    # Encoder blocks: layer[0] = self-attention, layer[1] = feed-forward
    for base_layer, chat_layer in zip(base_model.encoder.block, ft_model.encoder.block):
        layer_diff = {
            'SelfAttention.q': calculate_weight_diff(base_layer.layer[0].SelfAttention.q.weight, chat_layer.layer[0].SelfAttention.q.weight),
            'SelfAttention.k': calculate_weight_diff(base_layer.layer[0].SelfAttention.k.weight, chat_layer.layer[0].SelfAttention.k.weight),
            'SelfAttention.v': calculate_weight_diff(base_layer.layer[0].SelfAttention.v.weight, chat_layer.layer[0].SelfAttention.v.weight),
            'SelfAttention.o': calculate_weight_diff(base_layer.layer[0].SelfAttention.o.weight, chat_layer.layer[0].SelfAttention.o.weight),
            'DenseReluDense.wi_0': calculate_weight_diff(base_layer.layer[1].DenseReluDense.wi_0.weight, chat_layer.layer[1].DenseReluDense.wi_0.weight),
            'DenseReluDense.wi_1': calculate_weight_diff(base_layer.layer[1].DenseReluDense.wi_1.weight, chat_layer.layer[1].DenseReluDense.wi_1.weight),
            'DenseReluDense.wo': calculate_weight_diff(base_layer.layer[1].DenseReluDense.wo.weight, chat_layer.layer[1].DenseReluDense.wo.weight),
            'layer_norm': calculate_weight_diff(base_layer.layer[1].layer_norm.weight, chat_layer.layer[1].layer_norm.weight),
        }
        encoder_diffs.append(layer_diff)

    # Decoder blocks: layer[0] = self-attention, layer[1] = cross-attention, layer[2] = feed-forward
    for base_layer, chat_layer in zip(base_model.decoder.block, ft_model.decoder.block):
        layer_diff = {
            'SelfAttention.q': calculate_weight_diff(base_layer.layer[0].SelfAttention.q.weight, chat_layer.layer[0].SelfAttention.q.weight),
            'SelfAttention.k': calculate_weight_diff(base_layer.layer[0].SelfAttention.k.weight, chat_layer.layer[0].SelfAttention.k.weight),
            'SelfAttention.v': calculate_weight_diff(base_layer.layer[0].SelfAttention.v.weight, chat_layer.layer[0].SelfAttention.v.weight),
            'SelfAttention.o': calculate_weight_diff(base_layer.layer[0].SelfAttention.o.weight, chat_layer.layer[0].SelfAttention.o.weight),
            'EncDecAttention.q': calculate_weight_diff(base_layer.layer[1].EncDecAttention.q.weight, chat_layer.layer[1].EncDecAttention.q.weight),
            'EncDecAttention.k': calculate_weight_diff(base_layer.layer[1].EncDecAttention.k.weight, chat_layer.layer[1].EncDecAttention.k.weight),
            'EncDecAttention.v': calculate_weight_diff(base_layer.layer[1].EncDecAttention.v.weight, chat_layer.layer[1].EncDecAttention.v.weight),
            'EncDecAttention.o': calculate_weight_diff(base_layer.layer[1].EncDecAttention.o.weight, chat_layer.layer[1].EncDecAttention.o.weight),
            'DenseReluDense.wi_0': calculate_weight_diff(base_layer.layer[2].DenseReluDense.wi_0.weight, chat_layer.layer[2].DenseReluDense.wi_0.weight),
            'DenseReluDense.wi_1': calculate_weight_diff(base_layer.layer[2].DenseReluDense.wi_1.weight, chat_layer.layer[2].DenseReluDense.wi_1.weight),
            'DenseReluDense.wo': calculate_weight_diff(base_layer.layer[2].DenseReluDense.wo.weight, chat_layer.layer[2].DenseReluDense.wo.weight),
            'layer_norm': calculate_weight_diff(base_layer.layer[2].layer_norm.weight, chat_layer.layer[2].layer_norm.weight),
        }
        decoder_diffs.append(layer_diff)

    return encoder_diffs, decoder_diffs


# Visualize encoder and decoder differences as separate heatmap grids
def visualize_layer_diffs(encoder_diffs, decoder_diffs):
    def plot_layer_diffs(layer_diffs, title):
        num_layers = len(layer_diffs)
        num_components = len(layer_diffs[0])
        fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
        fig.suptitle(title, fontsize=16)
        for i, component in enumerate(layer_diffs[0].keys()):
            component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
            sns.heatmap(
                component_diffs,
                annot=True,
                fmt=".6f",
                cmap="YlGnBu",
                ax=axs[i],
                cbar_kws={"shrink": 0.8},
            )
            axs[i].set_title(component)
            axs[i].set_xlabel("Component")
            axs[i].set_ylabel("Layer Idx")
            axs[i].set_xticks([])
            axs[i].set_yticks(range(num_layers))
            axs[i].set_yticklabels(range(num_layers))
            axs[i].invert_yaxis()
        plt.tight_layout()
        plt.show()

    # Plot encoder and decoder differences separately
    plot_layer_diffs(encoder_diffs, "Encoder Layer Weight Differences")
    plot_layer_diffs(decoder_diffs, "Decoder Layer Weight Differences")


# Calculate and visualize layer differences
encoder_diffs, decoder_diffs = calculate_layer_diffs(base_model, ft_model)
visualize_layer_diffs(encoder_diffs, decoder_diffs)
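
The per-block loops above skip a few weights that can also move during fine-tuning: the shared embedding matrix, the encoder and decoder final layer norms, and the LM head. Below is a minimal sketch of how one might extend the comparison, assuming the standard T5ForConditionalGeneration attribute names (`shared`, `encoder.final_layer_norm`, `decoder.final_layer_norm`, `lm_head`); note that when word embeddings are tied, the `lm_head` diff simply mirrors the embedding diff:

# Non-block weights not covered by calculate_layer_diffs
extra_diffs = {
    "shared_embedding": calculate_weight_diff(base_model.shared.weight, ft_model.shared.weight),
    "encoder.final_layer_norm": calculate_weight_diff(
        base_model.encoder.final_layer_norm.weight, ft_model.encoder.final_layer_norm.weight
    ),
    "decoder.final_layer_norm": calculate_weight_diff(
        base_model.decoder.final_layer_norm.weight, ft_model.decoder.final_layer_norm.weight
    ),
    "lm_head": calculate_weight_diff(base_model.lm_head.weight, ft_model.lm_head.weight),
}
for name, diff in extra_diffs.items():
    print(f"{name}: {diff:.6f}")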
Credit to https://gist.github.com/StableFluffy/1c6f8be84cbe9499de2f9b63d7105ff0 for the idea.