@nanguoyu
Created January 21, 2025 09:26
Save and load gradients, and compute RGN (Relative Gradient Norm)
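A note on what RGN means here (not stated in the original gist, but read directly off the code below): for each parameter tensor θ with gradient g, the script computes RGN(θ) = ‖g‖₂ / ‖θ‖₂. One reason this ratio is useful: for plain SGD without momentum, an update with learning rate η changes θ by −ηg, so the relative step size ‖Δθ‖₂ / ‖θ‖₂ is exactly η · RGN(θ).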
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

def compute_rgn_from_gradients_and_params(gradients, parameters):
    """
    Compute RGN directly from gradient values and parameter values.

    Args:
        gradients (dict): mapping from parameter name to gradient tensor
        parameters (dict): mapping from parameter name to parameter tensor

    Returns:
        rgn_dict (dict): mapping from parameter name to RGN value
    """
    rgn_dict = {}
    for name in gradients.keys():
        grad = gradients[name]
        param = parameters[name]
        if grad is not None and param is not None:
            grad_norm = torch.norm(grad).item()    # L2 norm of the gradient
            param_norm = torch.norm(param).item()  # L2 norm of the parameter
            if param_norm > 0:  # guard against division by zero
                rgn_dict[name] = grad_norm / param_norm
            else:
                rgn_dict[name] = 0.0  # default when the parameter norm is zero
    return rgn_dict
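
# Quick sanity check (not in the original gist; illustrative names and values):
# a gradient of [2, 2, 2, 2] has L2 norm 4 and a parameter of [1, 1, 1, 1] has
# L2 norm 2, so the expected RGN is 4 / 2 = 2.
_g = {"w": torch.full((4,), 2.0)}
_p = {"w": torch.ones(4)}
assert abs(compute_rgn_from_gradients_and_params(_g, _p)["w"] - 2.0) < 1e-6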

# Visualize the RGN values
def visualize_rgn(rgn_dict):
    """
    Plot RGN values as a bar chart.

    Args:
        rgn_dict (dict): mapping from parameter name to RGN value
    """
    names = list(rgn_dict.keys())
    values = list(rgn_dict.values())
    plt.figure(figsize=(10, 6))
    plt.bar(names, values)
    plt.xlabel("Parameter Names")
    plt.ylabel("Relative Gradient Norm (RGN)")
    plt.title("Relative Gradient Norm (RGN) for Each Parameter")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

# Main workflow
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# Save model weights and gradients
def save_model_and_gradients(model, path):
    torch.save({
        'model_state_dict': model.state_dict(),
        'gradients': {name: param.grad.clone() for name, param in model.named_parameters() if param.grad is not None}
    }, path)

# Load model weights and gradients
def load_model_and_gradients(path):
    checkpoint = torch.load(path)
    return checkpoint['model_state_dict'], checkpoint['gradients']
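
# Note (not part of the original gist): recent PyTorch releases default
# torch.load to weights_only=True; this checkpoint is a plain dict of tensors
# plus a state_dict, which that mode can still deserialize. If the checkpoint
# was saved on GPU and is loaded on a CPU-only machine, pass
# map_location='cpu' to torch.load.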

# Create the model, loss function, and optimizer
# (the optimizer is never stepped; only backward() is needed to populate gradients)
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Create example data
x = torch.randn(5, 10)
y = torch.randn(5, 1)

# Forward pass
output = model(x)
loss = criterion(output, y)

# Backward pass
loss.backward()

# Save weights and gradients
save_path = "model_with_grad.pth"
save_model_and_gradients(model, save_path)

# Load model weights and gradients
parameters, gradients = load_model_and_gradients(save_path)

# Compute RGN
rgn_results = compute_rgn_from_gradients_and_params(gradients, parameters)

# Print RGN results
print("Relative Gradient Norm (RGN) Results:")
for name, rgn in rgn_results.items():
    print(f"{name}: {rgn:.6f}")

# Visualize RGN
visualize_rgn(rgn_results)
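
# Optional extension (not in the original gist): compute RGN directly from a
# live model right after loss.backward(), skipping the save/load round trip.
# A minimal sketch; compute_rgn_from_model is a hypothetical helper name.
def compute_rgn_from_model(model):
    named = dict(model.named_parameters())
    return compute_rgn_from_gradients_and_params(
        {name: param.grad for name, param in named.items()},
        {name: param.detach() for name, param in named.items()},
    )

# Usage, e.g.: visualize_rgn(compute_rgn_from_model(model))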