Last active
February 12, 2023 11:07
-
-
Save giuseppebonaccorso/3acff53906cbb7c37abb8c7e4bf3b0ef to your computer and use it in GitHub Desktop.
SVD Recommendations using Tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import tensorflow as tf
# Set random seed for reproducibility of the synthetic ratings
np.random.seed(1000)

# Problem size and model hyper-parameters
nb_users = 5000           # rows of the user-item matrix
nb_products = 2000        # columns of the user-item matrix
nb_factors = 500          # number of leading singular values/vectors kept
max_rating = 5            # ratings are drawn from 1..max_rating (inclusive)
nb_rated_products = 500   # product indices drawn per user (repeats possible)
top_k_products = 10       # number of suggestions returned per user
# Create a random User-Item matrix; entries that are never assigned stay 0
uim = np.zeros((nb_users, nb_products), dtype=np.float32)

for i in range(nb_users):
    # Indices are drawn with replacement, so a user can end up with fewer
    # than nb_rated_products distinct rated products
    nbp = np.random.randint(0, nb_products, size=nb_rated_products)

    # Draw one rating per selected index in a single call instead of one
    # Python-level call per element. When an index is repeated, the last
    # assigned rating is the one that is kept.
    # NOTE(review): expected to consume the seeded stream in the same order
    # as element-by-element draws - confirm if bit-exact ratings matter.
    uim[i, nbp] = np.random.randint(1, max_rating + 1, size=nb_rated_products)
# Create a Tensorflow graph
graph = tf.Graph()

with graph.as_default():
    # User-item matrix (fed at run time)
    user_item_matrix = tf.placeholder(tf.float32, shape=(nb_users, nb_products))

    # SVD. tf.svd returns (s, u, v) such that A = u * diag(s) * transpose(v).
    # Despite its name, Vt therefore holds V itself (right singular vectors
    # stored as columns), not its transpose.
    St, Ut, Vt = tf.svd(user_item_matrix)

    # Compute reduced matrices from the nb_factors leading singular triplets.
    # Slicing the singular values before tf.diag avoids building the full
    # diagonal matrix only to crop it.
    Sk = tf.diag(St[0:nb_factors])
    Uk = Ut[:, 0:nb_factors]
    # Keep the leading *columns* of V and transpose them to obtain the
    # (nb_factors, nb_products) factor. Slicing rows of V would select
    # products rather than latent factors.
    Vk = tf.transpose(Vt[:, 0:nb_factors])

    # Compute Su and Si (each side takes the square root of the singular
    # values, so that Su * Si = Uk * Sk * Vk)
    sqrt_Sk = tf.sqrt(Sk)
    Su = tf.matmul(Uk, sqrt_Sk)
    Si = tf.matmul(sqrt_Sk, Vk)

    # Compute user ratings (rank-nb_factors reconstruction of the matrix)
    ratings_t = tf.matmul(Su, Si)

    # Pick top k suggestions
    best_ratings_t, best_items_t = tf.nn.top_k(ratings_t, top_k_products)
# Create Tensorflow session
session = tf.InteractiveSession(graph=graph)

# Compute the top k suggestions for all users
feed_dict = {
    user_item_matrix: uim,
}

# The fetch is passed as a one-element list, so the result is indexed
# with [0] below to reach the array of item indices
best_items = session.run([best_items_t], feed_dict=feed_dict)

# Suggestions for users 1000 through 1009
for i in range(1000, 1010):
    print('User {}: {}'.format(i, best_items[0][i]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Print the suggested item indices for users 1000 through 1009
for i in range(1000, 1010):
    print('User {}: {}'.format(i, best_items[0][i]))
User 1000: [ 412 867 1040 509 1311 1562 758 1796 636 556] | |
User 1001: [ 548 88 1299 175 81 1837 282 1555 1796 1902] | |
User 1002: [ 433 667 460 821 1762 775 1673 278 284 1540] | |
User 1003: [1823 602 1874 43 1979 1612 1755 857 891 1701] | |
User 1004: [1700 1312 892 621 194 1919 196 1746 1697 1192] | |
User 1005: [ 891 221 1112 1387 768 1697 916 485 1673 1515] | |
User 1006: [ 463 611 1986 1253 175 1362 1112 1811 1045 768] | |
User 1007: [1170 70 1886 757 412 606 892 1772 1540 1415] | |
User 1008: [ 757 855 509 329 410 1304 1900 1631 476 284] | |
User 1009: [ 888 1654 6 1453 735 1745 505 422 1878 1965] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello Giuseppe Bonaccorso,
I would like to know what the dataset looks like (i.e., the structure of the data).
I'm a newbie, and I want to feed real data into it.
Thank you.