Created November 1, 2017 13:11
-
-
Save bowbowbow/e17016825bd24fd453f1f3d672062954 to your computer and use it in GitHub Desktop.
데이터 전처리 코드와 ALS CF (Data preprocessing code and ALS collaborative filtering)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.feedforward.classification; | |
import org.apache.spark.api.java.function.VoidFunction; | |
import org.bytedeco.javacv.FrameFilter; | |
import scala.Tuple2; | |
import org.apache.spark.api.java.*; | |
import org.apache.spark.api.java.function.Function; | |
import org.apache.spark.mllib.recommendation.ALS; | |
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; | |
import org.apache.spark.mllib.recommendation.Rating; | |
import org.apache.spark.SparkConf; | |
import java.io.*; | |
import java.io.FileWriter; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.*; | |
public class CF implements Serializable { | |
// data path | |
static String trainPath = "/Users/bowbowbow/Desktop/SourceTree/2017_CSE4007_2014004857/data/fr_train.txt"; | |
static String usersPath = "/Users/bowbowbow/Desktop/SourceTree/2017_CSE4007_2014004857/data/group40.txt"; | |
static File outputFile = new File("/Users/bowbowbow/Desktop/SourceTree/2017_CSE4007_2014004857/data/result_CF.txt"); | |
static MatrixFactorizationModel model; | |
static JavaSparkContext jsc; | |
static int rank = 20; | |
static int numIterations = 20; | |
static double lambda = 0.01; | |
public static void main(String[] args) throws Exception { | |
SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example") | |
.setMaster("local[*]") | |
.set("spark.driver.host", "localhost"); | |
jsc = new JavaSparkContext(conf); | |
buildOutput(); | |
// TestRMSE(); | |
// CheckBS(); | |
} | |
static void TestRMSE() { | |
JavaRDD<String> trainData = jsc.textFile(trainPath); | |
JavaRDD<String>[] tmp = trainData.randomSplit(new double[]{0.99, 0.01}); | |
JavaRDD<Rating> trainRatings = tmp[0].map( | |
new Function<String, Rating>() { | |
public Rating call(String s) { | |
String[] sarray = s.split(","); | |
return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), | |
Double.parseDouble(sarray[2])*-1); | |
} | |
} | |
); | |
// Build the recommendation model using ALS | |
model = ALS.train(JavaRDD.toRDD(trainRatings), rank, numIterations, lambda); | |
JavaRDD<String> testData = tmp[1]; | |
JavaRDD<Rating> testRatings = testData.map( | |
new Function<String, Rating>() { | |
public Rating call(String s) { | |
String[] sarray = s.split(","); | |
return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), | |
Double.parseDouble(sarray[2])*-1); | |
} | |
} | |
); | |
List<Rating> listTest = testRatings.collect(); | |
int cnt = 0; | |
double RMSE = 0.0; | |
for (int i = 0; i < listTest.size(); i++) { | |
if (i % 10 == 0) { | |
// 진행상황 체크 | |
System.out.println("Status i : " + i + "/" + listTest.size()); | |
} | |
int user = listTest.get(i).user(); | |
int product = listTest.get(i).product(); | |
double rating = listTest.get(i).rating(); | |
try { | |
cnt++; | |
double prediction = model.predict(user, product); | |
RMSE += (prediction - rating) * (prediction - rating); | |
} catch (Exception e) { | |
// 주어진 정보로 예측 불가능 | |
} | |
} | |
RMSE /= (double) cnt; | |
RMSE = Math.sqrt(RMSE); | |
System.out.println("RMSE : " + RMSE); | |
} | |
// static void CheckBS() throws Exception { | |
// JavaRDD<String> trainData = jsc.textFile(trainPath); | |
// JavaRDD<Rating> trainRatings = trainData.map( | |
// new Function<String, Rating>() { | |
// public Rating call(String s) { | |
// String[] sarray = s.split(","); | |
// int u = Integer.parseInt(sarray[0]); | |
// int v = Integer.parseInt(sarray[1]); | |
// double r = Double.parseDouble(sarray[2]); | |
// return new Rating(u, v, r ); | |
// } | |
// } | |
// ); | |
// | |
// HashMap<Integer, Integer> userMap = new HashMap<Integer , Integer>(); | |
// List<Rating> trainList = trainRatings.collect(); | |
// for(int i=0; i<trainList.size(); i++) { | |
// int u = trainList.get(i).user(); | |
// userMap.put(u, 1); | |
// } | |
// | |
// // Build the recommendation model using ALS | |
// model = ALS.train(JavaRDD.toRDD(trainRatings), rank, numIterations, lambda); | |
// | |
// List<BS> bsList = new ArrayList<BS>(); | |
// try(Scanner scanner = new Scanner(new File(bsPath))){ | |
// do{ | |
// String line = scanner.nextLine(); | |
// String tmpLine = line; | |
// String[] sarray = line.split(" "); | |
// int v = Integer.parseInt(sarray[0]); | |
// bsList.add(new BS(v, tmpLine)); | |
// }while(scanner.hasNextLine()); | |
// } catch(IOException e) { | |
// e.printStackTrace(); | |
// } | |
// | |
// FileWriter outputWriter = new FileWriter(outputFile, false); | |
// for(int i=0; i<bsList.size(); i++) { | |
// if(i > 3000) break; | |
// int v = bsList.get(i).v; | |
// | |
// if(i % 100 == 0) System.out.println("progress : " + i +"/"+bsList.size()); | |
// try { | |
// double prediction = model.predict(Integer.parseInt(target), v); | |
// outputWriter.write(bsList.get(i).line + " " + prediction + "\n"); | |
// } catch (Exception e) { | |
// // System.out.println("exception u : " + u + ", v : " + v); | |
// // 주어진 정보로 예측 불가능 | |
// } | |
// } | |
// outputWriter.close(); | |
// | |
// } | |
static void buildOutput() throws Exception { | |
JavaRDD<String> trainData = jsc.textFile(trainPath); | |
JavaRDD<Rating> trainRatings = trainData.map( | |
new Function<String, Rating>() { | |
public Rating call(String s) { | |
String[] sarray = s.split(","); | |
int u = Integer.parseInt(sarray[0]); | |
int v = Integer.parseInt(sarray[1]); | |
double r = Double.parseDouble(sarray[2]); | |
return new Rating(u, v, -1.0*r); | |
} | |
} | |
); | |
HashMap<Integer, Integer> userMap = new HashMap<Integer , Integer>(); | |
List<Rating> trainList = trainRatings.collect(); | |
for(int i=0; i<trainList.size(); i++) { | |
int u = trainList.get(i).user(); | |
userMap.put(u, 1); | |
} | |
// Build the recommendation model using ALS | |
model = ALS.train(JavaRDD.toRDD(trainRatings), rank, numIterations, lambda); | |
List<Integer> users = new ArrayList<Integer>(); | |
try(Scanner scanner = new Scanner(new File(usersPath))){ | |
do{ | |
String line = scanner.nextLine(); | |
users.add(Integer.parseInt(line)); | |
} while(scanner.hasNextLine()); | |
}catch(IOException e){ | |
e.printStackTrace(); | |
} | |
FileWriter outputWriter = new FileWriter(outputFile, false); | |
for(int i=0; i<users.size(); i++) { | |
int u = users.get(i); | |
System.out.println("progress : " + i +"/" +users.size() + ", target: " + u); | |
if (!userMap.containsKey(u)) { | |
System.out.println("continue! : " + u); | |
continue; | |
} | |
Rating[] recommendations = model.recommendProducts(u, 500); | |
for (Rating rating : recommendations) { | |
int nowu = rating.user(); | |
int v = rating.product(); | |
double r = rating.rating(); | |
outputWriter.write(nowu + " " + v + " " + -1.0*r + "\n"); | |
} | |
// List<Candi> candi = new ArrayList<Candi>(); | |
// for(int j=0;j<users.size();j++) { | |
// double s = j/users.size(); | |
// if(j % 1000 == 0) System.out.println("progress : " + j +"/"+users.size()); | |
// | |
// int v = users.get(j); | |
// try { | |
// double prediction = model.predict(u, v); | |
// candi.add(new Candi(u, v, prediction)); | |
// } catch (Exception e) { | |
// System.out.println("exception u : " + u + ", v : " + v); | |
// // 주어진 정보로 예측 불가능 | |
// } | |
// } | |
// Collections.sort(candi, new Comparator<Candi>() { | |
// public int compare(Candi a, Candi b) { | |
// return a.r < b.r ? -1 : a.r == b.r ? 0 : 1; | |
// } | |
// }); | |
// | |
// for (int j = 0; j < 1000; j++) { | |
// int v = candi.get(j).v; | |
// double r = candi.get(j).r; | |
// outputWriter.write(u + " " + v + " " + r + "\n"); | |
// } | |
} | |
outputWriter.close(); | |
} | |
} | |
/**
 * Holds an integer id together with the raw text line it came from.
 * NOTE(review): only referenced from commented-out code in this file.
 */
class BS {
    String line;
    int v;

    /**
     * @param id   stored in {@code v}
     * @param text stored in {@code line}
     */
    BS(int id, String text) {
        v = id;
        line = text;
    }
}
/**
 * A (u, v) pair with an associated double score r.
 * NOTE(review): only referenced from commented-out code in this file.
 */
class Candi {
    int u, v;
    double r;

    /** Creates a candidate whose fields are all zero. */
    Candi() {
        this(0, 0, 0.0);
    }

    /**
     * @param user   stored in {@code u}
     * @param target stored in {@code v}
     * @param score  stored in {@code r}
     */
    Candi(int user, int target, double score) {
        u = user;
        v = target;
        r = score;
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# raw user id -> uid; filled by build_users()
users = dict()
# (uid, uid) -> integer trust value read from trust.txt; filled by build_trust()
trust = dict()
# uids of users with 40 or more trust relations (value is always 1); filled by build_users()
group = dict()
def build_users():
    """Populate the module-level ``users`` and ``group`` dicts.

    ``./user_relation.txt``: whitespace-separated fields per line. The first field
    is a uid and the second is the id used in trust.txt / fr_given.txt, so
    ``users[second] = first``. Any further fields are ignored.

    ``./group40.txt``: the first whitespace-separated field of each line is a uid
    that is added to ``group`` with value 1.

    Blank lines are skipped. Malformed lines raise ValueError or IndexError.
    """
    with open("./user_relation.txt", "r") as relation_file:
        for line in relation_file:
            p = line.split()
            if not p:
                continue
            uid = int(p[0])
            u = int(p[1])
            users[u] = uid

    with open("./group40.txt", "r") as group_file:
        for line in group_file:
            p = line.split()
            if not p:
                continue
            group[int(p[0])] = 1
def build_trust():
    """Populate the module-level ``trust`` dict from ``./trust.txt``.

    Each line is tab-separated ``raw_u<TAB>raw_v<TAB>t``. Both raw ids are
    translated through ``users`` (so build_users() must run first), and
    ``trust[(u, v)] = int(t)``. Blank lines are skipped. An id missing from
    ``users`` raises KeyError.
    """
    with open("./trust.txt", "r") as trust_file:
        for line in trust_file:
            if not line.strip():
                continue
            p = line.split('\t')
            u = users[int(p[0])]
            v = users[int(p[1])]
            trust[(u, v)] = int(p[2])
# fr values selected for training: (uid, uid) -> float fr; filled by generate_train()
train_fr = dict()
def generate_train():
    """Select training fr values and write them to ``./fr_train.txt``.

    Reads ``./fr_given.txt`` (whitespace-separated ``raw_u raw_v r`` per line)
    and translates both ids through ``users``. ``train_fr[(u, v)] = float(r)``
    is stored only when both users are in ``group`` and ``(u, v)`` is a key of
    ``trust``. A later line for the same pair overwrites an earlier one.
    Blank input lines are skipped.

    Every ``train_fr`` entry whose users are both in ``group`` is then written
    as a comma-separated ``u,v,r`` line, in dict iteration order.
    """
    with open("./fr_given.txt", "r") as given_file:
        for line in given_file:
            p = line.split()
            if not p:
                continue
            u = users[int(p[0])]
            v = users[int(p[1])]
            r = float(p[2])
            if u in group and v in group and (u, v) in trust:
                train_fr[(u, v)] = r

    with open("./fr_train.txt", "w") as output:
        for (u, v), r in train_fr.items():
            if u in group and v in group:
                output.write(str(u) + "," + str(v) + "," + str(r) + "\n")
# Run the preprocessing pipeline in order: build_trust() and generate_train()
# translate ids through `users`, and generate_train() also reads `group`/`trust`.
for step in (build_users, build_trust, generate_train):
    step()
print('finish')
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment