Created
August 1, 2019 13:56
-
-
Save stephenroller/ae87eef9d704fcf5797067f15a4b742e to your computer and use it in GitHub Desktop.
FusedLayerNorm cannot handle batchsize >= 2**16
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Results of running this script. The outcome appears to be independent of
--dim and --eps: with 65535 rows apex matches PyTorch, but at 65536 (2**16)
rows the outputs diverge.
$ python flntest.py --batchsize 65535
Worse case difference: 2.86102294921875e-06
Average case difference: 3.698113104633194e-08
$ python flntest.py --batchsize 65536
Worse case difference: 14.040550231933594
Average case difference: 1.076439619064331
Failure
""" | |
import sys | |
import argparse | |
import torch | |
import torch.nn.functional as F | |
import apex.normalization.fused_layer_norm as apexnorm | |
def test(batchsize, dim, eps, atol=1e-5):
    """Compare apex's FusedLayerNorm against PyTorch's reference layer_norm.

    Builds a random ``(batchsize, dim)`` input plus random affine weight and
    bias, moves them to the current CUDA device, and normalizes the input with
    both implementations. It prints the worst-case and mean absolute
    element-wise difference, then reports whether the two outputs agree.

    Args:
        batchsize: number of input rows (leading dimension of the input).
        dim: size of the normalized (trailing) dimension.
        eps: epsilon passed to both layer-norm implementations.
        atol: absolute tolerance for the final ``allclose`` check; ``rtol``
            stays at ``torch.allclose``'s default.

    Returns:
        True if the apex output is close to the PyTorch output, else False.

    Raises:
        ValueError: if ``batchsize`` or ``dim`` is not positive. An empty
            input would otherwise fail inside ``max()`` with an opaque error.
    """
    if batchsize < 1 or dim < 1:
        raise ValueError(
            "batchsize and dim must be positive, got batchsize={} dim={}".format(
                batchsize, dim
            )
        )
    normalized_shape = (dim,)
    weight = torch.randn(dim).cuda()
    bias = torch.randn(dim).cuda()
    # input
    X = torch.randn(batchsize, dim).cuda()
    # apex's fused layer norm (the implementation under test)
    Yapx = apexnorm.FusedLayerNormAffineFunction.apply(
        X, weight, bias, normalized_shape, eps
    )
    # PyTorch's layer norm, used as the reference
    Ypyt = F.layer_norm(X, normalized_shape, weight, bias, eps)
    # Compute the element-wise error once and reuse it for both statistics.
    diff = (Yapx - Ypyt).abs()
    print("Worse case difference: {}".format(diff.max()))
    print("Average case difference: {}".format(diff.mean()))
    return Yapx.allclose(Ypyt, atol=atol)
def main(argv=None):
    """Parse command-line options, run the comparison, and exit 1 on mismatch.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default, used when run as a script) keeps argparse's normal
            behaviour.
    """
    ap = argparse.ArgumentParser(
        description="Compare apex FusedLayerNorm against "
                    "torch.nn.functional.layer_norm."
    )
    ap.add_argument('-b', '--batchsize', type=int, default=128,
                    help="number of input rows (default: %(default)s)")
    ap.add_argument('-d', '--dim', type=int, default=512,
                    help="size of the normalized dimension "
                         "(default: %(default)s)")
    ap.add_argument('-e', '--eps', type=float, default=1e-6,
                    help="layer-norm epsilon (default: %(default)s)")
    args = ap.parse_args(argv)
    # Non-positive sizes would produce an empty or invalid input tensor and
    # crash deep inside torch. ap.error prints usage and exits with status 2.
    if args.batchsize < 1:
        ap.error("--batchsize must be a positive integer")
    if args.dim < 1:
        ap.error("--dim must be a positive integer")
    result = test(args.batchsize, args.dim, args.eps)
    if not result:
        print("Failure")
        sys.exit(1)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment