Forked from gregcaporaso/neighbor-joining-experiments.ipynb
Created
January 20, 2017 13:32
-
-
Save sdwfrost/b2b54cfc2e009633d3e00117877bf9e2 to your computer and use it in GitHub Desktop.
initial experiments with implementing neighbor joining for scikit-bio
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:eeee5624356f270957bb474d6a81a050501cb69098420ec7a202dc491da8f437" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This notebook works through neighbor joining with help from the [Neighbor Joining Wikipedia page](http://en.wikipedia.org/wiki/Neighbour_joining) and [@justin212k](https://github.com/justin212k)'s [code](https://github.com/justin212k/scikit-bio/compare/neighbor_joining). The goal is to produce a SciPy linkage matrix, so that the interface matches [SciPy's linkage interface](http://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from __future__ import division\n", | |
"\n", | |
"import numpy as np\n", | |
"from skbio.core.distance import DistanceMatrix\n", | |
"\n", | |
"def compute_q(dm):\n", | |
" q = zeros(dm.shape)\n", | |
" n = dm.shape[0]\n", | |
" for i in range(n):\n", | |
" for j in range(i):\n", | |
" q[i, j] = q[j, i] = ((n - 2) * dm[i, j]) - dm[i].sum() - dm[j].sum()\n", | |
" return DistanceMatrix(q, dm.ids)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"data = [[0, 5, 9, 9, 8],\n", | |
" [5, 0, 10, 10, 9],\n", | |
" [9, 10, 0, 8, 7],\n", | |
" [9, 10, 8, 0, 3],\n", | |
" [8, 9, 7, 3, 0]]\n", | |
"ids = list('abcde')\n", | |
"dm = DistanceMatrix(data, ids)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"q = compute_q(dm)\n", | |
"print q" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"5x5 distance matrix\n", | |
"IDs:\n", | |
"a, b, c, d, e\n", | |
"Data:\n", | |
"[[ 0. -50. -38. -34. -34.]\n", | |
" [-50. 0. -38. -34. -34.]\n", | |
" [-38. -38. 0. -40. -40.]\n", | |
" [-34. -34. -40. 0. -48.]\n", | |
" [-34. -34. -40. -48. 0.]]\n" | |
] | |
} | |
], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def pair_members_to_new_node(dm, i, j, dissallow_negative_branch_length):\n", | |
" n = dm.shape[0]\n", | |
" i_to_j = dm[i, j]\n", | |
" i_to_u = (0.5 * i_to_j) + (1 / (2 * (n - 2))) * (dm[i].sum() - dm[j].sum())\n", | |
" j_to_u = i_to_j - i_to_u\n", | |
" \n", | |
" if dissallow_negative_branch_length and i_to_u < 0:\n", | |
" i_to_u = 0\n", | |
" if dissallow_negative_branch_length and j_to_u < 0:\n", | |
" j_to_u = 0\n", | |
" \n", | |
" return i_to_u, j_to_u" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"pair_members_to_new_node(dm, 0, 1, True)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 6, | |
"text": [ | |
"(2.0, 3.0)" | |
] | |
} | |
], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def otu_to_new_node(dm, i, j, k, dissallow_negative_branch_length):\n", | |
" k_to_u = 0.5 * (dm[i, k] + dm[j, k] - dm[i, j])\n", | |
" \n", | |
" if dissallow_negative_branch_length and k_to_u < 0:\n", | |
" k_to_u = 0\n", | |
" \n", | |
" return k_to_u" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"otu_to_new_node(dm, 0, 1, 2, True)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 8, | |
"text": [ | |
"7.0" | |
] | |
} | |
], | |
"prompt_number": 8 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"import numpy as np\n", | |
"\n", | |
"def lowest_index(dm):\n", | |
" lowest_value = np.inf\n", | |
" for i in range(dm.shape[0]):\n", | |
" for j in range(i):\n", | |
" curr_index = i, j\n", | |
" curr_value = dm[curr_index]\n", | |
" if curr_value < lowest_value:\n", | |
" lowest_value = curr_value\n", | |
" lowest_index = curr_index\n", | |
" return lowest_index\n", | |
" " | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 9 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def compute_collapsed_dm(dm, i, j, dissallow_negative_branch_length, new_node_id=None):\n", | |
" in_n = dm.shape[0]\n", | |
" out_n = in_n - 1\n", | |
" new_node_id = new_node_id or \"(%s, %s)\" % (i, j)\n", | |
" out_ids = [new_node_id]\n", | |
" out_ids.extend([e for e in dm.ids if e not in (i, j)])\n", | |
" result = zeros((out_n, out_n))\n", | |
" for idx1, out_id1 in enumerate(out_ids[1:]):\n", | |
" result[0, idx1 + 1] = result[idx1 + 1, 0] = \\\n", | |
" otu_to_new_node(dm, i, j, out_id1, dissallow_negative_branch_length)\n", | |
" for idx2, out_id2 in enumerate(out_ids[1:idx1+1]):\n", | |
" result[idx1+1, idx2+1] = result[idx2+1, idx1+1] = dm[out_id1, out_id2]\n", | |
" return DistanceMatrix(result, out_ids)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 11 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from skbio.core.tree import TreeNode\n", | |
"\n", | |
"def nj(dm, dissallow_negative_branch_length=True,\n", | |
" result_constructor=TreeNode.from_newick):\n", | |
" while(dm.shape[0] > 3):\n", | |
" q = compute_q(dm)\n", | |
" idx1, idx2 = lowest_index(q)\n", | |
" pair_member_1 = dm.ids[idx1]\n", | |
" pair_member_2 = dm.ids[idx2]\n", | |
" pair_member_1_len, pair_member_2_len = pair_members_to_new_node(dm, \n", | |
" idx1,\n", | |
" idx2,\n", | |
" dissallow_negative_branch_length)\n", | |
" node_definition = \"(%s:%d, %s:%d)\" % (pair_member_1,\n", | |
" pair_member_1_len,\n", | |
" pair_member_2,\n", | |
" pair_member_2_len)\n", | |
" dm = compute_collapsed_dm(dm, pair_member_1, pair_member_2, \n", | |
" dissallow_negative_branch_length=dissallow_negative_branch_length,\n", | |
" new_node_id=node_definition)\n", | |
" \n", | |
" # Define the last node - this is an internal node\n", | |
" pair_member_1 = dm.ids[1]\n", | |
" pair_member_2 = dm.ids[2]\n", | |
" pair_member_1_len, pair_member_2_len = pair_members_to_new_node(dm,\n", | |
" pair_member_1,\n", | |
" pair_member_2,\n", | |
" dissallow_negative_branch_length)\n", | |
" internal_len = otu_to_new_node(dm, pair_member_1, pair_member_2, node_definition,\n", | |
" dissallow_negative_branch_length=dissallow_negative_branch_length, \n", | |
" )\n", | |
" newick = \"(%s:%d, %s:%d, %s:%d);\" % (pair_member_1, pair_member_1_len,\n", | |
" node_definition, internal_len,\n", | |
" pair_member_2, pair_member_2_len)\n", | |
"\n", | |
" return result_constructor(newick)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 17 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"newick_str = nj(dm, result_constructor=str)\n", | |
"print newick_str" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"(d:2, (c:4, (b:3, a:2):3):2, e:1);\n" | |
] | |
} | |
], | |
"prompt_number": 20 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" /-d\n", | |
" |\n", | |
" | /-c\n", | |
" |---------|\n", | |
"---------| | /-b\n", | |
" | \\--------|\n", | |
" | \\-a\n", | |
" |\n", | |
" \\-e\n" | |
] | |
} | |
], | |
"prompt_number": 19 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment