Skip to content

Instantly share code, notes, and snippets.

@BrambleXu
Last active November 1, 2019 22:08
Show Gist options
  • Save BrambleXu/feb892476202ecc55d03f1f377869755 to your computer and use it in GitHub Desktop.
Save BrambleXu/feb892476202ecc55d03f1f377869755 to your computer and use it in GitHub Desktop.
import math, gzip, pickle
import numpy as np
import random
import torch
from torch import tensor
from torch.nn import init
from fastai import datasets
# Pin the global RNG so every torch.randn draw further down is reproducible run to run.
torch.manual_seed(seed=42)

##### preprocess #####
# NOTE(review): deeplearning.net may no longer serve this file; verify the URL still
# resolves before relying on it.
MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl'
def get_data(url=None):
    """Download the pickled MNIST archive and return its train/valid splits as tensors.

    Args:
        url: Base URL handed to fastai's ``datasets.download_data``. Defaults to
            ``MNIST_URL``. ``ext='.gz'`` is passed along, so the download is a gzip
            archive (presumably fetched from ``url + '.gz'``; confirm against fastai's
            ``download_data``).

    Returns:
        A ``map`` object yielding ``x_train, y_train, x_valid, y_valid``, each converted
        with ``torch.tensor``. It is a one-shot iterator, so unpack it immediately. The
        pickle's third split is discarded.
    """
    if url is None:
        url = MNIST_URL
    path = datasets.download_data(url, ext='.gz')
    # SECURITY: pickle.load can execute arbitrary code, and the file may arrive over
    # plain HTTP. Only point `url` at a source you trust.
    with gzip.open(path, 'rb') as f:
        # encoding='latin-1' is presumably needed because the archive was pickled under
        # Python 2 (NumPy arrays in byte strings).
        (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train, y_train, x_valid, y_valid))
def normalize(x, m, s):
    """Standardize ``x`` by subtracting the mean ``m`` and dividing by the std ``s``.

    Works with anything that supports ``-`` and ``/`` (Python numbers, tensors, arrays).
    """
    centered = x - m
    return centered / s
# Load MNIST and standardize it using statistics from the *training* split only.
x_train, y_train, x_valid, y_valid = get_data()
train_mean = x_train.mean()
train_std = x_train.std()
# NB: the validation split is shifted and scaled by the training mean/std, not its own,
# so both splits go through the identical transform.
x_train, x_valid = (
    normalize(x_train, train_mean, train_std),
    normalize(x_valid, train_mean, train_std),
)
# Sanity check: x_train.mean() / x_train.std() should now be ~0 / ~1. x_valid's
# statistics should be close to that, but not necessarily exact.
##### random init: weight mean and std #####
# Plain N(0, 1) initialization for a single 784 -> 50 layer.
w1 = torch.randn(784, 50)
b1 = torch.randn(50)


def linear(x, w, b):
    """Affine map: matrix-multiply ``x`` by ``w`` and add the bias ``b``."""
    return torch.matmul(x, w) + b


t1 = linear(x_valid, w1, b1)
# Inspect t1.mean() / t1.std(). With unit-variance weights, each output sums 784
# products, so expect a std roughly on the order of sqrt(784) ~= 28 rather than ~1.
##### comparison of kaiming init and random init #####
# Baseline: every weight and bias drawn from N(0, 1), for a 784 -> 50 -> 10 -> 1 stack.
# The draw order (w1, b1, w2, b2, w3, b3) fixes which random values each tensor gets.
w1 = torch.randn(784, 50)
b1 = torch.randn(50)
w2 = torch.randn(50, 10)
b2 = torch.randn(10)
w3 = torch.randn(10, 1)
b3 = torch.randn(1)


def linear(x, w, b):
    """Affine map: matrix-multiply ``x`` by ``w`` and add the bias ``b``."""
    return torch.matmul(x, w) + b


def relu(x):
    """Elementwise ReLU: negative entries become 0, everything else passes through."""
    return x.clamp(min=0.)


t1 = relu(linear(x_valid, w1, b1))
t2 = relu(linear(t1, w2, b2))
t3 = relu(linear(t2, w3, b3))
# Print .mean() / .std() of t1, t2, t3 to watch how the activation scale changes per layer.
# kaiming init
# Each weight matrix is scaled by sqrt(2 / fan_in), where fan_in is its row count (the
# layer's input width). Biases are still drawn from N(0, 1), so only the weight
# scaling differs from the random-init run.
# NOTE(review): He et al. (2015) initialize biases to zero; N(0, 1) biases are kept
# here to match the random-init comparison.
w1 = torch.randn(784, 50).mul(math.sqrt(2 / 784))
b1 = torch.randn(50)
w2 = torch.randn(50, 10).mul(math.sqrt(2 / 50))
b2 = torch.randn(10)
w3 = torch.randn(10, 1).mul(math.sqrt(2 / 10))
b3 = torch.randn(1)


def linear(x, w, b):
    """Affine map: ``x @ w`` followed by adding the bias ``b``."""
    return (x @ w).add(b)


def relu(x):
    """Elementwise ReLU: clamp every entry from below at 0."""
    return x.clamp(min=0.)


t1 = relu(linear(x_valid, w1, b1))
t2 = relu(linear(t1, w2, b2))
t3 = relu(linear(t2, w3, b3))
# Compare .mean() / .std() of t1, t2, t3 against the random-init run.
##### Understand fan_in and fan_out mode in Pytorch implementation #####
# linear layer implementation
# nn.Linear stores its weight as (out_features, in_features) = (50, 784). PyTorch's
# kaiming init takes fan_in from dim 1 and fan_out from dim 0, so mode='fan_in' here
# scales by the true input width, 784.
node_in, node_out = 784, 50
layer = torch.nn.Linear(node_in, node_out)
# nonlinearity='relu' is spelled out for clarity. It yields the same gain, sqrt(2), as
# the default ('leaky_relu' with a=0).
init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
# This forward pass only inspects activation statistics. The layer's parameters require
# grad, so skip building an autograd graph; `t` then carries no grad_fn.
with torch.no_grad():
    t = relu(layer(x_valid))
# Inspect t.mean() / t.std() to check the activation scale.
# weight matrix implementation
# Here the weight is laid out as (in, out) = (784, 50) so that `x @ w` works. That is the
# transpose of nn.Linear's layout. PyTorch still reads fan_out from dim 0, which is now
# 784, so mode='fan_out' is the mode that scales by the true input width.
def linear(x, w, b):
    """Affine map ``x @ w + b`` for a weight laid out as (in_features, out_features)."""
    return torch.matmul(x, w) + b


node_in = 784
node_out = 50
w1 = torch.randn(node_in, node_out)  # contents are overwritten in place just below
init.kaiming_normal_(w1, mode='fan_out')
b1 = torch.randn(node_out)
t = relu(linear(x_valid, w1, b1))
# Inspect t.mean() / t.std(); they should be comparable to the nn.Linear fan_in run.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment