Skip to content

Instantly share code, notes, and snippets.

@bowbowbow
Created November 1, 2017 13:11
Show Gist options
  • Save bowbowbow/e17016825bd24fd453f1f3d672062954 to your computer and use it in GitHub Desktop.
Save bowbowbow/e17016825bd24fd453f1f3d672062954 to your computer and use it in GitHub Desktop.
데이터 전처리 코드와 ALS CF
package org.deeplearning4j.examples.feedforward.classification;
import org.apache.spark.api.java.function.VoidFunction;
import org.bytedeco.javacv.FrameFilter;
import scala.Tuple2;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.SparkConf;
import java.io.*;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.*;
public class CF implements Serializable {
// data path
static String trainPath = "/Users/bowbowbow/Desktop/SourceTree/2017_CSE4007_2014004857/data/fr_train.txt";
static String usersPath = "/Users/bowbowbow/Desktop/SourceTree/2017_CSE4007_2014004857/data/group40.txt";
static File outputFile = new File("/Users/bowbowbow/Desktop/SourceTree/2017_CSE4007_2014004857/data/result_CF.txt");
static MatrixFactorizationModel model;
static JavaSparkContext jsc;
static int rank = 20;
static int numIterations = 20;
static double lambda = 0.01;
public static void main(String[] args) throws Exception {
SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example")
.setMaster("local[*]")
.set("spark.driver.host", "localhost");
jsc = new JavaSparkContext(conf);
buildOutput();
// TestRMSE();
// CheckBS();
}
static void TestRMSE() {
JavaRDD<String> trainData = jsc.textFile(trainPath);
JavaRDD<String>[] tmp = trainData.randomSplit(new double[]{0.99, 0.01});
JavaRDD<Rating> trainRatings = tmp[0].map(
new Function<String, Rating>() {
public Rating call(String s) {
String[] sarray = s.split(",");
return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
Double.parseDouble(sarray[2])*-1);
}
}
);
// Build the recommendation model using ALS
model = ALS.train(JavaRDD.toRDD(trainRatings), rank, numIterations, lambda);
JavaRDD<String> testData = tmp[1];
JavaRDD<Rating> testRatings = testData.map(
new Function<String, Rating>() {
public Rating call(String s) {
String[] sarray = s.split(",");
return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
Double.parseDouble(sarray[2])*-1);
}
}
);
List<Rating> listTest = testRatings.collect();
int cnt = 0;
double RMSE = 0.0;
for (int i = 0; i < listTest.size(); i++) {
if (i % 10 == 0) {
// 진행상황 체크
System.out.println("Status i : " + i + "/" + listTest.size());
}
int user = listTest.get(i).user();
int product = listTest.get(i).product();
double rating = listTest.get(i).rating();
try {
cnt++;
double prediction = model.predict(user, product);
RMSE += (prediction - rating) * (prediction - rating);
} catch (Exception e) {
// 주어진 정보로 예측 불가능
}
}
RMSE /= (double) cnt;
RMSE = Math.sqrt(RMSE);
System.out.println("RMSE : " + RMSE);
}
// static void CheckBS() throws Exception {
// JavaRDD<String> trainData = jsc.textFile(trainPath);
// JavaRDD<Rating> trainRatings = trainData.map(
// new Function<String, Rating>() {
// public Rating call(String s) {
// String[] sarray = s.split(",");
// int u = Integer.parseInt(sarray[0]);
// int v = Integer.parseInt(sarray[1]);
// double r = Double.parseDouble(sarray[2]);
// return new Rating(u, v, r );
// }
// }
// );
//
// HashMap<Integer, Integer> userMap = new HashMap<Integer , Integer>();
// List<Rating> trainList = trainRatings.collect();
// for(int i=0; i<trainList.size(); i++) {
// int u = trainList.get(i).user();
// userMap.put(u, 1);
// }
//
// // Build the recommendation model using ALS
// model = ALS.train(JavaRDD.toRDD(trainRatings), rank, numIterations, lambda);
//
// List<BS> bsList = new ArrayList<BS>();
// try(Scanner scanner = new Scanner(new File(bsPath))){
// do{
// String line = scanner.nextLine();
// String tmpLine = line;
// String[] sarray = line.split(" ");
// int v = Integer.parseInt(sarray[0]);
// bsList.add(new BS(v, tmpLine));
// }while(scanner.hasNextLine());
// } catch(IOException e) {
// e.printStackTrace();
// }
//
// FileWriter outputWriter = new FileWriter(outputFile, false);
// for(int i=0; i<bsList.size(); i++) {
// if(i > 3000) break;
// int v = bsList.get(i).v;
//
// if(i % 100 == 0) System.out.println("progress : " + i +"/"+bsList.size());
// try {
// double prediction = model.predict(Integer.parseInt(target), v);
// outputWriter.write(bsList.get(i).line + " " + prediction + "\n");
// } catch (Exception e) {
// // System.out.println("exception u : " + u + ", v : " + v);
// // 주어진 정보로 예측 불가능
// }
// }
// outputWriter.close();
//
// }
static void buildOutput() throws Exception {
JavaRDD<String> trainData = jsc.textFile(trainPath);
JavaRDD<Rating> trainRatings = trainData.map(
new Function<String, Rating>() {
public Rating call(String s) {
String[] sarray = s.split(",");
int u = Integer.parseInt(sarray[0]);
int v = Integer.parseInt(sarray[1]);
double r = Double.parseDouble(sarray[2]);
return new Rating(u, v, -1.0*r);
}
}
);
HashMap<Integer, Integer> userMap = new HashMap<Integer , Integer>();
List<Rating> trainList = trainRatings.collect();
for(int i=0; i<trainList.size(); i++) {
int u = trainList.get(i).user();
userMap.put(u, 1);
}
// Build the recommendation model using ALS
model = ALS.train(JavaRDD.toRDD(trainRatings), rank, numIterations, lambda);
List<Integer> users = new ArrayList<Integer>();
try(Scanner scanner = new Scanner(new File(usersPath))){
do{
String line = scanner.nextLine();
users.add(Integer.parseInt(line));
} while(scanner.hasNextLine());
}catch(IOException e){
e.printStackTrace();
}
FileWriter outputWriter = new FileWriter(outputFile, false);
for(int i=0; i<users.size(); i++) {
int u = users.get(i);
System.out.println("progress : " + i +"/" +users.size() + ", target: " + u);
if (!userMap.containsKey(u)) {
System.out.println("continue! : " + u);
continue;
}
Rating[] recommendations = model.recommendProducts(u, 500);
for (Rating rating : recommendations) {
int nowu = rating.user();
int v = rating.product();
double r = rating.rating();
outputWriter.write(nowu + " " + v + " " + -1.0*r + "\n");
}
// List<Candi> candi = new ArrayList<Candi>();
// for(int j=0;j<users.size();j++) {
// double s = j/users.size();
// if(j % 1000 == 0) System.out.println("progress : " + j +"/"+users.size());
//
// int v = users.get(j);
// try {
// double prediction = model.predict(u, v);
// candi.add(new Candi(u, v, prediction));
// } catch (Exception e) {
// System.out.println("exception u : " + u + ", v : " + v);
// // 주어진 정보로 예측 불가능
// }
// }
// Collections.sort(candi, new Comparator<Candi>() {
// public int compare(Candi a, Candi b) {
// return a.r < b.r ? -1 : a.r == b.r ? 0 : 1;
// }
// });
//
// for (int j = 0; j < 1000; j++) {
// int v = candi.get(j).v;
// double r = candi.get(j).r;
// outputWriter.write(u + " " + v + " " + r + "\n");
// }
}
outputWriter.close();
}
}
/**
 * Holds an integer {@code v} together with a text {@code line}.
 *
 * <p>The only construction site visible in this file is the commented-out
 * {@code CheckBS} routine, which stores a raw input line and the integer
 * parsed from that line's first space-separated field.
 */
class BS {
    int v;
    String line;

    /**
     * @param id      value stored in {@link #v}
     * @param rawLine value stored in {@link #line}
     */
    BS(int id, String rawLine) {
        v = id;
        line = rawLine;
    }
}
/**
 * Mutable triple of two integer ids ({@code u}, {@code v}) and a score
 * {@code r}.
 *
 * <p>The only use visible in this file is commented-out code in
 * {@code CF.buildOutput}, which stores a predicted value in {@code r}.
 */
class Candi {
    int u;
    int v;
    double r;

    /** Creates a triple with all fields at their zero defaults. */
    Candi() {
        this(0, 0, 0.0);
    }

    /**
     * @param first  value stored in {@link #u}
     * @param second value stored in {@link #v}
     * @param score  value stored in {@link #r}
     */
    Candi(int first, int second, double score) {
        u = first;
        v = second;
        r = score;
    }
}
# Lookup tables filled by build_users() and build_trust().
# users: second column of user_relation.txt -> first column of the same line.
users = dict()
# trust: (u, v) pair of mapped user ids -> integer trust value from trust.txt.
trust = dict()
# Users with 40 or more trust relations (ids listed in group40.txt), id -> 1.
group = dict()
def build_users():
    """Fill the module-level ``users`` and ``group`` tables from disk.

    ``./user_relation.txt`` is read as whitespace-separated integer columns;
    for every line, ``users[<second column>] = <first column>``. Any further
    columns are ignored.

    ``./group40.txt`` is read the same way; the first column of every line is
    recorded as a key of ``group`` with value ``1``.

    Blank lines in either file are skipped. Both files are closed even when
    parsing fails.

    Raises:
        OSError: if either file cannot be opened.
        ValueError: if a required column is not an integer.
        IndexError: if a non-blank line has too few columns.
    """
    with open("./user_relation.txt", "r") as relation_file:
        for line in relation_file:
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. a trailing newline at end of file
            uid = int(fields[0])
            key = int(fields[1])
            users[key] = uid

    with open("./group40.txt", "r") as group_file:
        for line in group_file:
            fields = line.split()
            if not fields:
                continue
            group[int(fields[0])] = 1
def build_trust():
    """Fill the module-level ``trust`` table from ``./trust.txt``.

    Each non-blank line holds three tab-separated integers. The first two are
    translated through the ``users`` table (so ``build_users()`` must have run
    first) and the resulting pair becomes the key; the third is the value.

    Blank lines are skipped. The file is closed even when parsing fails.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if an id on a line is not a key of ``users``.
        ValueError: if a column is not an integer.
        IndexError: if a non-blank line has fewer than three columns.
    """
    with open("./trust.txt", "r") as trust_file:
        for line in trust_file:
            if not line.strip():
                continue  # blank line, e.g. a trailing newline at end of file
            fields = line.split('\t')
            u = users[int(fields[0])]
            v = users[int(fields[1])]
            trust[(u, v)] = int(fields[2])
# fr values to use for training, keyed by (u, v) pair of mapped user ids.
train_fr = dict()
def generate_train():
    """Write the training file ``./fr_train.txt`` from ``./fr_given.txt``.

    Each non-blank input line holds whitespace-separated ``id id value``. Both
    ids are translated through ``users``; the pair is stored in the
    module-level ``train_fr`` only when both translated ids are in ``group``
    and the pair is a key of ``trust``. Every ``train_fr`` entry whose ids are
    both in ``group`` is then written as a ``u,v,r`` line.

    ``build_users()`` and ``build_trust()`` must have run first. Blank lines
    are skipped, and both files are closed even when parsing fails.

    Raises:
        OSError: if a file cannot be opened.
        KeyError: if an id on a line is not a key of ``users``.
        ValueError: if a column cannot be parsed.
        IndexError: if a non-blank line has fewer than three columns.
    """
    with open("./fr_given.txt", "r") as given_file, \
            open("./fr_train.txt", "w") as train_file:
        # Fill in the originally given fr values.
        for line in given_file:
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. a trailing newline at end of file
            u = users[int(fields[0])]
            v = users[int(fields[1])]
            r = float(fields[2])
            if u in group and v in group and (u, v) in trust:
                train_fr[(u, v)] = r

        # Overwrite with the trust relations (disabled).
        # for key, value in trust.items():
        #     u = key[0]
        #     v = key[1]
        #
        #     t = value
        #
        #     if u in group and v in group:
        #         if t == 1:
        #             train_fr[(u, v)] = 0.85
        #         else:
        #             train_fr[(u, v)] = 0

        # Write the result. The group check is repeated because train_fr is a
        # module-level table that may already hold entries from elsewhere.
        for (u, v), r in train_fr.items():
            if u in group and v in group:
                train_file.write(str(u) + "," + str(v) + "," + str(r) + "\n")
# Run the preprocessing steps in dependency order: build_trust() looks ids up
# in the `users` table that build_users() fills, and generate_train() reads the
# `users`, `group` and `trust` tables filled by the two calls before it.
build_users()
build_trust()
generate_train()
print('finish')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment