Skip to content

Instantly share code, notes, and snippets.

@usausa
Created July 2, 2025 09:13
Show Gist options
  • Save usausa/906517f23a523b16f54d9d126e06aaa9 to your computer and use it in GitHub Desktop.
Save usausa/906517f23a523b16f54d9d126e06aaa9 to your computer and use it in GitHub Desktop.
Bitmap color clustering with ML.NET (k-means)
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using SkiaSharp;
/// <summary>
/// A representative color (one byte per channel) together with the number of pixels it stands for.
/// </summary>
/// <param name="R">Red channel.</param>
/// <param name="G">Green channel.</param>
/// <param name="B">Blue channel.</param>
/// <param name="Count">Number of pixels associated with this color.</param>
public record ColorCount(byte R, byte G, byte B, int Count);
/// <summary>
/// One pixel as an ML.NET input row. The three channels are loaded as columns named after the
/// fields and concatenated into the feature vector by the clustering pipeline.
/// </summary>
public sealed class RgbData
{
    /// <summary>Red, green and blue channel values, stored as floats for ML.NET.</summary>
    public float R, G, B;
}
/// <summary>
/// One output row read back from the k-means transformer.
/// </summary>
public sealed class ClusterPrediction
{
    /// <summary>Assigned cluster id from the "PredictedLabel" column (1-based).</summary>
    [ColumnName("PredictedLabel")] public uint ClusterId;

    /// <summary>
    /// The "Score" column: distances from the row to each cluster centroid. Filled in by ML.NET
    /// when rows are materialized, so it starts out null.
    /// </summary>
    [ColumnName("Score")] public float[] Distances = null!;
}
/// <summary>
/// Groups the pixels of <paramref name="bitmap"/> into at most <paramref name="maxClusters"/> colors
/// using ML.NET k-means. Reports each cluster's centroid color and how many pixels were assigned to it.
/// </summary>
/// <param name="bitmap">Source image. The R, G and B channels of every pixel are clustered; alpha is ignored.</param>
/// <param name="maxClusters">
/// Upper bound on the number of clusters. Fewer clusters are used when the image contains fewer
/// distinct RGB colors. Must be positive.
/// </param>
/// <param name="maxIterations">Forwarded to <c>KMeansTrainer.Options.MaximumNumberOfIterations</c>.</param>
/// <param name="tolerance">Forwarded to <c>KMeansTrainer.Options.OptimizationTolerance</c>.</param>
/// <param name="seed">
/// Optional seed for <see cref="MLContext"/>. Pass a value to make the random centroid initialization
/// (and therefore the result) reproducible. <c>null</c> (the default) keeps the unseeded behavior.
/// </param>
/// <returns>
/// One entry per cluster, sorted by descending pixel count. Each entry holds the centroid rounded to a
/// byte per channel and the cluster's pixel count. The list is empty when the bitmap has no pixels.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="bitmap"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="maxClusters"/> is zero or negative.</exception>
public List<ColorCount> ClusterColors(
    SKBitmap bitmap,
    int maxClusters,
    int maxIterations,
    float tolerance,
    int? seed = null)
{
    ArgumentNullException.ThrowIfNull(bitmap);
    if (maxClusters <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxClusters), maxClusters, "maxClusters must be positive.");
    }

    var width = bitmap.Width;
    var height = bitmap.Height;

    // Distinctness is tracked on RGB only, because alpha is not part of the clustered features.
    // Keying on SKColor would count colors that differ only in alpha twice, which could request more
    // clusters than there are distinct feature vectors. The set stops growing at maxClusters, since
    // that is all we need to know to bound the cluster count.
    var distinctRgb = new HashSet<(byte R, byte G, byte B)>();
    var pixels = new RgbData[width * height];
    var index = 0;
    for (var y = 0; y < height; y++)
    {
        for (var x = 0; x < width; x++)
        {
            var color = bitmap.GetPixel(x, y);
            pixels[index++] = new RgbData { R = color.Red, G = color.Green, B = color.Blue };
            if (distinctRgb.Count < maxClusters)
            {
                distinctRgb.Add((color.Red, color.Green, color.Blue));
            }
        }
    }

    var actualClusters = Math.Min(maxClusters, distinctRgb.Count);
    if (actualClusters == 0)
    {
        // Zero-sized bitmap: there is nothing to cluster, and KMeans requires a positive cluster count.
        return new List<ColorCount>();
    }

    // KMeans
    var mlContext = new MLContext(seed);
    var dataView = mlContext.Data.LoadFromEnumerable(pixels);
    var options = new KMeansTrainer.Options
    {
        FeatureColumnName = "Features",
        NumberOfClusters = actualClusters,
        MaximumNumberOfIterations = maxIterations,
        OptimizationTolerance = tolerance,
        // Random initialization; KMeansTrainer.InitializationAlgorithm.KMeansPlusPlus is an alternative.
        InitializationAlgorithm = KMeansTrainer.InitializationAlgorithm.Random
    };
    var pipeline = mlContext.Transforms
        .Concatenate("Features", nameof(RgbData.R), nameof(RgbData.G), nameof(RgbData.B))
        .Append(mlContext.Clustering.Trainers.KMeans(options));
    var model = pipeline.Fit(dataView);
    var transformed = model.Transform(dataView);

    // Cluster centers in (R, G, B) feature order. centroids[i] pairs with counts[i] below.
    var centroids = default(VBuffer<float>[]);
    model.LastTransformer.Model.GetClusterCentroids(ref centroids, out _);

    // Count pixels per cluster. PredictedLabel (ClusterId) is 1-based, hence the "- 1".
    var counts = new int[actualClusters];
    foreach (var prediction in mlContext.Data.CreateEnumerable<ClusterPrediction>(transformed, reuseRowObject: false))
    {
        counts[prediction.ClusterId - 1]++;
    }

    var list = new List<ColorCount>(actualClusters);
    for (var i = 0; i < counts.Length; i++)
    {
        var centroid = centroids[i].DenseValues().ToArray();
        list.Add(new ColorCount(ToChannel(centroid[0]), ToChannel(centroid[1]), ToChannel(centroid[2]), counts[i]));
    }

    list.Sort(static (a, b) => b.Count.CompareTo(a.Count));
    return list;

    // Rounds a centroid component to a byte. Clamped because a float-to-byte cast of a value outside
    // [0, 255] yields an unspecified result in an unchecked context.
    static byte ToChannel(float value) => (byte)Math.Clamp(Math.Round(value), 0d, 255d);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment