Skip to content

Instantly share code, notes, and snippets.

@erap129
Created February 19, 2022 12:23
Show Gist options
  • Save erap129/d2a2dae52a6a1e1cc578f8e40295dd91 to your computer and use it in GitHub Desktop.
Save erap129/d2a2dae52a6a1e1cc578f8e40295dd91 to your computer and use it in GitHub Desktop.
from surprise import AlgoBase, KNNBasic
from surprise.prediction_algorithms.knns import SymmetricAlgo
class CustomSimKNNAlgorithm(KNNBasic):
    """KNNBasic variant that uses a caller-supplied similarity matrix.

    ``KNNBasic.fit`` computes similarities from the trainset. This class's
    ``fit`` instead takes a precomputed matrix and stores it as ``self.sim``.
    Prediction is inherited from ``KNNBasic``.

    Note that ``fit`` requires an extra ``similarities`` argument, so surprise
    helpers that call ``algo.fit(trainset)`` directly (e.g. cross-validation
    utilities) cannot drive this class unchanged.
    """

    def __init__(self, sim_options, k=40, min_k=1):
        """Configure the neighbourhood parameters.

        Args:
            sim_options: Similarity options dict. This class only reads
                ``'user_based'`` from it (in ``fit``). When the key is absent,
                surprise's base class fills it in as ``True``.
            k: Maximum number of neighbours used for a prediction.
            min_k: Minimum number of neighbours required to aggregate.
        """
        # Delegate to KNNBasic.__init__ so the caller's sim_options go through
        # surprise's AlgoBase initialisation, which adds the 'user_based'
        # default. The previous SymmetricAlgo.__init__(self) call ignored the
        # caller's dict, so a dict without 'user_based' made fit() raise KeyError.
        super().__init__(k=k, min_k=min_k, sim_options=sim_options)

    def fit(self, trainset, similarities):
        """Fit on ``trainset`` using a precomputed similarity matrix.

        Args:
            trainset: A surprise ``Trainset``.
            similarities: Similarity matrix indexed by inner ids: users when
                ``sim_options['user_based']`` is true, items otherwise.
                Presumably an ``n_x`` x ``n_x`` numpy array, since it replaces
                the ``self.sim`` that KNNBasic normally builds. TODO: confirm
                against callers.

        Returns:
            self, so calls can be chained like other surprise algorithms.
        """
        # SymmetricAlgo.fit stores the trainset and sets n_x, n_y, xr and yr
        # according to sim_options['user_based']. The previous version copied
        # that logic by hand.
        SymmetricAlgo.fit(self, trainset)
        # Use the supplied matrix instead of computing one (KNNBasic.fit is
        # deliberately not called).
        self.sim = similarities
        return self

    def test(self, testset, verbose=False):
        """Predict every ``(uid, iid, r_ui)`` triple in ``testset``.

        Behaves like surprise's ``AlgoBase.test``, but shows a tqdm progress bar
        when tqdm is installed.

        Args:
            testset: Iterable of ``(uid, iid, r_ui)`` triples (raw ids).
            verbose: Forwarded to ``self.predict``.

        Returns:
            A list with one prediction per triple, in ``testset`` order.
        """
        # tqdm was previously referenced without being imported, so every call
        # raised NameError. Import it lazily so the progress bar is optional.
        try:
            from tqdm import tqdm
        except ImportError:
            rows = testset
        else:
            rows = tqdm(testset, desc='making predictions')
        return [
            self.predict(uid, iid, r_ui, verbose=verbose)
            for uid, iid, r_ui in rows
        ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment