Created
March 6, 2012 11:23
-
-
Save mgronhol/1985756 to your computer and use it in GitHub Desktop.
k-NN luokitin
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.*; | |
import java.util.Iterator; | |
import java.util.HashMap; | |
import java.util.SortedMap; | |
import java.util.TreeMap; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
import java.lang.Math.*; | |
class KnnEntry { | |
private String luokka; | |
private ArrayList<Double> arvot; | |
public KnnEntry( String aLuokka ){ | |
luokka = aLuokka; | |
arvot = new ArrayList<Double>(); | |
} | |
public void add( double value ){ | |
arvot.add( new Double( value ) ); | |
} | |
public String get_class(){ return luokka; } | |
public Double get_value( int index ){ return arvot.get( index ); } | |
public int get_dimension(){ return arvot.size(); } | |
public KnnEntry normita( ArrayList<Double> ka, ArrayList<Double> kh ){ | |
KnnEntry out = new KnnEntry( luokka ); | |
for( int i = 0 ; i < arvot.size() ; ++i ){ | |
// normitettu arvo = (vanha arvo - keskiarvo) / keskihajonta | |
out.add( new Double( (arvot.get(i) - ka.get(i)) / kh.get(i) ) ); | |
} | |
return out; | |
} | |
} | |
class KnnClassifier { | |
private ArrayList<KnnEntry> points; | |
private ArrayList<KnnEntry> norm_points; | |
private ArrayList< Double > keskiarvot; | |
private ArrayList< Double > keskihajonnat; | |
public KnnClassifier(){ | |
points = new ArrayList<KnnEntry>(); | |
norm_points = new ArrayList<KnnEntry>(); | |
keskiarvot = new ArrayList<Double>(); | |
keskihajonnat = new ArrayList<Double>(); | |
} | |
public void add( KnnEntry entry ){ | |
points.add( entry ); | |
} | |
private double keskiarvo( int index ){ | |
double summa = 0; | |
for( KnnEntry entry : points ){ | |
summa += entry.get_value( index ); | |
} | |
return summa / points.size(); | |
} | |
private double keskihajonta( int index ){ | |
double hajonta = 0; | |
double avg = this.keskiarvo( index ); | |
for( KnnEntry entry : points ){ | |
hajonta += Math.pow( entry.get_value( index ) - avg, 2 ); | |
} | |
return Math.sqrt( hajonta / ( points.size() - 1 ) ); | |
} | |
private double distance( KnnEntry k0, KnnEntry k1 ){ | |
double summa = 0.0; | |
int dim = k0.get_dimension(); | |
for( int i = 0 ; i < dim ; ++i ){ | |
summa += Math.pow( k0.get_value( i ) - k1.get_value( i ), 2 ); | |
} | |
return summa; | |
} | |
public void preprocess(){ | |
int dim = points.get(0).get_dimension(); | |
norm_points.clear(); | |
keskiarvot.clear(); | |
keskihajonnat.clear(); | |
for( int i = 0 ; i < dim ; ++i ){ | |
keskiarvot.add( this.keskiarvo( i ) ); | |
keskihajonnat.add( this.keskihajonta( i ) ); | |
} | |
for( KnnEntry piste : points ){ | |
norm_points.add( piste.normita( keskiarvot, keskihajonnat ) ); | |
} | |
} | |
public String luokittele( int N, KnnEntry point ){ | |
String out = new String(); | |
SortedMap< Double, String > distances = new TreeMap<Double, String>(); | |
HashMap<String, Integer> votes = new HashMap<String,Integer>(); | |
KnnEntry tpoint = point.normita( this.keskiarvot, this.keskihajonnat ); | |
for( KnnEntry entry : norm_points ){ | |
distances.put( distance( entry, tpoint ), entry.get_class() ); | |
} | |
Iterator iterator = distances.keySet().iterator(); | |
int count = 0; | |
while( iterator.hasNext() && count < N ){ | |
Double key = (Double)iterator.next(); | |
String luokka = distances.get( key ); | |
if( votes.containsKey( luokka ) ){ | |
votes.put( luokka, votes.get( luokka ) + 1 ); | |
} | |
else{ | |
votes.put( luokka, 1 ); | |
} | |
} | |
int max_votes = 0; | |
iterator = votes.keySet().iterator(); | |
while( iterator.hasNext() ){ | |
String luokka = (String)iterator.next(); | |
if( votes.get( luokka ) > max_votes ){ | |
max_votes = votes.get( luokka ); | |
out = luokka; | |
} | |
} | |
return out; | |
} | |
} | |
public class KnnDemo { | |
static public void main( String[] args ){ | |
KnnClassifier luokitin = new KnnClassifier(); | |
KnnEntry entry; | |
// Lisätään jokaista kaksi entryä, lähes täysin samanvärisiä | |
// Tämä siis sen takia, ettei tosta keskihajonnasta tule nollaa :D | |
entry = new KnnEntry( "blue" ); | |
entry.add( 0.0 ); | |
entry.add( 0.0 ); | |
entry.add( 255.0 ); | |
luokitin.add( entry ); | |
entry = new KnnEntry( "blue" ); | |
entry.add( 1.0 ); | |
entry.add( 1.0 ); | |
entry.add( 254.0 ); | |
luokitin.add( entry ); | |
entry = new KnnEntry( "red" ); | |
entry.add( 255.0 ); | |
entry.add( 0.0 ); | |
entry.add( 0.0 ); | |
luokitin.add( entry ); | |
entry = new KnnEntry( "red" ); | |
entry.add( 254.0 ); | |
entry.add( 1.0 ); | |
entry.add( 1.0 ); | |
luokitin.add( entry ); | |
entry = new KnnEntry( "green" ); | |
entry.add( 0.0 ); | |
entry.add( 255.0 ); | |
entry.add( 0.0 ); | |
luokitin.add( entry ); | |
entry = new KnnEntry( "green" ); | |
entry.add( 1.0 ); | |
entry.add( 254.0 ); | |
entry.add( 1.0 ); | |
luokitin.add( entry ); | |
luokitin.preprocess(); | |
// kokeillaan | |
// Tästä pitäisi tulla punainen | |
entry = new KnnEntry(""); | |
entry.add( 200.0 ); | |
entry.add( 10.0 ); | |
entry.add( 10.0 ); | |
System.out.println( "Kokeilu luokittuu väriksi '" + luokitin.luokittele( 1, entry ) + "'" ); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment