
@MikeDepies
Created September 3, 2015 14:40
Simple class utilizing the KNN implementation
import java.util.Random;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * KNN test class: fits the KNN classifier on part of the Iris data set and
 * prints predicted vs. correct classes for the held-out rows.
 * Requires the tutorial's KNN class in the same package.
 */
public class App {
    private static final Logger log = LoggerFactory.getLogger(App.class);

    public static void main(String[] args) {
        knnTest();
    }

    private static void knnTest() {
        // Load all 150 Iris examples as a single batch.
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        DataSet dataset = iter.next();

        // Use 120 examples for training and hold out the remaining 30 for testing (seeded Random for reproducibility).
        SplitTestAndTrain trainTest = dataset.splitTestAndTrain(120, new Random(3));
        INDArray features = trainTest.getTest().getFeatureMatrix();
        INDArray labels = trainTest.getTest().getLabels();

        // k = 7: each prediction is based on the 7 nearest training examples.
        KNN knn = new KNN(7);
        knn.fit(trainTest.getTrain());
        int[] prediction = knn.predict(features);

        // CSV output: P = predicted class, C = correct class
        // (iamax gives the index of the largest entry in the one-hot label row).
        System.out.println("P,C");
        for (int i = 0; i < features.rows(); i++) {
            System.out.println(prediction[i] + "," + Nd4j.getBlasWrapper().iamax(labels.getRow(i)));
        }
    }
}
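The gist depends on a `KNN` class from the tutorial that is not included here. As a rough, hypothetical stand-in for readers who want to see what such a class has to do, the sketch below implements brute-force k-nearest-neighbour classification on plain `double[][]` arrays: `fit` only stores the training rows, and `predict` takes a majority vote among the k rows with the smallest Euclidean distance. The class name `SimpleKnn`, its array-based signatures, the `argmax` helper and the tie-breaking rule are assumptions for illustration, not the tutorial's API.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical brute-force k-nearest-neighbour classifier (NOT the tutorial's KNN class).
// It mirrors the fit/predict shape used in App, but works on plain double arrays.
public class SimpleKnn {
    private final int k;
    private double[][] trainX;
    private int[] trainY;

    public SimpleKnn(int k) {
        if (k < 1) throw new IllegalArgumentException("k must be at least 1");
        this.k = k;
    }

    // KNN is a lazy learner: "fitting" only stores the training rows and their class indices.
    public void fit(double[][] features, int[] labels) {
        if (features.length != labels.length) {
            throw new IllegalArgumentException("features and labels must have the same number of rows");
        }
        this.trainX = features;
        this.trainY = labels;
    }

    // Returns one predicted class index per query row.
    public int[] predict(double[][] queries) {
        if (trainX == null) throw new IllegalStateException("call fit() before predict()");
        int[] out = new int[queries.length];
        for (int q = 0; q < queries.length; q++) {
            out[q] = predictRow(queries[q]);
        }
        return out;
    }

    private int predictRow(double[] query) {
        int n = trainX.length;
        Integer[] order = new Integer[n];
        double[] dist = new double[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
            dist[i] = squaredDistance(query, trainX[i]);
        }
        // Sort training-row indices from nearest to farthest.
        Arrays.sort(order, (a, b) -> Double.compare(dist[a], dist[b]));

        // Count class votes among the k nearest rows (or all rows if there are fewer than k).
        Map<Integer, Integer> votes = new HashMap<>();
        for (int i = 0; i < Math.min(k, n); i++) {
            votes.merge(trainY[order[i]], 1, Integer::sum);
        }

        // Majority vote; ties go to the smaller class index. Returns -1 if there is no training data.
        int best = -1;
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> e : votes.entrySet()) {
            int label = e.getKey();
            int count = e.getValue();
            if (count > bestCount || (count == bestCount && label < best)) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }

    // Squared Euclidean distance; skipping the square root does not change the neighbour ordering.
    private static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int j = 0; j < a.length; j++) {
            double d = a[j] - b[j];
            sum += d * d;
        }
        return sum;
    }

    // Index of the largest entry, e.g. the class index of a one-hot label row (plays the role of iamax in App).
    public static int argmax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) best = j;
        }
        return best;
    }
}
```

Feeding it the Iris split from `App` would require copying the `INDArray` rows into `double[][]` first (for example with `getDouble(row, col)`), which is left out here. Brute force costs one distance computation per training row for every query, which is fine for the 120 Iris training examples but scales poorly to large data sets.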
@KrisRogos

Has anyone found a fix for the issue @parvathysasi had? I've implemented the KNN exactly as in the tutorial but get the same error.
