Skip to content

Instantly share code, notes, and snippets.

@justinormont
Created May 23, 2019 14:13
Show Gist options
  • Save justinormont/85dd47c513ceae66df2a498c204fc269 to your computer and use it in GitHub Desktop.
Save justinormont/85dd47c513ceae66df2a498c204fc269 to your computer and use it in GitHub Desktop.
ML.NET Feature Importance for Multiclass Logistic Regression
/*
This demonstrates how to inspect feature importance for a multiclass logistic regression in ML.NET.
The training pipeline is expected to end with a trainer created like this:
var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(new LbfgsMaximumEntropyMulticlassTrainer.Options() { LabelColumnName = "MyLabel", FeatureColumnName = "Features" });
*/
// Prints, for every class of a fitted multiclass maximum-entropy (logistic regression) model,
// the class bias followed by the non-zero feature weights with the largest magnitude,
// each paired with the name of its feature slot.
// Expects `trainingPipeline` and `trainingDataView` to be defined by the surrounding code.
var p = (Microsoft.ML.Data.EstimatorChain<Microsoft.ML.Data.MulticlassPredictionTransformer<Microsoft.ML.Trainers.MaximumEntropyModelParameters>>)trainingPipeline;

// Train the model; the last transformer of the fitted chain carries the model parameters.
var model = p.Fit(trainingDataView);
var linearPredictor = model.LastTransformer;

// Transform the dataset so the output schema (including the features column) is available.
var transformedData = model.Transform(trainingDataView);

// Read the schema directly from the data view; Preview() is a debugging helper that
// reads rows just to hand back the same schema.
var schema = transformedData.Schema;

// Get the feature weights per class, densified once so they can be indexed by class.
VBuffer<float>[] weights = null;
linearPredictor.Model.GetWeights(ref weights, out int classes);
var denseWeights = weights.Select(classWeights => classWeights.DenseValues().ToArray()).ToArray();

// Get the biases per class
var biases = linearPredictor.Model.GetBiases().ToArray();

// Get slot names; take the last column called "Features" in case the name occurs more than once.
var featuresCol = schema.Last(col => col.Name == "Features");
var slotNames = new VBuffer<ReadOnlyMemory<char>>();
featuresCol.GetSlotNames(ref slotNames);

// Materialize once: the names are reused for every class below.
var denseSlotNames = slotNames.DenseValues().ToArray();

// How many of the strongest features to print per class.
const int TopFeatureCount = 20;

for (int classIdx = 0; classIdx < denseWeights.Length; classIdx++)
{
    Console.WriteLine($"== Class {classIdx} ==");
    Console.WriteLine($"bias: {biases[classIdx]}");

    // Pair each weight with its slot name, drop features the model does not use (zero weight),
    // and keep the ones with the largest absolute weight.
    var topFeatures = denseWeights[classIdx]
        .Zip(denseSlotNames, (weight, name) => (Weight: weight, Name: name))
        .Where(feature => feature.Weight != 0f)
        .OrderByDescending(feature => Math.Abs(feature.Weight))
        .Take(TopFeatureCount);

    foreach (var feature in topFeatures)
    {
        Console.WriteLine($"{feature.Weight} {feature.Name}");
    }

    Console.WriteLine("\n\n");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment