@bmcfee
Created July 24, 2014 21:27
OVR-friendly grid search
# scikit-learn's one-vs-rest wrapper (OneVsRestClassifier) requires that the (binary) estimator object implements decision_function() or predict_proba().
# If you want the inner estimator to contain a parameter-search layer (so that each one-vs-rest classifier is tuned separately),
# this fails because of the following chain of events:
#
# 1. OVR checks `hasattr(estimator, 'predict_proba')` before the estimator has been fit.
# 2. `hasattr()` calls `getattr(estimator, 'predict_proba')` and returns False if that lookup raises an exception.
# 3. GridSearchCV's `predict_proba` delegates to `best_estimator_`; because the search has not been fit yet, that attribute does not exist, so the lookup raises an exception.
# 4. `hasattr()` misinterprets this exception as "no such method" and returns False, so OVR rejects the estimator.
#
# We can circumvent this problem by wrapping the predict_proba and decision_function properties so that they return None, instead of raising, until the search has been fit. This is a dirty, dirty hack.
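# This gist targets the old sklearn.grid_search module (deprecated in 0.18, removed in 0.20).
import sklearn.grid_search


class MyGridSearchCV(sklearn.grid_search.GridSearchCV):
    # Before fit(), return None instead of raising, so that hasattr() reports True.
    @property
    def predict_proba(self):
        if hasattr(self, 'best_estimator_'):
            return super(MyGridSearchCV, self).predict_proba
        return None

    @property
    def decision_function(self):
        if hasattr(self, 'best_estimator_'):
            return super(MyGridSearchCV, self).decision_function
        return None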
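Below is a minimal, dependency-free sketch of the hasattr() chain described in the comments and of the None-returning workaround, using plain Python stand-ins instead of scikit-learn. The names `InnerEstimator`, `NaiveSearch`, `NoneWhenUnfitSearch`, and `DelegatingSearch` are hypothetical and exist only for illustration. `DelegatingSearch` is an alternative that is not part of the original gist: before fitting, it looks the method up on the unfitted base estimator instead of returning None.

```python
# Hypothetical toy classes (NOT scikit-learn) that reproduce the hasattr() pitfall.


class InnerEstimator:
    """A binary estimator that exposes predict_proba."""

    def predict_proba(self, X):
        # Return a fixed [P(negative), P(positive)] row for every sample.
        return [[0.25, 0.75] for _ in X]


class NaiveSearch:
    """Mimics a search object whose predict_proba delegates to best_estimator_."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self):
        # A real search would evaluate parameter settings; here we just keep the estimator.
        self.best_estimator_ = self.estimator
        return self

    @property
    def predict_proba(self):
        # Raises AttributeError before fit(), so hasattr() reports False.
        return self.best_estimator_.predict_proba


class NoneWhenUnfitSearch(NaiveSearch):
    """The gist's workaround: return None instead of raising before fit()."""

    @property
    def predict_proba(self):
        if hasattr(self, 'best_estimator_'):
            return super(NoneWhenUnfitSearch, self).predict_proba
        return None


class DelegatingSearch(NaiveSearch):
    """Alternative (not from the gist): fall back to the unfitted base estimator."""

    @property
    def predict_proba(self):
        if hasattr(self, 'best_estimator_'):
            return self.best_estimator_.predict_proba
        # Raises AttributeError only if the base estimator really lacks predict_proba.
        return self.estimator.predict_proba


print(hasattr(NaiveSearch(InnerEstimator()), 'predict_proba'))          # False
print(hasattr(NoneWhenUnfitSearch(InnerEstimator()), 'predict_proba'))  # True
print(hasattr(NoneWhenUnfitSearch(object()), 'predict_proba'))         # True, even though object() has no predict_proba
print(hasattr(DelegatingSearch(object()), 'predict_proba'))            # False
print(NoneWhenUnfitSearch(InnerEstimator()).fit().predict_proba([1, 2]))  # [[0.25, 0.75], [0.25, 0.75]]
```

The third line shows the cost of returning None: hasattr() succeeds even when the wrapped estimator has no `predict_proba` at all, and anything that calls the attribute before fitting gets a `TypeError` on `None`. Delegating to the unfitted base estimator keeps hasattr() truthful. Later scikit-learn releases moved `GridSearchCV` to `sklearn.model_selection` and appear to delegate along these lines, so it is worth checking whether the workaround is still needed for your installed version.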