/*
 * Hivemall: Hive scalable Machine Learning Library
 *
 * Copyright (C) 2015 Makoto YUI
 * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package hivemall.tools;

import static hivemall.HivemallConstants.BIGINT_TYPE_NAME;
import static hivemall.HivemallConstants.BOOLEAN_TYPE_NAME;
import static hivemall.HivemallConstants.INT_TYPE_NAME;
import static hivemall.HivemallConstants.STRING_TYPE_NAME;
import static hivemall.HivemallConstants.TINYINT_TYPE_NAME;

import hivemall.utils.collections.BoundedPriorityQueue;
import hivemall.utils.hadoop.HiveUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare.CompareType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.ByteObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
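
/**
 * each_top_k(const int K, Object group, double cmpKey, *) - emits the top-K rows per group,
 * ranked by cmpKey (or the tail-K rows when K is negative).
 *
 * Rows are buffered per group and flushed whenever the group value changes, so the input is
 * expected to arrive clustered by the group column. A hypothetical usage sketch (table and
 * column names are illustrative only):
 *
 * <pre>
 * SELECT each_top_k(3, the_group, score, id, the_group) as (rank, key, id, the_group)
 * FROM (
 *     SELECT id, the_group, score FROM input_table CLUSTER BY the_group
 * ) t;
 * </pre>
 */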
@Description(name = "each_top_k", value = "_FUNC_(const int K, Object group, double cmpKey, *) - Returns top-K values (or tail-K values when k is less than 0)")
public final class EachTopKUDTF extends GenericUDTF {

    private transient ObjectInspector[] _argOIs;
    private transient FieldComparer _groupComparer;
    private transient PrimitiveObjectInspector _rankKeyOI;

    private int _sizeK;
    private BoundedPriorityQueue<TupleWithKey> _queue;
    private TupleWithKey _tuple;
    private Object _previousGroup;
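
    /**
     * Validates the arguments (at least four, with a non-zero constant K) and builds the output
     * schema: an int "rank" column, a double "key" column, and one pass-through column
     * ("c1", "c2", ...) per argument after the first three.
     */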
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgs = argOIs.length;
        if(numArgs < 4) {
            throw new UDFArgumentException("each_top_k(const int K, Object group, double cmpKey, *) takes at least 4 arguments, but got "
                    + numArgs);
        }
        this._argOIs = argOIs;

        int k = HiveUtils.getAsConstInt(argOIs[0]);
        if(k == 0) {
            throw new UDFArgumentException("k should not be 0");
        }
        ObjectInspector prevGroupOI = ObjectInspectorUtils.getStandardObjectInspector(argOIs[1], ObjectInspectorCopyOption.DEFAULT);
        this._groupComparer = new FieldComparer(argOIs[1], prevGroupOI);
        this._rankKeyOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]);

        this._sizeK = Math.abs(k);
        final Comparator<TupleWithKey> comparator;
        if(k < 0) {
            // tail-K: reverse the natural ordering of the ranking key
            comparator = Collections.reverseOrder();
        } else {
            comparator = new Comparator<TupleWithKey>() {
                @Override
                public int compare(TupleWithKey o1, TupleWithKey o2) {
                    return o1.compareTo(o2);
                }
            };
        }
        //this.queue = new BoundedPriorityQueue<Row>(sizeK, Comparator.nullsFirst(comparator));
        this._queue = new BoundedPriorityQueue<TupleWithKey>(_sizeK, comparator);
        this._tuple = null;
        this._previousGroup = null;

        final ArrayList<String> fieldNames = new ArrayList<String>(numArgs);
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(numArgs);
        fieldNames.add("rank");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("key");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        for(int i = 3; i < numArgs; i++) {
            fieldNames.add("c" + (i - 2));
            ObjectInspector rawOI = argOIs[i];
            ObjectInspector retOI = ObjectInspectorUtils.getStandardObjectInspector(rawOI, ObjectInspectorCopyOption.DEFAULT);
            fieldOIs.add(retOI);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }
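
    /**
     * Buffers one input row. When the group value changes from the previous row, the tuples
     * buffered for the previous group are flushed via {@link #drainQueue()} before the new row
     * is offered to the bounded priority queue.
     */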
    @Override
    public void process(Object[] args) throws HiveException {
        final Object arg1 = args[1];
        if(isSameGroup(arg1) == false) {
            Object group = ObjectInspectorUtils.copyToStandardObject(arg1, _argOIs[1], ObjectInspectorCopyOption.DEFAULT); // arg1 and group may be null
            drainQueue();
            this._previousGroup = group;
        }

        final double key = PrimitiveObjectInspectorUtils.getDouble(args[2], _rankKeyOI);
        final Object[] row;
        TupleWithKey tuple = this._tuple;
        if(_tuple == null) {
            row = new Object[args.length - 1];
            tuple = new TupleWithKey(key, row);
            this._tuple = tuple;
        } else {
            // reuse the tuple and its row array that were rejected by the queue last time
            row = tuple.getRow();
            tuple.setKey(key);
        }
        for(int i = 3; i < args.length; i++) {
            Object arg = args[i];
            ObjectInspector argOI = _argOIs[i];
            row[i - 1] = ObjectInspectorUtils.copyToStandardObject(arg, argOI, ObjectInspectorCopyOption.DEFAULT);
        }

        if(_queue.offer(tuple)) {
            this._tuple = null; // the tuple was accepted; allocate a fresh one on the next call
        }
    }
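
    /** Returns whether the given group value equals the previously seen one (two nulls compare equal). */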
    private boolean isSameGroup(final Object group) {
        return _groupComparer.areEqual(group, _previousGroup);
    }
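
    /**
     * Forwards the buffered tuples of the current group, assigning dense ranks: the rank advances
     * only when the ranking key changes. The queue is cleared afterwards.
     */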
    private void drainQueue() throws HiveException {
        final int queueSize = _queue.size();
        if(queueSize > 0) {
            final TupleWithKey[] tuples = new TupleWithKey[queueSize];
            for(int i = 0; i < queueSize; i++) {
                TupleWithKey tuple = _queue.poll();
                if(tuple == null) {
                    throw new IllegalStateException("Found null element in the queue");
                }
                tuples[i] = tuple;
            }

            final IntWritable rankProbe = new IntWritable(-1);
            final DoubleWritable keyProbe = new DoubleWritable(Double.NaN);
            int rank = 0;
            double lastKey = Double.NaN;
            for(int i = queueSize - 1; i >= 0; i--) {
                TupleWithKey tuple = tuples[i];
                tuples[i] = null; // help GC
                double key = tuple.getKey();
                if(key != lastKey) {
                    ++rank;
                    rankProbe.set(rank);
                    keyProbe.set(key);
                    lastKey = key;
                }
                Object[] row = tuple.getRow();
                row[0] = rankProbe;
                row[1] = keyProbe;
                forward(row);
            }
            _queue.clear();
        }
    }
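
    /** Flushes the tuples buffered for the last group and releases internal state. */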
    @Override
    public void close() throws HiveException {
        drainQueue();
        this._queue = null;
        this._tuple = null;
    }
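
    /**
     * An output row buffered together with its ranking key; the natural ordering compares keys
     * with Double.compare.
     */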
    private static final class TupleWithKey implements Comparable<TupleWithKey> {

        double key;
        Object[] row;

        TupleWithKey(double key, Object[] row) {
            this.key = key;
            this.row = row;
        }

        double getKey() {
            return key;
        }

        Object[] getRow() {
            return row;
        }

        void setKey(final double key) {
            this.key = key;
        }

        @Override
        public int compareTo(TupleWithKey o) {
            return Double.compare(key, o.key);
        }
    }
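
    /**
     * Equality comparator over two ObjectInspectors with type-specific fast paths for string,
     * int, bigint, tinyint, and boolean; any other type falls back to
     * ObjectInspectorUtils.compare().
     */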
    private static final class FieldComparer {

        protected final ObjectInspector oi0, oi1;
        protected final CompareType compareType;

        protected StringObjectInspector soi0, soi1;
        protected IntObjectInspector ioi0, ioi1;
        protected LongObjectInspector loi0, loi1;
        protected ByteObjectInspector byoi0, byoi1;
        protected BooleanObjectInspector boi0, boi1;

        FieldComparer(ObjectInspector oi0, ObjectInspector oi1) {
            this.oi0 = oi0;
            this.oi1 = oi1;

            final String type0 = oi0.getTypeName();
            final String type1 = oi1.getTypeName();
            if(STRING_TYPE_NAME.equals(type0) && STRING_TYPE_NAME.equals(type1)) {
                soi0 = (StringObjectInspector) oi0;
                soi1 = (StringObjectInspector) oi1;
                if(soi0.preferWritable() || soi1.preferWritable()) {
                    compareType = CompareType.COMPARE_TEXT;
                } else {
                    compareType = CompareType.COMPARE_STRING;
                }
            } else if(INT_TYPE_NAME.equals(type0) && INT_TYPE_NAME.equals(type1)) {
                compareType = CompareType.COMPARE_INT;
                ioi0 = (IntObjectInspector) oi0;
                ioi1 = (IntObjectInspector) oi1;
            } else if(BIGINT_TYPE_NAME.equals(type0) && BIGINT_TYPE_NAME.equals(type1)) {
                compareType = CompareType.COMPARE_LONG;
                loi0 = (LongObjectInspector) oi0;
                loi1 = (LongObjectInspector) oi1;
            } else if(TINYINT_TYPE_NAME.equals(type0) && TINYINT_TYPE_NAME.equals(type1)) {
                compareType = CompareType.COMPARE_BYTE;
                byoi0 = (ByteObjectInspector) oi0;
                byoi1 = (ByteObjectInspector) oi1;
            } else if(BOOLEAN_TYPE_NAME.equals(type0) && BOOLEAN_TYPE_NAME.equals(type1)) {
                compareType = CompareType.COMPARE_BOOL;
                boi0 = (BooleanObjectInspector) oi0;
                boi1 = (BooleanObjectInspector) oi1;
            } else {
                // We do not check the compatibility of the two object inspectors here but pass
                // them directly to ObjectInspectorUtils.compare(); users of this class should
                // make sure that ObjectInspectorUtils.compare() does not throw and returns
                // correct results.
                compareType = CompareType.SAME_TYPE;
            }
        }
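
        /** Null-safe equality check: two nulls are equal, a single null never equals a non-null value. */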
        boolean areEqual(final Object o0, final Object o1) {
            if(o0 == null && o1 == null) {
                return true;
            } else if(o0 == null || o1 == null) {
                return false;
            }
            switch(compareType) {
                case COMPARE_TEXT:
                    return soi0.getPrimitiveWritableObject(o0).equals(soi1.getPrimitiveWritableObject(o1));
                case COMPARE_INT:
                    return ioi0.get(o0) == ioi1.get(o1);
                case COMPARE_LONG:
                    return loi0.get(o0) == loi1.get(o1);
                case COMPARE_BYTE:
                    return byoi0.get(o0) == byoi1.get(o1);
                case COMPARE_BOOL:
                    return boi0.get(o0) == boi1.get(o1);
                case COMPARE_STRING:
                    return soi0.getPrimitiveJavaObject(o0).equals(soi1.getPrimitiveJavaObject(o1));
                default:
                    return ObjectInspectorUtils.compare(o0, oi0, o1, oi1) == 0;
            }
        }
    }
}