Created
October 23, 2011 09:05
-
-
Save ogrisel/1307136 to your computer and use it in GitHub Desktop.
Expected Mutual Information profiling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Line # Hits Time Per Hit % Time Line Contents | |
| ============================================================== | |
| 627 def expected_mutual_information(contingency, n_samples): | |
| 628 """Calculate the expected mutual information for two labelings.""" | |
| 629 25 82 3.3 0.0 R, C = contingency.shape | |
| 630 25 53 2.1 0.0 N = n_samples | |
| 631 25 1059 42.4 0.0 a = np.sum(contingency, axis=1, dtype='int') | |
| 632 25 932 37.3 0.0 b = np.sum(contingency, axis=0, dtype='int') | |
| 633 25 58 2.3 0.0 emi = 0 | |
| 634 1839 4005 2.2 0.0 for i in range(R): | |
| 635 165963 346569 2.1 1.0 for j in range(C): | |
| 636 164149 1155150 7.0 3.4 start = int(max(a[i] + b[j] - N, 1)) | |
| 637 164149 804241 4.9 2.4 end = int(min(a[i], b[j]) + 1) | |
| 638 331098 882214 2.7 2.6 for nij in range(start, end): | |
| 639 166949 464229 2.8 1.4 term1 = nij / float(N) | |
| 640 166949 2607495 15.6 7.7 term2 = np.log(float(N * nij) / (a[i] * b[j])) | |
| 641 # a! / (a - n)! | |
| 642 166949 3473142 20.8 10.3 term3a = np.multiply.reduce(range(a[i] - nij + 1, a[i] + 1)) | |
| 643 # b! / (b - n)! | |
| 644 166949 3190928 19.1 9.4 term3b = np.multiply.reduce(range(b[j] - nij + 1, b[j] + 1)) | |
| 645 # (N - a)! / N! | |
| 646 166949 2952123 17.7 8.7 t = np.multiply.reduce(range(N - a[i] + 1, N + 1)) | |
| 647 166949 578205 3.5 1.7 if t == 0: | |
| 648 continue | |
| 649 166949 1683693 10.1 5.0 term3c = 1. / t | |
| 650 # (N - b)! / (N - a - b + n)! | |
| 651 166949 902521 5.4 2.7 num3d = N - b[j] + 1 | |
| 652 166949 1057582 6.3 3.1 den3d = N - a[i] - b[j] + nij + 1 | |
| 653 166949 386928 2.3 1.1 if num3d > den3d: | |
| 654 16173 256298 15.8 0.8 term3d = np.multiply.reduce(range(den3d, num3d)) | |
| 655 else: | |
| 656 150776 2519192 16.7 7.5 term3d = np.multiply.reduce(range(num3d, den3d)) | |
| 657 150776 585305 3.9 1.7 term3d = 1. / term3d | |
| 658 # 1 / n! | |
| 659 166949 8339780 50.0 24.7 term3e = 1. / factorial(nij) | |
| 660 # Add the product of all terms | |
| 661 166949 405512 2.4 1.2 emi += (term1 * term2 * term3a * term3b | |
| 662 166949 1176503 7.0 3.5 * term3c * term3d * term3e) | |
| 663 25 49 2.0 0.0 return emi |
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
To reproduce the preceding profile output, install the line_profiler extension for IPython as explained in the scikit-learn profiling documentation: