Created July 2, 2025 09:13
-
-
Save usausa/906517f23a523b16f54d9d126e06aaa9 to your computer and use it in GitHub Desktop.
Bitmap color clustering with ML.NET
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Microsoft.ML; | |
using Microsoft.ML.Data; | |
using Microsoft.ML.Trainers; | |
using SkiaSharp; | |
/// <summary>
/// A representative RGB color paired with the number of pixels it stands for.
/// </summary>
/// <param name="R">Red channel.</param>
/// <param name="G">Green channel.</param>
/// <param name="B">Blue channel.</param>
/// <param name="Count">Number of pixels associated with this color.</param>
public record ColorCount(byte R, byte G, byte B, int Count);
/// <summary>
/// One input row for the clustering pipeline: a single pixel's red, green and blue channels as floats.
/// </summary>
public sealed class RgbData
{
    /// <summary>Red, green and blue channel values, declared in R, G, B order.</summary>
    public float R, G, B;
}
/// <summary>
/// One output row read back from the clustering model: the assigned cluster and its score vector.
/// </summary>
public sealed class ClusterPrediction
{
    /// <summary>
    /// Assigned cluster, bound to the <c>PredictedLabel</c> column.
    /// ML.NET key values are 1-based (0 means a missing key).
    /// </summary>
    [ColumnName("PredictedLabel")] public uint ClusterId;

    /// <summary>
    /// Contents of the <c>Score</c> column, one value per cluster
    /// (ML.NET documents these as distances to each centroid).
    /// </summary>
    [ColumnName("Score")] public float[] Distances = null!;
}
/// <summary>
/// Reduces the colors of <paramref name="bitmap"/> to at most <paramref name="maxClusters"/>
/// representative colors using ML.NET K-Means clustering over the R, G and B channels.
/// </summary>
/// <param name="bitmap">Source image. Every pixel is sampled; alpha is ignored.</param>
/// <param name="maxClusters">Upper bound on the number of clusters. The actual number is lowered
/// to the count of distinct RGB values when the image has fewer.</param>
/// <param name="maxIterations">Forwarded to <see cref="KMeansTrainer.Options.MaximumNumberOfIterations"/>.</param>
/// <param name="tolerance">Forwarded to <see cref="KMeansTrainer.Options.OptimizationTolerance"/>.</param>
/// <param name="seed">Optional seed for the <see cref="MLContext"/>. Supply a value to make the random
/// centroid initialization reproducible. <c>null</c> (the default) leaves it unseeded.</param>
/// <returns>One entry per cluster: the rounded centroid color and the number of pixels assigned to it,
/// sorted by descending pixel count. Empty when the bitmap has no pixels.</returns>
/// <exception cref="ArgumentNullException"><paramref name="bitmap"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="maxClusters"/> is not positive.</exception>
public List<ColorCount> ClusterColors(
    SKBitmap bitmap,
    int maxClusters,
    int maxIterations,
    float tolerance,
    int? seed = null)
{
    ArgumentNullException.ThrowIfNull(bitmap);
    if (maxClusters <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxClusters), maxClusters, "At least one cluster is required.");
    }

    var width = bitmap.Width;
    var height = bitmap.Height;
    if (width <= 0 || height <= 0)
    {
        // K-Means cannot be trained on an empty data set.
        return new List<ColorCount>();
    }

    // Distinct RGB values seen so far, capped at maxClusters. Alpha is deliberately dropped:
    // the clustering features are R, G and B only. Counting full RGBA colors would let pixels that
    // differ only in alpha request more clusters than there are distinct feature points.
    var distinctRgb = new HashSet<(byte R, byte G, byte B)>();
    var pixels = new RgbData[width * height];
    var index = 0;
    for (var y = 0; y < height; y++)
    {
        for (var x = 0; x < width; x++)
        {
            var color = bitmap.GetPixel(x, y);
            pixels[index++] = new RgbData { R = color.Red, G = color.Green, B = color.Blue };
            if (distinctRgb.Count < maxClusters)
            {
                distinctRgb.Add((color.Red, color.Green, color.Blue));
            }
        }
    }

    var actualClusters = Math.Min(maxClusters, distinctRgb.Count);

    // KMeans
    var mlContext = new MLContext(seed);
    var dataView = mlContext.Data.LoadFromEnumerable(pixels);
    var options = new KMeansTrainer.Options
    {
        FeatureColumnName = "Features",
        NumberOfClusters = actualClusters,
        MaximumNumberOfIterations = maxIterations,
        OptimizationTolerance = tolerance,
        // Alternative: KMeansTrainer.InitializationAlgorithm.KMeansPlusPlus
        InitializationAlgorithm = KMeansTrainer.InitializationAlgorithm.Random
    };
    var pipeline = mlContext.Transforms
        .Concatenate("Features", nameof(RgbData.R), nameof(RgbData.G), nameof(RgbData.B))
        .Append(mlContext.Clustering.Trainers.KMeans(options));
    var model = pipeline.Fit(dataView);
    var transformed = model.Transform(dataView);

    // Cluster centers in (R, G, B) feature space, indexed by zero-based cluster index.
    var centroids = default(VBuffer<float>[]);
    model.LastTransformer.Model.GetClusterCentroids(ref centroids, out _);

    // Count pixels per cluster. PredictedLabel is a 1-based key, so cluster k is stored at k - 1.
    // The row object can be reused because only ClusterId is read, and it is read immediately.
    var counts = new int[actualClusters];
    foreach (var prediction in mlContext.Data.CreateEnumerable<ClusterPrediction>(transformed, reuseRowObject: true))
    {
        counts[prediction.ClusterId - 1]++;
    }

    var list = new List<ColorCount>(actualClusters);
    for (var i = 0; i < counts.Length; i++)
    {
        var centroid = centroids[i].DenseValues().ToArray();
        list.Add(new ColorCount(ToChannel(centroid[0]), ToChannel(centroid[1]), ToChannel(centroid[2]), counts[i]));
    }

    // Most populated cluster first.
    list.Sort(static (x, y) => y.Count.CompareTo(x.Count));
    return list;

    // Rounds a centroid coordinate to a color channel (banker's rounding, as Math.Round defaults to).
    // Defensive: a mean of 0..255 values should stay in range, but an out-of-range or NaN value would
    // otherwise hit an unspecified unchecked double -> byte conversion.
    static byte ToChannel(float value) =>
        float.IsNaN(value) ? (byte)0 : (byte)Math.Clamp(Math.Round(value), 0d, 255d);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.