Created
March 31, 2021 09:07
-
-
Save t-kabaya/c9095f89befe9e3ba94be6285064577b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal Spotlight example: train an implicit-feedback sequence model on a
# tiny hand-written interaction log, then score every item as the possible
# next element of a user's history.
import numpy as np

from spotlight.cross_validation import user_based_train_test_split
from spotlight.datasets.synthetic import generate_sequential
from spotlight.evaluation import sequence_mrr_score
from spotlight.interactions import Interactions
from spotlight.sequence.implicit import ImplicitSequenceModel

# With the default settings each user needs at least 10 interactions.
# Fewer than 10 does not raise an error, but the predictions become
# inaccurate (original author's observation).
user_ids = np.array(
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    dtype='int32',
)
item_ids = np.array(
    [3, 4, 5, 4, 3, 3, 4, 5, 4, 3, 1, 6, 2, 2, 1, 1, 6, 2, 2, 1],
    dtype='int32',
)
# Every interaction carries the same rating of 1.
ratings = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    dtype='int32',
)
# Timestamps run 1..10 within each user.
timestamps = np.array(
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    dtype='int32',
)

dataset = Interactions(
    user_ids,
    item_ids,
    ratings=ratings,
    timestamps=timestamps,
)

# Convert the flat interaction log into the sequence form that the sequence
# model is fitted on below.
train = dataset.to_sequence()

model = ImplicitSequenceModel(
    n_iter=1000,
    representation='cnn',
    loss='bpr',
)
model.fit(train)

# By default, sequences of length 10 are used as training input, so for
# prediction we pass a sequence of length 9 and let the model predict the
# final (10th) element.
# The sequence given to predict() may of course be longer or shorter than
# 10, but given how the dataset is shaped, a length of about 9 works well.
prediction = model.predict([3, 4, 5, 4, 3, 3, 4, 5, 4])
print(prediction)
# Example output recorded by the original author:
#   [ 0., 2.245524, -4.476738, 10.410953, 3.155594, -2.9693732, -6.0920258]
#   (dtype=float32)
# The highest score is at index 3 (counting from 0), i.e. item id 3 is the
# recommended next item.

# Optional evaluation, left disabled as in the original (note that it would
# score the model on the same sequences it was trained on):
# mrr = sequence_mrr_score(model, train)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment