Created
September 1, 2017 19:33
-
-
Save nzw0301/5964a1f48bf91f2fe582689138d1664b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ref. Fig. 1 in _[A Scalable Asynchronous Distributed Algorithm for Topic Modeling](https://www.cs.utexas.edu/~rofuyu/papers/nomad-lda-www.pdf)_." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"struct FPlusTree\n", | |
" T::Int\n", | |
" f::Array{Float64, 1}\n", | |
" function FPlusTree(p::Array{Float64, 1})\n", | |
" function initTree(p::Array{Float64, 1})\n", | |
" T = length(p)\n", | |
" f = zeros(T*2-1)\n", | |
" for i in 2*T-1:-1:1\n", | |
" f[i] = if T <= i\n", | |
" p[i-T+1]\n", | |
" else\n", | |
" f[2i] + f[2i+1]\n", | |
" end\n", | |
" end\n", | |
" return f\n", | |
" end\n", | |
" new(length(p), initTree(p))\n", | |
" end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"sample (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"function sample(tree::FPlusTree)\n", | |
" u = tree.f[1] * rand()\n", | |
" i = 1\n", | |
" while (i < tree.T)\n", | |
" i = if u >= tree.f[2i]\n", | |
" u -= tree.f[2i]\n", | |
" 2i+1\n", | |
" else\n", | |
" 2i\n", | |
" end\n", | |
" end\n", | |
" return i - tree.T + 1\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"update (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"function update(tree::FPlusTree, t::Int, delta::Float64)\n", | |
" i = t + tree.T - 1\n", | |
" while (i > 0)\n", | |
" tree.f[i] += delta\n", | |
" i = div(i, 2)\n", | |
" end\n", | |
" \n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4-element Array{Float64,1}:\n", | |
" 0.3\n", | |
" 1.5\n", | |
" 0.4\n", | |
" 0.3" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"p = [0.3, 1.5, 0.4, 0.3] # unnormalized categorical distribution: index represent topic id, element is unnormalized probability" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"FPlusTree(4, [2.5, 1.8, 0.7, 0.3, 1.5, 0.4, 0.3])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tree = FPlusTree(p)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Sample topic" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Dict{Int64,Int64} with 4 entries:\n", | |
" 4 => 120038\n", | |
" 2 => 600459\n", | |
" 3 => 159441\n", | |
" 1 => 120062" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# sample topics\n", | |
"counter = Dict{Int, Int}()\n", | |
"\n", | |
"for _ in 1:1000000\n", | |
" z = sample(tree)\n", | |
" counter[z] = get(counter, z, 0) + 1\n", | |
"end\n", | |
"counter" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4-element Array{Float64,1}:\n", | |
" 0.12\n", | |
" 0.6 \n", | |
" 0.16\n", | |
" 0.12" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"p /= sum(p) # original categorical distribution" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Update" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7-element Array{Float64,1}:\n", | |
" 2.5\n", | |
" 1.8\n", | |
" 0.7\n", | |
" 0.3\n", | |
" 1.5\n", | |
" 0.4\n", | |
" 0.3" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# before update\n", | |
"tree.f" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# topic 3's value + 1.0\n", | |
"update(tree, 3, 1.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7-element Array{Float64,1}:\n", | |
" 3.5\n", | |
" 1.8\n", | |
" 1.7\n", | |
" 0.3\n", | |
" 1.5\n", | |
" 1.4\n", | |
" 0.3" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# after update\n", | |
"tree.f" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Dict{Int64,Int64} with 4 entries:\n", | |
" 4 => 86176\n", | |
" 2 => 429227\n", | |
" 3 => 399090\n", | |
" 1 => 85507" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# sample from updated categorical distribution\n", | |
"counter = Dict{Int, Int}()\n", | |
"\n", | |
"for _ in 1:1000000\n", | |
" z = sample(tree)\n", | |
" counter[z] = get(counter, z, 0) + 1\n", | |
"end\n", | |
"counter" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 0.6.0", | |
"language": "julia", | |
"name": "julia-0.6" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "0.6.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment