Skip to content

Instantly share code, notes, and snippets.

@t-kabaya
Created March 31, 2021 09:07
Show Gist options
  • Save t-kabaya/c9095f89befe9e3ba94be6285064577b to your computer and use it in GitHub Desktop.
Save t-kabaya/c9095f89befe9e3ba94be6285064577b to your computer and use it in GitHub Desktop.
from spotlight.cross_validation import user_based_train_test_split
from spotlight.datasets.synthetic import generate_sequential
from spotlight.evaluation import sequence_mrr_score
from spotlight.sequence.implicit import ImplicitSequenceModel
from spotlight.interactions import Interactions
import numpy as np
# デフォルトでは、各ユーザー最低10のアイテムが必要 10を下回るとエラーは発生しないが予測が不正確になる。
user_ids = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype = 'int32')
item_ids = np.array([3, 4, 5, 4, 3, 3, 4, 5, 4, 3, 1, 6, 2, 2, 1, 1, 6, 2, 2, 1], dtype = 'int32')
ratings = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype = 'int32')
timestamps = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype = 'int32')
dataset = Interactions(user_ids,
item_ids,
ratings=ratings,
timestamps=timestamps)
train = dataset.to_sequence()
model = ImplicitSequenceModel(n_iter=1000,
representation='cnn',
loss='bpr')
model.fit(train)
# デフォルトでは、10の長さのsequenceをinnputを学習に使用するので、予測には、9の長さのSequenceを使用し最後の一つを予測する形とする。
# 勿論predictに渡す引数は、10以上でも10未満でも良いがデータセットとの関係を考慮し大体9にするのが良い。
prediction = model.predict([3, 4, 5, 4, 3, 3, 4, 5, 4])
print(prediction)
# prediction = [ 0. , 2.245524 , -4.476738 , 10.410953 , 3.155594 , -2.9693732, -6.0920258], dtype=float32)
# 左から3番目つまりitemId = 3がレコメンドされている。
# mrr = sequence_mrr_score(model, train)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment