@J3698
Last active June 27, 2021 18:43
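
import torch.nn.functional as F


def adain(source, target):
    """Give each channel of `source` the per-channel mean and standard
    deviation of `target` (adaptive instance normalization).

    Both tensors are (batch, channels, height, width). The name `adain` is
    a placeholder: the snippet shows only the function body.
    """
    # check shapes
    assert len(target.shape) == 4, "expected 4 dimensions"
    assert target.shape == source.shape, "source/target shape mismatch"
    batch_size, channels, height, width = source.shape  # PyTorch order is N, C, H, W
    # calculate per-channel target stats over all spatial positions;
    # reshape (unlike view) also accepts non-contiguous tensors
    target_reshaped = target.reshape(batch_size, channels, 1, 1, -1)
    target_variances = target_reshaped.var(-1, unbiased=False)  # (B, C, 1, 1), biased
    target_means = target_reshaped.mean(-1)  # (B, C, 1, 1)
    # instance-normalize source (zero mean, unit variance per channel),
    # then rescale and shift it to match the target stats via broadcasting
    normalized = F.instance_norm(source)
    result = normalized * (target_variances ** 0.5) + target_means
    assert result.shape == (batch_size, channels, height, width)
    return result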
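
For readers without PyTorch at hand, the per-channel arithmetic above can be checked with a minimal pure-Python sketch of the same operation, σ(t)·(s − μ(s))/σ(s) + μ(t), computed separately for every (batch, channel) pair. It assumes nested lists shaped [batch][channel][pixels] in place of tensors; `adain_channel` and `adain_lists` are hypothetical helpers, and `eps=1e-5` mirrors the default of `F.instance_norm`.

```python
import math
from statistics import fmean, pvariance


def adain_channel(source, target, eps=1e-5):
    """Rescale one channel of `source` to the mean/std of `target`.

    Hypothetical helper: `source` and `target` are flat lists holding the
    pixel values of a single channel (H * W values each).
    """
    # population (biased) statistics, matching unbiased=False in the gist
    src_mean, src_var = fmean(source), pvariance(source)
    tgt_mean, tgt_std = fmean(target), math.sqrt(pvariance(target))
    # instance-normalize the source channel; eps guards against zero variance
    normalized = [(x - src_mean) / math.sqrt(src_var + eps) for x in source]
    # rescale and shift to the target's statistics
    return [n * tgt_std + tgt_mean for n in normalized]


def adain_lists(source, target, eps=1e-5):
    """Apply adain_channel to every (batch, channel) pair.

    Hypothetical helper: inputs are nested lists shaped [batch][channel][pixels].
    """
    return [
        [adain_channel(s_ch, t_ch, eps) for s_ch, t_ch in zip(s_img, t_img)]
        for s_img, t_img in zip(source, target)
    ]


src = [[[1.0, 2.0, 3.0, 4.0]]]      # batch 1, channel 1, 4 pixels
tgt = [[[10.0, 10.0, 20.0, 20.0]]]  # mean 15, std 5
out = adain_lists(src, tgt)
print(fmean(out[0][0]), math.sqrt(pvariance(out[0][0])))
```

The output keeps the ordering of the source pixels but takes on the target's statistics. Its standard deviation comes out a hair under 5 because the eps term in the denominator shrinks it slightly, just as it does in `F.instance_norm`.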