Skip to content

Instantly share code, notes, and snippets.

@johnhw
Last active March 20, 2020 13:12
Show Gist options
  • Save johnhw/117be58261d6f22a4cfbf285cbe67b65 to your computer and use it in GitHub Desktop.
Save johnhw/117be58261d6f22a4cfbf285cbe67b65 to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.metrics import confusion_matrix
from collections import defaultdict
def itr(confusion_matrix, timings, eps=1e-8):
    """Estimate the information transfer rate (ITR) from a confusion matrix.

    The confusion matrix has intended classes on the rows and actual
    (detected) classes on the columns, for example::

        (actual)     a   b   c   d
        (intended)
        a           10   0   1   9
        b            2  11   7   2
        c            0   0  15   0
        d            5   3  12   9

    ``timings`` gives the average time taken to specify each intended class
    (a, b, c, d, ...), in the same order as the rows.

    For each intended class the bits conveyed per selection are estimated as
    ``log2(n_classes) - H(row)``. Here ``log2(n_classes)`` is the maximum
    entropy and ``H(row) = -sum(p * log2(p))`` is the residual entropy of that
    class's normalised row. Dividing by the class's timing gives an
    approximate maximum bits-per-second figure for the class.

    Parameters
    ----------
    confusion_matrix : array_like, shape (n_classes, n_actual)
        Counts (or weights) of intended (rows) vs. actual (columns) classes.
    timings : array_like or float
        Mean time per selection for each intended class. A scalar is
        broadcast to every class.
    eps : float
        Lower bound applied to probabilities inside ``log2``. Zero entries
        therefore contribute exactly 0 instead of ``nan``, and entries
        >= eps are not biased.

    Returns
    -------
    (float, float)
        Mean and standard deviation of the per-class bits-per-second. Only
        classes with at least one intended sample are included.

    Raises
    ------
    ValueError
        If the matrix is not 2-D, if ``timings`` cannot be broadcast to one
        value per row, or if the matrix contains no samples at all.
    """
    counts = np.asarray(confusion_matrix, dtype=float)
    if counts.ndim != 2:
        raise ValueError("confusion_matrix must be a 2-D array")
    n_classes = counts.shape[0]
    # Accept a list, an array, or a single scalar shared by all classes.
    per_class_time = np.broadcast_to(np.asarray(timings, dtype=float), (n_classes,))

    row_totals = counts.sum(axis=1)
    # A class that was never intended has no distribution to measure.
    # Skip it rather than letting 0/0 turn the whole result into nan.
    observed = row_totals > 0
    if not np.any(observed):
        raise ValueError("confusion_matrix contains no samples")

    probs = counts[observed] / row_totals[observed][:, np.newaxis]
    # p * log2(max(p, eps)) is exactly 0 when p == 0 and unbiased when p >= eps.
    residual_entropy = -np.sum(probs * np.log2(np.maximum(probs, eps)), axis=1)
    # Maximum entropy counts every class in the alphabet, including unobserved ones.
    max_entropy = np.log2(n_classes)
    bps_per_class = (max_entropy - residual_entropy) / per_class_time[observed]
    return np.mean(bps_per_class), np.std(bps_per_class)
## Example of use
# data in the form (intended_class, detected_class, duration_seconds)
example_data = [
    (0, 0, 0.5),
    (1, 1, 0.54),
    (3, 3, 0.53),
    (3, 2, 0.44),
    (1, 1, 0.42),
    (0, 0, 0.33),
    (2, 2, 0.92),
]
# Fix the label order explicitly so that row i of the confusion matrix and
# entry i of mean_duration always refer to the same class, even if the
# labels are not the contiguous integers 0..k-1.
labels = sorted(
    {intended for (intended, _, _) in example_data}
    | {actual for (_, actual, _) in example_data}
)
# compute confusion matrix (rows: intended class, columns: detected class)
confusion = confusion_matrix(
    y_true=[intended for (intended, actual, duration) in example_data],
    y_pred=[actual for (intended, actual, duration) in example_data],
    labels=labels,
)
## compute average duration of each intended class
times = defaultdict(float)
n = defaultdict(int)
for intended, actual, duration in example_data:
    times[intended] += duration
    # Count selections, not seconds. Summing the duration here made every
    # "mean" duration exactly 1.0.
    n[intended] += 1
# Classes that were never intended get 0 here. Their confusion-matrix row
# is all zeros, so they carry no timing information.
mean_duration = [times.get(label, 0) / n.get(label, 1) for label in labels]
## show mean and std. itr
print(itr(confusion, mean_duration))
##############
## tests
# Perfect classification of 8 classes conveys log2(8) = 3 bits per selection.
confusion = np.eye(8)
mean_duration = [1] * 8
# should get ~ 3 bits / second
mean, sd = itr(confusion, mean_duration)
assert abs(mean - 3.0) < 1e-5
# should get ~ 6 bits/second (selections take half as long)
mean, sd = itr(confusion, [0.5] * 8)
assert abs(mean - 6.0) < 1e-5
# should get ~ 1.5 bits/second (selections take twice as long)
mean, sd = itr(confusion, [2.0] * 8)
assert abs(mean - 1.5) < 1e-5
## should be pretty much close to 0 for random performance, less than 0.5 bits/second
# Use a seeded local generator so the check is reproducible and leaves the
# global NumPy RNG state alone.
rng = np.random.RandomState(0)
for _ in range(20):
    confusion = rng.randint(0, 200, (8, 8))
    mean, sd = itr(confusion, [1] * 8)
    assert mean < 0.5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment