Created
April 2, 2018 14:29
-
-
Save a-h/e8600dadea283c8643b5b027506f8e33 to your computer and use it in GitHub Desktop.
KMeans on random 2D data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"math/rand" | |
"os" | |
"strconv" | |
"time" | |
"github.com/a-h/ml/clustering" | |
"github.com/a-h/ml/distance" | |
"gonum.org/v1/plot" | |
"gonum.org/v1/plot/plotter" | |
"gonum.org/v1/plot/plotutil" | |
"gonum.org/v1/plot/vg" | |
) | |
// init seeds the global math/rand source so that each run of the program
// generates a different random data set.
//
// The seed uses nanosecond resolution: a seconds-resolution seed (Unix())
// gives identical data to every run started within the same second.
//
// NOTE(review): newer Go releases deprecate rand.Seed and seed the global
// source automatically; this call may be removable depending on the
// toolchain in use — confirm.
func init() {
	rand.Seed(time.Now().UnixNano())
}
func main() { | |
p, err := plot.New() | |
if err != nil { | |
fmt.Println("Error creating Plot: ", err) | |
os.Exit(-1) | |
} | |
p.Title.Text = "KMeans" | |
p.X.Min = 0 | |
p.X.Padding = 0 | |
p.X.Label.Text = "X" | |
p.Y.Min = 0 | |
p.Y.Padding = 0 | |
p.Y.Label.Text = "Y" | |
// Create some random data and assign to n clusters. | |
data := random2DVectors(50) | |
n := 3 | |
assignment, err := clustering.KMeans(data, n, distance.Euclidean) | |
if err != nil { | |
fmt.Println("Error clustering data: ", err) | |
os.Exit(-1) | |
} | |
// Get the clusters. | |
clusters, err := clustering.Assign(data, assignment) | |
if err != nil { | |
fmt.Println("Error assigning data to clusters: ", err) | |
os.Exit(-1) | |
} | |
// Convert them to scatter inputs (something that implements the XYer interface). | |
for i, cluster := range clusters { | |
scatter := convert2DVectorToPlotterXY(cluster) | |
// Add them to the chart. | |
err = addScatters(p, i, strconv.Itoa(i+1), scatter) | |
if err != nil { | |
panic(err) | |
} | |
} | |
// Save the plot to a PNG file. | |
if err := p.Save(15*vg.Centimeter, 15*vg.Centimeter, "points.png"); err != nil { | |
panic(err) | |
} | |
} | |
func convert2DVectorToPlotterXY(v []clustering.Vector) plotter.XYs { | |
pts := make(plotter.XYs, len(v)) | |
for i := 0; i < len(v); i++ { | |
pts[i] = xy{ | |
X: v[i][0], | |
Y: v[i][1], | |
} | |
} | |
return pts | |
} | |
// xy is a simple point made of an X and a Y coordinate.
type xy struct {
	X float64
	Y float64
}
func random2DVectors(n int) []clustering.Vector { | |
op := make([]clustering.Vector, n) | |
for i := 0; i < n; i++ { | |
v := make(clustering.Vector, 2) | |
randomise(v, -10, 10) | |
op[i] = v | |
} | |
return op | |
} | |
// randomise overwrites every element of v with a pseudo-random whole number
// from the half-open interval [min, max), drawn from the global math/rand
// source.
//
// It panics (inside rand.Intn) when max <= min and v is non-empty. An empty v
// is left untouched and consumes no random numbers.
func randomise(v []float64, min, max int) {
	span := max - min
	for i := range v {
		v[i] = float64(min + rand.Intn(span))
	}
}
func addScatters(plt *plot.Plot, index int, name string, xyers plotter.XYs) error { | |
s, err := plotter.NewScatter(xyers) | |
if err != nil { | |
return err | |
} | |
s.Color = plotutil.Color(index) | |
s.Shape = plotutil.Shape(index) | |
plt.Add(s) | |
plt.Legend.Add(name) | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment