from lasagne.nonlinearities import softmax
from lasagne.layers import Layer


class SpatialSoftmaxLayer(Layer):
    """
    Softmax layer that computes the softmax over pixels in the same
    location, i.e., over the channel axis. This layer will automatically
    use the cuDNN version of this softmax if it is available.

    Parameters
    ----------
    incoming : a :class:`Layer`
        The layer feeding into this one.
    dnn_softmax_mode : str
        If cuDNN is enabled, which algorithm to use for that
        implementation: either 'accurate' or 'fast'.
    """
    def __init__(self, incoming, dnn_softmax_mode='accurate', **kwargs):
        super(SpatialSoftmaxLayer, self).__init__(incoming, **kwargs)
        # Note: the base Layer constructor already sets self.input_shape.
        self.use_dnn = False
        self.dnn_softmax_mode = dnn_softmax_mode
        try:
            # Import the cuda module itself so the `cuda_enabled` flag
            # below can be resolved.
            import theano.sandbox.cuda
            from theano.sandbox.cuda import dnn
            if theano.sandbox.cuda.cuda_enabled and dnn.dnn_available():
                self.use_dnn = True
                self.dnn_softmax = dnn.GpuDnnSoftmax
        except ImportError:
            pass

    def get_output_for(self, input, **kwargs):
        if self.use_dnn:
            # cuDNN's softmax has a 'channel' mode that normalises over
            # axis 1 directly, with no reshuffling required.
            return self.dnn_softmax('bc01', algo=self.dnn_softmax_mode,
                                    mode='channel')(input)
        else:
            # Fallback: move the channel axis last, flatten the batch and
            # spatial axes together, apply the ordinary softmax, then
            # restore the (batch, channel, height, width) layout.
            bs, c, h, w = self.input_shape
            ds1 = input.dimshuffle((0, 2, 3, 1))
            rs1 = ds1.reshape((-1, c))
            softm = softmax(rs1)
            rs2 = softm.reshape((-1, h, w, c))
            ds2 = rs2.dimshuffle((0, 3, 1, 2))
            return ds2
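For reference, a minimal usage sketch (the toy network below is illustrative, not part of the gist):

import lasagne
from lasagne.layers import InputLayer, Conv2DLayer

# A toy network: per-pixel class scores from a conv head, normalised
# across the 10 channels by SpatialSoftmaxLayer.
l_in = InputLayer((None, 3, 32, 32))
l_scores = Conv2DLayer(l_in, num_filters=10, filter_size=3, pad='same',
                       nonlinearity=lasagne.nonlinearities.linear)
l_sm = SpatialSoftmaxLayer(l_scores)
# The output keeps shape (batch, 10, 32, 32); at every spatial position
# the 10 channel values sum to 1.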
I found two problems with this code (a possible fix for both is sketched after this list):
- It assumes 4D arrays, i.e., it does not work with 3D CNNs, whose inputs are 5D.
- It does not seem to work with FCNs, where self.input_shape contains Nones.
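A minimal sketch of a rank- and shape-agnostic fallback (my own suggestion, not part of the gist): compute a numerically stable softmax over axis 1 directly, so no static shape information or reshaping is needed:

import theano.tensor as T

def get_output_for(self, input, **kwargs):
    # Softmax over the channel axis (axis=1), written rank-agnostically:
    # no reshapes, so self.input_shape may contain Nones (FCNs) and the
    # input may be 5D (3D CNNs). Subtracting the max keeps exp() stable.
    e = T.exp(input - input.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)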
Could you please give this code a public license (e.g. BSD, Apache, …)?
Looks sane for the time being. In Theano I found several issues/PRs (such as Theano/Theano#5719) that aim to introduce better softmax variants into Theano itself; eventually they'll make their way into Theano and Lasagne.