This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import math | |
# Model definition.
# See Kaiming He's implementation in Torch:
# https://github.com/KaimingHe/resnet-1k-layers/blob/master/README.md
class Bottleneck(nn.Module):
    """ResNet-style bottleneck block.

    NOTE(review): only the ``expansion`` class attribute is present in this
    copy; the layers and ``forward`` appear to have been cut off. Confirm
    against Kaiming He's resnet-1k-layers code before relying on this class.
    """

    # Channel expansion factor, described by the original author as
    # (number of output channels) / (number of input channels).
    expansion = 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def init_weight(m):
    """Re-initialize the weights of convolution and batch-norm modules in place.

    Dispatch is by substring match on the module's class name:

    * class name contains ``'Conv'``: weights drawn from N(0, 0.02**2);
    * class name contains ``'BatchNorm'``: weights drawn from N(1, 0.02**2).

    Any other module is left untouched, and biases are never modified.
    Presumably meant to be passed to ``nn.Module.apply``; the caller is not
    visible here.

    Args:
        m: the module to (possibly) initialize.
    """
    classname = type(m).__name__
    # The name match also hits modules without a weight tensor, such as
    # user-defined containers like ``ConvBlock`` or ``BatchNorm2d(affine=False)``
    # (whose ``weight`` is None). Skip them instead of raising AttributeError.
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    if 'Conv' in classname:
        nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(weight, 1.0, 0.02)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
class CharDataset(object):
    """Character-level view of a text file.

    Attributes:
        raw_data: full text of the file.
        chars: the distinct characters of ``raw_data``, sorted so that their
            order (and any char/index mapping derived from it) is identical
            on every run.
        data_size: number of characters in ``raw_data``.
    """

    def __init__(self, path):
        """Read *path* and compute its character set.

        Args:
            path: path to a text file. It is opened in text mode with the
                platform default encoding.

        Raises:
            RuntimeError: if *path* does not exist.
        """
        if not os.path.exists(path):
            raise RuntimeError('Cannot open the file: {}'.format(path))
        # Context manager so the handle is closed instead of leaked.
        with open(path, 'r') as f:
            self.raw_data = f.read()
        # set() iteration order for str depends on per-process hash
        # randomization (PYTHONHASHSEED); sort for a reproducible vocabulary.
        self.chars = sorted(set(self.raw_data))
        self.data_size = len(self.raw_data)
        print('There are {} characters in the file'.format(self.data_size))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Fuse Conv-BatchNorm patterns in a torch.nn.Module: an example of using torch.fx.

See: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
"""
import copy | |
from typing import Tuple, Dict, Any | |
import torch | |
import torch.fx as fx | |
import torch.nn as nn |
Older Newer