
@stas00
Created April 15, 2021 18:01
Experiment at overcoming overflow when using a bf16-trained model with fp16 mixed precision (a bf16-trained model leads to huge activations)
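For context, a minimal standalone sketch (not part of the gist file below) of why the overflow happens: fp16 tops out around 65504, while bf16 shares fp32's exponent range, so activations that a bf16-trained model produces without issue can turn into inf as soon as they are cast to fp16.

import torch

# illustration only: a value that bf16 represents fine overflows to inf in fp16
acts_bf16 = torch.tensor([50.0, 1e5], dtype=torch.bfloat16)
acts_fp16 = acts_bf16.to(torch.float16)
print(acts_fp16)  # tensor([50., inf], dtype=torch.float16)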
import torch

# Samyam: I have three thoughts here:
# 1) would dropping large activations push the network towards producing smaller activations? I don't know the answer, but it feels unlikely, as the network is not penalized in any way for producing large activations,
# 2) dropout is meant to be used as regularization, but by dropping out only the large values it introduces a bias; it may have an unexpected impact on convergence,
# 3) if 1) does not happen, then at inference time, where there is no dropout, we get the inf again
def dropout_abs_max_values(x, p=0.2):
    """Like Dropout, but instead of random sampling, this zeroes out the p fraction of the largest absolute values."""
    # how many elements to drop along the last dimension
    topk = int(p * x.shape[-1])
    # indices of the top-k largest-magnitude values
    indices = torch.topk(x.abs(), topk, dim=-1, largest=True)[1]
    # zero them out (out-of-place; x itself is left unchanged)
    return x.scatter(-1, indices, 0)


def dropout_bf16_to_fp16(x, p):
    # if activations are large enough to threaten fp16 overflow, drop the largest magnitudes,
    # otherwise apply regular dropout
    if max(abs(x.min()), abs(x.max())) > 1e2:
        return dropout_abs_max_values(x, p)
    else:
        return torch.nn.functional.dropout(x, p)
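

# Illustrative usage (not part of the original gist); a minimal sketch assuming x is an
# activation tensor from a bf16-trained layer, with made-up values chosen to hit both branches.
x_large = torch.tensor([[0.5, -2.0, 3.0e4, -1.5]])  # contains an fp16-threatening value
x_small = torch.tensor([[0.5, -2.0, 3.0, -1.5]])    # all values comfortably small

print(dropout_bf16_to_fp16(x_large, p=0.25))  # top 25% by magnitude zeroed -> the 3.0e4 is dropped
print(dropout_bf16_to_fp16(x_small, p=0.25))  # plain dropout: random elements zeroed, rest scaled by 1/(1-p)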