Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save yuyasugano/0fe89f0984c76e596a3aae23687933cf to your computer and use it in GitHub Desktop.
Save yuyasugano/0fe89f0984c76e596a3aae23687933cf to your computer and use it in GitHub Desktop.
Training script for SageMaker scikit-learn
import argparse
import pandas as pd
import os
# GradientBoosting Classifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.externals import joblib
# Pipeline and StandardScaler
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
def load_training_data(train_dir, channel="train"):
    """Read every entry of ``train_dir`` as a header-less CSV and stack the rows.

    Args:
        train_dir: Directory that holds the training CSV files.
        channel: Channel name; only used to build the error message.

    Returns:
        One pandas DataFrame containing the rows of all files, concatenated
        in sorted path order.

    Raises:
        ValueError: If ``train_dir`` contains no entries at all.
    """
    # os.listdir() yields entries in arbitrary order; sorting keeps the row
    # order of the training set identical from run to run.
    input_files = sorted(
        os.path.join(train_dir, name) for name in os.listdir(train_dir)
    )
    if not input_files:
        raise ValueError(('There are no files in {}.\n' +
                          'This usually indicates that the channel ({}) was incorrectly specified,\n' +
                          'the data specification in S3 was incorrectly specified or the role specified\n' +
                          'does not have permission to access the data.').format(train_dir, channel))
    frames = [pd.read_csv(path, header=None, engine="python") for path in input_files]
    return pd.concat(frames)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Hyperparameters of the gradient boosting classifier.
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--n_estimators', type=int, default=100)

    # SageMaker-specific arguments. Their defaults come from environment
    # variables. os.environ.get() is used so that a missing variable does not
    # raise KeyError while the parser is being built; the directories the
    # script actually uses become required on the command line instead.
    parser.add_argument('--output-data-dir', type=str,
                        default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--model-dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR'),
                        required='SM_MODEL_DIR' not in os.environ)
    parser.add_argument('--train', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAIN'),
                        required='SM_CHANNEL_TRAIN' not in os.environ)
    args = parser.parse_args()

    # Take the set of files and read them all into a single pandas dataframe.
    train_data = load_training_data(args.train)

    # Labels are in the last column; every preceding column is a feature.
    train_y = train_data.iloc[:, -1]
    train_X = train_data.iloc[:, :-1]

    # Train scikit-learn's gradient boosting classifier with the requested
    # hyperparameters.
    clf = GradientBoostingClassifier(learning_rate=args.learning_rate,
                                     n_estimators=args.n_estimators)
    clf = clf.fit(train_X, train_y)
    print(clf)

    # Persist the fitted classifier; model_fn() loads it back under the same
    # file name.
    # NOTE(review): joblib is imported from sklearn.externals, which newer
    # scikit-learn releases no longer provide -- confirm the container version.
    joblib.dump(clf, os.path.join(args.model_dir, "model.joblib"))
def model_fn(model_dir):
    """Deserialize and return the fitted model stored in ``model_dir``.

    The file name must stay in sync with the one the training entry point
    passes to ``joblib.dump`` ("model.joblib").

    Args:
        model_dir: Directory that contains the serialized model file.

    Returns:
        The object loaded from ``<model_dir>/model.joblib``.
    """
    model_path = os.path.join(model_dir, "model.joblib")
    return joblib.load(model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment