Skip to content

Instantly share code, notes, and snippets.

@RoaldSchuring
Last active February 5, 2020 13:00
Show Gist options
  • Save RoaldSchuring/cb4857602b82fbcbb65b105ac0f5f859 to your computer and use it in GitHub Desktop.
Save RoaldSchuring/cb4857602b82fbcbb65b105ac0f5f859 to your computer and use it in GitHub Desktop.
Wine Recommender Sklearn Nearest Neighbors
import argparse
import pandas as pd
import os
from sklearn.externals import joblib
from sklearn.neighbors import NearestNeighbors
import numpy as np
import subprocess
import sys
def install(package):
    """Install *package* with pip, using the interpreter running this script.

    Launches ``<sys.executable> -m pip install <package>`` as a child
    process and waits for it to finish. The exit status of pip is not
    inspected, so a failed install does not raise here.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.call(pip_command)


# Installed at import time, before any data is read.
# NOTE(review): presumably needed so the s3:// default paths used further
# down can be opened -- confirm against the runtime image.
install('s3fs')
def convert_to_list(raw_review_vec):
    """Parse a stringified vector such as ``"[0.1 0.2 0.3]"`` into floats.

    Square brackets are stripped and the remaining whitespace-separated
    numbers are parsed with NumPy.

    Args:
        raw_review_vec: The vector as text, numbers separated by spaces,
            optionally wrapped in ``[`` and ``]``.

    Returns:
        A plain Python list of floats.
    """
    review_vec_trimmed = raw_review_vec.replace('[', '').replace(']', '')
    return np.fromstring(review_vec_trimmed, dtype=float, sep=' ').tolist()


def parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` when no model directory is available from
    either ``--model-dir`` or the ``SM_MODEL_DIR`` environment variable.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here.
    parser.add_argument('--n_neighbors', type=int, default=10)
    parser.add_argument('--metric', type=str, default='cosine')

    # Sagemaker specific arguments. Defaults are set in the environment variables.
    # NOTE: --output-data-dir is accepted but not used anywhere in this script.
    parser.add_argument('--output-data-dir', type=str, default='s3://data-science-wine-reviews/nearest_neighbors/output_data')
    # .get() rather than [...] so that a missing SM_MODEL_DIR does not raise
    # KeyError while the parser is being built; an explicit --model-dir must
    # still work when the variable is absent (e.g. a local run).
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--train', type=str, default='s3://data-science-wine-reviews/nearest_neighbors/data/wine_review_vectors.csv')

    args = parser.parse_args(argv)
    if args.model_dir is None:
        parser.error('--model-dir is required when SM_MODEL_DIR is not set')
    return args


def main(argv=None):
    """Fit a nearest-neighbors model on the review vectors and persist it.

    Reads the CSV at ``--train``, parses its ``review_vector`` column into a
    2-D array, fits ``NearestNeighbors`` with the requested hyperparameters
    and dumps the fitted model to ``<model-dir>/model.joblib``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # Load the training data and turn the stringified vectors into one array
    # with a row per review.
    raw_data = pd.read_csv(args.train)
    wine_vectors = np.array(raw_data['review_vector'].apply(convert_to_list).tolist())

    # Fit the nearest neighbors model with the supplied hyperparameters.
    nn = NearestNeighbors(n_neighbors=args.n_neighbors, metric=args.metric)
    model_nn = nn.fit(wine_vectors)
    print('model has been fitted')

    # Persist the fitted model into the model directory.
    # NOTE(review): this writes to a local path; presumably the training
    # platform uploads the directory contents afterwards -- confirm.
    joblib.dump(model_nn, os.path.join(args.model_dir, "model.joblib"))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment