Created
January 13, 2020 14:32
-
-
Save Zeng1998/42058f8ad4ffc81400fac21da7a40300 to your computer and use it in GitHub Desktop.
zxc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Created by keane on 2020/1/11 | |
// | |
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <seal/seal.h>

using namespace std;
using namespace seal;
// Namespace-scope file streams. Because they are shared, any function that
// opens one must close it again before another caller can reuse it.
std::ifstream input{};
std::ofstream output{};
/// Serializes `t` into the file at `path`, truncating any existing file.
///
/// @tparam T any type with a `save(std::ostream&) const` member (e.g. SEAL
///           keys and ciphertexts).
/// @param path destination file.
/// @param t    object to serialize; taken by const reference because SEAL
///             keys and ciphertexts are large and must not be copied.
/// @throws std::runtime_error if the file cannot be opened or the write fails.
template<class T>
void write_to_file(const std::string &path, const T &t) {
    // Binary mode: the serialization is a raw byte stream and must not be
    // subjected to newline translation.
    std::ofstream out(path, std::ios::binary);
    if (!out) {
        throw std::runtime_error("write_to_file: cannot open " + path);
    }
    t.save(out);
    out.close();
    if (!out) {
        throw std::runtime_error("write_to_file: failed writing " + path);
    }
}
template<class T> | |
void read_by_file(string path, T &t, shared_ptr<SEALContext> context) { | |
input.open(path); | |
t.load(context, input); | |
input.close(); | |
} | |
/// Returns the entire contents of the file at `path` as raw bytes.
///
/// @param path file to read.
/// @return the file's bytes, unmodified (binary mode, so no newline
///         translation), or an empty string if the file cannot be opened.
std::string read_into_string(const std::string &path) {
    std::ifstream in(path, std::ios::binary);
    if (!in) {
        return {};
    }
    std::ostringstream buf;
    // Bulk copy of the whole stream buffer instead of a per-character loop.
    buf << in.rdbuf();
    return buf.str();
}
void load_from_string(Ciphertext &cr, string const &str, shared_ptr<SEALContext> context) { | |
istringstream s(str); | |
cr.load(context, s); | |
} | |
/// Reports whether the file `name` can currently be opened for reading.
bool exists(const std::string &name) {
    std::ifstream probe;
    probe.open(name);
    return probe.is_open();
}
/// Parses `str` as a `T` with stream extraction.
///
/// Leading whitespace is skipped and trailing characters (e.g. a '\r' from a
/// CRLF file) are ignored.
///
/// @return the parsed value, or a value-initialized `T` (0 for numbers) when
///         nothing can be parsed. `num` is value-initialized because an empty
///         or whitespace-only input makes `>>` return without writing it.
template<class T>
T stringToNum(const std::string &str) {
    std::istringstream iss(str);
    T num{};
    iss >> num;
    return num;
}
/// Prints both batching rows of `matrix`: row 0 is matrix[0, row_size) and
/// row 1 is matrix[row_size, 2*row_size).
///
/// Rows longer than 10 entries show their first and last 5 entries around an
/// ellipsis. Shorter rows are printed in full. Output starts and ends with a
/// blank line.
///
/// @param matrix   slot values; taken by const reference (no copy).
/// @param row_size number of entries per row.
/// @throws std::invalid_argument if `matrix` has fewer than 2*row_size entries.
template<typename T>
inline void print_matrix(const std::vector<T> &matrix, std::size_t row_size) {
    // Written as a division so that 2 * row_size cannot overflow.
    if (matrix.size() / 2 < row_size) {
        throw std::invalid_argument("print_matrix: matrix has fewer than 2 * row_size entries");
    }
    const std::size_t print_size = 5;
    const bool elide = row_size > 2 * print_size;

    // Prints the row that starts at index `offset`.
    auto print_row = [&](std::size_t offset) {
        std::cout << " [";
        if (row_size == 0) {
            std::cout << " ]\n";
            return;
        }
        const std::size_t head = elide ? print_size : row_size;
        for (std::size_t i = 0; i < head; i++) {
            const bool last = !elide && i + 1 == row_size;
            std::cout << std::setw(3) << std::right << matrix[offset + i] << (last ? " ]\n" : ",");
        }
        if (elide) {
            std::cout << std::setw(3) << " ...,";
            for (std::size_t i = row_size - print_size; i < row_size; i++) {
                std::cout << std::setw(3) << matrix[offset + i] << ((i + 1 != row_size) ? "," : " ]\n");
            }
        }
    };

    std::cout << '\n';
    print_row(0);
    print_row(row_size);
    std::cout << std::endl;
}
/// Prints the first `siz` entries of `v` as "[ a b c ]" followed by a newline.
///
/// @param v   values to print; taken by const reference (no copy).
/// @param siz number of leading entries to print. It is clamped to
///            [0, v.size()], so an oversized count cannot read past the end.
void print_vec(const std::vector<std::uint64_t> &v, int siz) {
    const std::size_t count = siz <= 0 ? 0 : std::min(static_cast<std::size_t>(siz), v.size());
    std::cout << "[ ";
    for (std::size_t i = 0; i < count; i++) {
        std::cout << v[i] << " ";
    }
    std::cout << "]\n";
}
void rotate_sum(Ciphertext &cr, Evaluator &evaluator, int poly_modulus_degree, GaloisKeys glk) { | |
Ciphertext tmp; | |
int poly_modulus_degree_power = log2(poly_modulus_degree); | |
for (int i = 0; i < poly_modulus_degree_power - 1; i++) { | |
evaluator.rotate_rows(cr, pow(2, i), glk, tmp); | |
evaluator.add_inplace(cr, tmp); | |
} | |
} | |
void rotate_sum_test(Ciphertext &cr, Evaluator &evaluator, int poly_modulus_degree, GaloisKeys glk, Decryptor &decryptor, | |
BatchEncoder &batchEncoder) { | |
Plaintext pr; | |
Ciphertext tmp; | |
decryptor.decrypt(cr, pr); | |
vector<uint64_t> v; | |
batchEncoder.decode(pr, v); | |
int poly_modulus_degree_power = log2(poly_modulus_degree); | |
for (int i = 0; i < poly_modulus_degree_power - 1; i++) { | |
evaluator.rotate_rows(cr, pow(2, i), glk, tmp); | |
evaluator.add_inplace(cr, tmp); | |
decryptor.decrypt(cr, pr); | |
vector<uint64_t> v; | |
batchEncoder.decode(pr, v); | |
// cout << i << endl; | |
// print_vec(v); | |
} | |
} | |
// Upper bounds iterated by main(): user ids run 1..N and item ids run 1..M.
constexpr int N = 100;
constexpr int M = 50;
int main() { | |
EncryptionParameters parms(scheme_type::BFV); | |
//noise log(coeff/poly) | |
size_t poly_modulus_degree = 8192; | |
parms.set_poly_modulus_degree(poly_modulus_degree); | |
parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); | |
parms.set_plain_modulus(PlainModulus::Batching(poly_modulus_degree, 20)); | |
auto context = SEALContext::Create(parms); | |
KeyGenerator keygen(context); | |
string psk_path = "/home/keane/CLionProjects/new_data/key/psk.key"; | |
string sk_path = "/home/keane/CLionProjects/new_data/key/sk.key"; | |
string rlk_path = "/home/keane/CLionProjects/new_data/key/rlk.key"; | |
string glk_path = "/home/keane/CLionProjects/new_data/key/glk.key"; | |
PublicKey psk; | |
SecretKey sk; | |
RelinKeys rlk; | |
GaloisKeys glk; | |
// PublicKey psk=keygen.public_key(); | |
// SecretKey sk=keygen.secret_key(); | |
// RelinKeys rlk=keygen.relin_keys(); | |
// GaloisKeys glk=keygen.galois_keys(); | |
// write_to_file(psk_path,psk); | |
// write_to_file(sk_path,sk); | |
// write_to_file(rlk_path,rlk); | |
// write_to_file(glk_path,glk); | |
read_by_file(psk_path, psk, context); | |
read_by_file(sk_path, sk, context); | |
read_by_file(rlk_path, rlk, context); | |
read_by_file(glk_path, glk, context); | |
Encryptor encryptor(context, psk); | |
Decryptor decryptor(context, sk); | |
Evaluator evaluator(context); | |
BatchEncoder batch_encoder(context); | |
IntegerEncoder encoder(context); | |
size_t slot_count = batch_encoder.slot_count(); | |
size_t row_size = slot_count / 2; | |
Plaintext pr; | |
Ciphertext cr; | |
//每个用户预处理一个密文变量 | |
vector<uint64_t> u(slot_count, 0ULL); | |
vector<uint64_t> w(slot_count,0ULL); | |
// 处理一个商品对应的用户的评分 | |
ifstream inFile("rate_9.csv", ios::in); | |
string lineStr; | |
int now=1; | |
while (getline(inFile, lineStr)){ | |
stringstream ss(lineStr); | |
string str; | |
vector<string> lineArray; | |
while (getline(ss, str, ',')){ | |
lineArray.push_back(str); | |
} | |
int userId=stringToNum<int>(lineArray[0]); | |
int itemId=stringToNum<int>(lineArray[1]); | |
int rating=stringToNum<int>(lineArray[2]); | |
if(userId!=now){ | |
// 处理缺失数据 | |
for(int i=0;i<M;i++){ | |
if(u[i]==0){ | |
u[i]=1; | |
} | |
} | |
// 预处理 | |
int sum=0; | |
for(int i=0;i<M;i++){ | |
sum+=u[i]*u[i]; | |
} | |
double sq=sqrt(1.0*sum); | |
for(int i=0;i<M;i++){ | |
u[i]=int(u[i]*100.0/sq); | |
} | |
batch_encoder.encode(u,pr); | |
encryptor.encrypt(pr,cr); | |
cout << now << endl; | |
string name="/home/keane/CLionProjects/new_data/data/user_vec_"+to_string(now)+".data"; | |
write_to_file(name,cr); | |
u.assign(slot_count,0ULL); | |
now=userId; | |
} | |
u[itemId-1]=rating; | |
w.assign(slot_count,0ULL); | |
w[userId-1]=rating; | |
string name="/home/keane/CLionProjects/new_data/data/item_vec_"+to_string(userId)+"_"+to_string(itemId)+".data"; | |
batch_encoder.encode(w,pr); | |
encryptor.encrypt(pr,cr); | |
write_to_file(name,cr); | |
} | |
for(int i=0;i<M;i++){ | |
if(u[i]==0){ | |
u[i]=1; | |
} | |
} | |
// 预处理 | |
int sum=0; | |
for(int i=0;i<M;i++){ | |
sum+=u[i]*u[i]; | |
} | |
double sq=sqrt(1.0*sum); | |
for(int i=0;i<M;i++){ | |
u[i]=int(u[i]*100.0/sq); | |
} | |
batch_encoder.encode(u,pr); | |
encryptor.encrypt(pr,cr); | |
cout << now << endl; | |
string name="/home/keane/CLionProjects/new_data/data/user_vec_"+to_string(now)+".data"; | |
write_to_file(name,cr); | |
u.assign(slot_count,0ULL); | |
batch_encoder.encode(u,pr); | |
encryptor.encrypt(pr,cr); | |
for(int i=1;i<=N;i++){ | |
string name="/home/keane/CLionProjects/new_data/data/user_vec_"+to_string(i)+".data"; | |
if(!exists(name)){ | |
write_to_file(name,cr); | |
} | |
} | |
// 求用户相似度 | |
Ciphertext ca,cb; | |
double st=clock(); | |
for(int i=1;i<=N;i++){ | |
for(int j=1;j<=N;j++){ | |
string name="/home/keane/CLionProjects/new_data/data/user_vec_"+to_string(i)+".data"; | |
read_by_file(name,ca,context); | |
name="/home/keane/CLionProjects/new_data/data/user_vec_"+to_string(j)+".data"; | |
read_by_file(name,cb,context); | |
evaluator.multiply_inplace(ca,cb); | |
evaluator.relinearize_inplace(ca,rlk); | |
rotate_sum(ca,evaluator,poly_modulus_degree,glk); | |
name="/home/keane/CLionProjects/new_data/sim/user_sim_"+to_string(i)+"_"+to_string(j)+".data"; | |
write_to_file(name,ca); | |
} | |
} | |
double ed=clock(); | |
cout << "the time to calculate user similarity matrix: " <<(ed-st)/CLOCKS_PER_SEC <<"s\n"; | |
// 将N*N的相似度密文 变成 N的相似度密文 | |
vector<uint64_t> mc(slot_count,0ULL); | |
vector<uint64_t> v(slot_count,0ULL); | |
Ciphertext sums; | |
for(int i=1;i<=N;i++){ | |
batch_encoder.encode(mc,pr); | |
encryptor.encrypt(pr,sums); | |
for(int j=1;j<=N;j++){ | |
mc.assign(slot_count,0ULL); | |
mc[j-1]=1; | |
batch_encoder.encode(mc,pr); | |
string name="/home/keane/CLionProjects/new_data/sim/user_sim_"+to_string(i)+"_"+to_string(j)+".data"; | |
read_by_file(name,cr,context); | |
evaluator.multiply_plain_inplace(cr,pr); | |
evaluator.add_inplace(sums,cr); | |
} | |
string name="/home/keane/CLionProjects/new_data/sim/user_sim_vec_"+to_string(i)+".data"; | |
write_to_file(name,sums); | |
decryptor.decrypt(sums,pr); | |
batch_encoder.decode(pr,v); | |
print_matrix(v,row_size); | |
} | |
// for(int i=1;i<=N;i++){ | |
// string name="/home/keane/CLionProjects/new_data/sim/user_sim_vec_"+to_string(i)+".data"; | |
// read_by_file(name,cr,context); | |
// decryptor.decrypt(cr,pr); | |
// batch_encoder.decode(pr,v); | |
// print_vec(v,50); | |
// } | |
//将每个物品对应每个用户的评分相加 | |
// vector<uint64_t> mc(slot_count,0ULL); | |
// Ciphertext sums; | |
for(int i=1;i<=M;i++){ | |
batch_encoder.encode(mc,pr); | |
encryptor.encrypt(pr,sums); | |
for(int j=1;j<=N;j++){ | |
mc.assign(slot_count,0ULL); | |
mc[j-1]=1; | |
batch_encoder.encode(mc,pr); | |
string name="/home/keane/CLionProjects/new_data/data/item_vec_"+to_string(j)+"_"+to_string(i)+".data"; | |
if(!exists(name)){ | |
continue; | |
} | |
read_by_file(name,cr,context); | |
// decryptor.decrypt(cr,pr); | |
// batch_encoder.decode(pr,v); | |
// if(i==6){ | |
// print_matrix(v,row_size); | |
// print_vec(v,50); | |
// } | |
// print_matrix(v,row_size); | |
evaluator.multiply_plain_inplace(cr,pr); | |
evaluator.add_inplace(sums,cr); | |
} | |
string name="/home/keane/CLionProjects/new_data/data/item_vec_"+to_string(i)+".data"; | |
write_to_file(name,sums); | |
decryptor.decrypt(sums,pr); | |
batch_encoder.decode(pr,v); | |
} | |
// | |
// string name="/home/keane/CLionProjects/new_data/data/item_vec_"+to_string(9)+"_"+to_string(6)+".data"; | |
// read_by_file(name,cr,context); | |
// decryptor.decrypt(cr,pr); | |
// batch_encoder.decode(pr,v); | |
// print_vec(v,50); | |
// for(int i=1;i<=M;i++){ | |
// string name="/home/keane/CLionProjects/new_data/data/item_vec_"+to_string(i)+".data"; | |
// read_by_file(name,cr,context); | |
// decryptor.decrypt(cr,pr); | |
// batch_encoder.decode(pr,v); | |
// print_vec(v,50); | |
// } | |
// cout << "enter the user_id and get the recommended ratings for items\n"; | |
int user_id = 4; | |
// cin >> user_id; | |
st = clock(); | |
name = "/home/keane/CLionProjects/new_data/data/user_vec_" + to_string(user_id) + ".data"; | |
read_by_file(name, cr, context); | |
vector<uint64_t> ans; | |
// vector<uint64_t> mc(slot_count,0ULL); | |
// Ciphertext sum; | |
for (int i = 1; i <= M; i++) { | |
Ciphertext cc,cb; | |
string name = "/home/keane/CLionProjects/new_data/sim/user_sim_vec_" + to_string(user_id) + ".data"; | |
read_by_file(name, cc, context); | |
name="/home/keane/CLionProjects/new_data/data/item_vec_" + to_string(i)+ ".data"; | |
read_by_file(name,cb,context); | |
evaluator.multiply_inplace(cc, cb); | |
evaluator.relinearize_inplace(cc, rlk); | |
rotate_sum(cc,evaluator,poly_modulus_degree,glk); | |
decryptor.decrypt(cc, pr); | |
batch_encoder.decode(pr, v); | |
if(v[0]>=320000){ | |
cout << user_id <<","<<i<<"\n"; | |
} | |
} | |
ed = clock(); | |
// print_vec(ans); | |
cout << "Calculate user " + to_string(user_id) + " recommendation score: " + to_string((ed - st) / CLOCKS_PER_SEC) | |
<< "s\n"; | |
uint64_t all=0; | |
// freopen("/home/keane/CLionProjects/new_data/rec5.csv","w",stdout); | |
for(int user_id=1;user_id<=N;user_id++){ | |
for (int i = 1; i <= M; i++) { | |
Ciphertext cc,cb; | |
string name = "/home/keane/CLionProjects/new_data/sim/user_sim_vec_" + to_string(user_id) + ".data"; | |
read_by_file(name, cc, context); | |
name="/home/keane/CLionProjects/new_data/data/item_vec_" + to_string(i)+ ".data"; | |
read_by_file(name,cb,context); | |
evaluator.multiply_inplace(cc, cb); | |
evaluator.relinearize_inplace(cc, rlk); | |
rotate_sum(cc,evaluator,poly_modulus_degree,glk); | |
decryptor.decrypt(cc, pr); | |
batch_encoder.decode(pr, v); | |
all+=v[0]; | |
cout << v[0]<<endl; | |
if(v[0]>=1000000){ | |
cout << user_id <<","<<i<<"\n"; | |
} | |
} | |
} | |
cout << all*1.0/N/M << endl; | |
uint64_t me=uint64_t(all*1.0/N/M); | |
freopen("/home/keane/CLionProjects/new_data/rec12.csv","w",stdout); | |
cout <<"user,item\n"; | |
for(int user_id=1;user_id<=N;user_id++){ | |
for (int i = 1; i <= M; i++) { | |
Ciphertext cc,cb; | |
string name = "/home/keane/CLionProjects/new_data/sim/user_sim_vec_" + to_string(user_id) + ".data"; | |
read_by_file(name, cc, context); | |
name="/home/keane/CLionProjects/new_data/data/item_vec_" + to_string(i)+ ".data"; | |
read_by_file(name,cb,context); | |
evaluator.multiply_inplace(cc, cb); | |
evaluator.relinearize_inplace(cc, rlk); | |
rotate_sum(cc,evaluator,poly_modulus_degree,glk); | |
cout << " + Noise budget in result: " | |
<< decryptor.invariant_noise_budget(cc) << " bits" << endl; | |
decryptor.decrypt(cc, pr); | |
batch_encoder.decode(pr, v); | |
// if(v[0]>=650000){ | |
// cout << user_id <<","<<i<<"\n"; | |
// } | |
} | |
} | |
cout << "the time to calculate user similarity matrix: " << 35747.2<<"s\n"; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment