Last active
February 5, 2020 13:00
-
-
Save RoaldSchuring/cb4857602b82fbcbb65b105ac0f5f859 to your computer and use it in GitHub Desktop.
Wine Recommender Sklearn Nearest Neighbors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import pandas as pd | |
import os | |
from sklearn.externals import joblib | |
from sklearn.neighbors import NearestNeighbors | |
import numpy as np | |
import subprocess | |
import sys | |
def install(package):
    """Install *package* into the running interpreter's environment with pip.

    Runs ``<sys.executable> -m pip install <package>`` so the package lands in
    the same environment this script executes in. This is best effort:
    a failing install is reported on stderr instead of raising, so the
    script keeps going. pip's exit status is returned so callers can
    check it if they need to.

    Args:
        package: pip requirement specifier, e.g. ``'s3fs'``.

    Returns:
        int: pip's exit status (0 on success).
    """
    returncode = subprocess.call([sys.executable, "-m", "pip", "install", package])
    if returncode != 0:
        # Previously the status was discarded, so a failed install only
        # surfaced later as an unrelated-looking error.
        print(
            "warning: 'pip install {}' exited with status {}".format(package, returncode),
            file=sys.stderr,
        )
    return returncode
# Module-level (runs on import, before argument parsing): ensure s3fs is
# present, since pandas relies on it to open the s3:// paths used as the
# --train / --output-data-dir defaults below.
install(package='s3fs')
def convert_to_list(raw_review_vec):
    """Parse one serialized review vector into a list of floats.

    The ``review_vector`` column stores each vector as text. Presumably it
    was written with NumPy's array repr, e.g. ``"[0.12 -0.5 3.1e-02]"``,
    possibly wrapped across lines. Python list style (``"[0.12, -0.5]"``)
    is accepted as well. Brackets and commas are treated as separators, and
    every remaining whitespace-delimited token must be a number.

    Args:
        raw_review_vec: the serialized vector text of one CSV cell.

    Returns:
        list[float]: the vector components, in order.

    Raises:
        ValueError: if the cell is not a string (e.g. a missing value that
            pandas read as NaN) or contains a non-numeric token. The former
            ``np.fromstring`` parser silently truncated such vectors instead.
    """
    if not isinstance(raw_review_vec, str):
        raise ValueError(
            'review vector must be a string, got {!r}'.format(raw_review_vec))
    cleaned = raw_review_vec.replace('[', ' ').replace(']', ' ').replace(',', ' ')
    return [float(token) for token in cleaned.split()]


def vectors_to_matrix(raw_data, column='review_vector'):
    """Parse every serialized vector in ``raw_data[column]`` into a 2-D array.

    Args:
        raw_data: DataFrame holding one serialized vector per row.
        column: name of the column containing the vector text.

    Returns:
        numpy.ndarray of shape (n_rows, dim) with dtype float64.

    Raises:
        ValueError: if there are no rows, a cell cannot be parsed, or the
            vectors do not all have the same length. A ragged input would
            otherwise produce an object array that NearestNeighbors rejects
            with a far less helpful message.
    """
    vectors = [convert_to_list(cell) for cell in raw_data[column]]
    if not vectors:
        raise ValueError('no review vectors found in column {!r}'.format(column))
    lengths = {len(vec) for vec in vectors}
    if len(lengths) != 1:
        raise ValueError(
            'review vectors have inconsistent lengths: {}'.format(sorted(lengths)))
    return np.array(vectors, dtype=float)


def parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None).

    Returns:
        argparse.Namespace with ``n_neighbors``, ``metric``,
        ``output_data_dir``, ``model_dir`` and ``train``.

    Exits (via ``parser.error``) when no model directory is given and the
    ``SM_MODEL_DIR`` environment variable is not set.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here.
    parser.add_argument('--n_neighbors', type=int, default=10)
    parser.add_argument('--metric', type=str, default='cosine')

    # SageMaker-specific arguments. Defaults are taken from the environment
    # where SageMaker provides them.
    parser.add_argument('--output-data-dir', type=str, default='s3://data-science-wine-reviews/nearest_neighbors/output_data')
    # .get() rather than indexing: os.environ['SM_MODEL_DIR'] raised KeyError
    # while building the parser, even when --model-dir was passed explicitly.
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--train', type=str, default='s3://data-science-wine-reviews/nearest_neighbors/data/wine_review_vectors.csv')

    args = parser.parse_args(argv)
    if args.model_dir is None:
        parser.error('--model-dir is required when SM_MODEL_DIR is not set')
    return args


def main(argv=None):
    """Fit a NearestNeighbors model on the wine review vectors and save it.

    Reads the CSV at ``--train``, parses its ``review_vector`` column,
    fits ``NearestNeighbors`` with the requested ``n_neighbors`` and
    ``metric``, and writes the fitted model to
    ``<model-dir>/model.joblib`` with joblib.
    """
    args = parse_args(argv)

    # Load the training data and turn the serialized vectors into a matrix.
    raw_data = pd.read_csv(args.train)
    wine_vectors = vectors_to_matrix(raw_data)

    # Fit the nearest neighbors model with the requested hyperparameters.
    nn = NearestNeighbors(n_neighbors=args.n_neighbors, metric=args.metric)
    model_nn = nn.fit(wine_vectors)
    print('model has been fitted')

    # Save the model into the model directory. When this is SageMaker's
    # SM_MODEL_DIR, SageMaker presumably uploads it to S3 after training.
    # Confirm for your setup.
    joblib.dump(model_nn, os.path.join(args.model_dir, "model.joblib"))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment