@Lyken17
Created May 26, 2022 01:04

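# dense update: gradient of the full weight (all 48 input channels)
# forward: 1x1 conv, stride 1, no padding, groups g = 1
# n = 1, c = 48, h = w = 8, oc = 240, kh = kw = 1, oh = ow = 8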
input: 1, 48, 8, 8
weight: 240, 48, 1, 1
output: 1, 240, 8, 8
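The trick rests on the weight-gradient formula of a cross-correlation with stride 1, no padding and groups g = 1, which is the setting of this layer. With X the input, G = dL/dY the output gradient, and zero-based indices b (sample), k (input channel), o (output channel):

```latex
\frac{\partial L}{\partial W_{o,k,p,q}}
  = \sum_{b=0}^{n-1} \sum_{i=0}^{oh-1} \sum_{j=0}^{ow-1} G_{b,o,i,j}\, X_{b,k,\,i+p,\,j+q},
  \qquad 0 \le p < kh,\; 0 \le q < kw
```

For fixed (b, k, o), the sum over (i, j) is a valid cross-correlation of the h x w map X[b, k] with the oh x ow map G[b, o] as its kernel, so its output is (h - oh + 1) x (w - ow + 1) = kh x kw. Putting every (b, k) pair in its own group computes all of these in one conv call, and the sum over b is the reduce step. Here h - oh + 1 = 8 - 8 + 1 = 1, matching the 1x1 kernel.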
# input: fold the batch into channels, one group per (sample, channel) pair
# (n, c, h, w) => (1, n * c, h, w)
input_1 = 1, 48, 8, 8
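Folding the batch into the channel dim should cost nothing for a contiguous, row-major tensor: every element keeps its flat offset, and channel k of the folded view is sample k // c, channel k % c. A small stdlib check of that claim; the helper names and the sizes n = 2, c = 3 are hypothetical:

```python
def flat_offset(shape, index):
    """Row-major (C-contiguous) flat offset of a multi-index into a tensor of `shape`."""
    offset = 0
    for size, i in zip(shape, index):
        assert 0 <= i < size
        offset = offset * size + i
    return offset


def folded_channel(b, k, c):
    """Channel of the (1, n * c, h, w) view that holds channel k of sample b."""
    return b * c + k


n, c, h, w = 2, 3, 4, 4  # hypothetical small sizes; the note itself uses n = 1, c = 48
for b in range(n):
    for k in range(c):
        for y in range(h):
            for x in range(w):
                # Same element, same flat offset: the fold is a view, not a copy.
                assert flat_offset((n, c, h, w), (b, k, y, x)) == flat_offset(
                    (1, n * c, h, w), (0, folded_channel(b, k, c), y, x))
```

Because channel b * c + k of input_1 is exactly (sample b, channel k), using groups = n * c later gives each (sample, channel) pair its own group.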
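# grad (output gradient), tiled to serve as the kernel of the weight-gradient conv
# (n, oc, oh, ow) = (tiling) => (n, c // g * oc, oh, ow), the oc block repeated c // g times per sample
# => (n * c // g * oc, 1, oh, ow)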
grad_1 = 1, 240, 8, 8
= (tiling) => 1, 11520, 8, 8
=> 11520, 1, 8, 8
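A minimal stdlib sketch of the tiling on nested lists, assuming forward groups g = 1 so that c // g = c. The helper name `tile_grad` and the sizes are hypothetical. The point is the order: per sample, the whole block of oc maps is repeated c times (what `grad.repeat(1, c, 1, 1)` produces in PyTorch), so kernel j falls in group j // oc, the same index as its input channel in input_1.

```python
def tile_grad(grad, c):
    """Tile an (n, oc, oh, ow) gradient into (n * c * oc, 1, oh, ow) kernels.

    Kernel j holds grad[b][o] with j = (b * c + k) * oc + o, so with
    groups = n * c it lands in group b * c + k.
    """
    # Loop order b -> k -> o; [gmap] adds the in-channel dim of size 1.
    return [[gmap] for sample in grad for _ in range(c) for gmap in sample]


n, c, oc, oh, ow = 2, 4, 3, 2, 2  # hypothetical small sizes
# Fill each map with the tag b * 100 + o so its origin is visible after tiling.
grad = [[[[b * 100 + o] * ow for _ in range(oh)] for o in range(oc)] for b in range(n)]
tiled = tile_grad(grad, c)
assert len(tiled) == n * c * oc
for j, kern in enumerate(tiled):
    group = j // oc                      # groups = n * c, oc kernels per group
    b, k, o = group // c, group % c, j % oc
    assert kern[0][0][0] == b * 100 + o  # kernel j is grad[b][o]
```

With the note's sizes (n = 1, c = 48, oc = 240) this yields 48 * 240 = 11520 kernels of shape (1, 8, 8).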
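# grad_weight: grouped conv with input_1 as data and grad_1 as the kernel
# (1, n * c, h, w) * (n * c // g * oc, 1, oh, ow) | groups = n * c
# => (1, n * c // g * oc, kh, kw), since h - oh + 1 = kh and w - ow + 1 = kw (stride 1, no padding)
grad_w = conv(input_1, grad_1, groups=48)
= (1, 48, 8, 8) * (11520, 1, 8, 8)
= (1, 11520, 1, 1)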
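The `conv(..., groups=48)` step can be sketched without a framework as a stride-1, unpadded grouped cross-correlation. The function `grouped_conv2d_valid` and the tiny layer (c = 2, oc = 3, 1x1 kernel on 3x3 maps) are hypothetical; the check compares each output channel against the direct sum of G[o] * X[k].

```python
def grouped_conv2d_valid(x, weight, groups):
    """Stride-1, padding-0 grouped cross-correlation on nested lists.

    x: (1, C, H, W); weight: (O, C // groups, KH, KW).
    Returns (1, O, H - KH + 1, W - KW + 1); output channel o reads only the
    input channels of group o // (O // groups).
    """
    (x0,) = x  # batch of size 1, as in input_1
    C, H, W = len(x0), len(x0[0]), len(x0[0][0])
    O, cin_g = len(weight), len(weight[0])
    KH, KW = len(weight[0][0]), len(weight[0][0][0])
    assert C % groups == 0 and O % groups == 0 and cin_g == C // groups
    out_g = O // groups
    OH, OW = H - KH + 1, W - KW + 1
    out = []
    for o in range(O):
        grp = o // out_g
        fmap = [[0.0] * OW for _ in range(OH)]
        for ci in range(cin_g):
            xc = x0[grp * cin_g + ci]
            kern = weight[o][ci]
            for i in range(OH):
                for j in range(OW):
                    fmap[i][j] += sum(kern[p][q] * xc[i + p][j + q]
                                      for p in range(KH) for q in range(KW))
        out.append(fmap)
    return [out]


# Hypothetical layer: n = 1, c = 2, oc = 3, 1x1 kernel, h = w = oh = ow = 3.
c, oc, h, w = 2, 3, 3, 3
x = [[[[k + i - j for j in range(w)] for i in range(h)] for k in range(c)]]      # (1, c, h, w)
grad = [[[[o * i + j for j in range(w)] for i in range(h)] for o in range(oc)]]  # (1, oc, oh, ow)

# input_1 is x itself (n = 1); grad_1 repeats the oc maps once per input channel.
grad_1 = [[grad[0][o]] for k in range(c) for o in range(oc)]  # (c * oc, 1, oh, ow)
conv_out = grouped_conv2d_valid(x, grad_1, groups=c)          # (1, c * oc, 1, 1)

# Direct formula for the same entries, in the same (k, o) order.
direct = [sum(grad[0][o][i][j] * x[0][k][i][j] for i in range(h) for j in range(w))
          for k in range(c) for o in range(oc)]
assert [fmap[0][0] for fmap in conv_out[0]] == direct
```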
# (1, n * c // g * oc, kh, kw)
# => (n, c // g * oc, kh, kw)
# = (reduce over n) => (c // g * oc, kh, kw)
# => (c // g, oc, kh, kw)
# = (transpose) => (oc, c // g, kh, kw)
grad_w = 1, 11520, 1, 1 => 1, 11520, 1, 1 = (reduce) => 11520, 1, 1
=> 48, 240, 1, 1 = (transpose) => 240, 48, 1, 1
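The reshape, reduce and transpose chain can be written as one index mapping. This sketch assumes the tiling order from the grad step (output channel (b * c + k) * oc + o holds sample b's contribution to weight[o][k]); the name `fold_grad_weight` and the test sizes are hypothetical.

```python
def fold_grad_weight(conv_out, n, c, oc, kh, kw):
    """Turn the grouped-conv output into grad_w.

    conv_out: (1, n * c * oc, kh, kw) nested lists.
    Returns (oc, c, kh, kw): view as (n, c, oc, kh, kw), sum over n,
    then swap the c and oc dims.
    """
    (chans,) = conv_out
    assert len(chans) == n * c * oc
    grad_w = [[[[0.0] * kw for _ in range(kh)] for _ in range(c)] for _ in range(oc)]
    for b in range(n):              # reduce over the batch
        for k in range(c):
            for o in range(oc):     # transpose: (k, o) -> (o, k)
                fmap = chans[(b * c + k) * oc + o]
                for p in range(kh):
                    for q in range(kw):
                        grad_w[o][k][p][q] += fmap[p][q]
    return grad_w


# Hypothetical sizes n = 2, c = 3, oc = 4 with a 1x1 kernel; channel j holds the value j.
n, c, oc = 2, 3, 4
conv_out = [[[[float(j)]] for j in range(n * c * oc)]]
grad_w = fold_grad_weight(conv_out, n, c, oc, 1, 1)
# grad_w[o][k] = (k * oc + o) + ((c + k) * oc + o), one term per sample
assert all(grad_w[o][k][0][0] == 2 * k * oc + 2 * o + c * oc
           for o in range(oc) for k in range(c))
```

For the dense case (n = 1, c = 48, oc = 240) the reduce is a no-op and the result has shape (240, 48, 1, 1).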
# sparse update: gradient for only 24 of the 48 input channels
# forward
input: 1, 48, 8, 8
weight: 240, 48, 1, 1
output: 1, 240, 8, 8
# input: keep only the 24 updated input channels, so c = 24 below
# (n, c, h, w) => (1, n * c, h, w)
input_1 = 1, 24, 8, 8
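Replacing c by the number of updated input channels is the only change to the shape arithmetic. A sketch that reproduces both columns of numbers, assuming forward groups = 1, stride 1 and no padding; `weight_grad_plan`, `select_input_channels` and the choice of channels 0..23 are hypothetical (the note does not say which 24 channels are updated).

```python
def weight_grad_plan(n, c, h, w, oc, oh, ow):
    """Shapes for the grouped-conv weight-gradient trick (forward groups = 1)."""
    kh, kw = h - oh + 1, w - ow + 1
    tiled = n * c * oc
    return {
        "input_1": (1, n * c, h, w),
        "grad_1": (tiled, 1, oh, ow),
        "groups": n * c,
        "conv_out": (1, tiled, kh, kw),
        "grad_w": (oc, c, kh, kw),
    }


def select_input_channels(x, channels):
    """Keep only the given input channels of an (n, c, h, w) nested-list tensor."""
    return [[sample[k] for k in channels] for sample in x]


dense = weight_grad_plan(n=1, c=48, h=8, w=8, oc=240, oh=8, ow=8)
sparse = weight_grad_plan(n=1, c=24, h=8, w=8, oc=240, oh=8, ow=8)
print(dense["grad_1"], sparse["grad_1"])  # (11520, 1, 8, 8) (5760, 1, 8, 8)

# Hypothetical input tagged by channel index; keep channels 0..23.
x = [[[[k] * 8 for _ in range(8)] for k in range(48)]]
x_sel = select_input_channels(x, range(24))
```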
# grad: the full output gradient (all 240 channels) is still needed
# (n, oc, oh, ow) = (tiling) => (n, c // g * oc, oh, ow)
# => (n * c // g * oc, 1, oh, ow)
grad_1 = 1, 240, 8, 8
= (tiling) => 1, 5760, 8, 8
=> 5760, 1, 8, 8
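# grad_weight: grouped conv with input_1 as data and grad_1 as the kernel
# (1, n * c, h, w) * (n * c // g * oc, 1, oh, ow) | groups = n * c
# => (1, n * c // g * oc, kh, kw), since h - oh + 1 = kh and w - ow + 1 = kw (stride 1, no padding)
grad_w = conv(input_1, grad_1, groups=24)
= (1, 24, 8, 8) * (5760, 1, 8, 8)
= (1, 5760, 1, 1)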
# (1, n * c // g * oc, kh, kw)
# => (n, c // g * oc, kh, kw)
# = (reduce over n) => (c // g * oc, kh, kw)
# => (c // g, oc, kh, kw)
# = (transpose) => (oc, c // g, kh, kw)
grad_w = 1, 5760, 1, 1 => 1, 5760, 1, 1 = (reduce) => 5760, 1, 1
=> 24, 240, 1, 1 = (transpose) => 240, 24, 1, 1
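Each entry grad_w[o][k] depends only on input channel k and output-gradient channel o, so the sparse gradient should equal the matching columns of the dense one while needing only the updated input channels plus the full output gradient. A direct-formula check for a 1x1, stride-1 conv; the function name, the sizes and the updated channel list are hypothetical.

```python
def conv1x1_weight_grad(x, grad):
    """dL/dW of a stride-1, unpadded 1x1 conv with groups = 1.

    x: (n, c, h, w); grad: (n, oc, h, w). Returns (oc, c, 1, 1) where
    entry [o][k] = sum over batch and pixels of grad[b][o] * x[b][k].
    """
    n, c, oc = len(x), len(x[0]), len(grad[0])
    h, w = len(x[0][0]), len(x[0][0][0])
    return [[[[sum(grad[b][o][i][j] * x[b][k][i][j]
                   for b in range(n) for i in range(h) for j in range(w))]]
             for k in range(c)]
            for o in range(oc)]


# Hypothetical small layer: n = 2, c = 6, oc = 5, h = w = 3 (integer data, exact sums).
n, c, oc, h, w = 2, 6, 5, 3, 3
x = [[[[(k + 1) * (i - j) + b for j in range(w)] for i in range(h)] for k in range(c)]
     for b in range(n)]
grad = [[[[o + i * j - b for j in range(w)] for i in range(h)] for o in range(oc)]
        for b in range(n)]

updated = [0, 1, 2]  # hypothetical choice of updated input channels
x_sparse = [[sample[k] for k in updated] for sample in x]

dense_gw = conv1x1_weight_grad(x, grad)          # (oc, c, 1, 1)
sparse_gw = conv1x1_weight_grad(x_sparse, grad)  # (oc, len(updated), 1, 1)
for o in range(oc):
    for idx, k in enumerate(updated):
        assert sparse_gw[o][idx] == dense_gw[o][k]
```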