Last active
December 16, 2015 18:48
-
-
Save jnothman/5480026 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tool to examine the output of model selection search results from scikit-learn (assuming #1787). | |
Pandas might be more appropriate, but I haven't worked out how to do group_best there... | |
For example: | |
>>> my_search = GridSearchCV(est, param_dict={'a': [...], 'b': [...], 'c': [...]}) | |
>>> my_search.fit(X, y) | |
>>> rw = ResultsWrangler(my_search.grid_results_, my_search.fold_results_) | |
>>> grouped = rw.group_best(['a', 'b']) | |
>>> print(zip(grouped.parameters, grouped.scores)) | |
""" | |
import itertools | |
import numpy as np | |
import numpy.ma.mrecords as mrecords | |
def params_to_array(parameter_dicts):
    """Convert a sequence of parameter dicts into a masked record array.

    Each distinct key found in any dict becomes a field (fields are ordered
    by sorted key name). A row whose dict lacks a key gets a placeholder
    value for that field and the field is masked for that row.

    Parameters
    ----------
    parameter_dicts : sequence of dict
        One mapping of parameter name -> value per row. It is iterated
        twice, so it must not be a one-shot iterator.

    Returns
    -------
    The masked record array built by ``numpy.ma.mrecords.fromrecords``,
    with one record per input dict.
    """
    # Remember one example value per field; it is only used as masked filler
    # so that every column still has a consistent type.
    fields = {}
    for params in parameter_dicts:
        # dict.update works on both Python 2 and 3 (iteritems() does not).
        fields.update(params)
    field_names = sorted(fields)

    data = []
    mask = []
    for params in parameter_dicts:
        # Present -> (real value, unmasked); absent -> (filler, masked).
        data.append(tuple(params[name] if name in params else fields[name]
                          for name in field_names))
        mask.append(tuple(name not in params for name in field_names))
    return mrecords.fromrecords(data, mask=mask, names=field_names)
def params_array_to_dicts(self, parameters):
    """Turn a masked record array of parameters back into a list of dicts.

    Parameters
    ----------
    self : object
        Unused; kept only so existing two-argument callers keep working.
    parameters : masked record array
        Its ``dtype.names`` give the parameter names. Each element is
        indexed by field name and must yield an object exposing ``.mask``
        and ``.item()``.

    Returns
    -------
    list of dict
        One dict per record, holding only the fields that are not masked.
    """
    names = parameters.dtype.names
    dicts = []
    for record in parameters:
        settings = {}
        for name in names:
            cell = record[name]
            # Masked fields are placeholders, so leave them out of the dict.
            if cell.mask:
                continue
            settings[name] = cell.item()
        dicts.append(settings)
    return dicts
class ResultsWrangler(object):
    """Query model-selection search results by parameter setting and score.

    Holds three index-aligned arrays: per-candidate results (``points``),
    per-fold results (``folds``) and parameter settings (``parameters``).
    Indexing a wrangler indexes all three and returns a new wrangler.

    Parameters
    ----------
    grid_results : array indexable by field name
        Must provide the ``score_field`` field. When ``parameters`` is None
        it must also provide a ``'parameters'`` field holding dicts.
    fold_results : array indexable by field name
        Its first axis is taken as the number of results; its
        ``score_field`` field is used by `compare_wilcoxon`.
        (Presumably shaped (n_results, n_folds) -- confirm with the caller.)
    parameters : structured array, optional
        Parameter settings, one record per result. Built from
        ``grid_results['parameters']`` with ``params_to_array`` if omitted.
    score_field : str
        Name of the field holding the score to rank by.
    greater_is_better : bool
        Whether a larger score is a better one.
    """

    def __init__(self, grid_results, fold_results, parameters=None, score_field='test_score', greater_is_better=True):
        self.points = grid_results
        self.folds = fold_results
        if parameters is None:
            parameters = params_to_array(grid_results['parameters'])
        self.parameters = parameters
        # TODO: check shapes/types are valid
        self.score_field = score_field
        self.greater_is_better = greater_is_better

    def __iter__(self):
        """Iterate over ``(parameters, point, fold)`` triples."""
        # The builtin zip exists on Python 2 and 3 (itertools.izip is
        # Python 2 only); iter() keeps the return value a true iterator.
        return iter(zip(self.parameters, self.points, self.folds))

    def __len__(self):
        return self.folds.shape[0]

    def __getitem__(self, index):
        """Return a new `ResultsWrangler` restricted to ``index``."""
        return ResultsWrangler(self.points[index], self.folds[index], self.parameters[index], score_field=self.score_field, greater_is_better=self.greater_is_better)

    @property
    def scores(self):
        """The ``score_field`` column of the per-candidate results."""
        return self.points[self.score_field]

    def groups(self, fields=None):
        """Assigns each entry an integer corresponding to distinct settings of `fields`."""
        # XXX: should this also return the unique parameter settings?
        parameters = self.parameters
        if fields is not None:
            parameters = parameters[list(fields)]
        _, inverse = np.unique(parameters, return_inverse=True)
        return inverse

    def all_best(self, tol=0.001):
        """Returns a `ResultsWrangler` including only results within `tol` of the best score."""
        scores = self.scores
        if self.greater_is_better:
            return self[scores > scores.max() - tol]
        else:
            return self[scores < scores.min() + tol]

    def group_best(self, fields=None):
        """Returns a `ResultsWrangler` of the best-scoring parameters for each distinct value of `fields`."""
        if self.greater_is_better:
            scores = self.scores
        else:
            # Negate so that "best" is always the maximum below.
            scores = -self.scores
        # Thanks to http://stackoverflow.com/questions/8623047/group-by-max-or-min-in-a-numpy-array/8623168#8623168
        # Sort with major key groups, minor key score:
        groups = self.groups(fields)
        order = np.lexsort((scores, groups))
        groups = groups[order]
        # Index marks change from one group to next, i.e. within-group max
        index = np.empty(len(groups), 'bool')
        if len(groups):
            # The last sorted entry always closes a group; skipping this for
            # empty results avoids an IndexError.
            index[-1] = True
        index[:-1] = groups[1:] != groups[:-1]
        return self[order[index]]

    def compare_wilcoxon(self, ind1, ind2):
        """Return the absolute Wilcoxon signed-rank sum for two results.

        The per-fold scores of result ``ind1`` are subtracted from those of
        result ``ind2``. The absolute differences are ranked from 1, with
        tied values sharing their average rank, and the result is
        ``abs(sum(sign(diff) * rank))``. Zero differences keep their rank
        but contribute nothing because their sign is 0.
        """
        fold_scores = self.folds[self.score_field][[ind1, ind2]]
        diffs = np.diff(fold_scores, axis=0).ravel()
        # Rank the *absolute* differences. Note argsort alone would give the
        # sorting permutation, not each element's rank.
        _, inverse, counts = np.unique(np.abs(diffs), return_inverse=True,
                                       return_counts=True)
        # cumsum(counts) is the highest rank in each tie group; step back by
        # half the group width to get the group's average rank.
        ranks = (np.cumsum(counts) - (counts - 1) / 2.0)[inverse]
        return np.abs(np.dot(ranks, np.sign(diffs)))
# TODO: plots | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment