@wanchaol
Created July 26, 2022 21:31
addmm(bias, input, weight)
bias: replicated
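output = mm(input, weight)
input: shard(1), weight: shard(0) -> output is a partial tensor
output -> partial -> replicated? (needs an all-reduce sum)
output + bias -> partial? only add bias on one rank, otherwise it is counted once per rank after the reduction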
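The sharding notes above can be checked with a small single-process simulation. This is only a sketch: the two "ranks" are plain Python lists, the helpers `matmul`, `add_bias`, and `all_reduce_sum` are hypothetical stand-ins (not a torch.distributed API), and the matrices are made-up values.

```python
# Single-process sketch of addmm with input shard(1) and weight shard(0).
# "Ranks" are simulated with plain lists; no real communication happens.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add_bias(m, bias):
    """Add a 1-D bias to every row of m."""
    return [[v + b for v, b in zip(row, bias)] for row in m]

def all_reduce_sum(parts):
    """Stand-in for an all-reduce: element-wise sum of per-rank matrices."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]

inp = [[1.0, 2.0, 3.0, 4.0],
       [5.0, 6.0, 7.0, 8.0]]              # shape (2, 4)
weight = [[1.0, 0.0],
          [0.0, 1.0],
          [2.0, 1.0],
          [1.0, 2.0]]                     # shape (4, 2)
bias = [10.0, 20.0]                       # replicated on every rank

# shard(1) on input: split columns; shard(0) on weight: split rows.
inp_shards = [[row[:2] for row in inp], [row[2:] for row in inp]]
weight_shards = [weight[:2], weight[2:]]

# Each rank multiplies its local shards and holds only a partial sum.
partials = [matmul(i, w) for i, w in zip(inp_shards, weight_shards)]

# partial -> replicated: summing the partials recovers the full product.
full = matmul(inp, weight)
assert all_reduce_sum(partials) == full

# Adding the bias on rank 0 only keeps the result a valid partial tensor.
one_rank = [add_bias(partials[0], bias), partials[1]]
assert all_reduce_sum(one_rank) == add_bias(full, bias)

# Adding it on every rank counts the bias once per rank after the reduction.
every_rank = [add_bias(p, bias) for p in partials]
assert all_reduce_sum(every_rank) == add_bias(add_bias(full, bias), bias)
```

With two ranks, adding the bias everywhere yields `mm(input, weight) + 2 * bias` after the reduction, which is why the note restricts the add to one rank.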
bias.grad: takes the placement of output -> partial
output.grad: takes the placement of bias -> replicated
forward of dist op:
dist_tensor + dist_tensor = out_dist_tensor
out_dist_tensor.grad_fn = addbackward ->
dist_tensor0.mm(dist_tensor_1) = out_dist_tensor
out_dist_tensor.grad_fn = mmbackward -> mm
dist_tensor0.grad -> mm(output_grad, dist_tensor_1.t())
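The mm backward rule can be sanity-checked numerically on ordinary (non-distributed) matrices. The sketch below assumes the standard formulas `grad_a = mm(grad_out, b.t())` and `grad_b = mm(a.t(), grad_out)` and compares them with central finite differences. All helper names and values are illustrative, not taken from any library.

```python
# Finite-difference check of the mm backward formulas on small matrices.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def loss(a, b, grad_out):
    """Scalar sum(grad_out * mm(a, b)); its gradient w.r.t. mm(a, b) is grad_out."""
    out = matmul(a, b)
    return sum(g * o for g_row, o_row in zip(grad_out, out)
               for g, o in zip(g_row, o_row))

def numeric_grad(f, m, eps=1e-3):
    """Central finite differences of the scalar f() w.r.t. each entry of m."""
    grad = [[0.0] * len(m[0]) for _ in m]
    for i in range(len(m)):
        for j in range(len(m[0])):
            orig = m[i][j]
            m[i][j] = orig + eps
            up = f()
            m[i][j] = orig - eps
            down = f()
            m[i][j] = orig                # restore the entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad

def close(x, y, tol=1e-6):
    return all(abs(p - q) < tol for r, s in zip(x, y) for p, q in zip(r, s))

a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]                  # (3, 2)
b = [[1.0, -1.0, 2.0, 0.5], [0.5, 3.0, -2.0, 1.0]]        # (2, 4)
grad_out = [[1.0, 0.0, -1.0, 2.0],
            [0.5, 1.0, 0.0, -1.0],
            [2.0, -2.0, 1.0, 0.0]]                        # (3, 4)

grad_a = matmul(grad_out, transpose(b))   # (3, 4) @ (4, 2) -> (3, 2)
grad_b = matmul(transpose(a), grad_out)   # (2, 3) @ (3, 4) -> (2, 4)

assert close(grad_a, numeric_grad(lambda: loss(a, b, grad_out), a))
assert close(grad_b, numeric_grad(lambda: loss(a, b, grad_out), b))
```

The shapes alone show why the transpose is needed: `grad_out` is (3, 4) and `b` is (2, 4), so only `mm(grad_out, b.t())` produces a gradient with the (3, 2) shape of `a`.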
addbackward ->
d(out)/d(bias) = 1, so bias.grad = output_grad (summed over the broadcast dim)
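A minimal sketch of the add backward step, assuming a 1-D bias broadcast over the rows of the mm result. The local derivative of the add with respect to each operand is 1, so the mm result receives `output_grad` unchanged and the bias receives `output_grad` summed over rows. `add_backward` is a hypothetical helper and the values are made up.

```python
# Backward of out = mm_out + bias, where bias (shape (n,)) is broadcast over rows.

def add_backward(grad_out):
    """Return (grad w.r.t. mm_out, grad w.r.t. bias) for a row-broadcast add."""
    grad_mm_out = [row[:] for row in grad_out]           # passes through unchanged
    grad_bias = [sum(col) for col in zip(*grad_out)]     # reduce the broadcast dim
    return grad_mm_out, grad_bias

grad_out = [[1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0]]                                  # shape (3, 2)

grad_mm_out, grad_bias = add_backward(grad_out)
assert grad_mm_out == grad_out
assert grad_bias == [9.0, 12.0]
```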