Created
November 21, 2013 23:35
-
-
Save loicknuchel/7591918 to your computer and use it in GitHub Desktop.
Writing some code in Scala and Java about collections.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa | |
---|---|---|---|---|---|
4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa | |
4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa | |
4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa | |
5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa | |
5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa | |
4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa | |
5.0 | 3.4 | 1.5 | 0.2 | Iris-setosa | |
4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa | |
4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa | |
5.4 | 3.7 | 1.5 | 0.2 | Iris-setosa | |
4.8 | 3.4 | 1.6 | 0.2 | Iris-setosa | |
4.8 | 3.0 | 1.4 | 0.1 | Iris-setosa | |
4.3 | 3.0 | 1.1 | 0.1 | Iris-setosa | |
5.8 | 4.0 | 1.2 | 0.2 | Iris-setosa | |
5.7 | 4.4 | 1.5 | 0.4 | Iris-setosa | |
5.4 | 3.9 | 1.3 | 0.4 | Iris-setosa | |
5.1 | 3.5 | 1.4 | 0.3 | Iris-setosa | |
5.7 | 3.8 | 1.7 | 0.3 | Iris-setosa | |
5.1 | 3.8 | 1.5 | 0.3 | Iris-setosa | |
5.4 | 3.4 | 1.7 | 0.2 | Iris-setosa | |
5.1 | 3.7 | 1.5 | 0.4 | Iris-setosa | |
4.6 | 3.6 | 1.0 | 0.2 | Iris-setosa | |
5.1 | 3.3 | 1.7 | 0.5 | Iris-setosa | |
4.8 | 3.4 | 1.9 | 0.2 | Iris-setosa | |
5.0 | 3.0 | 1.6 | 0.2 | Iris-setosa | |
5.0 | 3.4 | 1.6 | 0.4 | Iris-setosa | |
5.2 | 3.5 | 1.5 | 0.2 | Iris-setosa | |
5.2 | 3.4 | 1.4 | 0.2 | Iris-setosa | |
4.7 | 3.2 | 1.6 | 0.2 | Iris-setosa | |
4.8 | 3.1 | 1.6 | 0.2 | Iris-setosa | |
5.4 | 3.4 | 1.5 | 0.4 | Iris-setosa | |
5.2 | 4.1 | 1.5 | 0.1 | Iris-setosa | |
5.5 | 4.2 | 1.4 | 0.2 | Iris-setosa | |
4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa | |
5.0 | 3.2 | 1.2 | 0.2 | Iris-setosa | |
5.5 | 3.5 | 1.3 | 0.2 | Iris-setosa | |
4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa | |
4.4 | 3.0 | 1.3 | 0.2 | Iris-setosa | |
5.1 | 3.4 | 1.5 | 0.2 | Iris-setosa | |
5.0 | 3.5 | 1.3 | 0.3 | Iris-setosa | |
4.5 | 2.3 | 1.3 | 0.3 | Iris-setosa | |
4.4 | 3.2 | 1.3 | 0.2 | Iris-setosa | |
5.0 | 3.5 | 1.6 | 0.6 | Iris-setosa | |
5.1 | 3.8 | 1.9 | 0.4 | Iris-setosa | |
4.8 | 3.0 | 1.4 | 0.3 | Iris-setosa | |
5.1 | 3.8 | 1.6 | 0.2 | Iris-setosa | |
4.6 | 3.2 | 1.4 | 0.2 | Iris-setosa | |
5.3 | 3.7 | 1.5 | 0.2 | Iris-setosa | |
5.0 | 3.3 | 1.4 | 0.2 | Iris-setosa | |
7.0 | 3.2 | 4.7 | 1.4 | Iris-versicolor | |
6.4 | 3.2 | 4.5 | 1.5 | Iris-versicolor | |
6.9 | 3.1 | 4.9 | 1.5 | Iris-versicolor | |
5.5 | 2.3 | 4.0 | 1.3 | Iris-versicolor | |
6.5 | 2.8 | 4.6 | 1.5 | Iris-versicolor | |
5.7 | 2.8 | 4.5 | 1.3 | Iris-versicolor | |
6.3 | 3.3 | 4.7 | 1.6 | Iris-versicolor | |
4.9 | 2.4 | 3.3 | 1.0 | Iris-versicolor | |
6.6 | 2.9 | 4.6 | 1.3 | Iris-versicolor | |
5.2 | 2.7 | 3.9 | 1.4 | Iris-versicolor | |
5.0 | 2.0 | 3.5 | 1.0 | Iris-versicolor | |
5.9 | 3.0 | 4.2 | 1.5 | Iris-versicolor | |
6.0 | 2.2 | 4.0 | 1.0 | Iris-versicolor | |
6.1 | 2.9 | 4.7 | 1.4 | Iris-versicolor | |
5.6 | 2.9 | 3.6 | 1.3 | Iris-versicolor | |
6.7 | 3.1 | 4.4 | 1.4 | Iris-versicolor | |
5.6 | 3.0 | 4.5 | 1.5 | Iris-versicolor | |
5.8 | 2.7 | 4.1 | 1.0 | Iris-versicolor | |
6.2 | 2.2 | 4.5 | 1.5 | Iris-versicolor | |
5.6 | 2.5 | 3.9 | 1.1 | Iris-versicolor | |
5.9 | 3.2 | 4.8 | 1.8 | Iris-versicolor | |
6.1 | 2.8 | 4.0 | 1.3 | Iris-versicolor | |
6.3 | 2.5 | 4.9 | 1.5 | Iris-versicolor | |
6.1 | 2.8 | 4.7 | 1.2 | Iris-versicolor | |
6.4 | 2.9 | 4.3 | 1.3 | Iris-versicolor | |
6.6 | 3.0 | 4.4 | 1.4 | Iris-versicolor | |
6.8 | 2.8 | 4.8 | 1.4 | Iris-versicolor | |
6.7 | 3.0 | 5.0 | 1.7 | Iris-versicolor | |
6.0 | 2.9 | 4.5 | 1.5 | Iris-versicolor | |
5.7 | 2.6 | 3.5 | 1.0 | Iris-versicolor | |
5.5 | 2.4 | 3.8 | 1.1 | Iris-versicolor | |
5.5 | 2.4 | 3.7 | 1.0 | Iris-versicolor | |
5.8 | 2.7 | 3.9 | 1.2 | Iris-versicolor | |
6.0 | 2.7 | 5.1 | 1.6 | Iris-versicolor | |
5.4 | 3.0 | 4.5 | 1.5 | Iris-versicolor | |
6.0 | 3.4 | 4.5 | 1.6 | Iris-versicolor | |
6.7 | 3.1 | 4.7 | 1.5 | Iris-versicolor | |
6.3 | 2.3 | 4.4 | 1.3 | Iris-versicolor | |
5.6 | 3.0 | 4.1 | 1.3 | Iris-versicolor | |
5.5 | 2.5 | 4.0 | 1.3 | Iris-versicolor | |
5.5 | 2.6 | 4.4 | 1.2 | Iris-versicolor | |
6.1 | 3.0 | 4.6 | 1.4 | Iris-versicolor | |
5.8 | 2.6 | 4.0 | 1.2 | Iris-versicolor | |
5.0 | 2.3 | 3.3 | 1.0 | Iris-versicolor | |
5.6 | 2.7 | 4.2 | 1.3 | Iris-versicolor | |
5.7 | 3.0 | 4.2 | 1.2 | Iris-versicolor | |
5.7 | 2.9 | 4.2 | 1.3 | Iris-versicolor | |
6.2 | 2.9 | 4.3 | 1.3 | Iris-versicolor | |
5.1 | 2.5 | 3.0 | 1.1 | Iris-versicolor | |
5.7 | 2.8 | 4.1 | 1.3 | Iris-versicolor | |
6.3 | 3.3 | 6.0 | 2.5 | Iris-virginica | |
5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica | |
7.1 | 3.0 | 5.9 | 2.1 | Iris-virginica | |
6.3 | 2.9 | 5.6 | 1.8 | Iris-virginica | |
6.5 | 3.0 | 5.8 | 2.2 | Iris-virginica | |
7.6 | 3.0 | 6.6 | 2.1 | Iris-virginica | |
4.9 | 2.5 | 4.5 | 1.7 | Iris-virginica | |
7.3 | 2.9 | 6.3 | 1.8 | Iris-virginica | |
6.7 | 2.5 | 5.8 | 1.8 | Iris-virginica | |
7.2 | 3.6 | 6.1 | 2.5 | Iris-virginica | |
6.5 | 3.2 | 5.1 | 2.0 | Iris-virginica | |
6.4 | 2.7 | 5.3 | 1.9 | Iris-virginica | |
6.8 | 3.0 | 5.5 | 2.1 | Iris-virginica | |
5.7 | 2.5 | 5.0 | 2.0 | Iris-virginica | |
5.8 | 2.8 | 5.1 | 2.4 | Iris-virginica | |
6.4 | 3.2 | 5.3 | 2.3 | Iris-virginica | |
6.5 | 3.0 | 5.5 | 1.8 | Iris-virginica | |
7.7 | 3.8 | 6.7 | 2.2 | Iris-virginica | |
7.7 | 2.6 | 6.9 | 2.3 | Iris-virginica | |
6.0 | 2.2 | 5.0 | 1.5 | Iris-virginica | |
6.9 | 3.2 | 5.7 | 2.3 | Iris-virginica | |
5.6 | 2.8 | 4.9 | 2.0 | Iris-virginica | |
7.7 | 2.8 | 6.7 | 2.0 | Iris-virginica | |
6.3 | 2.7 | 4.9 | 1.8 | Iris-virginica | |
6.7 | 3.3 | 5.7 | 2.1 | Iris-virginica | |
7.2 | 3.2 | 6.0 | 1.8 | Iris-virginica | |
6.2 | 2.8 | 4.8 | 1.8 | Iris-virginica | |
6.1 | 3.0 | 4.9 | 1.8 | Iris-virginica | |
6.4 | 2.8 | 5.6 | 2.1 | Iris-virginica | |
7.2 | 3.0 | 5.8 | 1.6 | Iris-virginica | |
7.4 | 2.8 | 6.1 | 1.9 | Iris-virginica | |
7.9 | 3.8 | 6.4 | 2.0 | Iris-virginica | |
6.4 | 2.8 | 5.6 | 2.2 | Iris-virginica | |
6.3 | 2.8 | 5.1 | 1.5 | Iris-virginica | |
6.1 | 2.6 | 5.6 | 1.4 | Iris-virginica | |
7.7 | 3.0 | 6.1 | 2.3 | Iris-virginica | |
6.3 | 3.4 | 5.6 | 2.4 | Iris-virginica | |
6.4 | 3.1 | 5.5 | 1.8 | Iris-virginica | |
6.0 | 3.0 | 4.8 | 1.8 | Iris-virginica | |
6.9 | 3.1 | 5.4 | 2.1 | Iris-virginica | |
6.7 | 3.1 | 5.6 | 2.4 | Iris-virginica | |
6.9 | 3.1 | 5.1 | 2.3 | Iris-virginica | |
5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica | |
6.8 | 3.2 | 5.9 | 2.3 | Iris-virginica | |
6.7 | 3.3 | 5.7 | 2.5 | Iris-virginica | |
6.7 | 3.0 | 5.2 | 2.3 | Iris-virginica | |
6.3 | 2.5 | 5.0 | 1.9 | Iris-virginica | |
6.5 | 3.0 | 5.2 | 2.0 | Iris-virginica | |
6.2 | 3.4 | 5.4 | 2.3 | Iris-virginica | |
5.9 | 3.0 | 5.1 | 1.8 | Iris-virginica |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.knuchel.playground; | |
import java.io.BufferedReader; | |
import java.io.FileNotFoundException; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
import java.util.Comparator; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.Map.Entry; | |
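// K-nearest neighbors (KNN) classifier for the iris dataset
/*
Feature ranges observed in the dataset (the upper bounds are reached, so they are inclusive):
0.0 < sepalLength <= 7.9
0.0 < sepalWidth <= 4.4
0.0 < petalLength <= 6.9
0.0 < petalWidth <= 2.5
*/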
/**
 * K-nearest-neighbours (KNN) classifier for the iris dataset.
 *
 * <p>A sample is described by four numeric features (sepal length/width, petal length/width) and a
 * species label. An unknown sample is classified by ranking the labelled samples by Euclidean distance
 * and taking a majority vote among the {@code k} closest ones.
 */
public class KNNJava {

    public static void main(String[] args) {
        List<Iris> dataset = loadDataset("data/iris.data.csv");
        Integer k = 5;
        Double sepalLength = 5.7d, sepalWidth = 2.6d, petalLength = 3.5d, petalWidth = 1d;
        String predictedSpecie = predictSpecie(dataset, k, sepalLength, sepalWidth, petalLength, petalWidth);
        System.out.println("Iris with [sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth + ", petalLength=" + petalLength
                + ", petalWidth=" + petalWidth + "] should be a " + predictedSpecie);
        // try different values of k and see how prediction errors change
        // evaluate(dataset);
    }

    /**
     * Predicts the species of an iris by majority vote among its {@code k} nearest neighbours.
     *
     * <p>A tied vote is won by the species whose closest representative is nearest to the sample, so the
     * result does not depend on hash-map iteration order. A {@code k} below 1 is treated as 1 and a
     * {@code k} larger than the dataset lets every sample vote.
     *
     * @param dataset     labelled samples to compare against
     * @param k           number of nearest neighbours taking part in the vote
     * @param sepalLength sepal length of the sample to classify
     * @param sepalWidth  sepal width of the sample to classify
     * @param petalLength petal length of the sample to classify
     * @param petalWidth  petal width of the sample to classify
     * @return the predicted species, or {@code null} when {@code dataset} is empty
     */
    public static String predictSpecie(List<Iris> dataset, Integer k, Double sepalLength, Double sepalWidth, Double petalLength,
            Double petalWidth) {
        // score every labelled sample by its distance to the unknown one, nearest first (the sort is stable)
        Iris unknownIris = new Iris(sepalLength, sepalWidth, petalLength, petalWidth, null);
        List<Score> scores = new ArrayList<Score>(dataset.size());
        for (Iris iris : dataset) {
            scores.add(new Score(unknownIris.distance(iris), iris.specie));
        }
        Collections.sort(scores, Score.COMPARATOR);

        // count species occurrences among the k nearest neighbours
        int neighbours = Math.min(Math.max(k, 1), scores.size());
        Map<String, Integer> occurrenceCount = new HashMap<String, Integer>();
        for (int i = 0; i < neighbours; i++) {
            String specie = scores.get(i).specie;
            Integer count = occurrenceCount.get(specie);
            occurrenceCount.put(specie, count == null ? 1 : count + 1);
        }

        // walk the neighbours nearest-first: with a strict '>' the first species to show the highest count
        // wins, which breaks ties by proximity instead of by HashMap iteration order
        String mostFrequentSpecie = null;
        int nbOccurrence = 0;
        for (int i = 0; i < neighbours; i++) {
            String specie = scores.get(i).specie;
            int count = occurrenceCount.get(specie);
            if (count > nbOccurrence) {
                nbOccurrence = count;
                mostFrequentSpecie = specie;
            }
        }
        return mostFrequentSpecie;
    }

    /**
     * Loads labelled irises from a comma-separated file.
     *
     * <p>Every non-blank line is expected to hold four numeric features followed by the species name.
     * Whitespace-only lines are skipped. I/O failures are printed to stderr and the samples read so far
     * are returned (best effort); a malformed numeric cell still throws {@link NumberFormatException}.
     *
     * @param csvFile path of the CSV file to read
     * @return the samples read from the file, in file order
     */
    public static List<Iris> loadDataset(String csvFile) {
        List<Iris> dataset = new ArrayList<Iris>();
        BufferedReader br = null;
        String cvsSplitBy = ",";
        try {
            br = new BufferedReader(new FileReader(csvFile));
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] cell = line.split(cvsSplitBy);
                dataset.add(new Iris(Double.parseDouble(cell[0]), Double.parseDouble(cell[1]), Double.parseDouble(cell[2]),
                        Double.parseDouble(cell[3]), cell[4].trim()));
            }
        } catch (IOException e) {
            // also covers FileNotFoundException, which is a subclass of IOException
            e.printStackTrace();
        } finally {
            if (br != null) {
                try {
                    br.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        return dataset;
    }

    /** Distance between a labelled sample and the sample being classified, together with its species. */
    static class Score {
        /** Orders scores by ascending distance. */
        public static final Comparator<Score> COMPARATOR = new Comparator<Score>() {
            @Override
            public int compare(Score o1, Score o2) {
                return o1.score.compareTo(o2.score);
            }
        };
        public Double score;
        public String specie;

        public Score(Double score, String specie) {
            this.score = score;
            this.specie = specie;
        }
    }

    /** One iris sample: four features and an optional species label ({@code null} when unknown). */
    static class Iris {
        public Double sepalLength;
        public Double sepalWidth;
        public Double petalLength;
        public Double petalWidth;
        public String specie;

        public Iris(Double sepalLength, Double sepalWidth, Double petalLength, Double petalWidth, String specie) {
            this.sepalLength = sepalLength;
            this.sepalWidth = sepalWidth;
            this.petalLength = petalLength;
            this.petalWidth = petalWidth;
            this.specie = specie;
        }

        /** Euclidean distance between the feature vectors of this iris and {@code that}. */
        public Double distance(Iris that) {
            return Math.sqrt(Math.pow(sepalLength - that.sepalLength, 2) + Math.pow(sepalWidth - that.sepalWidth, 2)
                    + Math.pow(petalLength - that.petalLength, 2) + Math.pow(petalWidth - that.petalWidth, 2));
        }

        @Override
        public String toString() {
            return "Iris [specie=" + specie + ", sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth + ", petalLength=" + petalLength
                    + ", petalWidth=" + petalWidth + "]";
        }
    }

    /**
     * Estimates prediction quality for k = 1..20 and prints the number of errors for each k.
     *
     * <p>Each of the three known species is shuffled and split in half. The first halves form the
     * learning set and the second halves the test set. Samples with any other species are ignored.
     *
     * @param dataset labelled samples to split into learning and test data
     */
    public static void evaluate(List<Iris> dataset) {
        List<Iris> versicolor = new ArrayList<Iris>(), virginica = new ArrayList<Iris>(), setosa = new ArrayList<Iris>();
        for (Iris iris : dataset) {
            if ("Iris-versicolor".equals(iris.specie))
                versicolor.add(iris);
            else if ("Iris-virginica".equals(iris.specie))
                virginica.add(iris);
            else if ("Iris-setosa".equals(iris.specie))
                setosa.add(iris);
        }
        List<Iris> learningData = new ArrayList<Iris>(), testData = new ArrayList<Iris>();
        shuffleAndSplit(versicolor, learningData, testData);
        shuffleAndSplit(virginica, learningData, testData);
        shuffleAndSplit(setosa, learningData, testData);

        // for each value of k, count the number of errors
        for (int k = 1; k <= 20; k++) {
            int errors = 0;
            for (Iris iris : testData) {
                String predicted = predictSpecie(learningData, k, iris.sepalLength, iris.sepalWidth, iris.petalLength, iris.petalWidth);
                // compare from the known label: the prediction is null when learningData is empty
                if (!iris.specie.equals(predicted)) {
                    errors++;
                }
            }
            System.out.println(errors + " errors on " + testData.size() + " tests with k=" + k);
        }
    }

    /** Shuffles {@code group}, then appends its first half to {@code learningData} and the rest to {@code testData}. */
    private static void shuffleAndSplit(List<Iris> group, List<Iris> learningData, List<Iris> testData) {
        Collections.shuffle(group);
        int half = group.size() / 2;
        learningData.addAll(group.subList(0, half));
        testData.addAll(group.subList(half, group.size()));
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.knuchel.playground | |
import scala.io.Source | |
import scala.util.Random | |
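// K-nearest neighbors (KNN) classifier for the iris dataset
/*
Feature ranges observed in the dataset (the upper bounds are reached, so they are inclusive):
0.0 < sepalLength <= 7.9
0.0 < sepalWidth <= 4.4
0.0 < petalLength <= 6.9
0.0 < petalWidth <= 2.5
*/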
/**
 * K-nearest-neighbours (KNN) classifier for the iris dataset.
 *
 * A sample is described by four numeric features (sepal length/width, petal length/width) and a species
 * label. An unknown sample is classified by ranking the labelled samples by Euclidean distance and taking
 * a majority vote among the `k` closest ones.
 */
object KNNScala {

  def main(args: Array[String]): Unit = {
    val dataset = loadDataset("data/iris.data.csv")
    val k = 5
    val features = (5.7, 2.6, 3.5, 1d)
    val predictedSpecie = predictSpecie(dataset, k, features)
    println("Iris with [sepalLength=" + features._1 + ", sepalWidth=" + features._2 + ", petalLength=" + features._3 +
      ", petalWidth=" + features._4 + "] should be a " + predictedSpecie)
    // try different values of k and see how prediction errors change
    // evaluate(dataset)
  }

  /**
   * Predicts the species of an iris by majority vote among its `k` nearest neighbours.
   *
   * A tied vote is won by the species whose closest representative is nearest to the sample, so the
   * result does not depend on map iteration order. A `k` larger than the dataset lets every sample vote.
   *
   * @param dataset  labelled samples to compare against
   * @param k        number of nearest neighbours taking part in the vote
   * @param features (sepalLength, sepalWidth, petalLength, petalWidth) of the sample to classify
   * @return the predicted species
   * @throws IllegalArgumentException if `dataset` is empty or `k` is not positive
   */
  def predictSpecie(dataset: List[Iris], k: Int, features: (Double, Double, Double, Double)): String = {
    require(dataset.nonEmpty, "dataset must not be empty")
    require(k > 0, "k must be positive, got " + k)
    // species of the k nearest samples, nearest first (sorting the tuples orders equal distances by name)
    val nearest = dataset.map(iris => (iris.distance(features), iris.getSpecie)).sorted.take(k).map(_._2)
    val votes = nearest.groupBy(identity).map { case (specie, occurrences) => (specie, occurrences.length) }
    // maxBy returns the first element with the largest value, so ties go to the nearest species
    nearest.maxBy(specie => votes(specie))
  }

  /**
   * Loads labelled irises from a comma-separated file.
   *
   * Every non-blank line is expected to hold four numeric features followed by the species name;
   * whitespace-only lines are skipped. The file is closed even if parsing fails.
   *
   * @param csvFile path of the CSV file to read
   * @return the samples read from the file, in file order
   */
  def loadDataset(csvFile: String): List[Iris] = {
    val file = Source.fromFile(csvFile)
    try {
      file.getLines().filter(_.trim.nonEmpty).map { line =>
        val cell = line.split(",")
        new Iris(cell(0).toDouble, cell(1).toDouble, cell(2).toDouble, cell(3).toDouble, cell(4).trim)
      }.toList // materialise the lazy iterator before the finally block closes the file
    } finally {
      file.close()
    }
  }

  /** One iris sample: four features and its species label. */
  class Iris(sepalLength: Double, sepalWidth: Double, petalLength: Double, petalWidth: Double, specie: String) {
    def getSpecie: String = specie

    def getFeatures: (Double, Double, Double, Double) = (sepalLength, sepalWidth, petalLength, petalWidth)

    /** Euclidean distance between this iris and the feature tuple `that` (same order as `getFeatures`). */
    def distance(that: (Double, Double, Double, Double)): Double =
      math.sqrt(math.pow(sepalLength - that._1, 2) + math.pow(sepalWidth - that._2, 2) +
        math.pow(petalLength - that._3, 2) + math.pow(petalWidth - that._4, 2))

    override def toString: String =
      "Iris [specie=" + specie + ", sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth +
        ", petalLength=" + petalLength + ", petalWidth=" + petalWidth + "]"
  }

  /**
   * Estimates prediction quality for k = 1..20 and prints the number of errors for each k.
   *
   * Each of the three known species is shuffled and split in half. The first halves form the learning
   * set and the second halves the test set. Samples with any other species are ignored.
   *
   * @param dataset labelled samples to split into learning and test data
   */
  def evaluate(dataset: List[Iris]): Unit = {
    val halves = List("Iris-versicolor", "Iris-virginica", "Iris-setosa").map { specie =>
      val group = Random.shuffle(dataset.filter(_.getSpecie == specie))
      group.splitAt(group.length / 2)
    }
    val learningData = halves.flatMap(_._1)
    val testData = halves.flatMap(_._2)
    // test errors for different values of k
    for (k <- 1 to 20) {
      val errors = testData.count(iris => iris.getSpecie != predictSpecie(learningData, k, iris.getFeatures))
      println(errors + " errors on " + testData.length + " tests with k=" + k)
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment