diff --git a/theano/tensor/basic.py b/theano/tensor/basic.py
index 8f734e2..52327a6 100644
--- a/theano/tensor/basic.py
+++ b/theano/tensor/basic.py
@@ -3804,6 +3804,202 @@ def vertical_stack(*args):
     return concatenate(args, axis=0)
+class FixUnknownDimension(Op):
+    """Infer an unknown dimension indicated by `-1`.
+
+    In `Reshape`, one dimension can be provided as `-1`, which means the
+    size of this dimension is inferred. This op computes the missing
+    dimension.
+    """
+
+    def __init__(self, ndim):
+        self.ndim = ndim
+
+    def __eq__(self, other):
+        return (type(other) is type(self)) and (other.ndim == self.ndim)
+
+    def __hash__(self):
+        return hash(type(self)) ^ hash(self.ndim)
+
+    def __str__(self):
+        return '%s{%s}' % (self.__class__.__name__, self.ndim)
+
+    def make_node(self, newshape, size):
+        newshape = as_tensor_variable(newshape, ndim=1)
+        if not newshape.dtype.startswith('int'):
+            raise TypeError('`newshape` must be integers',
+                            newshape, newshape.dtype)
+        assert newshape.ndim == 1
+
+        size = as_tensor_variable(size, ndim=0)
+        if not size.dtype.startswith('int'):
+            raise TypeError('`size` must be an integer', size, size.dtype)
+        assert size.ndim == 0
+
+        return gof.Apply(self, [newshape, size], [newshape.type()])
+
+    def perform(self, node, inp, out_):
+        newshape, size = inp
+        out, = out_
+
+        if newshape.shape[0] != self.ndim:
+            raise ValueError('Argument `newshape` to '
+                             'FixUnknownDimension.perform has incorrect '
+                             'length %d, should be %d.'
+                             % (newshape.shape[0], self.ndim), newshape)
+
+        if size.ndim != 0:
+            raise ValueError('Argument `size` to FixUnknownDimension.perform '
+                             'must be a scalar (0 dimensions). (%d dimensions)'
+                             % size.ndim)
+
+        i_unknown = newshape < 0
+        n_unknown = i_unknown.sum()
+
+        if (out[0] is None) or (out[0].shape != newshape.shape):
+            out[0] = numpy.empty_like(newshape)
+
+        out[0][:] = newshape
+
+        if n_unknown == 0:
+            if newshape.prod() != size:
+                raise ValueError('Total size must not change.')
+        elif n_unknown == 1:
+            known = newshape[~i_unknown].prod()
+            if (known > 0) and (size % known == 0):
+                out[0][i_unknown] = size // known
+            else:
+                raise ValueError('Total size must not change.')
+        else:
+            raise ValueError('Can only specify one unknown dimension.')
+
+    def infer_shape(self, node, ishapes):
+        return [ishapes[0]]
+
+    def c_code_cache_version(self):
+        return (1,)
+
+    def c_support_code(self):
+        """
+        This code is borrowed from <numpy/core/src/multiarray/shape.c>.
+        """
+        return """
+        static int
+        _fix_unknown_dimension(PyArray_Dims *newshape, npy_intp s_original)
+        {
+            npy_intp *dimensions;
+            npy_intp i_unknown, s_known;
+            int i, n;
+            static char msg[] = "total size of new array must be unchanged";
+
+            dimensions = newshape->ptr;
+            n = newshape->len;
+            s_known = 1;
+            i_unknown = -1;
+
+            for (i = 0; i < n; i++) {
+                if (dimensions[i] < 0) {
+                    if (i_unknown == -1) {
+                        i_unknown = i;
+                    }
+                    else {
+                        PyErr_SetString(PyExc_ValueError,
+                                        "can only specify one" \
+                                        " unknown dimension");
+                        return -1;
+                    }
+                }
+                else {
+                    s_known *= dimensions[i];
+                }
+            }
+
+            if (i_unknown >= 0) {
+                if ((s_known == 0) || (s_original % s_known != 0)) {
+                    PyErr_SetString(PyExc_ValueError, msg);
+                    return -1;
+                }
+                dimensions[i_unknown] = s_original/s_known;
+            }
+            else {
+                if (s_original != s_known) {
+                    PyErr_SetString(PyExc_ValueError, msg);
+                    return -1;
+                }
+            }
+            return 0;
+        }
+        """
+
+    def c_code(self, node, name, inputs, outputs, sub):
+        newshape, size = inputs
+        out, = outputs
+        ndim = self.ndim
+        dtype_newshape = node.inputs[0].type.dtype_specs()[1]
+        dtype_size = node.inputs[1].type.dtype_specs()[1]
+        fail = sub['fail']
+        return """
+        if (PyArray_DIMS(%(newshape)s)[0] != %(ndim)s)
+        {
+            PyErr_Format(PyExc_ValueError,
+                         "Argument `newshape` to FixUnknownDimension.c_code "
+                         "has incorrect length %%ld, should be %%ld.",
+                         (long int)PyArray_DIMS(%(newshape)s)[0],
+                         (long int)%(ndim)s);
+            %(fail)s;
+        }
+
+        if (PyArray_NDIM(%(size)s) != 0)
+        {
+            PyErr_Format(PyExc_ValueError,
+                         "Argument `size` to FixUnknownDimension.c_code must "
+                         "be a scalar (0 dimensions). (%%ld dimensions)",
+                         (long int)PyArray_NDIM(%(size)s));
+            %(fail)s;
+        }
+
+        // Check if output memory can be reused. If not, allocate new memory.
+        if ((NULL == %(out)s) || (PyArray_DIMS(%(out)s)[0] != %(ndim)s))
+        {
+            if (NULL != %(out)s)
+                Py_XDECREF(%(out)s);
+
+            %(out)s = (PyArrayObject*) PyArray_SimpleNew(
+                PyArray_NDIM(%(newshape)s),
+                PyArray_DIMS(%(newshape)s),
+                NPY_INTP);
+
+            if (!%(out)s)
+            {
+                PyErr_SetString(PyExc_MemoryError, "Failed to alloc output.");
+                %(fail)s
+            }
+        }
+
+        PyArray_Dims newshape;
+        newshape.ptr = (npy_intp*)PyArray_DATA(%(out)s);
+        newshape.len = %(ndim)s;
+        for (int i = 0; i < %(ndim)s; ++i)
+        {
+            // -- We do not want an explicit cast here. `newshape` can be any
+            // -- int* dtype. The compiler will implicitly upcast it, but
+            // -- will complain if it would have to downcast. This could
+            // -- happen if the user passes an int64 dtype while npy_intp
+            // -- ends up being int32.
+            newshape.ptr[i] = ((%(dtype_newshape)s*)(
+                PyArray_BYTES(%(newshape)s) +
+                PyArray_STRIDES(%(newshape)s)[0] * i))[0];
+        }
+
+        {
+            const npy_intp size = *((%(dtype_size)s*)PyArray_BYTES(%(size)s));
+            if (_fix_unknown_dimension(&newshape, size) < 0) {
+                // The error message should have been set by
+                // `_fix_unknown_dimension`.
+                %(fail)s;
+            }
+        }
+        """ % locals()
+
+
 class Reshape(Op):
     """Perform a reshape operation of the input x to the new shape shp.
@@ -3892,54 +4088,9 @@ class Reshape(Op):
         return self(eval_points[0], *inputs[1:], **dict(return_list=True))
     def infer_shape(self, node, ishapes):
-        # inputs[1] can contain at most one value of '-1', meaning the actual
-        # shape of the output will be automatically computed by reshape, so
-        # that the total number of elements stays the same.
-        # TODO: Maybe put that formula here?
-        # It's not trivial, because we would have to check if the product of
-        # all the non-minus-one shapes is a divisor of the product of the
-        # original shapes.
-
-        # The following expression leads to cycles in feature_shape,
-        # because it tries to replace the Shape_i node by the switch
-        # statement, which depends on Shape_i.
-        # return [tuple([switch(eq(node.inputs[1][i], -1),
-        #                       theano.tensor.opt.Shape_i(i)(node.outputs[0]),
-        #                       node.inputs[1][i])
-        #                for i in xrange(self.ndim)]
-        #                )]
-
-        # Here, we only simplify if the shape (node.inputs[1]) is a constant,
-        # ideally it would suffice to check that it is always non-negative.
-
-        requ = node.inputs[1]
-        if isinstance(requ, theano.tensor.TensorConstant):
-            requ = list(requ.data)
-            requ_part = [ele for ele in requ if ele != -1]
-            crit = len(requ) - len(requ_part)
-            if crit == 1 and len(requ_part) > 0:
-                missing = mul(*ishapes[0]) // mul(*requ_part)
-                for i, ele in enumerate(requ):
-                    if ele == -1:
-                        requ[i] = missing
-            elif crit == 1:  # we reshape to -1
-                requ = [mul(*ishapes[0])]
-            elif crit > 1:
-                raise ValueError('shape argument to Reshape.perform'
-                                 ' must have at most one entry equal to -1')
-            return [requ]
-        else:
-            oshape = []
-            for i in xrange(self.ndim):
-                default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
-                try:
-                    os_i = get_scalar_constant_value(node.inputs[1][i]).item()
-                    if os_i == -1:
-                        os_i = default_os_i
-                except NotScalarConstantError:
-                    os_i = default_os_i
-                oshape.append(os_i)
-            return [tuple(oshape)]
+        newshape = node.inputs[1]
+        outshape = FixUnknownDimension(self.ndim)(newshape, mul(*ishapes[0]))
+        return [tuple(outshape[i] for i in xrange(self.ndim))]
     def c_code_cache_version(self):
         return (6,)
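
For reference, here is a minimal NumPy-only sketch of the inference rule the op above implements; the helper name and example values are illustrative and not part of the patch. At most one entry of the requested shape may be -1, and it is replaced by size // prod(known entries), failing when the product of the known entries does not divide the total size.

    import numpy

    def fix_unknown_dimension(newshape, size):
        """Sketch of the dimension-inference rule (hypothetical helper)."""
        newshape = numpy.array(newshape, dtype='int64')
        unknown = newshape < 0
        n_unknown = unknown.sum()
        if n_unknown == 0:
            # No -1 entry: the requested shape must already match the size.
            if newshape.prod() != size:
                raise ValueError('Total size must not change.')
        elif n_unknown == 1:
            # Exactly one -1 entry: infer it from the remaining dimensions.
            known = newshape[~unknown].prod()
            if known <= 0 or size % known != 0:
                raise ValueError('Total size must not change.')
            newshape[unknown] = size // known
        else:
            raise ValueError('Can only specify one unknown dimension.')
        return newshape

    # fix_unknown_dimension([2, -1, 3], 24) -> array([2, 4, 3])
    # fix_unknown_dimension([5, -1], 24)    -> ValueError (5 does not divide 24)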