Created April 15, 2021 18:01
An experiment in overcoming overflow when using a bf16-trained model with fp16 mixed precision (a bf16-trained model produces activations too large for fp16's range).
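To make the failure mode concrete, here is a tiny illustration (not part of the original gist, just a sketch): fp16 tops out around 65504, while bf16 shares fp32's exponent range, so an activation value that is perfectly representable in bf16 turns into inf the moment it is cast to fp16.

import torch

# A value comfortably inside bf16's dynamic range but above fp16's
# maximum representable value (~65504) overflows to inf on the cast.
x_bf16 = torch.tensor(1e5, dtype=torch.bfloat16)
print(x_bf16)                    # representable in bf16 (rounds to ~99840)
print(x_bf16.to(torch.float16))  # casts to inf in fp16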
# Samyam: I have three thoughts here:
# 1) would dropping large activations push the network towards producing smaller activations? I don't know the answer, but it feels unlikely since the network is not penalized in any way for producing large activations,
# 2) dropout is meant to be used as regularization, but by dropping out only large values it introduces a bias. It may have an unexpected impact on convergence,
# 3) if 1 does not happen, then at inference time, where there is no dropout, we get the inf again
import torch

def dropout_abs_max_values(x, p=0.2):
    """ Like Dropout, but instead of random sampling this zeroes out the p fraction of entries with the largest absolute values """
    topk = int(p * x.shape[-1])
    indices = torch.topk(x.abs(), topk, dim=-1, largest=True)[1]
    return x.scatter(-1, indices, 0)

def dropout_bf16_to_fp16(x, p):
    # zero the largest-magnitude values only when the activation range is already too hot for fp16
    if max(abs(x.min()), abs(x.max())) > 1e2:
        return dropout_abs_max_values(x, p)
    else:
        return torch.nn.functional.dropout(x, p)
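A minimal usage sketch, not part of the original gist, assuming the two functions above are in scope: when an activation tensor's range is already dangerously large, the drop-in replaces the usual random dropout with targeted zeroing of the largest-magnitude entries.

import torch

torch.manual_seed(0)
# simulated oversized activations, as a bf16-trained model might produce
hidden = torch.randn(2, 8) * 300
out = dropout_bf16_to_fp16(hidden, p=0.2)
# with these values the abs-max branch triggers, so the single
# largest-magnitude entry per row (int(0.2 * 8) = 1) is zeroed
print(out)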