@chenyaofo
Created September 8, 2024 07:26
Differences between matrix and tensor with nn.Linear
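import torch

# Toy sizes; b, s, h, d presumably stand for batch, sequence length, heads, head dimension.
b = 2
s = 4
h = 4
d = 4
device = "cuda"
dtype = torch.bfloat16
hidden_states = torch.rand((b, s, h * d), device=device, dtype=dtype)
q_proj = torch.nn.Linear(h * d, h * d, device=device, dtype=dtype)
# y1: project the full 3-D tensor, then keep only the last token.
y1 = q_proj(hidden_states)[:, -1:, :]
# y2: slice the last token first (still 3-D, shape (b, 1, h*d)), then project.
y2 = q_proj(hidden_states[:, -1:, :])
print(torch.abs(y1 - y2).max())
# y3: index the last token as a 2-D matrix of shape (b, h*d), then project.
y3: torch.Tensor = q_proj(hidden_states[:, -1, :])
print(torch.abs(y1 - y3.unsqueeze(1)).max())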
'''
Output of one run (first line: y1 vs y2, second line: y1 vs y3):
tensor(0.0039, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
'''
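
The nonzero gap between y1 and y2 above is about the size of one bfloat16 rounding step, so an exact-equality check is probably too strict for this kind of comparison. A minimal sketch of a tolerance-based check follows. The CPU fallback, the fixed seed, and the `rtol`/`atol` values are assumptions added for illustration and are not part of the original snippet.

```python
import torch

torch.manual_seed(0)  # fixed seed so repeated runs use the same data

# Assumption: fall back to CPU when no GPU is present; the original uses "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

b, s, h, d = 2, 4, 4, 4
hidden_states = torch.rand((b, s, h * d), device=device, dtype=dtype)
q_proj = torch.nn.Linear(h * d, h * d, device=device, dtype=dtype)

with torch.no_grad():  # gradients are not needed for a pure comparison
    y1 = q_proj(hidden_states)[:, -1:, :]  # project all tokens, then slice
    y2 = q_proj(hidden_states[:, -1:, :])  # slice first, then project

# Compare in float32 so the difference itself is not rounded to bfloat16.
max_diff = (y1.float() - y2.float()).abs().max().item()

# Hypothetical tolerances, chosen to cover a couple of bfloat16 rounding steps.
close = torch.allclose(y1.float(), y2.float(), rtol=2e-2, atol=1e-2)
print(f"max abs diff: {max_diff:.4f}, allclose: {close}")
```

The exact value of the maximum difference may vary with the device, the PyTorch version, and the kernel selected for each input shape, which is why the sketch reports it rather than asserting a specific number.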