Created
December 3, 2019 11:59
-
-
Save yuyasugano/0fe89f0984c76e596a3aae23687933cf to your computer and use it in GitHub Desktop.
Training script for SageMaker scikit-learn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import pandas as pd | |
import os | |
# GradientBoosting Classifier | |
from sklearn.ensemble import GradientBoostingClassifier | |
from sklearn.externals import joblib | |
# Pipeline and StandardScaler | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.pipeline import Pipeline | |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model hyperparameters. SageMaker passes these as command-line arguments
    # to the training container.
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--n_estimators', type=int, default=100)

    # SageMaker-specific directories. Defaults come from environment variables
    # that the SageMaker training container sets; outside SageMaker these
    # lookups raise KeyError, which is the original (intended) behavior.
    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])

    args = parser.parse_args()

    # Read every file in the training channel into a single DataFrame.
    # (Renamed the loop variable: the original shadowed the `file` builtin.)
    input_files = [os.path.join(args.train, fname) for fname in os.listdir(args.train)]
    if not input_files:
        raise ValueError(
            'There are no files in {}.\n'
            'This usually indicates that the channel ({}) was incorrectly specified,\n'
            'the data specification in S3 was incorrectly specified or the role specified\n'
            'does not have permission to access the data.'.format(args.train, "train")
        )
    raw_data = [pd.read_csv(fname, header=None, engine="python") for fname in input_files]
    train_data = pd.concat(raw_data)

    # Labels are in the last column; features are every column before it.
    train_y = train_data.iloc[:, -1]
    train_X = train_data.iloc[:, 0:-1]

    # Train a gradient-boosting classifier with the supplied hyperparameters.
    clf = GradientBoostingClassifier(
        learning_rate=args.learning_rate,
        n_estimators=args.n_estimators,
    )
    clf = clf.fit(train_X, train_y)
    print(clf)

    # Persist the fitted model where SageMaker expects to find it.
    # NOTE(review): `from sklearn.externals import joblib` (top of file) was
    # removed in scikit-learn 0.23+; on modern images switch to `import joblib`.
    joblib.dump(clf, os.path.join(args.model_dir, "model.joblib"))
def model_fn(model_dir):
    """Load the fitted model from *model_dir* and return it.

    SageMaker's scikit-learn serving container calls this hook at inference
    time; the filename must match the one written during training
    ("model.joblib").
    """
    model_path = os.path.join(model_dir, "model.joblib")
    return joblib.load(model_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment