/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hivemall.smile.classification;

import hivemall.UDTFWithOptions;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
import hivemall.smile.vm.StackMachine;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import smile.data.Attribute;
import smile.math.Math;
import smile.math.Random;
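
/**
* A UDTF that trains a random forest of decision trees and forwards one row per tree.
* <p>
* Example usage from HiveQL (a minimal sketch; the table and column names below are
* hypothetical, only the function name and its options come from this file):
* <pre>
* SELECT train_randomforest_classifier(features, label, '-trees 50 -seed 31')
* FROM training_data;
* </pre>
*/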
@Description(name = "train_randomforest_classifier", value = "_FUNC_(double[] features, int label [, string options]) - "
+ "Returns a relation consisting of <string pred_model, double[] var_importance, int oob_errors, int oob_tests>")
public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(RandomForestClassifierUDTF.class);
private ListObjectInspector featureListOI;
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector labelOI;
private List<double[]> featuresList;
private IntArrayList labels;
/**
* The number of trees for each task
*/
private int numTrees;
/**
* The number of randomly selected features
*/
private int numVars;
/**
* The maximum number of leaf nodes
*/
private int maxLeafNodes;
private long seed;
private Attribute[] attributes;
private OutputType outputType;
private SplitRule splitRule;
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("trees", "num_trees", true, "The number of trees for each task [default: 50]");
opts.addOption("vars", "num_variables", true, "The number of random selected features [default: floor(sqrt(dim))]");
opts.addOption("leafs", "max_leaf_nodes", true, "The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantative variable and C for categorical variable. e.g., [Q,C,Q,C])");
opts.addOption("output", "output_type", true, "The output type (opscode/vm or javascript/js) [default: opscode]");
opts.addOption("split", "split_rule", true, "Split algorithm [default: GINI, ENTROPY]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
int T = 50, M = -1, J = Integer.MAX_VALUE;
Attribute[] attrs = null;
long seed = -1L;
String output = "opscode";
SplitRule splitRule = SplitRule.GINI;
CommandLine cl = null;
if(argOIs.length >= 3) {
String rawArgs = HiveUtils.getConstString(argOIs[2]);
cl = parseOptions(rawArgs);
T = Primitives.parseInt(cl.getOptionValue("num_trees"), T);
if(T < 1) {
throw new IllegalArgumentException("Invlaid number of trees: " + T);
}
M = Primitives.parseInt(cl.getOptionValue("num_variables"), M);
J = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), J);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
output = cl.getOptionValue("output", output);
splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI"));
}
this.numTrees = T;
this.numVars = M;
this.maxLeafNodes = J;
this.seed = seed;
this.attributes = attrs;
this.outputType = OutputType.resolve(output);
this.splitRule = splitRule;
return cl;
}
public enum OutputType {
opscode, javascript;
public static OutputType resolve(String name) {
if("opscode".equalsIgnoreCase(name) || "vm".equalsIgnoreCase(name)) {
return opscode;
} else if("javascript".equalsIgnoreCase(name) || "js".equalsIgnoreCase(name)) {
return javascript;
}
throw new IllegalStateException("Unexpected output type: " + name);
}
}
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if(argOIs.length != 2 && argOIs.length != 3) {
throw new UDFArgumentException(getClass().getSimpleName()
+ " takes 2 or 3 arguments: double[] features, int label [, const string options]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
processOptions(argOIs);
this.featuresList = new ArrayList<double[]>(1024);
this.labels = new IntArrayList(1024);
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("pred_model");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("var_importance");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
fieldNames.add("oob_errors");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("oob_tests");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@Override
public void process(Object[] args) throws HiveException {
if(args[0] == null) {
throw new HiveException("array<double> features was null");
}
double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
featuresList.add(features);
labels.add(label);
}
@Override
public void close() throws HiveException {
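// All rows were buffered in process(); materialize them into dense arrays and release the buffers before training.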
int numExamples = featuresList.size();
double[][] x = featuresList.toArray(new double[numExamples][]);
this.featuresList = null;
int[] y = labels.toArray();
this.labels = null;
// run training
train(x, y, attributes, splitRule, numTrees, numVars, maxLeafNodes, seed);
// clean up
this.featureListOI = null;
this.featureElemOI = null;
this.labelOI = null;
this.attributes = null;
}
/**
* @param x
* features
* @param y
* label
* @param attrs
* attribute types
* @param splitRule
* The rule used to choose splits (GINI or ENTROPY)
* @param numTrees
* The number of trees
* @param numVars
* The number of variables to pick up in each node.
* @param maxLeafs
* The maximum number of leaf nodes
* @param seed
* The seed number for Random Forest
*/
private void train(@Nonnull final double[][] x, @Nonnull final int[] y, @Nullable final Attribute[] attrs, @Nonnull final SplitRule splitRule, final int numTrees, final int numVars, final int maxLeafs, final long seed)
throws HiveException {
if(x.length != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
int[] labels = SmileExtUtils.classLables(y);
Attribute[] attributes = SmileExtUtils.attributeTypes(attrs, x);
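// numVars == -1 means the caller did not specify it; fall back to floor(sqrt(dim)), the usual default for classification forests.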
int numInputVars = (numVars == -1) ? (int) Math.floor(Math.sqrt(x[0].length)) : numVars;
final int numExamples = x.length;
int[][] prediction = new int[numExamples][labels.length]; // placeholder for out-of-bag prediction
int[][] order = SmileExtUtils.sort(attributes, x);
AtomicInteger remainingTasks = new AtomicInteger(numTrees);
List<TrainingTask> tasks = new ArrayList<TrainingTask>();
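// One task per tree; offsetting the seed by the tree index (seed + i) gives each tree a distinct random stream.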
for(int i = 0; i < numTrees; i++) {
tasks.add(new TrainingTask(this, attributes, x, y, numInputVars, maxLeafs, order, prediction, splitRule, seed
+ i, remainingTasks));
}
MapredContext mapredContext = MapredContextAccessor.get();
final SmileTaskExecutor executor = new SmileTaskExecutor(mapredContext);
try {
executor.run(tasks);
} catch (Exception ex) {
throw new HiveException(ex);
} finally {
executor.shotdown();
}
}
/**
* Synchronized because {@link #forward(Object)} should be called from a
* single thread.
*/
public synchronized void forward(@Nonnull final String model, @Nonnull final double[] importance, final int[] y, final int[][] prediction, final boolean lastTask)
throws HiveException {
int oobErrors = 0;
int oobTests = 0;
if(lastTask) {
// out-of-bag error estimate
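// prediction[i] holds per-class vote counts from trees that did not sample example i; only such examples count as OOB tests.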
for(int i = 0; i < y.length; i++) {
int pred = Math.whichMax(prediction[i]);
if(prediction[i][pred] > 0) {
oobTests++;
if(pred != y[i]) {
oobErrors++;
}
}
}
}
Object[] forwardObjs = new Object[4];
forwardObjs[0] = new Text(model);
forwardObjs[1] = WritableUtils.toWritableList(importance);
forwardObjs[2] = new IntWritable(oobErrors);
forwardObjs[3] = new IntWritable(oobTests);
forward(forwardObjs);
}
/**
* Trains a single decision tree of the random forest.
*/
private static final class TrainingTask implements Callable<Integer> {
/**
* Attribute properties.
*/
private final Attribute[] attributes;
/**
* Training instances.
*/
private final double[][] x;
/**
* Training sample labels.
*/
private final int[] y;
/**
* The index of training values in ascending order. Note that only
* numeric attributes will be sorted.
*/
private final int[][] order;
/**
* Split rule of the decision tree.
*/
private final SplitRule splitRule;
/**
* The number of variables to pick up in each node.
*/
private final int numVars;
/**
* The maximum number of leaf nodes in the tree.
*/
private final int numLeafs;
/**
* The out-of-bag predictions.
*/
private final int[][] prediction;
private final RandomForestClassifierUDTF udtf;
private final long seed;
private final AtomicInteger remainingTasks;
/**
* Constructor.
*/
TrainingTask(RandomForestClassifierUDTF udtf, Attribute[] attributes, double[][] x, int[] y, int M, int J, int[][] order, int[][] prediction, SplitRule splitRule, long seed, AtomicInteger remainingTasks) {
this.udtf = udtf;
this.attributes = attributes;
this.x = x;
this.y = y;
this.order = order;
this.splitRule = splitRule;
this.numVars = M;
this.numLeafs = J;
this.prediction = prediction;
this.seed = seed;
this.remainingTasks = remainingTasks;
}
@Override
public Integer call() throws HiveException {
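// A seed of -1 means "random": derive a per-task seed from the thread id and the current time so concurrent tasks differ.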
long s = (this.seed == -1L) ? Thread.currentThread().getId()
* System.currentTimeMillis() : this.seed;
final Random random = new Random(s);
final int n = x.length;
int[] samples = new int[n]; // Training samples drawn with replacement (bootstrap sample).
for(int i = 0; i < n; i++) {
samples[random.nextInt(n)]++;
}
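// samples[i] is the number of times example i was drawn; samples[i] == 0 marks it out-of-bag for this tree.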
DecisionTree2 tree = new DecisionTree2(attributes, x, y, numVars, numLeafs, samples, order, splitRule);
// out-of-bag prediction
for(int i = 0; i < n; i++) {
if(samples[i] == 0) {
final int p = tree.predict(x[i]);
synchronized(prediction[i]) {
prediction[i][p]++;
}
}
}
String model = getModel(tree, udtf.outputType);
double[] importance = tree.importance();
int remain = remainingTasks.decrementAndGet();
boolean lastTask = (remain == 0);
udtf.forward(model, importance, y, prediction, lastTask);
return Integer.valueOf(remain);
}
private String getModel(@Nonnull final DecisionTree2 tree, @Nonnull final OutputType outputType) {
final String model;
switch (outputType) {
case opscode: {
model = tree.predictOpCodegen(StackMachine.SEP);
break;
}
case javascript: {
model = tree.predictCodegen();
break;
}
default: {
logger.warn("Unexpected output type: " + udtf.outputType
+ ". Use javascript for the output instead");
model = tree.predictCodegen();
break;
}
}
return model;
}
}
}