This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
# setup | |
# Dense embedding table (4 entries, 4 dims each) with a regular Adam optimizer.
emb1 = nn.Embedding(4, 4) | |
opt1 = torch.optim.Adam(emb1.parameters(), lr=1.) | |
# Same-shaped embedding created with sparse=True; its weights are copied from
# emb1 via the state dict so both start from identical parameters, and it is
# paired with SparseAdam instead of Adam.
emb2 = nn.Embedding(4, 4, sparse=True) | |
emb2.load_state_dict(emb1.state_dict()) | |
opt2 = torch.optim.SparseAdam(emb2.parameters(), lr=1.) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# for https://twitter.com/francoisfleuret/status/1550886362815012865 | |
import torch | |
# setup | |
# Problem sizes used to shape U (N x Q) and V (N x R) below.
N, Q, R = 5, 20, 10 | |
# U: N x Q tensor of standard-normal samples.
U = torch.randn(N, Q) | |
# V: the integers 0 .. N*R-1 laid out row by row as an N x R float tensor.
V = torch.arange(N*R).view(N, R).float() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class LayerNorm(nn.Module): | |
def __init__(self, normalized_shape, eps=1e-6): | |
super().__init__() | |
self.weight = nn.Parameter(torch.ones(normalized_shape)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm | |
import argparse | |
# Command-line options:
#   --local_rank  integer, defaults to 0
#   --apex        boolean flag, False unless passed
# NOTE(review): how these are consumed is not visible in this excerpt;
# --apex presumably selects the Apex SyncBatchNorm imported above -- confirm.
parser = argparse.ArgumentParser() | |
parser.add_argument('--local_rank', type=int, default=0) | |
parser.add_argument('--apex', action='store_true') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import random | |
import shutil | |
import time | |
import warnings | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
import time | |
def test(input, kernel, target): | |
# Warmup | |
for _ in range(50): | |
output = F.conv3d(input, kernel) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import time | |
# Turn on cuDNN benchmark mode. The attribute was previously misspelled
# ("benachmark"), so the assignment had no effect on cuDNN and benchmark
# mode stayed at its default (off).
torch.backends.cudnn.benchmark = True
# Two independent 1024 x 1024 x 10 standard-normal tensors, moved to the
# current CUDA device (requires a GPU).
a = torch.randn(1024, 1024, 10).cuda() | |
b = torch.randn(1024, 1024, 10).cuda() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
import time | |
# Create dummy data | |
# A single random 1 x 3 x 224 x 224 input, created directly on the CUDA device.
data = torch.randn(1, 3, 224, 224, device='cuda') | |
# A single integer label drawn uniformly from [0, 100), also on the CUDA device.
target = torch.randint(0, 100, (1,), device='cuda') | |
# torchvision ResNet-152 built with default constructor arguments; it is not
# moved to the GPU within the lines visible here.
model = models.resnet152() |
This file has been truncated, but you can view the full file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
# Enable cuDNN benchmark mode before any layers are built or run.
torch.backends.cudnn.benchmark = True | |
from apex.normalization import FusedLayerNorm | |
import time | |
Newer Older