Created
November 11, 2018 02:08
-
-
Save KevinLiao159/9f69049d6d3d8a096c0ea08dbc29591b to your computer and use it in GitHub Desktop.
A function for ALS hyper-param tuning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
def tune_ALS(train_data, validation_data, maxIter, regParams, ranks,
             userCol='userId', itemCol='movieId', ratingCol='rating',
             coldStartStrategy='drop'):
    """
    Grid search over ALS rank and regularization; select the model with the
    lowest RMSE on the validation data.

    Parameters
    ----------
    train_data: spark DF with columns [userCol, itemCol, ratingCol]

    validation_data: spark DF with columns [userCol, itemCol, ratingCol]

    maxIter: int, max number of learning iterations

    regParams: list of float, one dimension of hyper-param tuning grid

    ranks: list of int, one dimension of hyper-param tuning grid

    userCol: str, name of the user id column (default 'userId')

    itemCol: str, name of the item id column (default 'movieId')

    ratingCol: str, name of the rating column (default 'rating'); also used
        as the label column when computing RMSE

    coldStartStrategy: str, passed through to ALS (default 'drop')

    Return
    ------
    The best fitted ALS model with lowest RMSE score on validation data,
    or None when the grid is empty (either `ranks` or `regParams` has no
    entries) or no candidate produced an RMSE below infinity (e.g. NaN).
    """
    # Running best over the grid; best_model stays None until some
    # candidate yields an RMSE strictly below the current minimum.
    min_error = float('inf')
    best_rank = -1
    best_regularization = 0
    best_model = None

    for rank in ranks:
        for reg in regParams:
            # Column names are passed explicitly: the ALS defaults are
            # 'user' / 'item', which do not match the documented schema.
            # coldStartStrategy='drop' removes validation rows whose user
            # or item was unseen in training; with the ALS default ('nan')
            # those rows get NaN predictions, RMSE becomes NaN, and the
            # `rmse < min_error` comparison below would never succeed.
            als = ALS(
                maxIter=maxIter,
                rank=rank,
                regParam=reg,
                userCol=userCol,
                itemCol=itemCol,
                ratingCol=ratingCol,
                coldStartStrategy=coldStartStrategy,
            )
            # train ALS model
            model = als.fit(train_data)
            # evaluate the model by computing the RMSE on the validation data
            predictions = model.transform(validation_data)
            evaluator = RegressionEvaluator(metricName="rmse",
                                            labelCol=ratingCol,
                                            predictionCol="prediction")
            rmse = evaluator.evaluate(predictions)
            print('{} latent factors and regularization = {}: '
                  'validation RMSE is {}'.format(rank, reg, rmse))
            if rmse < min_error:
                min_error = rmse
                best_rank = rank
                best_regularization = reg
                best_model = model

    print('\nThe best model has {} latent factors and '
          'regularization = {}'.format(best_rank, best_regularization))
    return best_model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment