Last active
September 3, 2024 11:47
-
-
Save hermidalc/0ee8b3846b0ffd368dd6b7f3b310adcc to your computer and use it in GitHub Desktop.
Inspect any fitted scikit-learn Pipeline to transform a feature-metadata pandas DataFrame through the Pipeline's steps, and add model interpretation (feature weights) to the result.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def transform_feature_meta(pipe, feature_meta):
    """Transform feature metadata through a fitted scikit-learn Pipeline.

    The steps of *pipe* are replayed on the rows of *feature_meta* (instead
    of on data) so that the result holds one metadata row per feature seen
    by the final estimator:

    * ``ColumnTransformer`` steps select each transformer's columns, replay
      that transformer on them, and concatenate the parts in
      ``transformers_`` order; ``'drop'`` entries contribute no rows.
    * Nested pipelines (objects with a ``steps`` attribute) are replayed
      step by step.
    * Steps with ``get_support`` keep only the selected rows.
    * Steps with ``get_feature_names_out`` (or the legacy
      ``get_feature_names``) get one row per output feature, copied from
      the input feature that output was derived from.

    ELI5 feature weights for the final estimator (or for its ``estimator_``
    when the final estimator itself cannot be explained) are then
    inner-joined onto the metadata.

    Parameters
    ----------
    pipe : sklearn.pipeline.Pipeline
        Fitted pipeline; its last step is the model to interpret.
    feature_meta : pandas.DataFrame
        One row per input feature of *pipe*, indexed by feature name. It is
        not modified.

    Returns
    -------
    pandas.DataFrame
        Metadata rows for the final estimator's input features with the
        index named ``'Feature'``, plus title-cased weight columns when ELI5
        returned an explanation.
    """
    transformed_feature_meta = feature_meta
    for estimator in pipe:
        transformed_feature_meta = _propagate_feature_meta(
            estimator, transformed_feature_meta)

    final_estimator = pipe[-1]
    feature_names = transformed_feature_meta.index.values
    feature_weights = explain_weights_df(
        final_estimator, feature_names=feature_names)
    if feature_weights is None and hasattr(final_estimator, 'estimator_'):
        # Presumably a meta-estimator (e.g. RFE) wrapping the real model.
        feature_weights = explain_weights_df(
            final_estimator.estimator_, feature_names=feature_names)
    if feature_weights is not None:
        # NOTE(review): verify_integrity requires one row per feature; an
        # explanation listing a feature more than once (presumably
        # multi-target models) raises here.
        feature_weights = feature_weights.set_index(
            'feature', verify_integrity=True).rename(columns=str.title)
        # Inner join: only features present in feature_weights are kept.
        transformed_feature_meta = transformed_feature_meta.join(
            feature_weights, how='inner')
    # rename_axis returns a new frame, so feature_meta's index is untouched
    # even when no step changed the metadata.
    return transformed_feature_meta.rename_axis('Feature')


def _propagate_feature_meta(estimator, feature_meta):
    """Return the metadata rows describing *estimator*'s output features.

    *feature_meta* describes the features *estimator* receives as input.
    ColumnTransformers and nested pipelines are handled recursively; any
    other step is handled by ``_transform_step_meta``.
    """
    if isinstance(estimator, ColumnTransformer):
        parts = []
        for _, transformer, columns in estimator.transformers_:
            if isinstance(transformer, str) and transformer == 'drop':
                # Dropped columns are not part of the ColumnTransformer
                # output, so they contribute no rows at all.
                continue
            part = _select_feature_meta(feature_meta, columns)
            if len(part) == 0:
                # An empty selection produces no output columns; don't query
                # a transformer that had nothing to be fitted on.
                continue
            parts.append(_propagate_feature_meta(transformer, part))
        if not parts:
            return feature_meta.iloc[[]]
        return pd.concat(parts, axis=0)
    if hasattr(estimator, 'steps'):
        # Nested Pipeline: replay its steps in order.
        for step in estimator:
            feature_meta = _propagate_feature_meta(step, feature_meta)
        return feature_meta
    return _transform_step_meta(estimator, feature_meta)


def _select_feature_meta(feature_meta, columns):
    """Return the rows of *feature_meta* named by a ColumnTransformer spec.

    *columns* is the third element of a ``transformers_`` entry: a single
    name or position, a list/array/Index of names, positions or booleans,
    or a slice. Names (and slices with a string bound) use ``.loc``;
    positions and boolean masks use ``.iloc``. A DataFrame is always
    returned.
    """
    if isinstance(columns, slice):
        if isinstance(columns.start, str) or isinstance(columns.stop, str):
            return feature_meta.loc[columns]
        return feature_meta.iloc[columns]
    if np.isscalar(columns):
        # A scalar key would make .loc/.iloc return a Series; wrap it so
        # the result stays a one-row DataFrame.
        columns = [columns]
    if len(columns) == 0:
        return feature_meta.iloc[[]]
    if isinstance(columns[0], str):
        return feature_meta.loc[columns]
    return feature_meta.iloc[columns]


def _transform_step_meta(estimator, feature_meta):
    """Propagate *feature_meta* through a single, non-composite step.

    Selectors (``get_support``) keep the selected rows. Steps exposing
    ``get_feature_names_out`` or the legacy ``get_feature_names`` get one
    row per output feature, copied from its source input feature. Anything
    else, including output names that cannot be traced back to an input
    feature, leaves *feature_meta* unchanged.
    """
    if hasattr(estimator, 'get_support'):
        return feature_meta.loc[estimator.get_support()]
    if hasattr(estimator, 'get_feature_names_out'):
        get_names = estimator.get_feature_names_out
    elif hasattr(estimator, 'get_feature_names'):
        get_names = estimator.get_feature_names
    else:
        return feature_meta
    new_feature_names = np.asarray(
        get_names(input_features=feature_meta.index.values)).astype(str)
    positions = _source_feature_positions(
        feature_meta.index, new_feature_names)
    if positions is None:
        # Output names carry no per-input meaning; keep the metadata as is.
        return feature_meta
    # Positional take keeps per-column dtypes (unlike rebuilding from
    # .values) and is safe even if feature_meta has duplicate labels.
    new_feature_meta = feature_meta.iloc[positions]
    new_feature_meta.index = new_feature_names
    return new_feature_meta


def _source_feature_positions(input_names, output_names):
    """Map each output feature name to the position of its input feature.

    An output name maps to the input whose (stringified) name equals it, or
    else to the longest input name ``p`` such that the output starts with
    ``p + '_'`` (the ``<feature>_<category>`` convention). Longest-first
    means output ``'a_b_x'`` goes to input ``'a_b'`` rather than ``'a'``.

    Returns a list of integer positions into *input_names*, or ``None`` if
    any output name matches no input.
    """
    position_of = {}
    for pos, name in enumerate(input_names):
        position_of.setdefault(str(name), pos)
    positions = []
    for output_name in output_names:
        pos = position_of.get(output_name)
        cut = len(output_name)
        # Try prefixes ending just before each '_', longest first.
        while pos is None:
            cut = output_name.rfind('_', 0, cut)
            if cut < 0:
                return None
            pos = position_of.get(output_name[:cut])
        positions.append(pos)
    return positions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
If you do not want to use ELI5, replace https://gist.github.com/hermidalc/0ee8b3846b0ffd368dd6b7f3b310adcc#file-transform_feature_meta-py-L57-L68 with: