Skip to content

Instantly share code, notes, and snippets.

@prakhar1989
Created November 15, 2015 20:23
Show Gist options
  • Save prakhar1989/a49bd385053384137417 to your computer and use it in GitHub Desktop.
Save prakhar1989/a49bd385053384137417 to your computer and use it in GitHub Desktop.
K-means in Go
package main
import (
"image"
"image/color"
"image/png"
"log"
"math"
"math/rand"
"os"
"strconv"
)
// norm returns the Euclidean distance between c1 and c2 in RGB space, using
// the 16-bit alpha-premultiplied channels reported by color.Color.RGBA.
// The alpha channel is ignored.
func norm(c1 color.Color, c2 color.Color) float64 {
	r1, g1, b1, _ := c1.RGBA()
	r2, g2, b2, _ := c2.RGBA()
	// Work in float64: the channels are uint32, so subtracting them directly
	// wraps, and the sum of three squared 16-bit differences (up to
	// 3*0xffff^2) overflows uint32.
	dr := float64(r1) - float64(r2)
	dg := float64(g1) - float64(g2)
	db := float64(b1) - float64(b2)
	return math.Sqrt(dr*dr + dg*dg + db*db)
}
// getMin reports the smallest element of nums and the index of its first
// occurrence. nums must be non-empty; an empty slice causes a panic.
func getMin(nums []float64) (float64, int) {
	best := 0
	// Scan the tail, comparing each candidate against the current best.
	for offset, candidate := range nums[1:] {
		if candidate < nums[best] {
			best = offset + 1
		}
	}
	return nums[best], best
}
// main quantises the colours of ./tree.png with k-means clustering and writes
// the result to img-<iterations>.png.
//
// Usage: k-means <clusters> <number of iterations>
//
// Both arguments must be positive integers. Any error (bad arguments,
// unreadable input, failed write) is fatal.
//
// NOTE: centroids are seeded from the global math/rand source, which is only
// randomly seeded automatically from Go 1.20 on; older toolchains pick the
// same starting pixels on every run.
func main() {
	// set configuration here
	const imageFilePath string = "./tree.png"
	const usage = "USAGE: k-means <clusters> <number of iterations>"

	if len(os.Args) != 3 {
		log.Fatal(usage)
	}
	// Zero or negative counts would make getMin panic (no clusters) or
	// silently skip writing any output (no iterations).
	clusters, err := strconv.Atoi(os.Args[1])
	if err != nil || clusters < 1 {
		log.Fatal(usage)
	}
	maxIter, err := strconv.Atoi(os.Args[2])
	if err != nil || maxIter < 1 {
		log.Fatal(usage)
	}

	img, err := readPNG(imageFilePath)
	if err != nil {
		log.Fatal(err)
	}
	bounds := img.Bounds()
	if bounds.Empty() {
		log.Fatal("image has no pixels: ", imageFilePath)
	}

	centroids := initialCentroids(img, clusters)
	var labels []int
	for iter := 0; iter < maxIter; iter++ {
		labels = assignClusters(img, centroids)
		updateCentroids(img, labels, centroids)
	}

	// Only the final clustering is rendered; intermediate images were never
	// saved, so there is no point building them.
	filename := "img-" + strconv.Itoa(maxIter) + ".png"
	log.Print("Saving image to:", filename)
	if err := writePNG(filename, render(bounds, labels, centroids)); err != nil {
		log.Fatal(err)
	}
}

// readPNG opens and decodes the PNG file at path.
func readPNG(path string) (image.Image, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return png.Decode(f)
}

// writePNG encodes img as PNG into a newly created file at path. Encode and
// Close errors are both reported, since a failed Close on a written file can
// mean the data never reached disk.
func writePNG(path string, img image.Image) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	if err := png.Encode(f, img); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}

// initialCentroids picks k pixel colours of img uniformly at random (with
// replacement) as starting centroids. img.Bounds() must be non-empty.
func initialCentroids(img image.Image, k int) []color.Color {
	rect := img.Bounds()
	centroids := make([]color.Color, k)
	for i := range centroids {
		// Offset by Min: image bounds need not start at (0, 0).
		centroids[i] = img.At(rect.Min.X+rand.Intn(rect.Dx()), rect.Min.Y+rand.Intn(rect.Dy()))
	}
	return centroids
}

// assignClusters returns, for every pixel of img, the index of the nearest
// centroid under norm. Labels are stored row-major: pixel (x, y) is at index
// (y-r.Min.Y)*r.Dx() + (x-r.Min.X), where r = img.Bounds().
func assignClusters(img image.Image, centroids []color.Color) []int {
	rect := img.Bounds()
	labels := make([]int, 0, rect.Dx()*rect.Dy())
	distances := make([]float64, len(centroids)) // reused for every pixel
	for y := rect.Min.Y; y < rect.Max.Y; y++ {
		for x := rect.Min.X; x < rect.Max.X; x++ {
			px := img.At(x, y)
			for k, c := range centroids {
				distances[k] = norm(px, c)
			}
			_, nearest := getMin(distances)
			labels = append(labels, nearest)
		}
	}
	return labels
}

// updateCentroids replaces each centroid, in place, with the mean colour of
// the pixels labelled with it. A centroid with no pixels keeps its value.
// labels must use the row-major layout produced by assignClusters.
func updateCentroids(img image.Image, labels []int, centroids []color.Color) {
	rect := img.Bounds()
	// uint64 accumulators: with 16-bit channels a uint32 sum overflows after
	// about 65k pixels (a 256x256 image).
	sums := make([][3]uint64, len(centroids))
	counts := make([]uint64, len(centroids))
	i := 0
	for y := rect.Min.Y; y < rect.Max.Y; y++ {
		for x := rect.Min.X; x < rect.Max.X; x++ {
			k := labels[i]
			i++
			r, g, b, _ := img.At(x, y).RGBA()
			sums[k][0] += uint64(r)
			sums[k][1] += uint64(g)
			sums[k][2] += uint64(b)
			counts[k]++
		}
	}
	for k, n := range counts {
		if n == 0 {
			continue // empty cluster: keep the previous centroid
		}
		// Each mean is <= 0xffff, so it fits in uint16. Alpha is fully
		// opaque on the 16-bit scale.
		centroids[k] = color.RGBA64{
			R: uint16(sums[k][0] / n),
			G: uint16(sums[k][1] / n),
			B: uint16(sums[k][2] / n),
			A: 0xffff,
		}
	}
}

// render builds an opaque image over rect in which every pixel is painted
// with the colour of its centroid. labels uses assignClusters' layout.
func render(rect image.Rectangle, labels []int, centroids []color.Color) *image.RGBA {
	// Convert each centroid to 8-bit once. RGBA() yields 16-bit channels; the
	// high byte is the 8-bit value. Dividing by 255 instead overflows uint8
	// for bright colours.
	palette := make([]color.RGBA, len(centroids))
	for k, c := range centroids {
		r, g, b, _ := c.RGBA()
		palette[k] = color.RGBA{R: uint8(r >> 8), G: uint8(g >> 8), B: uint8(b >> 8), A: 0xff}
	}
	out := image.NewRGBA(rect)
	i := 0
	for y := rect.Min.Y; y < rect.Max.Y; y++ {
		for x := rect.Min.X; x < rect.Max.X; x++ {
			out.SetRGBA(x, y, palette[labels[i]])
			i++
		}
	}
	return out
}
@prakhar1989
Copy link
Author

kmeans

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment