# GitHub Gist 00bab43359e0f798918955666865fd1f by @kwojcicki
# Created March 9, 2020 11:41
# -*- coding: utf-8 -*-
"""bn_test.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1xtCz3CwsBSQw7-z2HAWgOpAGNdW96XTI
"""
import shutil
import subprocess
import sys
import urllib.request

VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]

# Download the PyTorch/XLA environment-setup script and run it with the
# selected VERSION. Plain-Python replacement for the notebook-only
# `!curl ... -o ...` / `!python ... --version $VERSION` shell magics, so the
# file is also valid when executed as a .py script.
_SETUP_URL = (
    "https://raw.githubusercontent.com/pytorch/xla/master/"
    "contrib/scripts/env-setup.py"
)
_SETUP_SCRIPT = "pytorch-xla-env-setup.py"

# urlopen raises on HTTP errors, unlike a bare `curl -o`, which would save
# an error page as the "script".
with urllib.request.urlopen(_SETUP_URL) as _resp, open(_SETUP_SCRIPT, "wb") as _out:
    shutil.copyfileobj(_resp, _out)

# sys.executable targets the interpreter running this code (the notebook
# kernel); presumably that is where the setup script should install packages
# for the imports below — confirm. check=True stops here if setup fails
# instead of failing later at import time.
subprocess.run([sys.executable, _SETUP_SCRIPT, "--version", VERSION], check=True)

# These imports must stay after the setup step above; it presumably installs
# torch_xla, which is imported here.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch_xla
import torch_xla.core.xla_model as xm

# XLA device that the model and inputs are moved to below.
device = xm.xla_device()
class ErrorNet(nn.Module):
    """Small conv -> BatchNorm2d -> BatchNorm2d network.

    Appears to be a minimal reproduction case (class ``ErrorNet`` in
    ``bn_test``): one 3x3 convolution followed by two BatchNorm2d layers
    applied back to back.

    Args:
        in_channels: Number of channels in the input to ``c_in``.
            Defaults to 1.
        num_channels: Output channels of ``c_in`` and number of features
            normalized by both ``bn_in1`` and ``bn_in2``. Defaults to 16.
    """

    def __init__(self, in_channels=1, num_channels=16):
        super().__init__()
        # 3x3 kernel with padding=1 (stride 1) keeps H and W unchanged.
        self.c_in = nn.Conv2d(in_channels, num_channels, 3, padding=1)
        self.bn_in1 = nn.BatchNorm2d(num_channels)
        self.bn_in2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        """Apply ``c_in``, then ``bn_in1``, then ``bn_in2``.

        Args:
            x: 4-D input of shape (N, in_channels, H, W); BatchNorm2d
                requires a 4-D tensor.

        Returns:
            Tensor of shape (N, num_channels, H, W).
        """
        # NOTE: per the original author's comment, `self.bn_in1(self.c_in(x))`
        # alone works as expected; the chained bn_in1 -> bn_in2 form below is
        # the case being reproduced.
        return self.bn_in2(self.bn_in1(self.c_in(x)))
# Build the model on the XLA device and run one forward pass on an
# all-zeros batch of shape (N=1, C=1, H=32, W=96).
model = ErrorNet().to(device)
# `inputs` rather than `input`, which would shadow the builtin.
inputs = torch.zeros(1, 1, 32, 96).to(device)
output = model(inputs)
# A bare trailing expression is only displayed inside a notebook; print()
# shows the result when run as a plain script too.
# NOTE(review): torch_xla tensors are presumably evaluated lazily, so
# materializing the output here is likely what triggers the behavior being
# reproduced — confirm.
print(output)
# -*- coding: utf-8 -*-
"""bn_test.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1xtCz3CwsBSQw7-z2HAWgOpAGNdW96XTI
"""
import shutil
import subprocess
import sys
import urllib.request

# NOTE(review): this section is a second, identical copy of the notebook
# setup that appears earlier in this file; running it repeats the download
# and installation.
VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]

# Download the PyTorch/XLA environment-setup script and run it with the
# selected VERSION. Plain-Python replacement for the notebook-only
# `!curl ... -o ...` / `!python ... --version $VERSION` shell magics, so the
# file is also valid when executed as a .py script.
_SETUP_URL = (
    "https://raw.githubusercontent.com/pytorch/xla/master/"
    "contrib/scripts/env-setup.py"
)
_SETUP_SCRIPT = "pytorch-xla-env-setup.py"

# urlopen raises on HTTP errors, unlike a bare `curl -o`, which would save
# an error page as the "script".
with urllib.request.urlopen(_SETUP_URL) as _resp, open(_SETUP_SCRIPT, "wb") as _out:
    shutil.copyfileobj(_resp, _out)

# sys.executable targets the interpreter running this code (the notebook
# kernel); presumably that is where the setup script should install packages
# for the imports below — confirm. check=True stops here if setup fails
# instead of failing later at import time.
subprocess.run([sys.executable, _SETUP_SCRIPT, "--version", VERSION], check=True)

# These imports must stay after the setup step above; it presumably installs
# torch_xla, which is imported here.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch_xla
import torch_xla.core.xla_model as xm

# XLA device that the model and inputs are moved to below.
device = xm.xla_device()
# NOTE(review): second, identical copy of the notebook code; this rebinds
# the `ErrorNet` name defined earlier in this file.
class ErrorNet(nn.Module):
    """Small conv -> BatchNorm2d -> BatchNorm2d network.

    Appears to be a minimal reproduction case (class ``ErrorNet`` in
    ``bn_test``): one 3x3 convolution followed by two BatchNorm2d layers
    applied back to back.

    Args:
        in_channels: Number of channels in the input to ``c_in``.
            Defaults to 1.
        num_channels: Output channels of ``c_in`` and number of features
            normalized by both ``bn_in1`` and ``bn_in2``. Defaults to 16.
    """

    def __init__(self, in_channels=1, num_channels=16):
        super().__init__()
        # 3x3 kernel with padding=1 (stride 1) keeps H and W unchanged.
        self.c_in = nn.Conv2d(in_channels, num_channels, 3, padding=1)
        self.bn_in1 = nn.BatchNorm2d(num_channels)
        self.bn_in2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        """Apply ``c_in``, then ``bn_in1``, then ``bn_in2``.

        Args:
            x: 4-D input of shape (N, in_channels, H, W); BatchNorm2d
                requires a 4-D tensor.

        Returns:
            Tensor of shape (N, num_channels, H, W).
        """
        # NOTE: per the original author's comment, `self.bn_in1(self.c_in(x))`
        # alone works as expected; the chained bn_in1 -> bn_in2 form below is
        # the case being reproduced.
        return self.bn_in2(self.bn_in1(self.c_in(x)))
# Build the model on the XLA device and run one forward pass on an
# all-zeros batch of shape (N=1, C=1, H=32, W=96).
model = ErrorNet().to(device)
# `inputs` rather than `input`, which would shadow the builtin.
inputs = torch.zeros(1, 1, 32, 96).to(device)
output = model(inputs)
# A bare trailing expression is only displayed inside a notebook; print()
# shows the result when run as a plain script too.
# NOTE(review): torch_xla tensors are presumably evaluated lazily, so
# materializing the output here is likely what triggers the behavior being
# reproduced — confirm.
print(output)
# Gist page footer: Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.