@kshirsagarsiddharth
Created June 23, 2021 16:50
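from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# `training` and `test` are assumed to be DataFrames prepared earlier with the columns named below.
als = ALS(maxIter=10,
          userCol="user_id",
          itemCol="isbn_indexed",
          ratingCol="book_rating",
          nonnegative=True,
          coldStartStrategy="drop")  # drop rows whose prediction is NaN (unseen users/items)

# Assumption: the original snippet never defines `evaluator`; an RMSE evaluator is used here to match the final printout.
evaluator = RegressionEvaluator(metricName="rmse", labelCol="book_rating", predictionCol="prediction")

# 2 ranks x 4 regularization values = 8 candidate parameter sets.
grid = (ParamGridBuilder()
        .addGrid(als.rank, [10, 30])
        .addGrid(als.regParam, [0.2, 0.01, 1, 2])
        .build())

# 3-fold cross-validation, fitting up to 4 models in parallel.
cv = CrossValidator(estimator=als, estimatorParamMaps=grid, evaluator=evaluator, parallelism=4, numFolds=3)
model = cv.fit(training)

# Score the best model found by cross-validation on the held-out test set.
best_model = model.bestModel
predictions = best_model.transform(test)
rmse = evaluator.evaluate(predictions)
print("Root mean squared error:", rmse)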