# Source: GitHub Gist Sandy4321/c26959779b34e45448743afba8328629
# -*- coding: utf-8 -*-
import argparse
import datetime
import pandas as pd
import os
from sklearn import preprocessing
from sklearn.metrics import mean_squared_error, make_scorer
from sklearn.externals import joblib
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from tensorflow.python.lib.io import file_io
from models import model_zoo
import logging
# Configure the root logger with the default handler so records are emitted.
logging.basicConfig()
# Module-level logger; raised to INFO so the progress messages logged by the
# script below (data preview, training status, MSE) are shown.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Columns selected from the input CSV as the model's feature matrix.
# Kept as a list (not a tuple) because it is used directly for DataFrame
# column selection.
INPUT_FEAT_NAMES = [
    'Age',
    'City_Category',
    'Gender',
    'Marital_Status',
    'Occupation',
    'Product_Category_1',
    'Product_Category_2',
    'Product_Category_3',
    'Stay_In_Current_City_Years',
    'Product_ID',
    'User_ID',
]

# Column of the input CSV used as the regression target.
TARGET_NAME = 'Purchase'
if __name__ == '__main__':
    # Script entry point: read a CSV, train the model selected by name on an
    # 80/20 train/eval split, log the evaluation MSE, and write the fitted
    # model to <model-dir>/model.joblib.
    parser = argparse.ArgumentParser()
    parser.add_argument('--csv-data-path',
                        dest='csv_data_path', required=True)
    parser.add_argument('--model-dir',
                        dest='model_dir', required=True)
    parser.add_argument('--model-name',
                        dest='model_name', required=True)
    # Accepted but not read anywhere in this script.
    # NOTE(review): presumably passed in by the job launcher -- confirm
    # before removing.
    parser.add_argument('--job-dir',
                        dest='job_dir', required=False, default='/tmp')
    args = parser.parse_args()

    df_raw = pd.read_csv(args.csv_data_path)
    logger.info(df_raw.head())

    # Split data into train and eval set
    df_raw_X = df_raw[INPUT_FEAT_NAMES]
    df_raw_Y = df_raw[TARGET_NAME].values
    # No random_state is given, so the split differs from run to run.
    (df_raw_train_X, df_raw_eval_X,
     df_raw_train_Y, df_raw_eval_Y) = train_test_split(
        df_raw_X, df_raw_Y, test_size=0.2)

    # Train model
    model = model_zoo.get_model(args.model_name)
    logger.info('Training %s', args.model_name)
    model.fit(df_raw_train_X, df_raw_train_Y)
    logger.info('Done')

    # Evaluate
    eval_pred = model.predict(df_raw_eval_X)
    logger.info("MSE: %s", mean_squared_error(df_raw_eval_Y, eval_pred))

    # Dump model
    # The serialized model is bytes, so open the target in binary mode, and
    # use a context manager so the buffer is flushed and closed
    # deterministically instead of relying on finalisation at interpreter
    # exit.
    model_path = os.path.join(args.model_dir, 'model.joblib')
    with file_io.FileIO(model_path, 'wb') as model_buffer:
        joblib.dump(model, model_buffer)
# End of gist (GitHub comment-thread footer omitted).