Baselines for MorphNet paper
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
class TanhNet(nn.Module):
    """Two-layer fully-connected baseline with tanh activations."""

    def __init__(self, in_features, h_units):
        super(TanhNet, self).__init__()
        self.fc1 = nn.Linear(in_features, h_units)
        self.fc2 = nn.Linear(h_units, 10)
        self.in_features = in_features
        self.h_units = h_units

    def forward(self, input: torch.Tensor):
        flattened = input.view(-1, self.in_features)
        # torch.tanh replaces the deprecated F.tanh.
        input = torch.tanh(self.fc1(flattened))
        input = torch.tanh(self.fc2(input))
        return input

    def name(self, dset_name: str):
        return "tanh_{}x{}_{}".format(
            self.in_features, self.h_units, dset_name)

    def store(self, dset_name: str, directory: Path):
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
class ReLUNet(nn.Module):
    def __init__(self, in_features, h_units):
        super(ReLUNet, self).__init__()
        self.fc1 = nn.Linear(in_features, h_units)
        self.fc2 = nn.Linear(h_units, 10)
        self.in_features = in_features
        self.h_units = h_units

    def forward(self, input: torch.Tensor):
        flattened = input.view(-1, self.in_features)
        input = F.relu(self.fc1(flattened))
        input = F.relu(self.fc2(input))
        return input

    def name(self, dset_name: str):
        return "relu_{}x{}_{}".format(
            self.in_features, self.h_units, dset_name)

    def store(self, dset_name: str, directory: Path):
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
class MaxoutNet(nn.Module):
    def __init__(self, in_features, h_units, out_features=10):
        super(MaxoutNet, self).__init__()
        self.fc1 = Maxout(in_features, h_units, 2)
        self.fc2 = Maxout(h_units, out_features, 2)
        self.in_features = in_features
        self.h_units = h_units

    def forward(self, input: torch.Tensor):
        flattened = input.view(-1, self.in_features)
        input = self.fc1(flattened)
        input = self.fc2(input)
        return input

    def name(self, dset_name: str):
        return "maxout_{}x{}_{}".format(
            self.in_features, self.h_units, dset_name)

    def store(self, dset_name: str, directory: Path):
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
# from https://github.com/pytorch/pytorch/issues/805
class Maxout(nn.Module):
    """Maxout layer: a linear map to d_out * pool_size units, followed by a
    max over each pool of pool_size units."""

    def __init__(self, d_in, d_out, pool_size):
        super().__init__()
        self.d_in, self.d_out, self.pool_size = d_in, d_out, pool_size
        self.lin = nn.Linear(d_in, d_out * pool_size)

    def forward(self, inputs):
        shape = list(inputs.size())
        shape[-1] = self.d_out
        shape.append(self.pool_size)
        max_dim = len(shape) - 1
        out = self.lin(inputs)
        # Reshape to (..., d_out, pool_size) and take the max over each pool.
        maxout, _i = out.view(*shape).max(max_dim)
        return maxout
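A minimal usage sketch for the three baselines (illustrative, not part of the gist): it assumes 28x28 inputs flattened to 784 features and 10 output classes, and `save_dir` is a hypothetical output directory.

# Usage sketch (illustrative). Assumes 28x28 inputs (784 features), 10 classes.
if __name__ == "__main__":
    x = torch.randn(32, 1, 28, 28)  # dummy batch; each forward flattens internally
    for net in (TanhNet(784, 200), ReLUNet(784, 200), MaxoutNet(784, 200)):
        logits = net(x)
        assert logits.shape == (32, 10)
    save_dir = Path(".")            # hypothetical output directory
    net.store("mnist", save_dir)    # writes maxout_784x200_mnist.tch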
The second file implements the dilation-erosion (DenMo) network:
import numpy as np
from pathlib import Path
import torch
from torch import nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def dilate(x, s):
    # torch.max along dim=1 returns (values, indices); keep only the values,
    # matching the inlined version in DilateErode.forward below.
    return torch.max(x + s, dim=1)[0]

def erode(x, s):
    return torch.min(x - s, dim=1)[0]
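# Illustrative check of dilate/erode (not part of the original gist):
# broadcasting a (batch, features, 1) input against a (features, k)
# structuring element, dilation takes a max of sums and erosion a min of
# differences over the feature dimension.
def _demo_dilate_erode():
    x = torch.tensor([[1.0], [2.0], [3.0]]).unsqueeze(0)  # shape (1, 3, 1)
    s = torch.zeros(3, 2)                                  # structuring element, shape (3, 2)
    assert torch.equal(dilate(x, s), torch.tensor([[3.0, 3.0]]))  # max of sums
    assert torch.equal(erode(x, s), torch.tensor([[1.0, 1.0]]))   # min of differences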
class DilateErode(nn.Module):
    def __init__(self, in_features: int, number_of_dilations: int, number_of_erosions: int):
        super().__init__()
        self.in_features = in_features
        self.number_of_dilations = number_of_dilations
        self.number_of_erosions = number_of_erosions
        if self.number_of_dilations > 0:
            self.dilations = nn.Parameter(torch.randn(in_features, number_of_dilations))
        else:
            self.dilations = torch.Tensor()
        if self.number_of_erosions > 0:
            self.erosions = nn.Parameter(torch.randn(in_features, number_of_erosions))
        else:
            self.erosions = torch.Tensor()
        self.dilation_bias = nn.Parameter(torch.zeros(1))
        self.erosion_bias = nn.Parameter(torch.zeros(1))

    def forward(self, input: torch.Tensor):
        batch_size = input.shape[0]
        flattened = input.view(batch_size, self.in_features, 1)
        if self.number_of_dilations > 0:
            # Each dilation is a max of a sum over all the input features.
            dsum = flattened + self.dilations
            dilated = torch.max(dsum, dim=1)[0]
            # Append the dilation bias. The paper treats it as a tensor, but
            # because a max is taken it is effectively just a constant.
            dilated_with_bias = torch.cat((dilated, self.dilation_bias.expand(batch_size, 1)), dim=1)
        else:
            # An empty 1-D tensor is skipped by torch.cat, so this branch is a no-op.
            dilated_with_bias = torch.Tensor().to(device)
        if self.number_of_erosions > 0:
            # Each erosion is a min of a difference over all the input features.
            esub = flattened - self.erosions
            eroded = torch.min(esub, dim=1)[0]
            # Append the erosion bias.
            eroded_with_bias = torch.cat((eroded, (-self.erosion_bias).expand(batch_size, 1)), dim=1)
        else:
            eroded_with_bias = torch.Tensor().to(device)
        combined = torch.cat((dilated_with_bias, eroded_with_bias), dim=1)
        return combined
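# Shape check for the layer (illustrative, not part of the original gist):
# with d dilations and e erosions the output has d + 1 + e + 1 columns,
# one appended bias column per branch.
def _demo_de_layer_shapes():
    layer = DilateErode(in_features=784, number_of_dilations=4, number_of_erosions=3)
    out = layer(torch.randn(8, 784))
    assert out.shape == (8, 4 + 1 + 3 + 1)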
class DenMoNet(nn.Module):
    """The dilation-erosion network."""

    def __init__(self, input_space_dim: int, number_dilations: int, number_erosions: int, output_space_dim: int):
        super().__init__()
        self.de_layer = DilateErode(input_space_dim, number_dilations, number_erosions)
        # The linear combination size is the number of erosions plus the number
        # of dilations, plus one bias node for each (if there's at least one).
        lc_size = number_erosions + np.sign(number_erosions) + number_dilations + np.sign(number_dilations)
        self.linear_combination_layer = nn.Linear(lc_size, output_space_dim)

    def name(self, dset_name: str):
        return "denmo_{}x{}_{}".format(self.de_layer.number_of_dilations,
                                       self.de_layer.number_of_erosions, dset_name)

    def forward(self, input: torch.Tensor):
        temp = self.de_layer(input)
        # Keep the dilate-erode activations around for later inspection.
        self.temp = temp
        classification = self.linear_combination_layer(temp)
        return classification

    def store(self, dset_name: str, directory: Path):
        name = self.name(dset_name)
        fname = "{}.tch".format(name)
        torch.save(self, str(directory / fname))
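A minimal end-to-end training sketch (illustrative, not from the gist): the dilation/erosion counts, learning rate, and dummy batch below are placeholder assumptions.

# Training sketch (illustrative). Hyper-parameters and data are placeholders.
if __name__ == "__main__":
    net = DenMoNet(input_space_dim=784, number_dilations=200,
                   number_erosions=200, output_space_dim=10).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    x = torch.randn(64, 784, device=device)          # dummy batch
    y = torch.randint(0, 10, (64,), device=device)   # dummy labels
    for _ in range(10):
        optimizer.zero_grad()
        loss = criterion(net(x), y)
        loss.backward()
        optimizer.step()
    net.store("mnist", Path("."))   # writes denmo_200x200_mnist.tch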