Skip to content

Instantly share code, notes, and snippets.

@KevinLiao159
Created November 11, 2018 02:08
Show Gist options
  • Save KevinLiao159/9f69049d6d3d8a096c0ea08dbc29591b to your computer and use it in GitHub Desktop.
Save KevinLiao159/9f69049d6d3d8a096c0ea08dbc29591b to your computer and use it in GitHub Desktop.
A function for ALS hyper-param tuning
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
def tune_ALS(train_data, validation_data, maxIter, regParams, ranks, *,
             userCol='userId', itemCol='movieId', ratingCol='rating',
             coldStartStrategy='drop'):
    """
    grid search function to select the best model based on RMSE of
    validation data

    Every (rank, regParam) combination is fitted on ``train_data`` and
    scored on ``validation_data``; the fitted model with the lowest
    validation RMSE is returned.

    Parameters
    ----------
    train_data: spark DF with columns ['userId', 'movieId', 'rating']
    validation_data: spark DF with columns ['userId', 'movieId', 'rating']
    maxIter: int, max number of learning iterations
    regParams: list of float, one dimension of hyper-param tuning grid
    ranks: list of int, one dimension of hyper-param tuning grid
        (number of latent factors)
    userCol, itemCol, ratingCol: str, keyword-only, names of the user id,
        item id and rating columns. They default to the schema above;
        ALS's own defaults ('user', 'item') do not match that schema.
    coldStartStrategy: str, keyword-only, forwarded to ALS. 'drop'
        removes predictions for users/items not seen in train_data; with
        'nan' those predictions are NaN and the validation RMSE becomes
        NaN, so no model could ever be selected.

    Return
    ------
    The best fitted ALS model with lowest RMSE score on validation data,
    or None if either grid is empty (or every RMSE is NaN, since NaN never
    compares lower than the running minimum).
    """
    # The evaluator does not depend on the grid point, so build it once.
    evaluator = RegressionEvaluator(metricName="rmse",
                                    labelCol=ratingCol,
                                    predictionCol="prediction")
    # initial
    min_error = float('inf')
    best_rank = -1
    best_regularization = 0
    best_model = None
    for rank in ranks:
        for reg in regParams:
            # get ALS model configured for this grid point and the
            # caller's column names
            als = ALS(maxIter=maxIter,
                      rank=rank,
                      regParam=reg,
                      userCol=userCol,
                      itemCol=itemCol,
                      ratingCol=ratingCol,
                      coldStartStrategy=coldStartStrategy)
            # train ALS model
            model = als.fit(train_data)
            # evaluate the model by computing the RMSE on the validation data
            predictions = model.transform(validation_data)
            rmse = evaluator.evaluate(predictions)
            print('{} latent factors and regularization = {}: '
                  'validation RMSE is {}'.format(rank, reg, rmse))
            if rmse < min_error:
                min_error = rmse
                best_rank = rank
                best_regularization = reg
                best_model = model
    print('\nThe best model has {} latent factors and '
          'regularization = {}'.format(best_rank, best_regularization))
    return best_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment