This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sam import SAM | |
... | |
model = YourModel() | |
base_optimizer = torch.optim.SGD # define an optimizer for the "sharpness-aware" update | |
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9) | |
... | |
for input, output in data: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class SAM(torch.optim.Optimizer): | |
def __init__(self, params, base_optimizer, rho=0.05, **kwargs): | |
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" | |
defaults = dict(rho=rho, **kwargs) | |
super(SAM, self).__init__(params, defaults) | |
self.base_optimizer = base_optimizer(self.param_groups, **kwargs) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ResNet(nn.Module): | |
def __init__(self,block,layers,image_channels,num_classes): | |
super(ResNet,self).__init__() | |
self.in_channels = 64 | |
self.conv1 = nn.Conv2d(image_channels,64,kernel_size=7,stride=2,padding=3) | |
self.bn1 = nn.BatchNorm2d(64) | |
self.relu = nn.ReLU() | |
self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) | |
#the resnet layers | |
self.layer1 = self._make_layer(block,layers[0],int_channels=64,stride=1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# an essential block of layers which forms resnets | |
class ResBlock(nn.Module): | |
#in_channels -> input channels,int_channels->intermediate channels | |
def __init__(self,in_channels,int_channels,identity_downsample=None,stride=1): | |
super(ResBlock,self).__init__() | |
self.expansion = 4 | |
self.conv1 = nn.Conv2d(in_channels,int_channels,kernel_size=1,stride=1,padding=0) | |
self.bn1 = nn.BatchNorm2d(int_channels) | |
self.conv2 = nn.Conv2d(int_channels,int_channels,kernel_size=3,stride=stride,padding=1) | |
self.bn2 = nn.BatchNorm2d(int_channels) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _make_layer(self,block,num_res_blocks,int_channels,stride): | |
identity_downsample = None | |
layers = [] | |
if stride!=1 or self.in_channels != int_channels*4: | |
identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels,int_channels*4, | |
kernel_size=1,stride=stride), | |
nn.BatchNorm2d(int_channels*4)) | |
layers.append(ResBlock(self.in_channels,int_channels,identity_downsample,stride)) | |
#this expansion size will always be 4 for all the types of ResNets |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class InceptionNet(nn.Module): | |
def __init__(self,aux_logits=True,num_classes=1000): | |
super(InceptionNet, self).__init__() | |
assert aux_logits == True or aux_logits == False | |
self.aux_logits = aux_logits | |
self.conv1 = conv_block(in_channels=3,out_channels=64,kernel_size=(7,7), | |
stride=(2,2), padding=(3,3)) | |
self.maxpool1 = nn.MaxPool2d(kernel_size=3,stride=2, padding=1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Inception_block(nn.Module): | |
def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool): | |
super(Inception_block, self).__init__() | |
self.branch1 = conv_block(in_channels, out_1x1, kernel_size=(1,1)) | |
self.branch2 = nn.Sequential( | |
conv_block(in_channels, red_3x3, kernel_size=(1,1)), | |
conv_block(red_3x3, out_3x3, kernel_size=(3,3),padding=(1,1)) | |
) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class InceptionAux(nn.Module): | |
def __init__(self, in_channels, num_classes): | |
super(InceptionAux,self).__init__() | |
self.relu = nn.ReLU() | |
self.dropout = nn.Dropout(p=0.7) | |
self.pool = nn.AvgPool2d(kernel_size=5,stride=3) | |
self.conv = conv_block(in_channels, 128, kernel_size=1) | |
self.fc1 = nn.Linear(2048, 1024) | |
self.fc2 = nn.Linear(1024, num_classes) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class conv_block(nn.Module):
    """Convolution followed by batch normalisation and a ReLU activation.

    Args:
        in_channels: number of channels in the input passed to the convolution.
        out_channels: number of channels produced by the convolution; also the
            number of features normalised by the batch-norm layer.
        **kwargs: forwarded unchanged to ``nn.Conv2d``. Since only the two
            channel counts are passed positionally, ``kernel_size`` must be
            supplied here (along with any ``stride``/``padding`` etc.).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Child modules are registered in the same order as before
        # (relu, conv, batchnorm) so named_children() and the printed repr
        # stay unchanged.
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(batchnorm(conv(x)))``."""
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.relu(normalised)