Created February 24, 2015 13:59
-
-
Save sisp/6b78d68b69727413c22b to your computer and use it in GitHub Desktop.
CrossentropySoftmax1HotWithBiasDx + local_useless_incsubtensor_alloc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/theano/tensor/nnet/nnet.py b/theano/tensor/nnet/nnet.py
index 788213d..e468385 100644
--- a/theano/tensor/nnet/nnet.py
+++ b/theano/tensor/nnet/nnet.py
@@ -1099,7 +1099,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
return [g_dy, g_sm, g_y_idx]
def c_code_cache_version(self):
- return (3,)
+ return (4,)
def c_code(self, node, name, inp, out, sub):
dnll, sm, y_idx = inp
@@ -1128,7 +1128,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
PyErr_SetString(PyExc_ValueError, "rank error");
%(fail)s;
}
- if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(sm)s)[0])
+ if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(sm)s)[0] && PyArray_DIMS(%(dnll)s)[0] != 1)
{
PyErr_Format(PyExc_ValueError,
"dnll.shape[0] (%%ld) != sm.shape[0] (%%ld)",
@@ -1136,7 +1136,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
(long int)PyArray_DIMS(%(sm)s)[0]);
%(fail)s;
}
- if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(y_idx)s)[0])
+ if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(y_idx)s)[0] && PyArray_DIMS(%(dnll)s)[0] != 1)
{
PyErr_Format(PyExc_ValueError,
"dnll.shape[0] (%%ld) != y_idx.shape[0] (%%ld)",
@@ -1161,7 +1161,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
for (size_t i = 0; i < PyArray_DIMS(%(dx)s)[0]; ++i)
{
- const dtype_%(dnll)s dnll_i = ((dtype_%(dnll)s*)(PyArray_BYTES(%(dnll)s) + PyArray_STRIDES(%(dnll)s)[0] * i))[0];
+ const dtype_%(dnll)s dnll_i = ((dtype_%(dnll)s*)(PyArray_BYTES(%(dnll)s) + PyArray_STRIDES(%(dnll)s)[0] * (PyArray_DIMS(%(dnll)s)[0] > 1 ? i : 0)))[0];
const %(y_idx_type) s y_i = ((%(y_idx_type)s*)(PyArray_BYTES(%(y_idx)s) + PyArray_STRIDES(%(y_idx)s)[0] * i))[0];
@@ -1736,7 +1736,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# if the graph is valid, they have the same shape, so we
# also know that z has the right shape.
- if incr.type not in (dvector, fvector):
+ if incr.type not in (dvector, fvector) and not all(incr.broadcastable):
return
# here we know that we are incrementing some part of matrix z by a vector
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment