Created
December 20, 2018 14:59
-
-
Save ericmjl/4806c44a4a33aa067ea68a6a853299f8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import autograd.numpy as np\n", | |
"from autograd.core import defvjp\n", | |
"import networkx as nx\n", | |
"from autograd import elementwise_grad as egrad\n", | |
"from autograd.misc import flatten\n", | |
"from scipy.sparse import block_diag\n", | |
"\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2\n", | |
"%matplotlib inline\n", | |
"%config InlineBackend.figure_format = 'retina'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def spdot(s, d):\n", | |
" \"\"\"\n", | |
" Wrapped sparse dot for autograd.\n", | |
" \"\"\"\n", | |
" return s.dot(d)\n", | |
"\n", | |
"\n", | |
"def _spdot_vjp_0(g, ans, lhs, rhs):\n", | |
" if max(anp.ndim(lhs), anp.ndim(rhs)) > 2:\n", | |
" raise NotImplementedError(\"Current spdot vjps only support ndim = 2.\")\n", | |
"\n", | |
"# if anp.ndim(lhs) == 0:\n", | |
"# return anp.sum(rhs * g)\n", | |
"# if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 1:\n", | |
"# return g * rhs\n", | |
"# if anp.ndim(lhs) == 2 and anp.ndim(rhs) == 1:\n", | |
"# return g[:, None] * rhs\n", | |
"# if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 2:\n", | |
"# return anp.dot(rhs, g)\n", | |
" return anp.dot(g, rhs.T)\n", | |
"\n", | |
"def _spdot_vjp_1(g, ans, lhs, rhs):\n", | |
" if max(anp.ndim(lhs), anp.ndim(rhs)) > 2:\n", | |
" raise NotImplementedError(\"Current spdot vjps only support ndim = 2.\")\n", | |
"\n", | |
"# if anp.ndim(rhs) == 0:\n", | |
"# return anp.sum(lhs * g)\n", | |
"# if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 1:\n", | |
"# return g * lhs\n", | |
"# if anp.ndim(lhs) == 2 and anp.ndim(rhs) == 1:\n", | |
"# return anp.dot(g, lhs)\n", | |
"# if anp.ndim(lhs) == 1 and anp.ndim(rhs) == 2:\n", | |
"# return lhs[:, None] * g\n", | |
" return anp.dot(lhs.T, g)\n", | |
"\n", | |
"defvjp(spdot, _spdot_vjp_0, _spdot_vjp_1)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n", | |
"-> result = np.zeros((other.shape[1], self.shape[0]),\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.shape\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(22, 2)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('int64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype.char\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"'l'\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) c\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"> <ipython-input-3-651894338714>(32)model()\n", | |
"-> act = spdot(sums, fs)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) fs\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"array([[-0.99996475, -0.9959886 ],\n", | |
" [-0.99819975, -0.96650431],\n", | |
" [-0.99819975, -0.96650431],\n", | |
" [-0.99819975, -0.96650431],\n", | |
" [-0.99996475, -0.9959886 ],\n", | |
" [-0.99996475, -0.9959886 ],\n", | |
" [-0.98720008, -0.90551051],\n", | |
" [-0.99996475, -0.9959886 ],\n", | |
" [-0.99819975, -0.96650431],\n", | |
" [-0.999748 , -0.98836645],\n", | |
" [-0.51904078, 0.99547746],\n", | |
" [-0.52378338, 0.99955411],\n", | |
" [-0.51904078, 0.99547746],\n", | |
" [-0.53317175, 0.99999568],\n", | |
" [-0.51904078, 0.99547746],\n", | |
" [-0.5378174 , 0.99999958],\n", | |
" [-0.51426601, 0.95496806],\n", | |
" [-0.52849372, 0.99995612],\n", | |
" [-0.51904078, 0.99547746],\n", | |
" [-0.52378338, 0.99955411],\n", | |
" [-0.51904078, 0.99547746],\n", | |
" [-0.52849372, 0.99995612]])\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) c\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n", | |
"-> result = np.zeros((other.shape[1], self.shape[0]),\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.shape\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(22, 2)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('float64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype.char\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"'d'\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) self.dtype.char\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"'d'\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) upcast_char\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<function upcast_char at 0x7fad1cd46510>\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) upcast_char((other.dtype.char, self.dtype.char))\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'numpy.float64'>\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) c\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n", | |
"-> result = np.zeros((other.shape[1], self.shape[0]),\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.shape\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(22, 2)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('int64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) self.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('float64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) upcast_char(self.dtype.char, other.dtype.char)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'numpy.float64'>\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) c\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"> <ipython-input-3-651894338714>(32)model()\n", | |
"-> act = spdot(sums, fs)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) fs\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bc41288>\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) fs.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('float64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) c\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n", | |
"-> result = np.zeros((other.shape[1], self.shape[0]),\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('O')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"array([[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77048>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77908>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb778c8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77988>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77948>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb779c8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ac8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77bc8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77cc8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77dc8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ec8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77108>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77848>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77fc8>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b088>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b148>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b108>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b1c8>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b188>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b248>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b208>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b2c8>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b288>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b348>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b308>]],\n", | |
" dtype=object)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) import numpy as np\n", | |
"(Pdb) np.asarray(other)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"array([[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77048>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77908>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb778c8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77988>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77948>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb779c8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ac8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77bc8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77cc8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77dc8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f08>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ec8>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f88>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f48>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77108>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77848>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77fc8>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b088>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b148>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b108>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b1c8>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b188>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b248>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b208>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b2c8>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b288>],\n", | |
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b348>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b308>]],\n", | |
" dtype=object)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other[0]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"array([<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>,\n", | |
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77048>],\n", | |
" dtype=object)\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other[0][0]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) \n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other[0][0].dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('float64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other[0][1].dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('float64')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dtype('O')\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) other.dtype = 'float64'\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"*** TypeError: Cannot change data-type for object array.\n" | |
] | |
}, | |
{ | |
"name": "stdin", | |
"output_type": "stream", | |
"text": [ | |
"(Pdb) c\n" | |
] | |
}, | |
{ | |
"ename": "TypeError", | |
"evalue": "no supported conversion for types: (dtype('float64'), dtype('O'))", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-3-651894338714>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mflat_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munflattener\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mflat_params\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0mdloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munflattener\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m0.01\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/wrap_util.py\u001b[0m in \u001b[0;36mnary_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margnum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0munary_operator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munary_f\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mnary_op_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mnary_op_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnary_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnary_operator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/differential_operators.py\u001b[0m in \u001b[0;36melementwise_grad\u001b[0;34m(fun, x)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munary_to_nary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0melementwise_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mvjp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_make_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miscomplex\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Elementwise_grad only applies to real-output functions.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/core.py\u001b[0m in \u001b[0;36mmake_vjp\u001b[0;34m(fun, x)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmake_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mstart_node\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVJPNode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_root\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mend_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_node\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_node\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mend_node\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/tracer.py\u001b[0m in \u001b[0;36mtrace\u001b[0;34m(start_node, fun, x)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtrace_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mstart_box\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_box\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_node\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mend_box\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_box\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misbox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mend_box\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mend_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_trace\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mstart_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_trace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mend_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_node\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/wrap_util.py\u001b[0m in \u001b[0;36munary_f\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msubargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msubvals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margnum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0msubargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margnum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0margnum\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-3-651894338714>\u001b[0m in \u001b[0;36mmseloss\u001b[0;34m(flat_params, unflattener, model, a, f, s, y)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmseloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munflattener\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munflattener\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_mse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-3-651894338714>\u001b[0m in \u001b[0;36mmodel\u001b[0;34m(adj, feats, sums, params)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mfs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'mp1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'w'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'mp1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'b'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mact\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msums\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mact\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mact\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dense1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'w'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dense1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'b'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-2-6ab25bcf1507>\u001b[0m in \u001b[0;36mspdot\u001b[0;34m(s, d)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mWrapped\u001b[0m \u001b[0msparse\u001b[0m \u001b[0mdot\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \"\"\"\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/base.py\u001b[0m in \u001b[0;36mdot\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 360\u001b[0m \"\"\"\n\u001b[0;32m--> 361\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 362\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 363\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/base.py\u001b[0m in \u001b[0;36m__mul__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'dimension mismatch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 516\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 517\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mul_multivector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 518\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatrix\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py\u001b[0m in \u001b[0;36m_mul_multivector\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 575\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_mul_multivector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 577\u001b[0;31m result = np.zeros((other.shape[1], self.shape[0]),\n\u001b[0m\u001b[1;32m 578\u001b[0m dtype=upcast_char(self.dtype.char, other.dtype.char))\n\u001b[1;32m 579\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/sputils.py\u001b[0m in \u001b[0;36mupcast_char\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mupcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0m_upcast_memo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/sputils.py\u001b[0m in \u001b[0;36mupcast\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'no supported conversion for types: %r'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mTypeError\u001b[0m: no supported conversion for types: (dtype('float64'), dtype('O'))" | |
] | |
} | |
], | |
"source": [ | |
"import pdb\n", | |
"# Define a model that involves a sparse matrix multiply followed by summation followed by dense.\n", | |
"\n", | |
"g1 = nx.erdos_renyi_graph(n=10, p=0.3)\n", | |
"a1 = nx.to_numpy_array(g1) + np.eye(10)\n", | |
"f1 = np.array([[0, 1] * 10]).reshape(-1, 2)\n", | |
"s1 = np.ones(10)\n", | |
"\n", | |
"g2 = nx.erdos_renyi_graph(n=12, p=0.3)\n", | |
"a2 = nx.to_numpy_array(g2) + np.eye(12)\n", | |
"f2 = np.array([[1, 0] * 12]).reshape(-1, 2)\n", | |
"s2 = np.ones(12)\n", | |
"\n", | |
"a = block_diag([a1, a2])\n", | |
"f = np.vstack([f1, f2])\n", | |
"s = block_diag([s1, s2])\n", | |
"\n", | |
"params = dict()\n", | |
"params['mp1'] = dict()\n", | |
"params['mp1']['w'] = np.random.normal(size=(2, 2))\n", | |
"params['mp1']['b'] = np.random.normal(size=(2,))\n", | |
"\n", | |
"params['dense1'] = dict()\n", | |
"params['dense1']['w'] = np.random.normal(size=(2, 1))\n", | |
"params['dense1']['b'] = np.random.normal(size=(1,))\n", | |
"\n", | |
"def model(adj, feats, sums, params):\n", | |
" \n", | |
" fs = spdot(adj, feats)\n", | |
" fs = np.tanh(np.dot(fs, params['mp1']['w']) + params['mp1']['b'])\n", | |
" pdb.set_trace() \n", | |
" act = spdot(sums, fs)\n", | |
" \n", | |
" act = np.tanh(np.dot(act, params['dense1']['w']) + params['dense1']['b'])\n", | |
" \n", | |
" return act\n", | |
"\n", | |
"y = np.array([[-1, 1]])\n", | |
"\n", | |
"preds = model(a, f, s, params)\n", | |
"\n", | |
"def _mse(y_true, y_pred):\n", | |
" return np.mean(np.power(y_true - y_pred, 2))\n", | |
"\n", | |
"def mseloss(flat_params, unflattener, model, a, f, s, y):\n", | |
" params = unflattener(flat_params)\n", | |
" preds = model(a, f, s, params)\n", | |
" return _mse(y, preds)\n", | |
" \n", | |
"dloss = egrad(mseloss)\n", | |
"flat_params, unflattener = flatten(params)\n", | |
"for i in range(1000):\n", | |
" flat_params -= dloss(flat_params, unflattener, model, a, f, s, y) * 0.01\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dtype('float64')" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"spdot(a, f).dtype\n", | |
"spdot(s, spdot(a, f)).dtype" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "mpnn", | |
"language": "python", | |
"name": "mpnn" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment