Created
January 21, 2025 09:26
-
-
Save nanguoyu/c1dd1808612902597cb9f56ca13b974a to your computer and use it in GitHub Desktop.
Save and load gradients, and compute the relative gradient norm (RGN)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import matplotlib.pyplot as plt | |
def compute_rgn_from_gradients_and_params(gradients, parameters):
    """
    Compute the relative gradient norm (RGN) of each parameter directly from
    gradient and parameter tensors.

    RGN is ||grad||_2 / ||param||_2, where each norm is taken over all elements
    of the tensor (the tensor is treated as one flat vector).

    Args:
        gradients (dict): mapping from parameter name to gradient tensor
            (entries may be None).
        parameters (dict): mapping from parameter name to parameter tensor,
            e.g. a ``state_dict``. Keys that have no gradient entry are ignored.

    Returns:
        rgn_dict (dict): mapping from parameter name to RGN as a Python float.
        A name is skipped when its gradient is None or when it has no
        (non-None) entry in ``parameters``. When the parameter norm is not
        positive, the RGN is reported as 0.0 to avoid dividing by zero.
    """
    rgn_dict = {}
    for name, grad in gradients.items():
        # .get() instead of [] so a gradient with no matching parameter is
        # skipped (as the None check intends) rather than raising KeyError.
        param = parameters.get(name)
        if grad is None or param is None:
            continue
        # vector_norm flattens its input, giving the L2 (Frobenius) norm for
        # tensors of any rank -- the same value the deprecated torch.norm(t)
        # returned by default.
        grad_norm = torch.linalg.vector_norm(grad).item()
        param_norm = torch.linalg.vector_norm(param).item()
        # Guard against division by zero for all-zero parameters.
        rgn_dict[name] = grad_norm / param_norm if param_norm > 0 else 0.0
    return rgn_dict
# Plotting helper for RGN values.
def visualize_rgn(rgn_dict):
    """
    Draw a bar chart with one bar per parameter showing its RGN, then display it.

    Args:
        rgn_dict (dict): mapping from parameter name to RGN value.
    """
    labels = list(rgn_dict)
    heights = [rgn_dict[label] for label in labels]

    plt.figure(figsize=(10, 6))
    plt.bar(labels, heights)

    # x label, y label, then title -- applied in that order.
    for set_text, text in (
        (plt.xlabel, "Parameter Names"),
        (plt.ylabel, "Relative Gradient Norm (RGN)"),
        (plt.title, "Relative Gradient Norm (RGN) for Each Parameter"),
    ):
        set_text(text)

    # Slant the parameter names so long labels do not overlap.
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()
# Main workflow
# A small fully connected model used for the demo below.
class SimpleModel(nn.Module):
    """Three stacked linear layers mapping 10 -> 20 -> 10 -> 1 features.

    No activation is applied between the layers, so the whole stack is an
    affine map of its input.
    """

    def __init__(self):
        super().__init__()
        # Created in this order so parameter initialisation consumes the RNG
        # in a fixed sequence and state_dict keys are fc1.*, fc2.*, fc3.*.
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        """Apply fc1, then fc2, then fc3 to ``x``."""
        return self.fc3(self.fc2(self.fc1(x)))
# Save model weights and gradients.
def save_model_and_gradients(model, path):
    """
    Write a checkpoint holding the model's state_dict and its current gradients.

    The saved object is a dict with two keys:
        'model_state_dict': ``model.state_dict()``
        'gradients': {name: copy of param.grad} for every named parameter
            whose ``.grad`` is not None; parameters without a gradient are
            omitted.

    Args:
        model (nn.Module): model whose weights and gradients are saved.
        path: destination passed straight to ``torch.save``.
    """
    grads = {}
    for pname, p in model.named_parameters():
        if p.grad is None:
            continue  # no gradient recorded for this parameter
        grads[pname] = p.grad.clone()
    torch.save({'model_state_dict': model.state_dict(), 'gradients': grads}, path)
# Load model weights and gradients.
def load_model_and_gradients(path, map_location=None, *, weights_only=True):
    """
    Load a checkpoint written by ``save_model_and_gradients``.

    Args:
        path: file path (or file-like object) accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"`` lets
            a checkpoint containing GPU tensors be loaded on a CPU-only
            machine. The default None keeps ``torch.load``'s own placement.
        weights_only: forwarded to ``torch.load``. True (the default here)
            restricts unpickling to tensors and plain containers. That is all
            the save helper writes, and it avoids running arbitrary pickled
            code from an untrusted file. Pass False only for trusted files
            that need it.

    Returns:
        tuple: ``(model_state_dict, gradients)``, the dictionaries stored
        under the 'model_state_dict' and 'gradients' keys.

    Raises:
        KeyError: if either key is missing from the checkpoint.
    """
    checkpoint = torch.load(path, map_location=map_location, weights_only=weights_only)
    return checkpoint['model_state_dict'], checkpoint['gradients']
# --- Demo: one forward/backward pass, checkpoint round-trip, RGN report ---

# Model, loss function and optimizer (the optimizer is not stepped below).
model, criterion = SimpleModel(), nn.MSELoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

# Random example batch: 5 samples with 10 input features and 1 target each.
x, y = torch.randn(5, 10), torch.randn(5, 1)

# Forward pass, loss, then backward pass to populate each param.grad.
output = model(x)
loss = criterion(output, y)
loss.backward()

# Round-trip the weights and gradients through a checkpoint file.
save_path = "model_with_grad.pth"
save_model_and_gradients(model, save_path)
parameters, gradients = load_model_and_gradients(save_path)

# Compute, print and plot the RGN for every parameter that has a gradient.
rgn_results = compute_rgn_from_gradients_and_params(gradients, parameters)
print("Relative Gradient Norm (RGN) Results:")
for name, rgn in rgn_results.items():
    print(f"{name}: {rgn:.6f}")
visualize_rgn(rgn_results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment