Skip to content

Instantly share code, notes, and snippets.

@ericmjl
Created December 20, 2018 14:59
Show Gist options
  • Save ericmjl/4806c44a4a33aa067ea68a6a853299f8 to your computer and use it in GitHub Desktop.
Save ericmjl/4806c44a4a33aa067ea68a6a853299f8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import autograd.numpy as np\n",
"from autograd.core import defvjp\n",
"import networkx as nx\n",
"from autograd import elementwise_grad as egrad\n",
"from autograd.misc import flatten\n",
"from scipy.sparse import block_diag\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from autograd.extend import primitive, defvjp\n",
"\n",
"\n",
"@primitive\n",
"def spdot(s, d):\n",
"    \"\"\"\n",
"    Wrapped sparse dot for autograd.\n",
"\n",
"    Computes ``s.dot(d)`` for a 2-D ``s`` (scipy.sparse matrix or dense array)\n",
"    and a dense 1-D or 2-D ``d``. ``@primitive`` makes autograd unbox the\n",
"    arguments before the call; without it scipy receives ArrayBox objects,\n",
"    coerces them into a dtype=object array and fails with a TypeError\n",
"    (no supported conversion for types).\n",
"    \"\"\"\n",
"    return s.dot(d)\n",
"\n",
"\n",
"def _check_spdot_ndim(s, d):\n",
"    \"\"\"Raise NotImplementedError for shapes the spdot VJPs do not handle.\"\"\"\n",
"    if np.ndim(s) != 2 or np.ndim(d) not in (1, 2):\n",
"        raise NotImplementedError(\n",
"            \"Current spdot vjps only support a 2-D s and a 1-D or 2-D d.\")\n",
"\n",
"\n",
"def _spdot_vjp_0(ans, s, d):\n",
"    \"\"\"\n",
"    VJP maker for ``s`` (argnum 0): returns g -> g @ d.T (outer(g, d) for 1-D d).\n",
"\n",
"    autograd only calls this when ``s`` itself is traced, i.e. a dense array\n",
"    computed from differentiated values; a constant sparse ``s`` never gets here.\n",
"    \"\"\"\n",
"    _check_spdot_ndim(s, d)\n",
"    if np.ndim(d) == 1:\n",
"        return lambda g: np.outer(g, d)\n",
"    return lambda g: np.dot(g, d.T)\n",
"\n",
"\n",
"def _spdot_vjp_1(ans, s, d):\n",
"    \"\"\"VJP maker for ``d`` (argnum 1): returns g -> s.T @ g.\"\"\"\n",
"    _check_spdot_ndim(s, d)\n",
"    # s.T.dot (not np.dot) so a scipy.sparse s is multiplied natively.\n",
"    return lambda g: s.T.dot(g)\n",
"\n",
"\n",
"# autograd >= 1.2 VJP makers take (ans, *args) and return a function of g.\n",
"defvjp(spdot, _spdot_vjp_0, _spdot_vjp_1)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n",
"-> result = np.zeros((other.shape[1], self.shape[0]),\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.shape\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(22, 2)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('int64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype.char\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"'l'\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"> <ipython-input-3-651894338714>(32)model()\n",
"-> act = spdot(sums, fs)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) fs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([[-0.99996475, -0.9959886 ],\n",
" [-0.99819975, -0.96650431],\n",
" [-0.99819975, -0.96650431],\n",
" [-0.99819975, -0.96650431],\n",
" [-0.99996475, -0.9959886 ],\n",
" [-0.99996475, -0.9959886 ],\n",
" [-0.98720008, -0.90551051],\n",
" [-0.99996475, -0.9959886 ],\n",
" [-0.99819975, -0.96650431],\n",
" [-0.999748 , -0.98836645],\n",
" [-0.51904078, 0.99547746],\n",
" [-0.52378338, 0.99955411],\n",
" [-0.51904078, 0.99547746],\n",
" [-0.53317175, 0.99999568],\n",
" [-0.51904078, 0.99547746],\n",
" [-0.5378174 , 0.99999958],\n",
" [-0.51426601, 0.95496806],\n",
" [-0.52849372, 0.99995612],\n",
" [-0.51904078, 0.99547746],\n",
" [-0.52378338, 0.99955411],\n",
" [-0.51904078, 0.99547746],\n",
" [-0.52849372, 0.99995612]])\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n",
"-> result = np.zeros((other.shape[1], self.shape[0]),\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.shape\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(22, 2)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('float64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype.char\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"'d'\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) self.dtype.char\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"'d'\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) upcast_char\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<function upcast_char at 0x7fad1cd46510>\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) upcast_char((other.dtype.char, self.dtype.char))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'numpy.float64'>\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n",
"-> result = np.zeros((other.shape[1], self.shape[0]),\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.shape\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(22, 2)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('int64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) self.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('float64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) upcast_char(self.dtype.char, other.dtype.char)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'numpy.float64'>\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"> <ipython-input-3-651894338714>(32)model()\n",
"-> act = spdot(sums, fs)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) fs\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bc41288>\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) fs.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('float64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"> /home/maer3/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py(577)_mul_multivector()\n",
"-> result = np.zeros((other.shape[1], self.shape[0]),\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('O')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77048>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77908>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb778c8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77988>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77948>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb779c8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ac8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77bc8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77cc8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77dc8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ec8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77108>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77848>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77fc8>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b088>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b148>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b108>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b1c8>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b188>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b248>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b208>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b2c8>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b288>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b348>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b308>]],\n",
" dtype=object)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) import numpy as np\n",
"(Pdb) np.asarray(other)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77048>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77908>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb778c8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77988>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77948>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb779c8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77a48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ac8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77b48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77bc8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77c48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77cc8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77d48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77dc8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77e48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f08>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77ec8>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f88>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77f48>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77108>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77848>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77fc8>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b088>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b148>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b108>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b1c8>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b188>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b248>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b208>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b2c8>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b288>],\n",
" [<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b348>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb8b308>]],\n",
" dtype=object)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other[0]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>,\n",
" <autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77048>],\n",
" dtype=object)\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other[0][0]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<autograd.numpy.numpy_boxes.ArrayBox object at 0x7fad0bb77888>\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other[0][0].dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('float64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other[0][1].dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('float64')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dtype('O')\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) other.dtype = 'float64'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"*** TypeError: Cannot change data-type for object array.\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"(Pdb) c\n"
]
},
{
"ename": "TypeError",
"evalue": "no supported conversion for types: (dtype('float64'), dtype('O'))",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-651894338714>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mflat_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munflattener\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mflat_params\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0mdloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munflattener\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m0.01\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/wrap_util.py\u001b[0m in \u001b[0;36mnary_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margnum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0munary_operator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munary_f\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mnary_op_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mnary_op_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnary_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnary_operator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/differential_operators.py\u001b[0m in \u001b[0;36melementwise_grad\u001b[0;34m(fun, x)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munary_to_nary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0melementwise_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mvjp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_make_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miscomplex\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Elementwise_grad only applies to real-output functions.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/core.py\u001b[0m in \u001b[0;36mmake_vjp\u001b[0;34m(fun, x)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmake_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mstart_node\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVJPNode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_root\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mend_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_node\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_node\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mend_node\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/tracer.py\u001b[0m in \u001b[0;36mtrace\u001b[0;34m(start_node, fun, x)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtrace_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mstart_box\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_box\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_node\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mend_box\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_box\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misbox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mend_box\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mend_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_trace\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mstart_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_trace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mend_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_box\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_node\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/autograd/wrap_util.py\u001b[0m in \u001b[0;36munary_f\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msubargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msubvals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margnum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0msubargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margnum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0margnum\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-651894338714>\u001b[0m in \u001b[0;36mmseloss\u001b[0;34m(flat_params, unflattener, model, a, f, s, y)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmseloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munflattener\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munflattener\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_mse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-651894338714>\u001b[0m in \u001b[0;36mmodel\u001b[0;34m(adj, feats, sums, params)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mfs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'mp1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'w'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'mp1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'b'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mact\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msums\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mact\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mact\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dense1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'w'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m 
\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dense1'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'b'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-2-6ab25bcf1507>\u001b[0m in \u001b[0;36mspdot\u001b[0;34m(s, d)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mWrapped\u001b[0m \u001b[0msparse\u001b[0m \u001b[0mdot\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \"\"\"\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/base.py\u001b[0m in \u001b[0;36mdot\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 360\u001b[0m \"\"\"\n\u001b[0;32m--> 361\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 362\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 363\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/base.py\u001b[0m in \u001b[0;36m__mul__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'dimension mismatch'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 516\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 517\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mul_multivector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 518\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatrix\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/coo.py\u001b[0m in \u001b[0;36m_mul_multivector\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 575\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_mul_multivector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 577\u001b[0;31m result = np.zeros((other.shape[1], self.shape[0]),\n\u001b[0m\u001b[1;32m 578\u001b[0m dtype=upcast_char(self.dtype.char, other.dtype.char))\n\u001b[1;32m 579\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/sputils.py\u001b[0m in \u001b[0;36mupcast_char\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mupcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0m_upcast_memo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/mpnn/lib/python3.6/site-packages/scipy/sparse/sputils.py\u001b[0m in \u001b[0;36mupcast\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'no supported conversion for types: %r'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: no supported conversion for types: (dtype('float64'), dtype('O'))"
]
}
],
"source": [
"import pdb\n",
"# Toy batch for a sparse message-passing model: a sparse (block-diagonal)\n",
"# adjacency multiply, a per-graph sum, then a dense readout layer.\n",
"\n",
"# Graph 1: 10 nodes, every node carries feature [0, 1].\n",
"g1 = nx.erdos_renyi_graph(n=10, p=0.3)\n",
"a1 = nx.to_numpy_array(g1) + np.eye(10)  # adjacency with self-loops\n",
"f1 = np.tile([0, 1], (10, 1))\n",
"s1 = np.ones(10)\n",
"\n",
"# Graph 2: 12 nodes, every node carries feature [1, 0].\n",
"g2 = nx.erdos_renyi_graph(n=12, p=0.3)\n",
"a2 = nx.to_numpy_array(g2) + np.eye(12)  # adjacency with self-loops\n",
"f2 = np.tile([1, 0], (12, 1))\n",
"s2 = np.ones(12)\n",
"\n",
"# Batch both graphs: one block-diagonal adjacency, stacked node features,\n",
"# and a summation matrix assembled from s1 and s2.\n",
"a = block_diag([a1, a2])\n",
"f = np.vstack([f1, f2])\n",
"s = block_diag([s1, s2])\n",
"\n",
"# Random initial weights; entries are evaluated in the same order as before\n",
"# (mp1 w, mp1 b, dense1 w, dense1 b).\n",
"params = {\n",
" 'mp1': {\n",
"  'w': np.random.normal(size=(2, 2)),\n",
"  'b': np.random.normal(size=(2,)),\n",
" },\n",
" 'dense1': {\n",
"  'w': np.random.normal(size=(2, 1)),\n",
"  'b': np.random.normal(size=(1,)),\n",
" },\n",
"}\n",
"\n",
"def model(adj, feats, sums, params):\n",
" \"\"\"\n",
" Forward pass: sparse aggregation, dense tanh layer, sparse pooling, tanh readout.\n",
"\n",
" adj    -- left operand of the first spdot (here `a`, the block-diagonal\n",
"           adjacency-plus-identity matrix).\n",
" feats  -- node features multiplied by adj (here `f`).\n",
" sums   -- left operand of the pooling spdot (here `s`); presumably one\n",
"           row per graph -- confirm.\n",
" params -- nested dict: params['mp1'] and params['dense1'], each holding\n",
"           'w' and 'b' arrays.\n",
"\n",
" Returns tanh(pooled @ params['dense1']['w'] + params['dense1']['b']).\n",
" \"\"\"\n",
" fs = spdot(adj, feats)\n",
" fs = np.tanh(np.dot(fs, params['mp1']['w']) + params['mp1']['b'])\n",
" # NOTE(review): under egrad, params are traced, so fs is an autograd box\n",
" # rather than a plain ndarray. scipy's sparse .dot then sees an\n",
" # object-dtype array, which matches the TypeError (float64 vs object)\n",
" # recorded in this cell's output. spdot probably needs to be registered\n",
" # as an autograd primitive with its VJPs for this line to be\n",
" # differentiable -- TODO confirm.\n",
" act = spdot(sums, fs)\n",
"\n",
" act = np.tanh(np.dot(act, params['dense1']['w']) + params['dense1']['b'])\n",
"\n",
" return act\n",
"\n",
"# One target per graph, as a column: model returns one row per row of `s`\n",
"# (block_diag turns s1 and s2 into one row each) and dense1 maps to a single\n",
"# output, so preds is (2, 1). A (1, 2) row here would broadcast against it\n",
"# to (2, 2) and score every graph against every target.\n",
"y = np.array([[-1], [1]])\n",
"\n",
"preds = model(a, f, s, params)\n",
"\n",
"def _mse(y_true, y_pred):\n",
" \"\"\"Mean of the squared element-wise differences between y_true and y_pred.\"\"\"\n",
" residual = y_true - y_pred\n",
" return np.mean(np.power(residual, 2))\n",
"\n",
"def mseloss(flat_params, unflattener, model, a, f, s, y):\n",
" \"\"\"\n",
" Loss expressed over a flat parameter vector, so egrad can differentiate it.\n",
"\n",
" unflattener rebuilds the nested params dict from flat_params; the result is\n",
" fed to model(a, f, s, params) and the predictions are scored against y\n",
" with _mse.\n",
" \"\"\"\n",
" return _mse(y, model(a, f, s, unflattener(flat_params)))\n",
" \n",
"# Plain gradient descent on the flattened parameter vector.\n",
"flat_params, unflattener = flatten(params)\n",
"dloss = egrad(mseloss)\n",
"for step in range(1000):\n",
" grads = dloss(flat_params, unflattener, model, a, f, s, y)\n",
" flat_params -= grads * 0.01\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float64')"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Only a cell's last expression is echoed, so print the intermediate dtype.\n",
"print(spdot(a, f).dtype)\n",
"spdot(s, spdot(a, f)).dtype"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpnn",
"language": "python",
"name": "mpnn"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment