Created
May 18, 2013 02:29
-
-
Save dapurv5/5603016 to your computer and use it in GitHub Desktop.
Sparse Matrix Multiplication in Map Reduce
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Copyright (c) 2013 Apurv Verma
 */
package org.anahata.hadoop.illustrations; | |
import java.io.IOException; | |
import java.util.HashMap; | |
import java.util.Map; | |
import org.anahata.commons.hadoop.io.IntInt; | |
import org.anahata.commons.hadoop.io.IntIntInt; | |
import org.anahata.commons.io.IOBoilerplate; | |
import org.apache.commons.logging.Log; | |
import org.apache.commons.logging.LogFactory; | |
import org.apache.hadoop.conf.Configuration; | |
import org.apache.hadoop.fs.FileStatus; | |
import org.apache.hadoop.fs.FileSystem; | |
import org.apache.hadoop.fs.Path; | |
import org.apache.hadoop.io.IntWritable; | |
import org.apache.hadoop.io.SequenceFile; | |
import org.apache.hadoop.mapreduce.Job; | |
import org.apache.hadoop.mapreduce.Mapper; | |
import org.apache.hadoop.mapreduce.Reducer; | |
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; | |
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; | |
public class SpMatrixMult { | |
private final static Log LOG = LogFactory.getLog(SpMatrixMult.class); | |
private final static String INPUT_PATH = "/tmp/matrix-mult/input.seq"; | |
private final static String OUTPUT_PATH = "/tmp/matrix-mult/output"; | |
private final static String L = "org.anahata.hadoop.illustrations.SpMatrixMult.L"; | |
private final static String M = "org.anahata.hadoop.illustrations.SpMatrixMult.M"; | |
private final static String N = "org.anahata.hadoop.illustrations.SpMatrixMult.N"; | |
private final static Configuration conf = new Configuration(); | |
public static class SpMatrixMultMapper extends Mapper<IntIntInt, IntWritable, IntIntInt, IntInt>{ | |
private int n; | |
private int l; | |
private IntIntInt keyOut; | |
private IntInt valOut; | |
@Override | |
protected void setup(Context context) throws IOException, | |
InterruptedException { | |
super.setup(context); | |
Configuration conf = context.getConfiguration(); | |
n = conf.getInt(N, 0); | |
l = conf.getInt(L, 0); | |
keyOut = new IntIntInt(); | |
valOut = new IntInt(); | |
} | |
@Override | |
protected void map(IntIntInt key, IntWritable value, Context context) | |
throws IOException, InterruptedException { | |
int matrixId = key.getFirst(); | |
int i = key.getSecond(); | |
int j = key.getThird(); | |
if(matrixId == 0){ | |
for(int n_ = 0; n_ < n; n_++){ | |
keyOut.set(2, i, n_); | |
valOut.set(j, value.get()); | |
context.write(keyOut, valOut); | |
} | |
} else{ | |
for(int l_ = 0; l_ < l; l_++){ | |
keyOut.set(2, l_, j); | |
valOut.set(i, value.get()); | |
context.write(keyOut, valOut); | |
} | |
} | |
} | |
} | |
public static class SpMatrixMultReducer extends Reducer<IntIntInt, IntInt, IntIntInt, IntWritable>{ | |
private IntIntInt key; | |
private IntWritable val; | |
private Map<Integer, Integer> aggrMap; | |
@Override | |
protected void setup(Context context) | |
throws IOException, InterruptedException { | |
super.setup(context); | |
key = new IntIntInt(); | |
val = new IntWritable(); | |
aggrMap = new HashMap<>(); | |
} | |
@Override | |
protected void reduce(IntIntInt key, Iterable<IntInt> values, Context context) throws IOException, | |
InterruptedException { | |
for(IntInt ii: values){ | |
int a = 1; | |
if(aggrMap.get(ii.getFirst()) != null){ | |
a = aggrMap.get(ii.getFirst()); | |
} | |
aggrMap.put(ii.getFirst(), a * ii.getSecond()); | |
} | |
int res = 0; | |
for(Integer index: aggrMap.keySet()){ | |
res += aggrMap.get(index); | |
} | |
val.set(res); | |
context.write(key, val); | |
aggrMap.clear(); | |
} | |
} | |
public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException { | |
int[][] A = { | |
{0, 1, 2}, | |
{2, 1, 3} | |
}; | |
int[][] B = { | |
{1, 2}, | |
{2, 0}, | |
{1, 4} | |
}; | |
assert(A[0].length == B.length); | |
//A (L cross M), B(M cross N) | |
writeInput(A, B); | |
conf.setInt(L, A.length); | |
conf.setInt(M, A[0].length); | |
conf.setInt(N, B[0].length); | |
Job job = new Job(conf); | |
job.setJarByClass(SpMatrixMult.class); | |
job.setJobName("Matrix Multiplication"); | |
job.setMapperClass(SpMatrixMultMapper.class); | |
job.setReducerClass(SpMatrixMultReducer.class); | |
SequenceFileInputFormat.addInputPaths(job, INPUT_PATH); | |
SequenceFileOutputFormat.setOutputPath(job, new Path(OUTPUT_PATH)); | |
job.setInputFormatClass(SequenceFileInputFormat.class); | |
job.setOutputFormatClass(SequenceFileOutputFormat.class); | |
job.setMapOutputKeyClass(IntIntInt.class); | |
job.setMapOutputValueClass(IntInt.class); | |
job.setOutputKeyClass(IntIntInt.class); | |
job.setOutputValueClass(IntWritable.class); | |
job.waitForCompletion(true); | |
writeOutput(); | |
cleanUp(); | |
} | |
private static void writeInput(int[][] A, int[][] B){ | |
IntIntInt key = new IntIntInt(); | |
IntWritable val = new IntWritable(); | |
SequenceFile.Writer writer = null; | |
try { | |
FileSystem fs = FileSystem.get(conf); | |
writer = SequenceFile.createWriter(fs, conf, new Path(INPUT_PATH), | |
IntIntInt.class, IntWritable.class); | |
for(int i = 0; i < A.length; i++){ | |
for(int j = 0; j < A[0].length; j++){ | |
key.set(0, i, j); //0 specifies first matrix | |
val.set(A[i][j]); | |
writer.append(key, val); | |
} | |
} | |
for(int i = 0; i < B.length; i++){ | |
for(int j = 0; j < B[0].length; j++){ | |
key.set(1, i, j); //1 specifies first matrix | |
val.set(B[i][j]); | |
writer.append(key, val); | |
} | |
} | |
} catch (IOException e) { | |
LOG.error("Could not write the input", e); | |
} finally{ | |
IOBoilerplate.closeGracefully(writer); | |
} | |
} | |
private static void writeOutput(){ | |
IntIntInt key = new IntIntInt(); | |
IntWritable val = new IntWritable(); | |
SequenceFile.Reader reader = null; | |
try { | |
FileSystem fs = FileSystem.get(conf); | |
FileStatus[] status = fs.listStatus(new Path(OUTPUT_PATH)); | |
for(FileStatus stat : status){ | |
if(stat.isDir() || stat.getLen() == 0){continue;} | |
reader = new SequenceFile.Reader(fs, stat.getPath(), conf); | |
while(reader.next(key, val)){ | |
System.out.println(key+" __ "+val); | |
} | |
reader.close(); | |
} | |
} catch (IOException e) { | |
LOG.error("Could not fetch result", e); | |
} finally{ | |
IOBoilerplate.closeGracefully(reader); | |
} | |
} | |
private static void cleanUp(){ | |
try { | |
FileSystem fs = FileSystem.get(conf); | |
fs.delete(new Path(OUTPUT_PATH), true); | |
} catch (IOException e) { | |
LOG.error("Could not cleanup files", e); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment