Created July 14, 2017 18:42
-
-
Save jeetsukumaran/605de7053e9122af39d091f74bdd946d to your computer and use it in GitHub Desktop.
Randomly partitions a set of elements using the Dirichlet process
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/env python | |
| ############################################################################### | |
| ## | |
| ## Copyright 2017 Jeet Sukumaran. | |
| ## | |
| ## This program is free software; you can redistribute it and/or modify | |
| ## it under the terms of the GNU General Public License as published by | |
| ## the Free Software Foundation; either version 3 of the License, or | |
| ## (at your option) any later version. | |
| ## | |
| ## This program is distributed in the hope that it will be useful, | |
| ## but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| ## GNU General Public License for more details. | |
| ## | |
| ## You should have received a copy of the GNU General Public License along | |
| ## with this program. If not, see <http://www.gnu.org/licenses/>. | |
| ## | |
| ############################################################################### | |
| """ | |
| Randomly partitions a set of elements using the Dirichlet process. | |
| """ | |
| import argparse | |
| import random | |
def weighted_index_choice(weights, sum_of_weights=None, rng=None):
    """
    Return an index into ``weights`` chosen at random in proportion to its weight.

    (Adapted from:
    http://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/)

    For example, given [2, 3, 5] it returns 0 with probability 0.2, 1 with
    probability 0.3 and 2 with probability 0.5. The weights need not sum to
    anything in particular and may be arbitrary non-negative floats.
    Placing the largest weights first makes the scan terminate earlier on
    average.

    Args:
        weights: iterable of non-negative numbers.
        sum_of_weights: the total of ``weights``. It is computed when omitted.
            If supplied, it must equal the actual total, otherwise the
            selection is biased.
        rng: object with a ``uniform(a, b)`` method (e.g. ``random.Random``).
            Defaults to the module-level ``random`` functions.

    Returns:
        int: the selected index.

    Raises:
        ValueError: if ``weights`` is empty or contains no positive weight.
    """
    weights = list(weights)
    if not weights:
        raise ValueError("weights must not be empty")
    if sum_of_weights is None:
        sum_of_weights = sum(weights)
    if rng is None:
        rng = random
    remaining = rng.uniform(0, 1) * sum_of_weights
    last_positive = None
    for index, weight in enumerate(weights):
        remaining -= weight
        if remaining < 0:
            return index
        if weight > 0:
            last_positive = index
    # The scan can finish without ``remaining`` dropping below zero:
    # uniform(0, 1) may return exactly 1.0, and floating-point round-off can
    # leave a tiny positive residue. Attribute that residue to the last
    # selectable (positive-weight) index instead of implicitly returning None.
    if last_positive is None:
        raise ValueError("at least one weight must be positive")
    return last_positive
def sample_partition(
        number_of_elements,
        scaling_parameter,
        rng,):
    """
    Draw one random partition of a set using the Dirichlet process
    (Chinese-restaurant-process construction).

    Elements are labelled ``1 .. number_of_elements`` and visited in an order
    shuffled with ``rng``. The first element starts its own group. When ``i``
    elements have already been placed, the next element starts a new group
    with probability ``scaling_parameter / (scaling_parameter + i)``.
    Otherwise it joins an existing group ``g`` with probability
    ``len(g) / (scaling_parameter + i)``.

    Args:
        number_of_elements: number of elements to partition (>= 0).
        scaling_parameter: the concentration parameter alpha (>= 0). Larger
            values yield more, smaller groups.
        rng: object providing ``shuffle`` and ``uniform`` (e.g.
            ``random.Random``).

    Returns:
        list of lists: the groups in creation order. Each group lists element
        labels in the order they were assigned.

    Raises:
        ValueError: if ``number_of_elements`` or ``scaling_parameter`` is
            negative.
    """
    if number_of_elements < 0:
        raise ValueError("number_of_elements must be non-negative")
    if scaling_parameter < 0:
        raise ValueError("scaling_parameter must be non-negative")
    element_ids = list(range(1, number_of_elements + 1))
    rng.shuffle(element_ids)
    groups = []
    for num_placed, element_id in enumerate(element_ids):
        if not groups:
            # The first element always opens the first group.
            groups.append([element_id])
            continue
        # Unnormalized weights: slot 0 is "start a new group", slot k >= 1
        # is "join groups[k - 1]". Their exact total is alpha plus the number
        # of elements placed so far, so no normalization (and no
        # floating-point drift away from a hard-coded total of 1.0) is needed.
        weights = [scaling_parameter] + [len(group) for group in groups]
        selected_idx = weighted_index_choice(
            weights=weights,
            sum_of_weights=scaling_parameter + num_placed,
            rng=rng)
        if selected_idx == 0:
            groups.append([element_id])
        else:
            groups[selected_idx - 1].append(element_id)
    return groups
def main():
    """
    Command-line entry point.

    Parses the command line and draws ``--num-replicates`` partitions with
    ``sample_partition``. Each partition is printed when ``--verbosity`` is
    at least 1. Finally it prints the mean number of subsets per partition
    and the mean (over replicates) of the per-partition average number of
    elements per subset.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-K", "--number-of-elements",
            type=int,
            default=10,
            help="Number of elements in the set. Default: %(default)s.")
    parser.add_argument("-a", "--scaling-parameter", "--alpha",
            type=float,
            default=1.5,
            help="(Anti-)Concentration or scaling parameter:"
                " low values result in a more clumpier/clustered"
                " partitions, while higher values result in a more"
                " dispersed partitions. Default: %(default)s."
            )
    parser.add_argument("-n", "--num-replicates",
            type=int,
            default=10,
            help="How many draws to run. Default: %(default)s.")
    parser.add_argument("-v", "--verbosity",
            type=int,
            default=1,
            help="How much to report about what is going on"
                " with 0 being almost completely quiet and"
                " higher numbers reporting more and more."
                " Default: %(default)s.")
    parser.add_argument("-z", "--random-seed",
            type=int,
            default=None,
            help="Seed for random number generator.")
    args = parser.parse_args()
    # Reject inputs that would make the summary statistics below divide by
    # zero (empty partitions / no replicates) or give negative probabilities.
    if args.number_of_elements < 1:
        parser.error("--number-of-elements must be at least 1")
    if args.num_replicates < 1:
        parser.error("--num-replicates must be at least 1")
    if args.scaling_parameter < 0:
        parser.error("--scaling-parameter must be non-negative")
    rng = random.Random(args.random_seed)
    num_subsets = []
    num_elements_in_subsets = []
    for _ in range(args.num_replicates):
        partition = sample_partition(
                number_of_elements=args.number_of_elements,
                scaling_parameter=args.scaling_parameter,
                rng=rng)
        if args.verbosity >= 1:
            print(partition)
        num_subsets.append(len(partition))
        num_elements_in_subsets.append(
                sum(len(s) for s in partition) / float(len(partition)))
    print("---")
    print("Mean number of subsets per partition: {}".format(
        sum(num_subsets) / float(len(num_subsets))))
    print(" Mean number of elements per subset: {}".format(
        sum(num_elements_in_subsets) / float(len(num_elements_in_subsets))))
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.