.. automodule:: lasagne.updates

The update functions implement different optimization methods based on stochastic gradient descent (SGD), differing mainly in how they control and adapt the learning rate.

Update functions take either a loss expression or a list of gradient expressions, together with a list of parameters, and return an ordered dictionary of updates (see the sketch following the list):

.. autosummary::
        sgd
        momentum
        nesterov_momentum
        adagrad
        rmsprop
        adadelta
        adam
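
For example, the two call styles are interchangeable: you can pass a scalar loss directly, or differentiate it yourself and pass the gradient expressions. A minimal sketch, reusing ``loss`` and ``params`` as defined in the Examples section below:

>>> grads = theano.grad(loss, params)  # explicit gradient expressions
>>> updates = lasagne.updates.sgd(grads, params, learning_rate=0.0001)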

Two functions can be used to further modify the updates to include momentum (see the sketch following the list):

.. autosummary::
        apply_momentum
        apply_nesterov_momentum
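
For instance, :func:`momentum()` should be equivalent to applying :func:`apply_momentum()` on top of plain SGD updates. A sketch, again reusing ``loss`` and ``params`` from the Examples section below:

>>> updates_a = lasagne.updates.momentum(loss, params, learning_rate=0.0001, momentum=0.9)
>>> updates_b = apply_momentum(sgd(loss, params, learning_rate=0.0001), params, momentum=0.9)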

Finally, we provide two helper functions to constrain the norm of tensors:

.. autosummary::
        norm_constraint
        total_norm_constraint

:func:`norm_constraint()` can be used to constrain the norm of parameters (as an alternative to weight decay) or to implement a form of gradient clipping. :func:`total_norm_constraint()` constrains the total norm of a list of tensors; this is often used when training recurrent neural networks.
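
A common pattern for gradient clipping is to rescale the gradient expressions with :func:`total_norm_constraint()` before handing them to an update function. A sketch reusing ``loss`` and ``params`` from the Examples section below; the ``max_norm`` value here is purely illustrative:

>>> from lasagne.updates import total_norm_constraint
>>> grads = theano.grad(loss, params)
>>> scaled_grads = total_norm_constraint(grads, max_norm=5)  # rescale if total norm > 5
>>> updates = sgd(scaled_grads, params, learning_rate=0.0001)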

Examples
--------

>>> import lasagne
>>> import theano.tensor as T
>>> import theano
>>> from lasagne.nonlinearities import softmax
>>> from lasagne.layers import InputLayer, DenseLayer, get_output
>>> from lasagne.updates import sgd, apply_momentum
>>> l_in = InputLayer((100, 20))
>>> l1 = DenseLayer(l_in, num_units=3, nonlinearity=softmax)
>>> x = T.matrix('x')  # shp: num_batch x num_features
>>> y = T.ivector('y') # shp: num_batch
>>> l_out = get_output(l1, x)
>>> params = lasagne.layers.get_all_params(l1)  # [W, b] of the dense layer
>>> loss = T.mean(T.nnet.categorical_crossentropy(l_out, y))
>>> updates_sgd = sgd(loss, params, learning_rate=0.0001)
>>> updates = apply_momentum(updates_sgd, params, momentum=0.9)
>>> train_function = theano.function([x, y], updates=updates)  # each call applies the updates
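
The compiled ``train_function`` can then be called on numpy arrays of the matching shapes; each call performs one SGD-with-momentum step. A minimal sketch; the random data here is purely illustrative:

>>> import numpy as np
>>> X = np.random.randn(100, 20).astype(theano.config.floatX)  # num_batch x num_features
>>> y_true = np.random.randint(0, 3, size=100).astype('int32')  # class labels in {0, 1, 2}
>>> _ = train_function(X, y_true)  # one parameter update step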

Update functions
----------------

.. autofunction:: sgd
.. autofunction:: momentum
.. autofunction:: nesterov_momentum
.. autofunction:: adagrad
.. autofunction:: rmsprop
.. autofunction:: adadelta
.. autofunction:: adam


Update modification functions
-----------------------------

.. autofunction:: apply_momentum
.. autofunction:: apply_nesterov_momentum


Helper functions
----------------

.. autofunction:: norm_constraint
.. autofunction:: total_norm_constraint