Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sdwfrost/b2b54cfc2e009633d3e00117877bf9e2 to your computer and use it in GitHub Desktop.
Save sdwfrost/b2b54cfc2e009633d3e00117877bf9e2 to your computer and use it in GitHub Desktop.
initial experiments with implementing neighbor joining for scikit-bio
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:eeee5624356f270957bb474d6a81a050501cb69098420ec7a202dc491da8f437"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Working on this with help from the [Neighbor Joining Wikipedia page](http://en.wikipedia.org/wiki/Neighbour_joining) and [@justin212k](https://github.com/justin212k)'s [code](https://github.com/justin212k/scikit-bio/compare/neighbor_joining). Goal is to create a scipy linkage matrix, so the interface matches [SciPy's linkage interface](http://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html). "
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from __future__ import division\n",
"from skbio.core.distance import DistanceMatrix\n",
"\n",
"def compute_q(dm):\n",
" q = zeros(dm.shape)\n",
" n = dm.shape[0]\n",
" for i in range(n):\n",
" for j in range(i):\n",
" q[i, j] = q[j, i] = ((n - 2) * dm[i, j]) - dm[i].sum() - dm[j].sum()\n",
" return DistanceMatrix(q, dm.ids)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"data = [[0, 5, 9, 9, 8],\n",
" [5, 0, 10, 10, 9],\n",
" [9, 10, 0, 8, 7],\n",
" [9, 10, 8, 0, 3],\n",
" [8, 9, 7, 3, 0]]\n",
"ids = list('abcde')\n",
"dm = DistanceMatrix(data, ids)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"q = compute_q(dm)\n",
"print q"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"5x5 distance matrix\n",
"IDs:\n",
"a, b, c, d, e\n",
"Data:\n",
"[[ 0. -50. -38. -34. -34.]\n",
" [-50. 0. -38. -34. -34.]\n",
" [-38. -38. 0. -40. -40.]\n",
" [-34. -34. -40. 0. -48.]\n",
" [-34. -34. -40. -48. 0.]]\n"
]
}
],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def pair_members_to_new_node(dm, i, j, dissallow_negative_branch_length):\n",
" n = dm.shape[0]\n",
" i_to_j = dm[i, j]\n",
" i_to_u = (0.5 * i_to_j) + (1 / (2 * (n - 2))) * (dm[i].sum() - dm[j].sum())\n",
" j_to_u = i_to_j - i_to_u\n",
" \n",
" if dissallow_negative_branch_length and i_to_u < 0:\n",
" i_to_u = 0\n",
" if dissallow_negative_branch_length and j_to_u < 0:\n",
" j_to_u = 0\n",
" \n",
" return i_to_u, j_to_u"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"pair_members_to_new_node(dm, 0, 1, True)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 6,
"text": [
"(2.0, 3.0)"
]
}
],
"prompt_number": 6
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def otu_to_new_node(dm, i, j, k, dissallow_negative_branch_length):\n",
" k_to_u = 0.5 * (dm[i, k] + dm[j, k] - dm[i, j])\n",
" \n",
" if dissallow_negative_branch_length and k_to_u < 0:\n",
" k_to_u = 0\n",
" \n",
" return k_to_u"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 7
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"otu_to_new_node(dm, 0, 1, 2, True)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 8,
"text": [
"7.0"
]
}
],
"prompt_number": 8
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"\n",
"def lowest_index(dm):\n",
" lowest_value = np.inf\n",
" for i in range(dm.shape[0]):\n",
" for j in range(i):\n",
" curr_index = i, j\n",
" curr_value = dm[curr_index]\n",
" if curr_value < lowest_value:\n",
" lowest_value = curr_value\n",
" lowest_index = curr_index\n",
" return lowest_index\n",
" "
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 9
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def compute_collapsed_dm(dm, i, j, dissallow_negative_branch_length, new_node_id=None):\n",
" in_n = dm.shape[0]\n",
" out_n = in_n - 1\n",
" new_node_id = new_node_id or \"(%s, %s)\" % (i, j)\n",
" out_ids = [new_node_id]\n",
" out_ids.extend([e for e in dm.ids if e not in (i, j)])\n",
" result = zeros((out_n, out_n))\n",
" for idx1, out_id1 in enumerate(out_ids[1:]):\n",
" result[0, idx1 + 1] = result[idx1 + 1, 0] = \\\n",
" otu_to_new_node(dm, i, j, out_id1, dissallow_negative_branch_length)\n",
" for idx2, out_id2 in enumerate(out_ids[1:idx1+1]):\n",
" result[idx1+1, idx2+1] = result[idx2+1, idx1+1] = dm[out_id1, out_id2]\n",
" return DistanceMatrix(result, out_ids)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 11
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from skbio.core.tree import TreeNode\n",
"\n",
"def nj(dm, dissallow_negative_branch_length=True,\n",
" result_constructor=TreeNode.from_newick):\n",
" while(dm.shape[0] > 3):\n",
" q = compute_q(dm)\n",
" idx1, idx2 = lowest_index(q)\n",
" pair_member_1 = dm.ids[idx1]\n",
" pair_member_2 = dm.ids[idx2]\n",
" pair_member_1_len, pair_member_2_len = pair_members_to_new_node(dm, \n",
" idx1,\n",
" idx2,\n",
" dissallow_negative_branch_length)\n",
" node_definition = \"(%s:%d, %s:%d)\" % (pair_member_1,\n",
" pair_member_1_len,\n",
" pair_member_2,\n",
" pair_member_2_len)\n",
" dm = compute_collapsed_dm(dm, pair_member_1, pair_member_2, \n",
" dissallow_negative_branch_length=dissallow_negative_branch_length,\n",
" new_node_id=node_definition)\n",
" \n",
" # Define the last node - this is an internal node\n",
" pair_member_1 = dm.ids[1]\n",
" pair_member_2 = dm.ids[2]\n",
" pair_member_1_len, pair_member_2_len = pair_members_to_new_node(dm,\n",
" pair_member_1,\n",
" pair_member_2,\n",
" dissallow_negative_branch_length)\n",
" internal_len = otu_to_new_node(dm, pair_member_1, pair_member_2, node_definition,\n",
" dissallow_negative_branch_length=dissallow_negative_branch_length, \n",
" )\n",
" newick = \"(%s:%d, %s:%d, %s:%d);\" % (pair_member_1, pair_member_1_len,\n",
" node_definition, internal_len,\n",
" pair_member_2, pair_member_2_len)\n",
"\n",
" return result_constructor(newick)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 17
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"newick_str = nj(dm, result_constructor=str)\n",
"print newick_str"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"(d:2, (c:4, (b:3, a:2):3):2, e:1);\n"
]
}
],
"prompt_number": 20
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
" /-d\n",
" |\n",
" | /-c\n",
" |---------|\n",
"---------| | /-b\n",
" | \\--------|\n",
" | \\-a\n",
" |\n",
" \\-e\n"
]
}
],
"prompt_number": 19
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment