Created
March 9, 2020 11:41
-
-
Save kwojcicki/00bab43359e0f798918955666865fd1f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""bn_test.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1xtCz3CwsBSQw7-z2HAWgOpAGNdW96XTI
"""

import subprocess
import sys

# Version string handed to the PyTorch/XLA environment setup script below.
VERSION = "20200220"  #@param ["20200220","nightly", "xrt==1.15.0"]

# Download the PyTorch/XLA setup script and run it with the selected version.
# The notebook used the IPython shell escapes `!curl ...` and
# `!python ... --version $VERSION`; subprocess.run is the plain-Python
# equivalent, and check=True stops the script if either step fails.
subprocess.run(
    [
        "curl",
        "https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py",
        "-o",
        "pytorch-xla-env-setup.py",
    ],
    check=True,
)
subprocess.run(
    [sys.executable, "pytorch-xla-env-setup.py", "--version", VERSION],
    check=True,
)

# These imports deliberately come after the setup step above, matching the
# notebook's cell order (presumably the setup script is what installs the
# torch_xla build being imported here -- confirm).
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from torch.autograd import Variable  # noqa: E402

import torch_xla  # noqa: E402
import torch_xla.core.xla_model as xm  # noqa: E402

# XLA device on which the model and its input are placed.
device = xm.xla_device()
class ErrorNet(nn.Module):
    """Small network: one 3x3 convolution followed by two stacked BatchNorm2d layers.

    The convolution maps 1 input channel to 16 output channels with
    ``padding=1``; both batch-norm layers operate on those 16 channels.
    """

    def __init__(self):
        """Create the convolution and the two batch-norm layers."""
        super().__init__()
        channels = 16
        self.c_in = nn.Conv2d(1, channels, 3, padding=1)
        self.bn_in1 = nn.BatchNorm2d(channels)
        self.bn_in2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``bn_in2(bn_in1(c_in(x)))`` for the input batch ``x``."""
        # NOTE (from the original author): with only one batch norm, i.e.
        # `h = self.bn_in1(self.c_in(x))`, things work as expected; the
        # second, stacked batch norm is the case under test.
        h = self.bn_in2(self.bn_in1(self.c_in(x)))
        return h
# Build the network and move its parameters and buffers to the XLA device.
model = ErrorNet().to(device)

# All-zero input of shape (1, 1, 32, 96). It is created on the host and then
# transferred with .to(device), exactly as in the original notebook. The name
# `inputs` avoids shadowing the builtin `input`.
inputs = torch.zeros(1, 1, 32, 96).to(device)

# Run one forward pass through conv -> bn_in1 -> bn_in2.
model(inputs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""bn_test.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1xtCz3CwsBSQw7-z2HAWgOpAGNdW96XTI
"""

import subprocess
import sys

# Version string handed to the PyTorch/XLA environment setup script below.
VERSION = "20200220"  #@param ["20200220","nightly", "xrt==1.15.0"]

# Download the PyTorch/XLA setup script and run it with the selected version.
# The notebook used the IPython shell escapes `!curl ...` and
# `!python ... --version $VERSION`; subprocess.run is the plain-Python
# equivalent, and check=True stops the script if either step fails.
subprocess.run(
    [
        "curl",
        "https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py",
        "-o",
        "pytorch-xla-env-setup.py",
    ],
    check=True,
)
subprocess.run(
    [sys.executable, "pytorch-xla-env-setup.py", "--version", VERSION],
    check=True,
)

# These imports deliberately come after the setup step above, matching the
# notebook's cell order (presumably the setup script is what installs the
# torch_xla build being imported here -- confirm).
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from torch.autograd import Variable  # noqa: E402

import torch_xla  # noqa: E402
import torch_xla.core.xla_model as xm  # noqa: E402

# XLA device on which the model and its input are placed.
device = xm.xla_device()
class ErrorNet(nn.Module):
    """Small network: one 3x3 convolution followed by two stacked BatchNorm2d layers.

    The convolution maps 1 input channel to 16 output channels with
    ``padding=1``; both batch-norm layers operate on those 16 channels.
    """

    def __init__(self):
        """Create the convolution and the two batch-norm layers."""
        super().__init__()
        channels = 16
        self.c_in = nn.Conv2d(1, channels, 3, padding=1)
        self.bn_in1 = nn.BatchNorm2d(channels)
        self.bn_in2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``bn_in2(bn_in1(c_in(x)))`` for the input batch ``x``."""
        # NOTE (from the original author): with only one batch norm, i.e.
        # `h = self.bn_in1(self.c_in(x))`, things work as expected; the
        # second, stacked batch norm is the case under test.
        h = self.bn_in2(self.bn_in1(self.c_in(x)))
        return h
# Build the network and move its parameters and buffers to the XLA device.
model = ErrorNet().to(device)

# All-zero input of shape (1, 1, 32, 96). It is created on the host and then
# transferred with .to(device), exactly as in the original notebook. The name
# `inputs` avoids shadowing the builtin `input`.
inputs = torch.zeros(1, 1, 32, 96).to(device)

# Run one forward pass through conv -> bn_in1 -> bn_in2.
model(inputs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment