-
-
Save kwccoin/34cb0823b14f8cdbf988d175c94a49e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(defn f-beta
  "F-beta score: the weighted harmonic mean of precision and recall,

      (1 + beta^2) * P * R / (beta^2 * P + R)

  The 2-arity version uses beta = 1, i.e. the F1 score.
  When the denominator (beta^2 * P + R) is zero (e.g. precision and recall
  are both 0), the ratio is treated as 0 and the result is
  (* (+ 1 beta^2) 0).

  The zero denominator is checked explicitly rather than by catching
  ArithmeticException. Clojure only throws on boxed division by zero;
  primitive double division silently yields NaN/Infinity."
  ([precision recall] (f-beta precision recall 1))
  ([precision recall beta]
   (let [beta-squared (* beta beta)
         denominator  (+ (* beta-squared precision) recall)
         ;; P*R / (beta^2*P + R), or 0 when that would divide by zero
         ratio        (if (zero? denominator)
                        0
                        (/ (* precision recall) denominator))]
     (* (+ 1 beta-squared) ratio))))
(defn f1-test-fn
  "Test function that takes two map arguments: global info and per-epoch info.

  Runs `new-network` over `test-ds`, decodes actual and predicted labels,
  computes precision, recall and the F1 score (via `f-beta` with its default
  beta of 1), and compares that score with the `:cv-score` stored on
  `old-network`. The scored network is saved and the metrics are logged.

  Global args:    {:keys [batch-size context]}
  Per-epoch args: {:keys [new-network old-network test-ds]}

  Returns:
    {:best-network? boolean  ; true when old-network has no :cv-score,
                             ; or the new F1 is strictly greater than it
     :network       (assoc new-network :cv-score f1)}"
  [;; global arguments
   {:keys [batch-size context]}
   ;; per-epoch arguments
   {:keys [new-network old-network test-ds]}]
  (let [batch-size      (long batch-size)
        test-results    (execute/run new-network test-ds
                                     :batch-size batch-size
                                     :loss-outputs? true
                                     :context context)
        ;; test metrics
        ;; NOTE(review): predictions are decoded with an extra [1 0.9] argument
        ;; to vec->label — presumably per-class thresholds; confirm.
        test-actual     (mapv #(vec->label [0.0 1.0] %) (map :label test-ds))
        test-pred       (mapv #(vec->label [0.0 1.0] % [1 0.9]) (map :label test-results))
        precision       (metrics/precision test-actual test-pred)
        recall          (metrics/recall test-actual test-pred)
        ;; named `score` (not `f-beta`) so the local does not shadow the f-beta fn
        score           (f-beta precision recall)
        old-score       (get old-network :cv-score)
        ;; the first scored network is best by default; afterwards the new
        ;; score must strictly beat the previous one
        best-network?   (or (nil? old-score)
                            (> score old-score))
        updated-network (assoc new-network :cv-score score)
        epoch           (get new-network :epoch-count)]
    ;; NOTE(review): `network-file` is not bound in this fn, so it must be a var
    ;; defined elsewhere. This also saves every epoch's network, not only the
    ;; best one — confirm that overwriting is intended.
    (experiment-train/save-network updated-network network-file)
    (log (str "Epoch: " epoch "\n"
              "Precision: " precision "\n"
              "Recall: " recall "\n"
              "F1: " score "\n\n"))
    {:best-network? best-network?
     :network       updated-network}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment