Created
February 20, 2021 17:22
-
-
Save behrica/68972a626ab797931232bec5611341f7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns sciloj.evaluate
  (:require [tech.v3.dataset.modelling :as ds-mod]
            [tech.v3.datatype.functional :as dfn]
            [tech.v3.datatype.argops :as argops]
            [tech.v3.dataset :as ds]
            [tech.v3.datatype.errors :as errors]
            [tech.v3.dataset.column-filters :as cf]
            [tech.v3.ml.loss :as loss]
            [clojure.pprint]
            [clojure.tools.logging :as log]
            [tech.v3.ml :as ml]
            [tech.v3.ml.gridsearch :as ml-gs]
            [pppmap.core :as ppp]))
(defn default-loss-fn
  "Pick a default loss function for `dataset` based on its inference target.

  The inference target columns are selected with
  `tech.v3.dataset.column-filters/target`; the `errors/when-not-errorf` check
  requires that exactly one such column exists.

  Returns `tech.v3.ml.loss/classification-loss` when the metadata of the target
  column has a truthy `:categorical?` entry, otherwise `tech.v3.ml.loss/mae`."
  [dataset]
  (let [targets   (cf/target dataset)
        n-targets (ds/column-count targets)]
    ;; Note: the message text is also what is reported when there are zero targets.
    (errors/when-not-errorf
     (== 1 n-targets)
     "Dataset has more than 1 target specified: %d"
     n-targets)
    (let [target-col (first (vals targets))]
      (if (-> target-col meta :categorical?)
        loss/classification-loss
        loss/mae))))
(defn train-split
  "Fit a metamorph pipeline on a train split of `dataset` and score it on the
  matching test split.

  * `dataset`     - dataset handed to `tech.v3.dataset.modelling/train-test-split`.
  * `pipeline-fn` - function of a context map. It is called once with
    `:metamorph/mode :fit` and the train split, and once more with the fitted
    context merged with `:metamorph/mode :transform` and the test split.
  * `options`     - passed through unchanged to `train-test-split`.
  * `loss-fn`     - required; called as `(loss-fn predictions truth)`, the same
    argument order `do-k-fold` uses.

  The target column name is the first inference-target column of the fitted
  context's dataset. Returns the transformed context with the computed loss
  added under `:loss`."
  ([dataset pipeline-fn options loss-fn]
   (let [{:keys [train-ds test-ds]} (ds-mod/train-test-split dataset options)
         fitted-ctx     (pipeline-fn {:metamorph/mode :fit
                                      :metamorph/data train-ds})
         predicted-ctx  (pipeline-fn (merge fitted-ctx
                                            {:metamorph/mode :transform
                                             :metamorph/data test-ds}))
         predictions    (:metamorph/data predicted-ctx)
         target-colname (first (ds/column-names
                                (cf/target (:metamorph/data fitted-ctx))))]
     ;; Predictions first, ground truth second -- consistent with do-k-fold, so
     ;; an asymmetric custom loss-fn scores the same way in both code paths.
     ;; NOTE(review): the truth column is read from the untransformed test split;
     ;; confirm the pipeline does not re-encode the target column.
     (assoc predicted-ctx :loss (loss-fn (predictions target-colname)
                                         (test-ds target-colname))))))
(defn do-k-fold
  "Fit and score `pipeline-fn` on every train/test pair of `ds-seq` and return
  the result of the fold with the lowest loss.

  * `pipeline-fn`    - function of a context map; called in `:fit` mode on each
    `:train-ds` and then in `:transform` mode (fitted context merged in) on the
    matching `:test-ds`.
  * `loss-fn`        - called as `(loss-fn predictions truth)` for each fold.
  * `target-colname` - accepted for signature compatibility but never read: the
    target column name is recomputed per fold from the fitted context's dataset.
  * `ds-seq`         - sequence of maps with `:train-ds` and `:test-ds`.

  The returned context is the transformed context of the lowest-loss fold with
  its own `:loss`, plus `:min-loss`, `:max-loss`, `:avg-loss` over all folds and
  `:n-k-folds` (the number of entries in `ds-seq`)."
  [pipeline-fn loss-fn target-colname ds-seq]
  (let [score-fold   (fn [{:keys [train-ds test-ds]}]
                       (let [fit-ctx     (pipeline-fn {:metamorph/mode :fit
                                                       :metamorph/data train-ds})
                             xform-ctx   (pipeline-fn (merge fit-ctx
                                                             {:metamorph/mode :transform
                                                              :metamorph/data test-ds}))
                             predicted   (:metamorph/data xform-ctx)
                             ;; per-fold target name; this is what the loss uses
                             fold-target (-> fit-ctx
                                             :metamorph/data
                                             cf/target
                                             ds/column-names
                                             first)]
                         (assoc xform-ctx
                                :loss (loss-fn (predicted fold-target)
                                               (test-ds fold-target)))))
        fold-results (mapv score-fold ds-seq)
        losses       (mapv :loss fold-results)
        stats        (dfn/descriptive-statistics [:min :max :mean] losses)
        best-idx     (argops/argmin losses)]
    (merge (fold-results best-idx)
           {:min-loss  (:min stats)
            :max-loss  (:max stats)
            :avg-loss  (:mean stats)
            :n-k-folds (count ds-seq)})))
(defn train-k-fold
  "Train a pipeline across k-fold splits of a dataset and return the best fold.

  The pipeline is first run once in `:fit` mode over the whole `dataset`; the
  resulting dataset is split with `tech.v3.dataset.modelling/k-fold-datasets`
  and handed to `do-k-fold`, which adds `:n-k-folds`, `:min-loss`, `:max-loss`,
  `:avg-loss` and `:loss` to the context of the fold with the lowest loss.

  * `dataset`     - input dataset.
  * `pipeline-fn` - metamorph pipeline function of a context map.
  * `n-k-folds`   - number of folds; defaults to 5.
  * `loss-fn`     - defaults to `(tech.v3.ml/default-loss-fn dataset)`."
  ([dataset pipeline-fn n-k-folds loss-fn]
   (let [fitted-ds      (:metamorph/data
                         (pipeline-fn {:metamorph/data dataset
                                       :metamorph/mode :fit}))
         target-colname (first (ds/column-names (cf/target fitted-ds)))]
     (do-k-fold pipeline-fn loss-fn target-colname
                (ds-mod/k-fold-datasets fitted-ds n-k-folds))))
  ([dataset pipeline-fn n-k-folds]
   ;; Argument order must match the 4-arity signature (dataset first); the
   ;; previous code passed pipeline-fn and dataset swapped.
   ;; NOTE(review): the default loss is derived from the raw dataset, which
   ;; assumes its inference target is already marked -- confirm with callers.
   (train-k-fold dataset pipeline-fn n-k-folds (ml/default-loss-fn dataset)))
  ([dataset pipeline-fn]
   (train-k-fold dataset pipeline-fn 5)))
(defn- pprint-to-string
  "Return the `clojure.pprint/pprint` rendering of `o` as a string.

  `clojure.pprint` must be loaded for the fully qualified reference to resolve;
  it is required in this namespace's `ns` form."
  [o]
  (with-out-str
    (clojure.pprint/pprint o)))
(defn- safe-do-k-fold
  "Run `do-k-fold` with the given arguments, shielding the caller from failures.

  Returns whatever `do-k-fold` returns. If it throws an `Exception`, the
  exception is logged at error level and nil is returned instead."
  [pipeline-fn loss-fn target-colname ds-seq]
  (try
    (do-k-fold pipeline-fn loss-fn target-colname ds-seq)
    (catch Exception ex
      (log/error ex "Exception caught during grid search ")
      ;; explicit nil marks a failed run for the caller
      nil)))
(defn train-auto-gridsearch
  "Gridsearch pipeline options and return the best resulting contexts.

  * `dataset`            - input dataset. It is first run through
    `(pipeline-create-fn options)` in `:fit` mode; the resulting dataset is what
    gets split into k folds.
  * `pipeline-create-fn` - function from one concrete option map to a metamorph
    pipeline function.
  * `options`            - option map handed to
    `tech.v3.ml.gridsearch/sobol-gridsearch`. A warning is logged when the
    resulting sequence has exactly one element.
  * gridsearch options (4th argument, may be nil):
    * `:n-k-folds`       - number of folds, defaults to 5.
    * `:n-gridsearch`    - number of option maps taken from the sobol sequence,
      defaults to 75.
    * `:n-result-models` - number of results returned, defaults to 5.
    * `:loss-fn`         - loss function; defaults to `default-loss-fn` applied
      to the fitted dataset.
    The whole map is also passed as options to
    `tech.v3.dataset.modelling/k-fold-datasets`.

  Every option map is evaluated with `safe-do-k-fold` via
  `ppp/ppmap-with-progress`. Searches that threw (and therefore yielded nil)
  are dropped. Returns a sequence of at most `n-result-models` contexts sorted
  by ascending `:avg-loss` -- a sequence, not a single model."
  ([dataset pipeline-create-fn options {:keys [n-k-folds
                                               n-gridsearch
                                               n-result-models
                                               loss-fn]
                                        :or {n-k-folds 5
                                             n-gridsearch 75
                                             n-result-models 5}
                                        :as gridsearch-options}]
   (let [gs-seq         (take n-gridsearch (ml-gs/sobol-gridsearch options))
         dataset        (:metamorph/data
                         ((pipeline-create-fn options)
                          {:metamorph/data dataset :metamorph/mode :fit}))
         target-colname (first (ds/column-names (cf/target dataset)))
         ;; Without this default a missing :loss-fn made every fold throw when
         ;; calling nil, so the search produced nothing but logged errors.
         loss-fn        (or loss-fn (default-loss-fn dataset))
         _              (when (== 1 (count gs-seq))
                          (log/warn "Did not find any gridsearch axis in options map"))
         ds-seq         (ds-mod/k-fold-datasets dataset n-k-folds gridsearch-options)]
     (->> gs-seq
          (ppp/ppmap-with-progress
           "gridsearch" 1
           #(safe-do-k-fold (pipeline-create-fn %) loss-fn target-colname ds-seq))
          ;; safe-do-k-fold yields nil for failed runs; nil sorts ahead of any
          ;; number, so failures would otherwise occupy the top result slots.
          (remove nil?)
          (sort-by :avg-loss)
          (take n-result-models))))
  ([dataset pipeline-create-fn options]
   (train-auto-gridsearch dataset pipeline-create-fn options nil)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment