@eddjberry
Created July 17, 2020 14:42
Generate the data required for a partial dependence plot from a PySpark model
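The function relies on a prob_element UDF that is not defined in the gist. Judging by the `<lambda>(probability)` column name it produces, it appears to be a lambda-based UDF that pulls one element (here assumed to be the positive-class probability of a binary classifier) out of the ML prediction vector. A minimal sketch under that assumption:

from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# assumed helper: extract the positive-class probability from the prediction vector;
# a lambda-based UDF gets the default output column name `<lambda>(probability)`,
# which the function below renames via selectExpr
prob_element = F.udf(lambda v: float(v[1]), DoubleType())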
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.functions import lit


def partial_dependency_data(df, model, col, values, sample_fraction=0.1):
    # empty list for the average predictions
    avg_predictions = list()
    # take a sample of the data to keep the prediction loop cheap
    df_sample = df.sample(fraction=sample_fraction)
    # loop through the grid of values for the column of interest
    for val in values:
        # replace the column of interest with the constant value val
        df_sample_replace = (df_sample
                             .drop(col)
                             .withColumn(col, lit(val)))
        # generate predictions with the fitted model
        df_predictions = model.transform(df_sample_replace)
        # get the probability element of the predictions via the prob_element UDF
        df_predictions_refine = df_predictions.select(col, prob_element('probability'))
        # rename the UDF output column to something usable
        df_predictions_refine = df_predictions_refine.selectExpr(col, "`<lambda>(probability)` AS probability")
        # calculate the average predicted probability for this value
        avg_prediction = df_predictions_refine.agg(F.mean(df_predictions_refine.probability)).collect()
        # append the average prediction (a single-row Row)
        avg_predictions.append(avg_prediction[0])
    # put the average probabilities in a pandas DataFrame
    df_pd = pd.DataFrame(avg_predictions, columns=['avg_probability'])
    # add the grid values for the column of interest
    df_pd[col] = values
    return df_pd
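A hypothetical call, assuming a Spark DataFrame df, a fitted PySpark model fitted_model, and a numeric feature 'age' (all names illustrative):

# illustrative only: df, fitted_model and the 'age' feature are assumptions
age_grid = [20, 30, 40, 50, 60, 70]
pd_data = partial_dependency_data(df, fitted_model, 'age', age_grid, sample_fraction=0.1)
# pd_data has columns ['avg_probability', 'age'], one row per grid value,
# ready to plot as a partial dependence curve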