#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import itertools
import datetime
import time
from math import sqrt
from operator import add
from os.path import join
from pyspark import SparkConf, SparkContext
from pyspark.mllib.recommendation import ALS
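# A minimal usage sketch (assuming Spark 1.x with spark-submit on the PATH
# and a data directory holding the ldgourmet ratings.csv and
# restaurants.csv; the script file name here is hypothetical):
#
#   spark-submit recommend_restaurants.py /path/to/ldgourmet/data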
def parseRating(line):
    """
    Parse a record line of ratings.csv into a (key, value) pair:
    user_id -> "timestamp_digit,restaurant_id,rating".
    """
    fields = line.strip().split(",")
    try:
        date = datetime.datetime.strptime(fields[11], '"%Y-%m-%d %H:%M:%S"')
        # Keep only the last digit of the epoch time; it is used later to
        # split the data into training/validation/test sets.
        epoch = time.mktime(date.timetuple()) % 10
    except ValueError:
        epoch = 0
    # The quotes around the user id are stripped so the key matches the
    # userids RDD built with getUserId below.
    return fields[2].strip('"'), str(epoch) + "," + fields[1] + "," + fields[3]
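# For a hypothetical ratings.csv record with user_id "abc12",
# restaurant_id 42, total 3 and a timestamp whose epoch ends in 6,
# parseRating returns ('abc12', '6.0,42,3').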
def parseRestaurant(line):
    """
    Parse a record line of restaurants.csv and return a
    (restaurant_id, name) tuple.
    """
    fields = line.strip().split(",")
    return int(fields[0]), fields[1]
def parseUserids(line):
    """
    Parse a "user_id,user_id_int" line into a (user_id, user_id_int)
    tuple. (Not used in the main flow below.)
    """
    fields = line.split(",")
    return fields[0], int(fields[1])
def getUserId(line):
    """
    Extract the user id from a ratings.csv record, with the surrounding
    double quotes stripped.
    """
    fields = line.strip().split(",")
    return fields[2].strip('"')
def reformat(x):
    # zipWithUniqueId starts at 0, so shift every id up by 1; this keeps
    # id 0 free for the personal ratings (myRatings) below.
    uid = x[1] + 1
    return str(x[0]), int(uid)
def parseJoinedData(x):
    """
    Reshape a joined record into (user_id, restaurant_id, rating,
    timestamp). The total (rating) value starts at 0, so add 1.
    """
    fields = x[0].split(",")
    return x[1], int(fields[1]), float(fields[2]) + 1, float(fields[0])
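# For a hypothetical joined record ('6.0,42,3', 1001), parseJoinedData
# returns (1001, 42, 4.0, 6.0): integer user id, restaurant id, rating
# shifted up by 1, and the timestamp digit.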
def computeRmse(model, data, n):
    """
    Compute the RMSE (Root Mean Squared Error) of the model's
    predictions against the actual ratings in data (n records).
    """
    predictions = model.predictAll(data.map(lambda x: (x[0], x[1])))
    # Key both predictions and actual ratings by (user, product) and join
    # them, so each value becomes a (predicted, actual) pair.
    predictionsAndRatings = predictions.map(lambda x: ((x[0], x[1]), x[2])) \
        .join(data.map(lambda x: ((x[0], x[1]), x[2]))) \
        .values()
    return sqrt(predictionsAndRatings.map(lambda x: (x[0] - x[1]) ** 2).reduce(add) / float(n))
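# Note on computeRmse: predictAll returns MLlib Rating(user, product,
# rating) records, which index like 3-tuples, so x[0]/x[1]/x[2] above are
# user / product / score. The result is sqrt(sum((pred - actual)^2) / n).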
if __name__ == "__main__":
    # Set up the Spark environment.
    conf = SparkConf().setAppName("ldgourmetALS").set("spark.executor.memory", "1g") \
        .set("spark.python.worker.memory", "3g")
    sc = SparkContext(conf=conf)

    # Personal ratings as (userid, itemid, rating) tuples. User id 0 is
    # reserved for these ratings and is the user the recommendations at
    # the end are generated for.
    myRatings = [
        (0, 169, 5.0),   # 天一
        (0, 3127, 5.0),  # 一蘭
        (0, 333, 5.0),   # 二郎
        # (0, 142, 2.0), # 香港ガーデン
    ]
    myRatingsRDD = sc.parallelize(myRatings, 1)
    # Load ratings.csv:
    #   1. read the file,
    #   2. drop the header line (the record starting with "id"),
    #   3. extract userid / itemid / rating as tuples (parseRating).
    homeDir = sys.argv[1]
    ratings = sc.textFile(join(homeDir, "ratings.csv")).filter(lambda x: not x.startswith("id")).map(parseRating)

    # Build the restaurant id -> name lookup:
    #   1. read restaurants.csv,
    #   2. drop the header line (the record starting with "id"),
    #   3. extract restaurantid / name as tuples (parseRestaurant).
    restaurantsRdd = sc.textFile(join(homeDir, "restaurants.csv")).filter(lambda x: not x.startswith("id")).map(parseRestaurant)
    restaurants = dict(restaurantsRdd.collect())
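    # The id -> name mapping is collected to the driver as a plain dict so
    # restaurant names can be looked up locally when printing the
    # recommendations at the end.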
    # Map each user id to a sequential integer. The original user ids are
    # alphanumeric strings, which MLlib cannot handle directly, so every
    # distinct user id is assigned an integer (userid -> userid_int).
    userids = sc.textFile(join(homeDir, "ratings.csv")).filter(lambda x: not x.startswith("id")).map(getUserId).distinct().zipWithUniqueId().map(reformat)
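    # Note: zipWithUniqueId assigns ids of the form k + i * numPartitions,
    # so the ids are unique but not necessarily consecutive. That is
    # sufficient here, since ALS only needs a distinct integer per user.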
    # Replace the userid in ratings with userid_int:
    #   1. join userids onto ratings: (userid, (itemid+rating, userid_int)),
    #   2. take the value part with values(): (itemid+rating, userid_int),
    #   3. reorder into (userid_int, itemid, rating, timestamp) (parseJoinedData).
    ratings = ratings.join(userids).values().map(parseJoinedData)
    # Count and report the size of the data set.
    numRatings = ratings.count()
    numUsers = ratings.map(lambda r: r[0]).distinct().count()
    numRestaurants = ratings.map(lambda r: r[1]).distinct().count()
    print "Got %d ratings from %d users on %d restaurants." % (numRatings, numUsers, numRestaurants)
    # Split ratings into training, validation and test sets by the last
    # digit of the timestamp (roughly 60% / 20% / 20%); the personal
    # ratings are appended to the training set.
    numPartitions = 9
    training = ratings.filter(lambda x: x[3] < 6) \
        .map(lambda x: (x[0], x[1], x[2])).union(myRatingsRDD).repartition(numPartitions).cache()
    validation = ratings.filter(lambda x: x[3] >= 6 and x[3] < 8) \
        .map(lambda x: (x[0], x[1], x[2])).repartition(numPartitions).cache()
    test = ratings.filter(lambda x: x[3] >= 8) \
        .map(lambda x: (x[0], x[1], x[2])).repartition(numPartitions).cache()
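    # Each split is cached because it is scanned more than once: every
    # split is counted below, and training/validation are reused across
    # all grid-search iterations.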
    # Count and report the size of each split.
    numTraining = training.count()
    numValidation = validation.count()
    numTest = test.count()
    print "Training: %d, validation: %d, test: %d" % (numTraining, numValidation, numTest)
    print training.take(5)
    # Training hyperparameters. Each list can hold several candidate
    # values; the grid search below keeps whichever combination yields
    # the best (lowest) validation RMSE.
    ranks = [8, 9, 10]
    lambdas = [0.31, 0.32, 0.33]
    numIters = [3]

    # Best-so-far state for the grid search.
    bestValidationRmse = float("inf")
    bestModel = None
    bestRank = 0
    bestLambda = -1.0
    bestNumIter = -1
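    # The loop tries every combination in the Cartesian product of the
    # lists above: 3 ranks x 3 lambdas x 1 iteration count = 9 models.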
    for rank, lmbda, numIter in itertools.product(ranks, lambdas, numIters):
        # Train a model with this parameter combination. In this MLlib
        # version the positional arguments are (ratings, rank, iterations,
        # lambda).
        model = ALS.train(training, rank, numIter, lmbda)
        # Compute and report the RMSE on the validation set.
        validationRmse = computeRmse(model, validation, numValidation)
        print "RMSE (validation) = %f for the model trained with " % validationRmse \
            + "rank = %d, lambda = %.2f, and numIter = %d." % (rank, lmbda, numIter)
        # If this RMSE is the best so far, keep the model and its parameters.
        if validationRmse < bestValidationRmse:
            bestModel = model
            bestRank = rank
            bestLambda = lmbda
            bestNumIter = numIter
            bestValidationRmse = validationRmse

    # Evaluate the best model on the held-out test set.
    testRmse = computeRmse(bestModel, test, numTest)
    print "The best model was trained with rank = %d and lambda = %.2f, " % (bestRank, bestLambda) \
        + "and numIter = %d, and its RMSE on the test set is %f." % (bestNumIter, testRmse)
    # Recommend restaurants for user 0:
    #   1. collect the restaurant ids user 0 has not rated yet,
    #   2. predict a rating for each of them,
    #   3. sort by predicted rating and print the top 50.
    myRatedRestaurantIds = set([x[1] for x in myRatings])
    candidates = sc.parallelize([m for m in restaurants if m not in myRatedRestaurantIds]).repartition(numPartitions).cache()
    predictions = bestModel.predictAll(candidates.map(lambda x: (0, x))).collect()
    recommendations = sorted(predictions, key=lambda x: x[2], reverse=True)[:50]
    for i in xrange(len(recommendations)):
        print "%2d: %s, Score: %0.2f, User Id: %s" % (i + 1, restaurants[recommendations[i][1]].encode('utf-8', 'ignore'), recommendations[i][2], recommendations[i][0])
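    # User 0 only exists in the model because myRatings was unioned into
    # the training set above; without that union, predictAll would have
    # no latent factors for user 0.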
    # Clean up.
    sc.stop()