Created March 21, 2020 14:08
-
-
Save bicycle1885/105d39c5573a0bc2687ae1bdb414b522 to your computer and use it in GitHub Desktop.
Quickly optimized version of https://gist.github.com/raryo/f755b3f89eba3e084082dc9d57ebb68f
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra | |
"""
    main(file_path="./ml-100k/u.data")

Read `(user, item, rating)` triples from `file_path` with `read_data` and fit the
factor matrices with `fit` using its default hyper-parameters.

Returns the `(P, Q)` tuple produced by `fit`.
"""
function main(file_path="./ml-100k/u.data")
    # input data
    train_data = read_data(file_path)
    # `fit` uses user/item ids directly as column indices, so the factor matrices
    # must be as wide as the largest id — counting distinct ids under-sizes them
    # (and throws a BoundsError) whenever the id space has gaps.
    n_user = mapreduce(t -> t[1], max, train_data; init=0)
    n_item = mapreduce(t -> t[2], max, train_data; init=0)
    # parameters
    P, Q = fit(n_user, n_item, train_data)
    return P, Q
end
"""
    read_data(file_path)

Parse a whitespace-separated ratings file into a `Vector{Tuple{Int,Int,Float32}}`.

Each non-blank line must contain at least three fields. The first two are parsed
as integer user and item ids. The third is parsed as a `Float32` rating, so both
`3` and `3.5` are accepted. Any further columns, such as a timestamp, are ignored.

Throws an `ArgumentError` for a non-blank line with fewer than three fields.
"""
function read_data(file_path)
    train_data = Tuple{Int,Int,Float32}[]
    for line in eachline(file_path)
        fields = split(line)
        # Tolerate blank lines (e.g. stray empty lines at the end of a file).
        isempty(fields) && continue
        length(fields) >= 3 || throw(ArgumentError(
            "expected at least 3 fields (user, item, rating), got: $(repr(line))"))
        user = parse(Int, fields[1])
        item = parse(Int, fields[2])
        # Parse the rating as Float32 directly so fractional ratings are accepted.
        rating = parse(Float32, fields[3])
        push!(train_data, (user, item, rating))
    end
    return train_data
end
"""
    fit(n_user, n_item, train_data; n_itr=50, n_fac=5, γ=0.07f0, λ=0.01f0, verbose=true)

Learn factor matrices `P` (`n_fac × n_user`) and `Q` (`n_fac × n_item`) by
stochastic gradient descent, predicting a rating as `P[:, u] ⋅ Q[:, i]`.

# Arguments
- `train_data`: iterable of `(u, i, r)` tuples. `u` must lie in `1:n_user` and
  `i` in `1:n_item`, because they index columns of `P` and `Q` directly.

# Keywords
- `n_itr`: number of passes over `train_data`.
- `n_fac`: number of latent factors.
- `γ`: learning rate.
- `λ`: L2 regularization weight.
- `verbose`: when `true`, print `"<itr>: <loss>"` after every pass.

Both matrices start as standard-normal `Float32` draws from the global RNG.
Returns `(P, Q)`.
"""
function fit(n_user, n_item, train_data;
             n_itr=50, n_fac=5, γ=0.07f0, λ=0.01f0, verbose::Bool=true)
    # init parameters
    P = randn(Float32, n_fac, n_user)
    Q = randn(Float32, n_fac, n_item)
    # Reused per-sample snapshots of the current user/item vectors. Both updates
    # below must read the *pre-update* values, so they cannot be plain views.
    pu = similar(P, n_fac)
    qi = similar(Q, n_fac)
    # optimize: SGD
    for itr in 1:n_itr
        loss = 0f0
        for (u, i, r) in train_data
            pu .= view(P, :, u)
            qi .= view(Q, :, i)
            # calc error
            e = r - pu ⋅ qi
            @. Q[:, i] += γ * (e * pu - λ * qi)
            @. P[:, u] += γ * (e * qi - λ * pu)
            # Regularized loss, evaluated at the same (pre-update) point as `e`.
            loss += e * e + λ * (pu ⋅ pu + qi ⋅ qi)
        end
        verbose && println("$itr: $loss")
    end
    return P, Q
end
using Random

# Seed the global RNG before training so the random factor initialization in
# `fit` — and therefore the printed per-epoch loss trace — is identical on every run.
const RNG_SEED = 1234
Random.seed!(RNG_SEED)
main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment