Created
May 23, 2019 14:13
-
-
Save justinormont/85dd47c513ceae66df2a498c204fc269 to your computer and use it in GitHub Desktop.
ML.NET Feature Importance for Multiclass Logistic Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
This demonstrates feature importance for a multiclass logistic regression in ML.NET.
The pipeline is expected to end in a maximum-entropy trainer configured like:
var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(new LbfgsMaximumEntropyMulticlassTrainer.Options() { LabelColumnName = "MyLabel", FeatureColumnName = "Features" })
For each class it prints the bias, then up to `topFeatureCount` non-zero weights
ordered by absolute magnitude (largest first), each paired with its feature slot name.
*/

// Number of largest-magnitude weights printed per class.
const int topFeatureCount = 20;

// The cast throws InvalidCastException if the pipeline does not end in a
// MulticlassPredictionTransformer over MaximumEntropyModelParameters.
var p = (Microsoft.ML.Data.EstimatorChain<Microsoft.ML.Data.MulticlassPredictionTransformer<Microsoft.ML.Trainers.MaximumEntropyModelParameters>>)trainingPipeline;

// Train the model.
var model = p.Fit(trainingDataView);
var modelParameters = model.LastTransformer.Model;

// Output schema of the trained model. Transform is lazy, so reading .Schema does not
// process any rows (Preview() would eagerly run the pipeline on a sample just to get it).
var schema = model.Transform(trainingDataView).Schema;

// Per-class weight vectors (one VBuffer per class).
VBuffer<float>[] weights = null;
modelParameters.GetWeights(ref weights, out int classes);

// Per-class biases, in the same class order as `weights`.
var biases = modelParameters.GetBiases().ToArray();

// "Features" is the FeatureColumnName configured on the trainer; take the last
// (i.e. currently visible) column with that name.
var featuresCol = schema.Last(col => col.Name == "Features");

// Resolve one display name per feature slot, computed once and reused for every class.
// GetSlotNames throws when the column has no "SlotNames" annotation, so fall back to
// index-based names in that case.
int featureCount = weights.Length > 0 ? weights[0].Length : 0;
string[] featureNames = Enumerable.Range(0, featureCount).Select(i => $"Feature{i}").ToArray();
if (featuresCol.Annotations.Schema.GetColumnOrNull("SlotNames") != null)
{
    var slotNames = default(VBuffer<ReadOnlyMemory<char>>);
    featuresCol.GetSlotNames(ref slotNames);
    int slot = 0;
    foreach (var slotName in slotNames.DenseValues().Take(featureCount))
    {
        featureNames[slot++] = slotName.ToString();
    }
}

// NOTE(review): classIdx is the zero-based class index used by the model, presumably the
// label's key index; map it through the label column's KeyValues annotation to show the
// original label value.
for (int classIdx = 0; classIdx < weights.Length; classIdx++)
{
    Console.WriteLine($"== Class {classIdx} ==");
    Console.WriteLine($"bias: {biases[classIdx]}");

    // Pair each dense weight with its feature name, drop exact zeros (features the
    // model does not use for this class), and keep the largest magnitudes.
    var topWeights = weights[classIdx].DenseValues()
        .Zip(featureNames, (w, n) => (Weight: w, Name: n))
        .Where(t => t.Weight != 0f)
        .OrderByDescending(t => Math.Abs(t.Weight))
        .Take(topFeatureCount);

    foreach (var (weight, name) in topWeights)
    {
        Console.WriteLine($"{weight} {name}");
    }

    Console.WriteLine("\n\n");
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment