Skip to content

Instantly share code, notes, and snippets.

@mgronhol
Created March 6, 2012 11:23
Show Gist options
  • Save mgronhol/1985756 to your computer and use it in GitHub Desktop.
Save mgronhol/1985756 to your computer and use it in GitHub Desktop.
k-NN luokitin
import java.io.*;
import java.util.Iterator;
import java.util.HashMap;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.ArrayList;
import java.util.Collections;
import java.lang.Math.*;
class KnnEntry {
private String luokka;
private ArrayList<Double> arvot;
public KnnEntry( String aLuokka ){
luokka = aLuokka;
arvot = new ArrayList<Double>();
}
public void add( double value ){
arvot.add( new Double( value ) );
}
public String get_class(){ return luokka; }
public Double get_value( int index ){ return arvot.get( index ); }
public int get_dimension(){ return arvot.size(); }
public KnnEntry normita( ArrayList<Double> ka, ArrayList<Double> kh ){
KnnEntry out = new KnnEntry( luokka );
for( int i = 0 ; i < arvot.size() ; ++i ){
// normitettu arvo = (vanha arvo - keskiarvo) / keskihajonta
out.add( new Double( (arvot.get(i) - ka.get(i)) / kh.get(i) ) );
}
return out;
}
}
class KnnClassifier {
private ArrayList<KnnEntry> points;
private ArrayList<KnnEntry> norm_points;
private ArrayList< Double > keskiarvot;
private ArrayList< Double > keskihajonnat;
public KnnClassifier(){
points = new ArrayList<KnnEntry>();
norm_points = new ArrayList<KnnEntry>();
keskiarvot = new ArrayList<Double>();
keskihajonnat = new ArrayList<Double>();
}
public void add( KnnEntry entry ){
points.add( entry );
}
private double keskiarvo( int index ){
double summa = 0;
for( KnnEntry entry : points ){
summa += entry.get_value( index );
}
return summa / points.size();
}
private double keskihajonta( int index ){
double hajonta = 0;
double avg = this.keskiarvo( index );
for( KnnEntry entry : points ){
hajonta += Math.pow( entry.get_value( index ) - avg, 2 );
}
return Math.sqrt( hajonta / ( points.size() - 1 ) );
}
private double distance( KnnEntry k0, KnnEntry k1 ){
double summa = 0.0;
int dim = k0.get_dimension();
for( int i = 0 ; i < dim ; ++i ){
summa += Math.pow( k0.get_value( i ) - k1.get_value( i ), 2 );
}
return summa;
}
public void preprocess(){
int dim = points.get(0).get_dimension();
norm_points.clear();
keskiarvot.clear();
keskihajonnat.clear();
for( int i = 0 ; i < dim ; ++i ){
keskiarvot.add( this.keskiarvo( i ) );
keskihajonnat.add( this.keskihajonta( i ) );
}
for( KnnEntry piste : points ){
norm_points.add( piste.normita( keskiarvot, keskihajonnat ) );
}
}
public String luokittele( int N, KnnEntry point ){
String out = new String();
SortedMap< Double, String > distances = new TreeMap<Double, String>();
HashMap<String, Integer> votes = new HashMap<String,Integer>();
KnnEntry tpoint = point.normita( this.keskiarvot, this.keskihajonnat );
for( KnnEntry entry : norm_points ){
distances.put( distance( entry, tpoint ), entry.get_class() );
}
Iterator iterator = distances.keySet().iterator();
int count = 0;
while( iterator.hasNext() && count < N ){
Double key = (Double)iterator.next();
String luokka = distances.get( key );
if( votes.containsKey( luokka ) ){
votes.put( luokka, votes.get( luokka ) + 1 );
}
else{
votes.put( luokka, 1 );
}
}
int max_votes = 0;
iterator = votes.keySet().iterator();
while( iterator.hasNext() ){
String luokka = (String)iterator.next();
if( votes.get( luokka ) > max_votes ){
max_votes = votes.get( luokka );
out = luokka;
}
}
return out;
}
}
public class KnnDemo {
static public void main( String[] args ){
KnnClassifier luokitin = new KnnClassifier();
KnnEntry entry;
// Lisätään jokaista kaksi entryä, lähes täysin samanvärisiä
// Tämä siis sen takia, ettei tosta keskihajonnasta tule nollaa :D
entry = new KnnEntry( "blue" );
entry.add( 0.0 );
entry.add( 0.0 );
entry.add( 255.0 );
luokitin.add( entry );
entry = new KnnEntry( "blue" );
entry.add( 1.0 );
entry.add( 1.0 );
entry.add( 254.0 );
luokitin.add( entry );
entry = new KnnEntry( "red" );
entry.add( 255.0 );
entry.add( 0.0 );
entry.add( 0.0 );
luokitin.add( entry );
entry = new KnnEntry( "red" );
entry.add( 254.0 );
entry.add( 1.0 );
entry.add( 1.0 );
luokitin.add( entry );
entry = new KnnEntry( "green" );
entry.add( 0.0 );
entry.add( 255.0 );
entry.add( 0.0 );
luokitin.add( entry );
entry = new KnnEntry( "green" );
entry.add( 1.0 );
entry.add( 254.0 );
entry.add( 1.0 );
luokitin.add( entry );
luokitin.preprocess();
// kokeillaan
// Tästä pitäisi tulla punainen
entry = new KnnEntry("");
entry.add( 200.0 );
entry.add( 10.0 );
entry.add( 10.0 );
System.out.println( "Kokeilu luokittuu väriksi '" + luokitin.luokittele( 1, entry ) + "'" );
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment