Skip to content

Instantly share code, notes, and snippets.

@myui
Created November 22, 2024 05:54
Show Gist options
  • Save myui/a8a416c390e8f9458045d19c3f31b501 to your computer and use it in GitHub Desktop.
Save myui/a8a416c390e8f9458045d19c3f31b501 to your computer and use it in GitHub Desktop.
SLIM Elastic
import numpy as np
import scipy.sparse as sp
from sklearn.linear_model import ElasticNet
import warnings
from sklearn.exceptions import ConvergenceWarning
class SLIMElastic:
    """
    Sparse linear method (SLIM) for top-K recommendation.

    For every item ``j`` an ElasticNet regression (L1 + L2 regularized) is
    fitted that predicts column ``j`` of the user-item matrix from all the
    other columns. The learned coefficients are stored in column ``j`` of
    ``item_similarity``, so the predicted scores are
    ``interaction_matrix @ item_similarity``.

    Attributes:
        alpha: Regularization strength passed to ElasticNet.
        l1_ratio: Mix between L1 and L2 regularization passed to ElasticNet.
        positive_only: Passed as ElasticNet's ``positive`` flag.
        max_iter: Maximum number of ElasticNet iterations per item.
        tol: ElasticNet optimization tolerance.
        item_similarity: ``(num_items, num_items)`` dense coefficient matrix,
            or ``None`` until ``fit`` has been called.
    """

    def __init__(self, config: dict = None):
        """
        Initialize the SLIMElastic model.

        Args:
            config (dict, optional): Hyper-parameters. Recognized keys and
                their defaults: ``"alpha"`` (0.1), ``"l1_ratio"`` (0.1),
                ``"positive_only"`` (True), ``"max_iter"`` (100),
                ``"tol"`` (1e-4). ``None`` means "use all defaults".
        """
        # A fresh dict per call avoids sharing one mutable default object.
        if config is None:
            config = {}
        self.alpha = config.get("alpha", 0.1)
        self.l1_ratio = config.get("l1_ratio", 0.1)
        self.positive_only = config.get("positive_only", True)
        self.max_iter = config.get("max_iter", 100)
        self.tol = config.get("tol", 1e-4)
        # Item-item coefficient matrix; populated by fit().
        self.item_similarity = None

    def _require_fitted(self, action):
        """Raise RuntimeError if the model has no learned coefficients yet."""
        if self.item_similarity is None:
            raise RuntimeError(f"Model must be fitted before calling {action}.")

    @staticmethod
    def _prepare_matrix(interaction_matrix):
        """Validate the input and return a private float32 CSC copy of it."""
        if not isinstance(interaction_matrix, sp.csr_matrix):
            raise ValueError("Interaction matrix must be a scipy.sparse.csr_matrix of user-item interactions.")
        # CSC stores each item column contiguously, which makes masking a
        # column a cheap in-place write on ``data``. Both conversions copy,
        # so the caller's matrix is never modified.
        return interaction_matrix.tocsc().astype(np.float32)

    def _fit_columns(self, X, item_indices):
        """Refit the coefficient column of every item in ``item_indices``.

        ``X`` must be the CSC matrix produced by ``_prepare_matrix``; it is
        temporarily modified in place and restored after each item.
        """
        num_items = X.shape[1]
        model = ElasticNet(
            alpha=self.alpha,
            l1_ratio=self.l1_ratio,
            positive=self.positive_only,
            fit_intercept=False,
            max_iter=self.max_iter,
            tol=self.tol,
        )
        # With a small max_iter many per-item problems stop before
        # converging; silence the resulting flood of warnings.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=ConvergenceWarning)
            for j in item_indices:
                j = int(j)
                if not -num_items <= j < num_items:
                    raise IndexError(f"Item index {j} is out of range for {num_items} items.")
                j %= num_items  # normalize negative indices for indptr lookup
                # Target vector: the interactions of the current item.
                y = X[:, j].toarray().ravel()
                start, end = X.indptr[j], X.indptr[j + 1]
                saved = X.data[start:end].copy()
                # Mask the item's own column so it cannot predict itself.
                X.data[start:end] = 0
                try:
                    model.fit(X, y)
                finally:
                    # Put the column back for the following items.
                    X.data[start:end] = saved
                self.item_similarity[:, j] = model.coef_

    def fit(self, interaction_matrix):
        """
        Fit the SLIMElastic model to the interaction matrix.

        Args:
            interaction_matrix (csr_matrix): User-item interaction matrix
                (sparse), one row per user and one column per item.

        Returns:
            SLIMElastic: ``self``.

        Raises:
            ValueError: If ``interaction_matrix`` is not a ``csr_matrix``.
        """
        X = self._prepare_matrix(interaction_matrix)
        num_items = X.shape[1]
        self.item_similarity = np.zeros((num_items, num_items))
        self._fit_columns(X, range(num_items))
        return self

    def partial_fit(self, interaction_matrix, updated_items):
        """
        Incrementally refit the model for new or updated items.

        Only the coefficient columns of ``updated_items`` are recomputed.
        If ``interaction_matrix`` has more item columns than the fitted
        model, ``item_similarity`` is grown with zeros first; columns of
        items that are not listed in ``updated_items`` keep their previous
        coefficients.

        Args:
            interaction_matrix (csr_matrix): User-item interaction matrix
                (sparse).
            updated_items (list): Item (column) indices to refit.

        Returns:
            SLIMElastic: ``self``.

        Raises:
            ValueError: If ``interaction_matrix`` is not a ``csr_matrix`` or
                has fewer item columns than the fitted model.
            RuntimeError: If the model has not been fitted yet.
            IndexError: If an entry of ``updated_items`` is out of range.
        """
        X = self._prepare_matrix(interaction_matrix)
        self._require_fitted("partial_fit")
        num_items = X.shape[1]
        known_items = self.item_similarity.shape[0]
        if num_items < known_items:
            raise ValueError(
                f"Interaction matrix has {num_items} items but the model was fitted with {known_items}."
            )
        if num_items > known_items:
            # Make room for items that did not exist at the previous fit.
            grown = np.zeros((num_items, num_items), dtype=self.item_similarity.dtype)
            grown[:known_items, :known_items] = self.item_similarity
            self.item_similarity = grown
        self._fit_columns(X, updated_items)
        return self

    def predict(self, user_id, interaction_matrix):
        """
        Compute the predicted scores for a specific user across all items.

        Args:
            user_id (int): The user ID (row index in interaction_matrix).
            interaction_matrix (csr_matrix): User-item interaction matrix.

        Returns:
            numpy.ndarray: Predicted scores for the user across all items,
            as the product of the user's row with ``item_similarity``.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        self._require_fitted("predict")
        return interaction_matrix[user_id].dot(self.item_similarity)

    def predict_all(self, interaction_matrix):
        """
        Compute the predicted scores for every user across all items.

        Args:
            interaction_matrix (csr_matrix): User-item interaction matrix.

        Returns:
            numpy.ndarray: ``interaction_matrix`` multiplied by
            ``item_similarity`` (one row of scores per user).

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        self._require_fitted("predict_all")
        return interaction_matrix.dot(self.item_similarity)

    def recommend(self, user_id, interaction_matrix, top_k=10, exclude_seen=True):
        """
        Recommend top-K items for a given user.

        Args:
            user_id (int): ID of the user (row index in interaction_matrix).
            interaction_matrix (csr_matrix): User-item interaction matrix (sparse).
            top_k (int): Maximum number of recommendations to return.
            exclude_seen (bool): Whether to exclude items the user has
                already interacted with.

        Returns:
            numpy.ndarray: Item indices ordered from highest to lowest
            score. Fewer than ``top_k`` indices are returned when there are
            not enough candidate items; empty when ``top_k <= 0``.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        user_scores = np.asarray(self.predict(user_id, interaction_matrix)).ravel()
        num_candidates = user_scores.size
        if exclude_seen:
            seen_items = interaction_matrix[user_id].indices
            # -inf pushes seen items to the bottom of the ranking.
            user_scores[seen_items] = -np.inf
            num_candidates -= np.unique(seen_items).size
        # Cap k at the number of candidates so seen items can never leak in,
        # and guard k <= 0: ``[-0:]`` would otherwise select every item.
        k = min(top_k, num_candidates)
        if k <= 0:
            return np.array([], dtype=np.intp)
        # argsort is ascending: take the last k and reverse for best-first.
        return np.argsort(user_scores)[-k:][::-1]
from scipy.sparse import csr_matrix

# Toy user-item interactions: one row per user, one column per item.
data = [
    [1, 0, 3, 0],
    [0, 2, 0, 4],
    [1, 0, 0, 5],
    [0, 3, 0, 0]
]
interaction_matrix = csr_matrix(data)

# Build the SLIM model with default hyper-parameters and train it
# (fit() returns the model itself, so construction and training chain).
slim = SLIMElastic().fit(interaction_matrix)

# Scores for every user/item pair right after the initial fit.
initial_scores = slim.predict_all(interaction_matrix)
print("Initial Predicted Scores:\n", initial_scores)

# Fresh interactions that only touch items 2 and 3.
new_data = [
    [0, 0, 5, 1],
    [0, 0, 0, 3],
    [0, 0, 2, 0],
    [0, 0, 0, 4]
]
new_interaction_matrix = csr_matrix(new_data)

# Refit only the coefficient columns of the two affected items.
slim.partial_fit(new_interaction_matrix, updated_items=[2, 3])

# Score the original interactions again with the refreshed coefficients.
updated_scores = slim.predict_all(interaction_matrix)
print("Updated Predicted Scores:\n", updated_scores)

# Two best items for the first user (row 0).
recommendations = slim.recommend(user_id=0, interaction_matrix=interaction_matrix, top_k=2)
print("Recommendations for user 0:", recommendations)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment