Created
May 18, 2015 17:18
-
-
Save sotelo/bd70927c2ce04e227e49 to your computer and use it in GitHub Desktop.
Mean and Variance aggregation scheme.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from theano import tensor | |
from theano.ifelse import ifelse | |
from blocks.utils import shared_like | |
from blocks.monitoring.aggregation import AggregationScheme, Aggregator | |
class MeanAndVariance(AggregationScheme):
    """Aggregation scheme which computes the mean and the variance.

    Across batches it accumulates three quantities:

    * the sum of ``numerator`` (summed over ``axis`` within each batch),
    * the sum of ``numerator ** 2`` (summed over ``axis`` within each batch),
    * the sum of ``denominator``.

    On readout it returns ``tensor.stacklists([mean, variance])``, where
    ``mean = sum(x) / sum(d)`` and ``variance = sum(x**2) / sum(d) - mean**2``.
    Variance is therefore computed as E[x^2] - E[x]^2.

    Parameters
    ----------
    numerator : :class:`~tensor.TensorVariable`
        Theano variable for the numerator e.g. the likelihood
    denominator : :class:`~tensor.TensorVariable`
        Theano variable for the denominator e.g. the batch size
    axis : int or tuple of ints, optional
        Passed directly to ``sum`` when reducing ``numerator`` and its
        square within a batch. Defaults to ``()``.
        NOTE(review): confirm that Theano treats an empty tuple as
        "reduce no axes".

    """
    def __init__(self, numerator, denominator, axis=()):
        # The original code hard-coded ``self.axis = ()``, which silently
        # ignored the caller's argument. Store what was actually used.
        self.axis = axis
        self.numerator = numerator.sum(axis=axis)
        self.denominator = denominator
        self.squared_num = (numerator ** 2).sum(axis=axis)

    def get_aggregator(self):
        """Build the :class:`Aggregator` that accumulates over batches.

        Returns
        -------
        :class:`Aggregator`
            Aggregator whose readout variable stacks ``[mean, variance]``.

        """
        # Flag that is reset to 0 by the initialization updates and set
        # to 1 by every accumulation. While it is 0, the first batch
        # overwrites the accumulators instead of adding to them.
        initialized = shared_like(0.)
        numerator_acc = shared_like(self.numerator)
        denominator_acc = shared_like(self.denominator)
        squared_num_acc = shared_like(self.squared_num)

        # (accumulator, per-batch value) pairs, in the original update order.
        accumulators = [(numerator_acc, self.numerator),
                        (denominator_acc, self.denominator),
                        (squared_num_acc, self.squared_num)]

        initialization_updates = [(acc, tensor.zeros_like(acc))
                                  for acc, _ in accumulators]
        initialization_updates.append((initialized, 0.))

        accumulation_updates = [(acc, ifelse(initialized, value + acc, value))
                                for acc, value in accumulators]
        accumulation_updates.append((initialized, 1.))

        mean = numerator_acc / denominator_acc
        # E[x^2] - E[x]^2. This can come out slightly negative through
        # floating-point cancellation when the variance is tiny relative
        # to the mean.
        variance = squared_num_acc / denominator_acc - mean ** 2
        readout_variable = tensor.stacklists([mean, variance])

        return Aggregator(aggregation_scheme=self,
                          initialization_updates=initialization_updates,
                          accumulation_updates=accumulation_updates,
                          readout_variable=readout_variable)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment