Last active
March 20, 2020 13:12
-
-
Save johnhw/117be58261d6f22a4cfbf285cbe67b65 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.metrics import confusion_matrix | |
from collections import defaultdict | |
def itr(confusion_matrix, timings, eps=1e-8):
    """Estimate the information transfer rate (ITR) from a confusion matrix.

    The confusion matrix has one row per intended class and one column per
    actual (detected) class, e.g.

        (actual)     a   b   c   d
        (intended)
            a       10   0   1   9
            b        2  11   7   2
            c        0   0  15   0
            d        5   3  12   9

    and ``timings`` lists the average time to specify each class a, b, c, d, ...

    Every row is normalised to a distribution ``p`` over the actual classes.
    Maximum entropy is ``log2(n_classes)``; the residual entropy of a row is
    ``-sum(p * log2(p))``. The information carried by each intended class is
    ``max entropy - residual entropy``, which is divided by the time for that
    class to give an approximate maximum possible ITR per class.

    Parameters
    ----------
    confusion_matrix : array-like, shape (n_classes, n_classes)
        Counts indexed as ``[intended, actual]``. Any array-like (ndarray or
        nested sequences) is accepted.
    timings : array-like, length n_classes
        Average time per intended class, in the same order as the rows.
    eps : float
        Small constant added inside the logarithm so that zero-probability
        cells contribute 0 instead of ``0 * log2(0)``.

    Returns
    -------
    (mean, std)
        Mean and standard deviation of the per-class information rate, in
        bits per unit of ``timings`` (bits/second when timings are seconds).

    Notes
    -----
    A row whose counts sum to zero is divided by zero during normalisation,
    so the result for that class (and hence the mean and std) is NaN.
    """
    # Coerce up front so nested lists / tuples work as well as ndarrays.
    counts = np.asarray(confusion_matrix, dtype=float)
    durations = np.asarray(timings, dtype=float)

    # Normalise each row (intended class) to a probability distribution.
    row_totals = counts.sum(axis=1, keepdims=True)
    probabilities = counts / row_totals

    # sum(p * log2(p)) is the *negative* of the row entropy, so adding it to
    # the maximum entropy yields the information transferred per selection.
    negative_entropy = np.sum(probabilities * np.log2(probabilities + eps), axis=1)
    max_entropy = np.log2(len(counts))

    bps_per_class = (max_entropy + negative_entropy) / durations
    return np.mean(bps_per_class), np.std(bps_per_class)
## Example of use
# data in the form (intended_class, detected_class, duration_seconds)
example_data = [
    (0, 0, 0.5),
    (1, 1, 0.54),
    (3, 3, 0.53),
    (3, 2, 0.44),
    (1, 1, 0.42),
    (0, 0, 0.33),
    (2, 2, 0.92),
]

# compute confusion matrix (y_true = intended class, y_pred = detected class)
confusion = confusion_matrix(
    y_true=[intended for (intended, actual, duration) in example_data],
    y_pred=[actual for (intended, actual, duration) in example_data],
)

## compute average duration of each class
times = defaultdict(float)  # total duration accumulated per intended class
n = defaultdict(int)  # number of trials seen per intended class
for intended, actual, duration in example_data:
    times[intended] += duration
    # Count the trial. Adding the duration here instead would make every
    # mean equal total / total == 1.0 and hide the real timings.
    n[intended] += 1
mean_duration = [times.get(i, 0) / n.get(i, 1) for i in range(len(confusion))]

## show mean and std. itr
print(itr(confusion, mean_duration))
##############
## tests
# A perfect 8-class confusion matrix carries log2(8) = 3 bits per selection,
# so the expected rate is simply 3 / (time per selection).
confusion = np.identity(8)
mean_duration = [1] * 8

# one second per selection: should get ~ 3 bits / second
mean, sd = itr(confusion, mean_duration)
assert -1e-5 < mean - 3.0 < 1e-5

# half a second per selection: should get ~ 6 bits/second
mean, sd = itr(confusion, [0.5 for _ in range(8)])
print(mean)
assert -1e-5 < mean - 6.0 < 1e-5

# two seconds per selection: should get ~ 1.5 bits/second
mean, sd = itr(confusion, [2.0 for _ in range(8)])
assert -1e-5 < mean - 1.5 < 1e-5

## random performance should be pretty much close to 0,
## i.e. less than 0.5 bits/second
for i in range(20):
    confusion = np.random.randint(low=0, high=200, size=(8, 8))
    mean, sd = itr(confusion, [1] * 8)
    assert 0.5 > mean
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment