Last active
April 30, 2018 03:33
-
-
Save wassname/2ac3e0e393f0e0dbc2f830b4f0750f43 to your computer and use it in GitHub Desktop.
kdtree_scaling_scipy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Trying to scale up kdnet\n", | |
| "\n", | |
| "- testing with decimation=100 at first\n", | |
| "\n", | |
| "- Looks like the number of points needs to be a multiple of 2.\n", | |
| "- Looks like the number of points must be constant.\n", | |
| "- Looks like the first conv must have kernel = number of points, which is very large (like having a dense layer).\n", | |
| "- [ ] Make it work with odd numbers of points?" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:24.025541Z", | |
| "start_time": "2018-04-30T03:32:23.829795Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from datasets import PartDataset\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "from torch.autograd import Variable\n", | |
| "import torch.optim as optim\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-23T02:23:55.873150Z", | |
| "start_time": "2018-04-23T02:23:55.869711Z" | |
| } | |
| }, | |
| "source": [ | |
| "# Generate some fake points" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:32.406139Z", | |
| "start_time": "2018-04-30T03:32:32.369176Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "32768" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import torch.utils.data\n", | |
| "class Dataset(torch.utils.data.Dataset):\n", | |
| "    \"\"\"Dataset over several parallel arrays that share the size of their first axis.\n", | |
| "\n", | |
| "    Indexing returns a tuple with one (copied) item from each stored array.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def __init__(self, *tensors):\n", | |
| "        # every array must have as many items as the first one\n", | |
| "        for tensor in tensors:\n", | |
| "            assert tensors[0].shape[0] == tensor.shape[0]\n", | |
| "        self.tensors = tensors\n", | |
| "\n", | |
| "    def __getitem__(self, index):\n", | |
| "        # multiplying by 1 yields a new object instead of a view into the stored array\n", | |
| "        items = [tensor[index] * 1 for tensor in self.tensors]\n", | |
| "        return tuple(items)\n", | |
| "\n", | |
| "    def __len__(self):\n", | |
| "        first = self.tensors[0]\n", | |
| "        return first.shape[0]\n", | |
| " \n", | |
| "# Make some fake points, enough to fill a kd-tree that is `levels` deep.\n", | |
| "# Seed the RNG so that Restart & Run All reproduces the same points every time.\n", | |
| "np.random.seed(1)\n", | |
| "levels = 15\n", | |
| "npoints = 2**levels\n", | |
| "\n", | |
| "l = 2  # number of fake samples\n", | |
| "x = np.random.random((l, npoints, 3))  # l samples of npoints points with 3 coordinates each\n", | |
| "y = np.random.random((l, 3))  # one random 3-vector per sample\n", | |
| "dataset = Dataset(x, y)\n", | |
| "npoints" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T01:06:41.491373Z", | |
| "start_time": "2018-04-30T01:06:41.487705Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Test default split_ps" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:33.391618Z", | |
| "start_time": "2018-04-30T03:32:33.372734Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# split points into kdtree\n", | |
| "\n", | |
| "def split_ps(point_set):\n", | |
| "    \"\"\"Split a point set at the median value along its most varying dimension.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        point_set: 2-D tensor with one point per row.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        (left_ps, right_ps, dim): the points strictly below and strictly above the\n", | |
| "        median along `dim`, each padded with a median point up to half the input\n", | |
| "        size, and the index of the dimension that was cut (a Python int).\n", | |
| "    \"\"\"\n", | |
| "    num_points = point_set.size()[0] // 2\n", | |
| "\n", | |
| "    # find where to cut (median in most varying dimension)\n", | |
| "    diff = point_set.max(dim=0, keepdim=True)[0] - point_set.min(dim=0, keepdim=True)[0]\n", | |
| "    # int() so callers get a plain integer whatever the torch version returns here\n", | |
| "    dim = int(torch.max(diff, dim=1, keepdim=True)[1][0, 0])\n", | |
| "    coords = point_set[:, dim]\n", | |
| "    # pass dim explicitly: median(input, keepdim=...) without dim is not a valid overload\n", | |
| "    cut = torch.median(coords, dim=0)[0]\n", | |
| "\n", | |
| "    # get indices for left, right and middle; reshape(-1) always gives a 1-D index,\n", | |
| "    # whereas squeeze() collapses a single match to a 0-dim tensor that can be\n", | |
| "    # neither sliced nor concatenated\n", | |
| "    left_idx = torch.nonzero(coords < cut).reshape(-1)\n", | |
| "    right_idx = torch.nonzero(coords > cut).reshape(-1)\n", | |
| "    middle_idx = torch.nonzero(coords == cut).reshape(-1)\n", | |
| "\n", | |
| "    # pad with middle points (slow)\n", | |
| "    if torch.numel(left_idx) < num_points:\n", | |
| "        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)\n", | |
| "    if torch.numel(right_idx) < num_points:\n", | |
| "        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(num_points - torch.numel(right_idx))], 0)\n", | |
| "\n", | |
| "    left_ps = torch.index_select(point_set, dim=0, index=left_idx)\n", | |
| "    right_ps = torch.index_select(point_set, dim=0, index=right_idx)\n", | |
| "    return left_ps, right_ps, dim" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:33.671330Z", | |
| "start_time": "2018-04-30T03:32:33.668507Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Take the first sample of the fake dataset: dataset[0] returns one item from each\n", | |
| "# array the dataset was built from (here the points and the 3-vector for sample 0).\n", | |
| "# j = np.random.randint(l)\n", | |
| "point_set_n, class_label = dataset[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:36.798115Z", | |
| "start_time": "2018-04-30T03:32:33.779569Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 2.93 s, sys: 52 ms, total: 2.98 s\n", | |
| "Wall time: 3.01 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "# try default kdtree\n", | |
| "point_set = torch.FloatTensor(point_set_n)\n", | |
| "\n", | |
| "# construct kd-tree: tree[d] collects the point subsets at depth d,\n", | |
| "# cutdim[d] the cut dimension of each split at depth d (once per child)\n", | |
| "tree = [[] for _ in range(levels + 1)]\n", | |
| "cutdim = [[] for _ in range(levels)]\n", | |
| "tree[0].append(point_set)\n", | |
| "\n", | |
| "for depth, dims_at_depth in enumerate(cutdim):\n", | |
| "    children = tree[depth + 1]\n", | |
| "    for node_points in tree[depth]:\n", | |
| "        left_ps, right_ps, dim = split_ps(node_points)\n", | |
| "        children.extend([left_ps, right_ps])\n", | |
| "        dims_at_depth.extend([dim, dim])\n", | |
| "\n", | |
| "cutdim_v = [torch.from_numpy(np.array(dims).astype(np.int64)) for dims in cutdim]\n", | |
| "\n", | |
| "points = torch.stack(tree[-1])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Compare with scipy.spatial.cKDTree" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:36.802590Z", | |
| "start_time": "2018-04-30T03:32:36.799872Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Fetch sample 0 again (same call as in the previous section) as input for the scipy tree.\n", | |
| "point_set_n, class_label = dataset[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:37.020482Z", | |
| "start_time": "2018-04-30T03:32:36.804782Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 160 ms, sys: 4 ms, total: 164 ms\n", | |
| "Wall time: 166 ms\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "# but now we need to extract nodes from this and compare to make sure it has the same results\n", | |
| "from collections import defaultdict\n", | |
| "import scipy.spatial\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def get_cutdims(tree, max_depth=7):\n", | |
| "    \"\"\"Collect, level by level, the cut dimensions and point indices of a kd-tree.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        tree: root node of the tree (e.g. `scipy.spatial.cKDTree(...).tree`); nodes must\n", | |
| "            expose `lesser`, `greater`, `indices` and `split_dim`.\n", | |
| "        max_depth: level at which the walk stops. Branches that end earlier are\n", | |
| "            extended by repeating their leaf; index arrays at this level are cropped\n", | |
| "            or padded to the expected size.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        (cutdims, tree_idxs), both lists ordered from the root level downwards:\n", | |
| "        - cutdims[level]: split dimension of every node on that level, each listed twice\n", | |
| "        - tree_idxs[level]: one index array per node on that level\n", | |
| "    \"\"\"\n", | |
| "    cutdims = defaultdict(list)\n", | |
| "    tree_idxs = defaultdict(list)\n", | |
| "\n", | |
| "    def _get_cutdims(node, level=0, parent=None, inherited_dim=-1):\n", | |
| "        if node is None:\n", | |
| "            # deal with premature leaf by repeating the leaf\n", | |
| "            node = parent\n", | |
| "\n", | |
| "        if level >= max_depth:\n", | |
| "            indices = node.indices\n", | |
| "\n", | |
| "            # make sure it's the right amount of indices for this depth\n", | |
| "            n = 2 ** max(max_depth - level, 0)\n", | |
| "            if len(indices) > n:\n", | |
| "                # crop: a node that is not a single-point leaf (or an input too large\n", | |
| "                # for the tree) holds more indices than fit here; sample without\n", | |
| "                # replacement so cropping never duplicates a point\n", | |
| "                indices = np.random.choice(indices, n, replace=False)\n", | |
| "            elif len(indices) < n:\n", | |
| "                # pad if input is too small for tree\n", | |
| "                indices = np.concatenate([indices, indices[0:1].repeat(n - len(indices))])\n", | |
| "\n", | |
| "            # end recursion\n", | |
| "            tree_idxs[level].append(indices)\n", | |
| "            return indices\n", | |
| "\n", | |
| "        # a leaf has split_dim == -1, and a repeated leaf is its own parent, so the\n", | |
| "        # split of the nearest ancestor that really was split is handed down instead\n", | |
| "        split_dim = node.split_dim\n", | |
| "        if split_dim == -1:\n", | |
| "            split_dim = inherited_dim\n", | |
| "        assert split_dim > -1, 'no split dimension available at level %s' % level\n", | |
| "\n", | |
| "        indices = np.concatenate([\n", | |
| "            _get_cutdims(node.lesser, level + 1, node, split_dim),\n", | |
| "            _get_cutdims(node.greater, level + 1, node, split_dim),\n", | |
| "        ])\n", | |
| "        tree_idxs[level].append(indices)\n", | |
| "        cutdims[level].append(split_dim)\n", | |
| "        cutdims[level].append(split_dim)\n", | |
| "        return indices\n", | |
| "\n", | |
| "    # init the recursive search\n", | |
| "    _get_cutdims(tree, level=0)\n", | |
| "\n", | |
| "    # order by level explicitly: the dicts are filled deepest level first, so relying\n", | |
| "    # on dict ordering would return the levels reversed on Python 3.7+\n", | |
| "    cutdims = [cutdims[level] for level in sorted(cutdims)]\n", | |
| "    tree_idxs = [tree_idxs[level] for level in sorted(tree_idxs)]\n", | |
| "    return cutdims, tree_idxs\n", | |
| "\n", | |
| "\n", | |
| "def make_cKDTree(point_set, depth):\n", | |
| "    \"\"\"\n", | |
| "    Build a balanced scipy cKDTree over a numpy point set and unpack it per level.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "    - cutdims: (list) for each level, the dimension cut at each node\n", | |
| "    - tree: (list) for each level, the data points gathered per node\n", | |
| "\n", | |
| "    \"\"\"\n", | |
| "    kdtree = scipy.spatial.cKDTree(point_set, leafsize=1, balanced_tree=True)\n", | |
| "    cutdims, idxs_per_level = get_cutdims(kdtree.tree, max_depth=depth)\n", | |
| "\n", | |
| "    # swap every level's index arrays for the points they refer to\n", | |
| "    points_per_level = []\n", | |
| "    for level_idxs in idxs_per_level:\n", | |
| "        points_per_level.append(np.take(point_set, indices=level_idxs, axis=0))\n", | |
| "    return cutdims, points_per_level\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "# Build the scipy-based tree for the same sample, down to the same number of levels\n", | |
| "# as the default tree, so that the two can be compared.\n", | |
| "cutdims2, tree2 = make_cKDTree(point_set_n, levels)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-25T06:32:01.680429Z", | |
| "start_time": "2018-04-25T06:32:01.430561Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:37.074922Z", | |
| "start_time": "2018-04-30T03:32:37.022108Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Check that the final layer of both trees holds the same points\n", | |
| "# (the other layers are not used, so they are not compared).\n", | |
| "last_layer = -1\n", | |
| "desired = torch.stack(tree[last_layer]).numpy()\n", | |
| "actual = tree2[last_layer]\n", | |
| "np.testing.assert_allclose(actual, desired, rtol=1e-4)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2018-04-30T03:32:37.090620Z", | |
| "start_time": "2018-04-30T03:32:37.078816Z" | |
| } | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# also check that the cut dimensions match, level by level\n", | |
| "assert len(cutdim) == len(cutdims2), 'both trees should have the same number of levels'\n", | |
| "for i in range(len(cutdim)):\n", | |
| "    # array_equal also copes with levels of different length, where `==` cannot\n", | |
| "    # compare element-wise and `.all()` would not give a meaningful answer\n", | |
| "    assert np.array_equal(np.array(cutdim[i]), np.array(cutdims2[i])), 'cutdims should match ind=%s' % i" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "ipython2", | |
| "language": "python", | |
| "name": "ipython2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.13" | |
| }, | |
| "toc": { | |
| "colors": { | |
| "hover_highlight": "#DAA520", | |
| "navigate_num": "#000000", | |
| "navigate_text": "#333333", | |
| "running_highlight": "#FF0000", | |
| "selected_highlight": "#FFD700", | |
| "sidebar_border": "#EEEEEE", | |
| "wrapper_background": "#FFFFFF" | |
| }, | |
| "moveMenuLeft": true, | |
| "nav_menu": { | |
| "height": "102px", | |
| "width": "252px" | |
| }, | |
| "navigate_menu": true, | |
| "number_sections": true, | |
| "sideBar": true, | |
| "threshold": 4, | |
| "toc_cell": false, | |
| "toc_position": { | |
| "height": "549px", | |
| "left": "0px", | |
| "right": "1058px", | |
| "top": "110px", | |
| "width": "212px" | |
| }, | |
| "toc_section_display": "block", | |
| "toc_window_display": true, | |
| "widenNotebook": false | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment