Created
July 17, 2020 14:42
-
-
Save eddjberry/43ceca3d29905781ab6c8dab07e8e1da to your computer and use it in GitHub Desktop.
Generate the data required for a partial dependence plot from a PySpark model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def partial_dependency_data(df, model, col, values, sample_fraction=0.1,
                            seed=None):
    """Generate the data required for a partial dependence plot.

    A sample of ``df`` is drawn once. For every entry of ``values`` the
    column ``col`` of that sample is replaced by the entry, the sample is
    scored with ``model.transform`` and the mean of the extracted
    probability element is recorded.

    Relies on ``lit``, ``F`` (``pyspark.sql.functions``), ``pd`` (pandas)
    and the ``prob_element`` helper being available at module level.

    Args:
        df: Spark DataFrame to sample rows from.
        model: Object whose ``transform(df)`` returns a DataFrame that has
            a ``probability`` column.
        col: Name of the column whose value is varied.
        values: Iterable of values to substitute into ``col``.
        sample_fraction: Fraction passed to ``df.sample``.
        seed: Optional seed passed to ``df.sample`` so the sample can be
            reproduced; ``None`` lets Spark choose one.

    Returns:
        A pandas DataFrame with one row per entry of ``values``: the
        column ``avg_probability`` holds the mean extracted probability
        and the column named ``col`` holds the substituted value.
    """
    # Materialise once: ``values`` is iterated in the loop AND assigned to
    # the result below, which would silently break for a generator and would
    # be index-aligned (not positional) for a pandas Series.
    values = list(values)

    # The same sample is reused for every value so the averages are
    # comparable with each other.
    df_sample = df.sample(fraction=sample_fraction, seed=seed)

    avg_predictions = []
    for val in values:
        # Overwrite the column of interest with the constant ``val``.
        df_sample_replace = df_sample.drop(col).withColumn(col, lit(val))

        df_predictions = model.transform(df_sample_replace)

        # Alias the extracted element directly instead of looking it up by
        # the auto-generated name "<lambda>(probability)", which only exists
        # when ``prob_element`` is a UDF built from a lambda.
        # NOTE(review): assumes ``prob_element(...)`` returns a Column (as a
        # UDF call does) - confirm against its definition.
        df_predictions_refine = df_predictions.select(
            col, prob_element('probability').alias('probability')
        )

        # ``agg`` without ``groupBy`` yields exactly one row with one field;
        # pull out the scalar rather than storing the Row object.
        avg_row = df_predictions_refine.agg(F.mean('probability')).collect()[0]
        avg_predictions.append(avg_row[0])

    # One row per value: the average probability next to the value used.
    df_pd = pd.DataFrame(avg_predictions, columns=['avg_probability'])
    df_pd[col] = values
    return df_pd
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment