Created
November 19, 2012 13:29
-
-
Save rofr/4110637 to your computer and use it in GitHub Desktop.
Multithreaded kNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Thread pool from the Java 1.5 Executor Framework | |
*/ | |
private ExecutorService executorService; | |
/** | |
* Initialize the thread pool | |
*/ | |
private void init() { | |
int numThreads = Runtime.getRuntime().availableProcessors(); | |
executorService = Executors.newFixedThreadPool(numThreads); | |
} | |
protected double[] calculateDistances(final T item) { | |
int numThreads = Runtime.getRuntime().availableProcessors(); | |
int numItems = items.size(); | |
final double[] distances = new double[numItems]; | |
final int numItemsPerThread = numItems / numThreads; | |
//Create the tasks to be executed by the executor service | |
List<Callable<Void>> tasks = new ArrayList<Callable<Void>>(numThreads); | |
for (int i = 0; i < numThreads;i++) { | |
final int j = i; | |
tasks.add( new Callable<Void>() { | |
public Void call() { | |
calculateSegment(distances,item, j * numItemsPerThread, numItemsPerThread); | |
return null; | |
} | |
}); | |
} | |
//run all tasks, waiting for them to complete | |
try { | |
executorService.invokeAll(tasks); | |
} catch (InterruptedException e) { | |
e.printStackTrace(); | |
} | |
return distances; | |
} | |
/** | |
* Calculates distances for a specific slice of the items | |
* @param A reference to the entire array of distances | |
* @param item The item to compare with | |
* @param offset index of the first item to compare | |
* @param count number of items to compare | |
*/ | |
private void calculateSegment(double[] distances, T item, int offset, int count) { | |
for(int i = offset; i < offset + count; i++) { | |
//avoid array out of bounds when items not evenly | |
//divisible by number of threads | |
if (i == distances.length) break; | |
distances[i] = distanceFunction.calculateDistance(item, items.get(i)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment