This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Save/Load state_dict
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(model.state_dict(), PATH)

# Loading side: re-create the architecture first, then copy the saved
# parameters into it.
model = TheModelClass(*args, **kwargs)
# NOTE(review): torch.load unpickles the file. On torch >= 1.13, prefer
# torch.load(PATH, weights_only=True) so that an untrusted checkpoint cannot
# run arbitrary code; a plain state_dict loads fine in that mode.
model.load_state_dict(
    torch.load(PATH)
)

# Put layers such as dropout / batch norm into evaluation behavior before inference.
model.eval()
| # Save/Load Entire Model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class LuongAttnDecoderRNN(nn.Module): | |
| def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): | |
| super(LuongAttnDecoderRNN, self).__init__() | |
| # Keep for reference | |
| self.attn_model = attn_model | |
| self.hidden_size = hidden_size | |
| self.output_size = output_size | |
| self.n_layers = n_layers | |
| self.dropout = dropout |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Luong attention layer | |
| class Attn(torch.nn.Module): | |
| def __init__(self, method, hidden_size): | |
| super(Attn, self).__init__() | |
| self.method = method | |
| if self.method not in ['dot', 'general', 'concat']: | |
| raise ValueError(self.method, "is not an appropriate attention method.") | |
| self.hidden_size = hidden_size | |
| if self.method == 'general': | |
| self.attn = torch.nn.Linear(self.hidden_size, hidden_size) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class EncoderRNN(nn.Module): | |
| def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): | |
| super(EncoderRNN, self).__init__() | |
| self.n_layers = n_layers | |
| self.hidden_size = hidden_size | |
| self.embedding = embedding | |
| # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size' | |
| # because our input size is a word embedding with number of features == hidden_size | |
| self.gru = nn.GRU(hidden_size, hidden_size, n_layers, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class Voc: | |
| def __init__(self, name): | |
| self.name = name | |
| self.trimmed = False | |
| self.word2index = {} | |
| self.word2count = {} | |
| self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} | |
| self.num_words = 3 # Count SOS, EOS, PAD | |
| def addSentence(self, sentence): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Transfer learning as a fixed feature extractor: load a pretrained ResNet-18,
# freeze every existing parameter, then attach a new 2-output linear head.
model_conv = torchvision.models.resnet18(pretrained=True)

# Disable gradient tracking for all parameters of the pretrained network.
for param in model_conv.parameters():
    param.requires_grad = False

# The replacement head is constructed *after* the freeze. Parameters of newly
# constructed modules have requires_grad=True by default, so the head is the
# only trainable part of the model.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

# Move the assembled model to the target device.
model_conv = model_conv.to(device)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fine-tuning: start from a pretrained ResNet-18, swap its final fully connected
# layer for a fresh 2-output head, and train the whole network.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features  # input width of the original `fc` layer
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Nothing was frozen above, so every parameter of the model is handed to SGD.
optimizer_ft = optim.SGD(
    model_ft.parameters(),
    lr=0.001,
    momentum=0.9,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| import torch | |
| class TwoLayerNet(torch.nn.Module): | |
| def __init__(self, D_in, H, D_out): | |
| """ | |
| In the constructor we instantiate two nn.Linear modules and assign them as | |
| member variables. | |
| """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class MyReLU(torch.autograd.Function): | |
| """ | |
| We can implement our own custom autograd Functions by subclassing | |
| torch.autograd.Function and implementing the forward and backward passes | |
| which operate on Tensors. | |
| """ | |
| @staticmethod | |
| def forward(ctx, input): | |
| """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torchvision import transforms, datasets | |
# Image preprocessing pipeline: random resized crop to 224x224, random
# horizontal flip, conversion to a tensor, then per-channel normalization with
# the mean/std values torchvision documents for its ImageNet-pretrained models.
data_transform = transforms.Compose([
    # `RandomSizedCrop` was a deprecated alias of `RandomResizedCrop` and is not
    # available in recent torchvision releases. `RandomResizedCrop` behaves the
    # same way and exists in old versions too, so this works everywhere.
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])