Created
September 25, 2015 03:27
-
-
Save myui/c79011fccde45e327d9c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Hivemall: Hive scalable Machine Learning Library
 *
 * Copyright (C) 2015 Makoto YUI
 * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package hivemall.tools; | |
import static hivemall.HivemallConstants.BIGINT_TYPE_NAME; | |
import static hivemall.HivemallConstants.BOOLEAN_TYPE_NAME; | |
import static hivemall.HivemallConstants.INT_TYPE_NAME; | |
import static hivemall.HivemallConstants.STRING_TYPE_NAME; | |
import static hivemall.HivemallConstants.TINYINT_TYPE_NAME; | |
import hivemall.utils.collections.BoundedPriorityQueue; | |
import hivemall.utils.hadoop.HiveUtils; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
import java.util.Comparator; | |
import org.apache.hadoop.hive.ql.exec.Description; | |
import org.apache.hadoop.hive.ql.exec.UDFArgumentException; | |
import org.apache.hadoop.hive.ql.metadata.HiveException; | |
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare.CompareType; | |
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; | |
import org.apache.hadoop.hive.serde2.io.DoubleWritable; | |
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; | |
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; | |
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; | |
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.ByteObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; | |
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; | |
import org.apache.hadoop.io.IntWritable; | |
@Description(name = "each_top_k", value = "_FUNC_(const int K, Object group, double cmpKey, *) - Returns top-K values (or tail-K values when k is less than 0)") | |
public final class EachTopKUDTF extends GenericUDTF { | |
private transient ObjectInspector[] _argOIs; | |
private transient FieldComparer _groupComparer; | |
private transient PrimitiveObjectInspector _rankKeyOI; | |
private int _sizeK; | |
private BoundedPriorityQueue<TupleWithKey> _queue; | |
private TupleWithKey _tuple; | |
private Object _previousGroup; | |
@Override | |
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { | |
final int numArgs = argOIs.length; | |
if(numArgs < 4) { | |
throw new UDFArgumentException("each_top_k(const int K, Object group, double cmpKey, *) takes at least 4 arguments: " | |
+ numArgs); | |
} | |
this._argOIs = argOIs; | |
int k = HiveUtils.getAsConstInt(argOIs[0]); | |
if(k == 0) { | |
throw new UDFArgumentException("k should not be 0"); | |
} | |
ObjectInspector prevGroupOI = ObjectInspectorUtils.getStandardObjectInspector(argOIs[1], ObjectInspectorCopyOption.DEFAULT); | |
this._groupComparer = new FieldComparer(argOIs[1], prevGroupOI); | |
this._rankKeyOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]); | |
this._sizeK = Math.abs(k); | |
final Comparator<TupleWithKey> comparator; | |
if(k < 0) { | |
comparator = Collections.reverseOrder(); | |
} else { | |
comparator = new Comparator<TupleWithKey>() { | |
@Override | |
public int compare(TupleWithKey o1, TupleWithKey o2) { | |
return o1.compareTo(o2); | |
} | |
}; | |
} | |
//this.queue = new BoundedPriorityQueue<Row>(sizeK, Comparator.nullsFirst(comparator)); | |
this._queue = new BoundedPriorityQueue<TupleWithKey>(_sizeK, comparator); | |
this._tuple = null; | |
this._previousGroup = null; | |
final ArrayList<String> fieldNames = new ArrayList<String>(numArgs); | |
final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(numArgs); | |
fieldNames.add("rank"); | |
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); | |
fieldNames.add("key"); | |
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); | |
for(int i = 3; i < numArgs; i++) { | |
fieldNames.add("c" + (i - 2)); | |
ObjectInspector rawOI = argOIs[i]; | |
ObjectInspector retOI = ObjectInspectorUtils.getStandardObjectInspector(rawOI, ObjectInspectorCopyOption.DEFAULT); | |
fieldOIs.add(retOI); | |
} | |
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); | |
} | |
@Override | |
public void process(Object[] args) throws HiveException { | |
final Object arg1 = args[1]; | |
if(isSameGroup(arg1) == false) { | |
Object group = ObjectInspectorUtils.copyToStandardObject(arg1, _argOIs[1], ObjectInspectorCopyOption.DEFAULT); // arg1 and group may be null | |
drainQueue(); | |
this._previousGroup = group; | |
} | |
final double key = PrimitiveObjectInspectorUtils.getDouble(args[2], _rankKeyOI); | |
final Object[] row; | |
TupleWithKey tuple = this._tuple; | |
if(_tuple == null) { | |
row = new Object[args.length - 1]; | |
tuple = new TupleWithKey(key, row); | |
this._tuple = tuple; | |
} else { | |
row = tuple.getRow(); | |
tuple.setKey(key); | |
} | |
for(int i = 3; i < args.length; i++) { | |
Object arg = args[i]; | |
ObjectInspector argOI = _argOIs[i]; | |
row[i - 1] = ObjectInspectorUtils.copyToStandardObject(arg, argOI, ObjectInspectorCopyOption.DEFAULT); | |
} | |
if(_queue.offer(tuple)) { | |
this._tuple = null; | |
} | |
} | |
private boolean isSameGroup(final Object group) { | |
return _groupComparer.areEqual(group, _previousGroup); | |
} | |
private void drainQueue() throws HiveException { | |
final int queueSize = _queue.size(); | |
if(queueSize > 0) { | |
final TupleWithKey[] tuples = new TupleWithKey[queueSize]; | |
for(int i = 0; i < queueSize; i++) { | |
TupleWithKey tuple = _queue.poll(); | |
if(tuple == null) { | |
throw new IllegalStateException("Found null element in the queue"); | |
} | |
tuples[i] = tuple; | |
} | |
final IntWritable rankProbe = new IntWritable(-1); | |
final DoubleWritable keyProbe = new DoubleWritable(Double.NaN); | |
int rank = 0; | |
double lastKey = Double.NaN; | |
for(int i = queueSize - 1; i >= 0; i--) { | |
TupleWithKey tuple = tuples[i]; | |
tuples[i] = null; // help GC | |
double key = tuple.getKey(); | |
if(key != lastKey) { | |
++rank; | |
rankProbe.set(rank); | |
keyProbe.set(key); | |
lastKey = key; | |
} | |
Object[] row = tuple.getRow(); | |
row[0] = rankProbe; | |
row[1] = keyProbe; | |
forward(row); | |
} | |
_queue.clear(); | |
} | |
} | |
@Override | |
public void close() throws HiveException { | |
drainQueue(); | |
this._queue = null; | |
this._tuple = null; | |
} | |
private static final class TupleWithKey implements Comparable<TupleWithKey> { | |
double key; | |
Object[] row; | |
TupleWithKey(double key, Object[] row) { | |
this.key = key; | |
this.row = row; | |
} | |
double getKey() { | |
return key; | |
} | |
Object[] getRow() { | |
return row; | |
} | |
void setKey(final double key) { | |
this.key = key; | |
} | |
@Override | |
public int compareTo(TupleWithKey o) { | |
return Double.compare(key, o.key); | |
} | |
} | |
private static final class FieldComparer { | |
protected final ObjectInspector oi0, oi1; | |
protected final CompareType compareType; | |
protected StringObjectInspector soi0, soi1; | |
protected IntObjectInspector ioi0, ioi1; | |
protected LongObjectInspector loi0, loi1; | |
protected ByteObjectInspector byoi0, byoi1; | |
protected BooleanObjectInspector boi0, boi1; | |
FieldComparer(ObjectInspector oi0, ObjectInspector oi1) { | |
this.oi0 = oi0; | |
this.oi1 = oi1; | |
final String type0 = oi0.getTypeName(); | |
final String type1 = oi1.getTypeName(); | |
if(STRING_TYPE_NAME.equals(type0) && STRING_TYPE_NAME.equals(type1)) { | |
soi0 = (StringObjectInspector) oi0; | |
soi1 = (StringObjectInspector) oi1; | |
if(soi0.preferWritable() || soi1.preferWritable()) { | |
compareType = CompareType.COMPARE_TEXT; | |
} else { | |
compareType = CompareType.COMPARE_STRING; | |
} | |
} else if(INT_TYPE_NAME.equals(type0) && INT_TYPE_NAME.equals(type1)) { | |
compareType = CompareType.COMPARE_INT; | |
ioi0 = (IntObjectInspector) oi0; | |
ioi1 = (IntObjectInspector) oi1; | |
} else if(BIGINT_TYPE_NAME.equals(type0) && BIGINT_TYPE_NAME.equals(type1)) { | |
compareType = CompareType.COMPARE_LONG; | |
loi0 = (LongObjectInspector) oi0; | |
loi1 = (LongObjectInspector) oi1; | |
} else if(TINYINT_TYPE_NAME.equals(type0) && TINYINT_TYPE_NAME.equals(type1)) { | |
compareType = CompareType.COMPARE_BYTE; | |
byoi0 = (ByteObjectInspector) oi0; | |
byoi1 = (ByteObjectInspector) oi1; | |
} else if(BOOLEAN_TYPE_NAME.equals(type0) && BOOLEAN_TYPE_NAME.equals(type1)) { | |
compareType = CompareType.COMPARE_BOOL; | |
boi0 = (BooleanObjectInspector) oi0; | |
boi1 = (BooleanObjectInspector) oi1; | |
} else { | |
// We don't check compatibility of two object inspectors, but directly | |
// pass them into ObjectInspectorUtils.compare(), users of this class | |
// should make sure ObjectInspectorUtils.compare() doesn't throw exceptions | |
// and returns correct results. | |
compareType = CompareType.SAME_TYPE; | |
} | |
} | |
boolean areEqual(final Object o0, final Object o1) { | |
if(o0 == null && o1 == null) { | |
return true; | |
} else if(o0 == null || o1 == null) { | |
return false; | |
} | |
switch (compareType) { | |
case COMPARE_TEXT: | |
return (soi0.getPrimitiveWritableObject(o0).equals(soi1.getPrimitiveWritableObject(o1))); | |
case COMPARE_INT: | |
return (ioi0.get(o0) == ioi1.get(o1)); | |
case COMPARE_LONG: | |
return (loi0.get(o0) == loi1.get(o1)); | |
case COMPARE_BYTE: | |
return (byoi0.get(o0) == byoi1.get(o1)); | |
case COMPARE_BOOL: | |
return (boi0.get(o0) == boi1.get(o1)); | |
case COMPARE_STRING: | |
return (soi0.getPrimitiveJavaObject(o0).equals(soi1.getPrimitiveJavaObject(o1))); | |
default: | |
return (ObjectInspectorUtils.compare(o0, oi0, o1, oi1) == 0); | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment