Skip to content

Instantly share code, notes, and snippets.

@lebedov
Created January 22, 2017 17:53
Show Gist options
  • Save lebedov/062d0cd8e51d00e2dba209cff9c492fc to your computer and use it in GitHub Desktop.
Save lebedov/062d0cd8e51d00e2dba209cff9c492fc to your computer and use it in GitHub Desktop.
Non-legacy View module for PyTorch (http://pytorch.org).
#!/usr/bin/env python
"""
Non-legacy View module for pytorch (http://pytorch.org)
"""
import torch.autograd as autograd
import torch.nn as nn
class View(nn.Module):
    """Reshape every sample of a minibatch, preserving the batch dimension.

    ``View(*sizes)`` maps an input of shape ``(N, ...)`` to ``(N, *sizes)``
    via ``Tensor.view``, so the element count per sample must match (one
    entry may be ``-1`` to be inferred).

    Args:
        *sizes: Target per-sample shape, either as separate ints
            (``View(2, 3)``) or as a single tuple/list/``torch.Size``
            (``View((2, 3))``), mirroring the forms ``Tensor.view`` accepts.
    """

    def __init__(self, *sizes):
        # nn.Module.__init__ must run first: it creates every internal
        # registry (_parameters, _buffers, _modules, all hook dicts,
        # `training`, ...) that __call__/to()/state_dict() rely on.
        super(View, self).__init__()
        # Accept View((2, 3)) as well as View(2, 3); torch.Size is a tuple
        # subclass, so it is covered by the isinstance check.
        if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
            sizes = tuple(sizes[0])
        self.sizes = tuple(sizes)

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), *self.sizes)``.

        The batch size is always the first dimension because PyTorch modules
        operate on minibatches, so it is prepended to the configured shape.
        ``Tensor.view`` shares storage with ``input`` and requires a
        compatible (e.g. contiguous) memory layout.
        """
        sizes = (input.size(0),) + self.sizes
        return input.view(*sizes)

    def __repr__(self):
        # e.g. "View (2, 3)" -- all sizes are shown in full.
        return self.__class__.__name__ + ' (' + ', '.join(map(str, self.sizes)) + ')'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment