Skip to content

Instantly share code, notes, and snippets.

@Nanthini10
Created May 13, 2020 16:39
Show Gist options
  • Save Nanthini10/c93d9d8d2401ae96f8c68cfdf58208f8 to your computer and use it in GitHub Desktop.
Save Nanthini10/c93d9d8d2401ae96f8c68cfdf58208f8 to your computer and use it in GitHub Desktop.
Creating a Ray Tune Trainable model and evaluating its performance
def _train(self):
    """Run one training iteration: split the data, fit a random forest, score it.

    The backend is chosen by ``compute``, a name defined outside this method
    (presumably a module-level setting -- confirm against the rest of the file):

    * ``"GPU"``: split with ``train_test_split`` (``self._y_label`` is passed
      as ``y``) and fit a ``cuml.ensemble.RandomForestClassifier``.
    * ``"CPU"``: split with ``sktrain_test_split`` (the ``self._y_label``
      column is dropped from the features) and fit a
      ``sklearn.ensemble.RandomForestClassifier`` with ``n_jobs=-1``.

    Both paths train on 80% of the shuffled rows and score on the rest, using
    ``random_state=self.iteration`` (0 when the attribute is absent), so the
    split is reproducible for a given iteration number.

    Side effects:
        Replaces ``self.rf_model``. If the new test accuracy is strictly
        greater than ``self._global_best_test_accuracy``, that attribute is
        updated and the fitted model is stored in ``self._global_best_model``.

    Returns:
        dict: ``test_accuracy``; ``train_time`` and ``infer_time`` (the
        ``PerfTimer`` durations, rounded to 4 decimals); and ``is_bad``,
        True when the accuracy is not a finite number.

    Raises:
        ValueError: if ``compute`` is neither ``"GPU"`` nor ``"CPU"``.
    """
    # Presumably Ray Tune's per-trainable iteration counter; 0 if not set.
    iteration = getattr(self, "iteration", 0)
    params = self._model_params

    if compute == "GPU":
        X_train, X_test, y_train, y_test = train_test_split(
            X=self._dataset,
            y=self._y_label,
            train_size=0.8,
            shuffle=True,
            random_state=iteration,
        )
        self.rf_model = cuml.ensemble.RandomForestClassifier(
            n_estimators=params["n_estimators"],
            max_depth=params["max_depth"],
            n_bins=params["n_bins"],
            max_features=params["max_features"],
        )
    elif compute == "CPU":
        # Optional CPU path, kept for performance comparison with the GPU one.
        features = self._dataset.loc[:, self._dataset.columns != self._y_label]
        X_train, X_test, y_train, y_test = sktrain_test_split(
            features,
            self._dataset[self._y_label],
            train_size=0.8,
            shuffle=True,
            random_state=iteration,
        )
        self.rf_model = sklearn.ensemble.RandomForestClassifier(
            n_estimators=params["n_estimators"],
            max_depth=params["max_depth"],
            max_features=params["max_features"],
            n_jobs=-1,
        )
    else:
        # Ray Tune's Trainable expects _train to return a metrics dict.
        # Printing and returning None only fails later with an unrelated
        # error, so fail fast and name the offending value.
        raise ValueError(
            f"Unknown compute option {compute!r}. Please select CPU or GPU"
        )

    # Train the model. The duration is read after the timer's block exits.
    with PerfTimer() as train_timer:
        trained_model = self.rf_model.fit(X_train, y_train)
    training_time = train_timer.duration

    # Evaluate on the held-out split. Test labels are cast to int32 before
    # scoring (reason not visible here; presumably to match the classifier's
    # expected label dtype -- confirm).
    with PerfTimer() as inference_timer:
        test_accuracy = trained_model.score(X_test, y_test.astype("int32"))
    infer_time = inference_timer.duration

    # Track the best model seen so far (assumes higher accuracy is better).
    if test_accuracy > self._global_best_test_accuracy:
        self._global_best_test_accuracy = test_accuracy
        self._global_best_model = trained_model

    return {
        "test_accuracy": test_accuracy,
        "train_time": round(training_time, 4),
        "infer_time": round(infer_time, 4),
        "is_bad": not math.isfinite(test_accuracy),
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment