Skip to content

Instantly share code, notes, and snippets.

@tomron
Created July 20, 2017 06:49
Show Gist options
  • Save tomron/d27fbb13d41d2817d2250c1b2c0c2fd8 to your computer and use it in GitHub Desktop.
Save tomron/d27fbb13d41d2817d2250c1b2c0c2fd8 to your computer and use it in GitHub Desktop.
MergeMap: a Spark user-defined aggregate function (UDAF) that combines rows of (String key, Map<String, Long> values) into a single Map<String, Map<String, Long>>.
package com.tomron;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Spark user-defined aggregate function that collects rows of
 * {@code (key: String, values: Map<String, Long>)} into a single
 * {@code Map<String, Map<String, Long>>} keyed by {@code key}.
 *
 * <p>Rows whose {@code key} is null are ignored. If the same key occurs more than once, the
 * inner map written last replaces the earlier one; inner maps are not merged entry by entry.
 *
 * <p>Created by tomron on 7/10/17.
 */
public class MergeMapUDAF extends UserDefinedAggregateFunction {
    private static final DataType VALUE_TYPE = DataTypes.LongType;
    private static final DataType INNER_KEY_TYPE = DataTypes.StringType;
    private static final DataType OUTER_KEY_TYPE = DataTypes.StringType;
    /** Per-key payload type: a map from String to Long. */
    private static final DataType INNER_MAP = DataTypes.createMapType(INNER_KEY_TYPE, VALUE_TYPE);
    /** Aggregated result type: a map from String to {@link #INNER_MAP}. */
    private static final DataType OUTER_MAP = DataTypes.createMapType(OUTER_KEY_TYPE, INNER_MAP);

    private final StructType inputDataType;
    private final StructType bufferDataType;
    private final DataType returnDataType;

    /** Builds the input schema (key, values), the buffer schema (data) and the result type. */
    public MergeMapUDAF() {
        List<StructField> inputFields = new ArrayList<>(2);
        inputFields.add(DataTypes.createStructField("key", OUTER_KEY_TYPE, true));
        inputFields.add(DataTypes.createStructField("values", INNER_MAP, true));
        inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>(1);
        bufferFields.add(DataTypes.createStructField("data", OUTER_MAP, true));
        bufferDataType = DataTypes.createStructType(bufferFields);

        returnDataType = OUTER_MAP;
    }

    /** Returns the input schema: a nullable String {@code key} and a nullable {@code values} map. */
    @Override
    public StructType inputSchema() {
        return inputDataType;
    }

    /** Returns the buffer schema: a single nullable {@code data} column holding the outer map. */
    @Override
    public StructType bufferSchema() {
        return bufferDataType;
    }

    /** Returns the result type, {@code Map<String, Map<String, Long>>}. */
    @Override
    public DataType dataType() {
        return returnDataType;
    }

    /**
     * Reports this function as non-deterministic. With duplicate keys, the surviving inner map
     * depends on the order in which rows and partial buffers are combined, which is presumably
     * why false is returned here.
     */
    @Override
    public boolean deterministic() {
        return false;
    }

    /** Starts every aggregation buffer with an empty outer map. */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new HashMap<String, Map<String, Long>>());
    }

    /**
     * Adds one input row to the buffer, stored under its key.
     *
     * @param buffer the aggregation buffer; column 0 holds the outer map (possibly null)
     * @param input  a row of (key, values); rows with a null key are skipped
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (input.isNullAt(0)) {
            return; // a row without a key contributes nothing
        }
        String inputKey = input.getString(0);
        Map<String, Long> inputValues = input.<String, Long>getJavaMap(1);
        // Build a new map and write it back; the buffer is only modified through update().
        Map<String, Map<String, Long>> newData = copyOfData(buffer);
        newData.put(inputKey, inputValues);
        buffer.update(0, newData);
    }

    /**
     * Combines two partial buffers into {@code buffer1}. On duplicate keys, entries from
     * {@code buffer2} win. A null map in either buffer is treated as empty.
     *
     * @param buffer1 the buffer to merge into (updated in place)
     * @param buffer2 the buffer to merge from
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        Map<String, Map<String, Long>> newData = copyOfData(buffer1);
        newData.putAll(copyOfData(buffer2));
        buffer1.update(0, newData);
    }

    /** Returns the accumulated outer map stored in column 0 of the buffer. */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.<String, Map<String, Long>>getJavaMap(0);
    }

    /**
     * Copies the outer map in column 0 of {@code row} into a new mutable HashMap. Returns an
     * empty map if the column is null.
     */
    private static Map<String, Map<String, Long>> copyOfData(Row row) {
        Map<String, Map<String, Long>> copy = new HashMap<>();
        if (!row.isNullAt(0)) {
            copy.putAll(row.<String, Map<String, Long>>getJavaMap(0));
        }
        return copy;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment