Created
January 21, 2017 11:37
Process outputs for Multiclass Classification.
import qualified Data.Vector.Storable as V
import qualified Numeric.LinearAlgebra as LA

-- Assumed aliases: the snippet uses Vector and Matrix without defining them.
type Vector = LA.Vector Double
type Matrix = LA.Matrix Double

-- | Process outputs for Multiclass Classification.
-- Takes the number of labels and the output vector y.
-- Returns a matrix of binary outputs (One-vs-All Classification).
-- Labels are assumed to be integers starting at 0.
processOutputMulti :: Int -> Vector -> Matrix
processOutputMulti numLabels y = LA.fromColumns $ map f [0 .. numLabels-1]
  where f sample = V.map (\a -> if round a == sample then 1 else 0) y

-- Second version of the same function, built with LA.assoc. It is given a
-- primed name here only so that both definitions can live in one module.
processOutputMulti' :: Int -> Vector -> Matrix
processOutputMulti' numLabels y = LA.assoc (V.length y, numLabels) 0 assocList
  where assocList = zipWith (\index label -> ((index, round label), 1)) [0..] (V.toList y)
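
To make the expected result easier to see, here is a minimal, dependency-free sketch of the same one-vs-all encoding on plain lists. The name `oneVsAll` is hypothetical and not part of the gist. It only illustrates the intended shape of the output: one row per sample, one column per label.

```haskell
-- Dependency-free sketch of one-vs-all label encoding on plain lists.
-- `oneVsAll` is a hypothetical helper used only for illustration.
-- Row i has a 1 in the column matching the (rounded) label of sample i
-- and 0 everywhere else.
oneVsAll :: Int -> [Double] -> [[Double]]
oneVsAll numLabels ys =
  [ [ if round y == label then 1 else 0 | label <- [0 .. numLabels - 1] ]
  | y <- ys
  ]
```

For example, `oneVsAll 3 [0, 2, 1]` evaluates to `[[1,0,0],[0,0,1],[0,1,0]]`, which should correspond row by row to the matrix that `processOutputMulti 3` builds for the same labels.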