Created
August 13, 2016 19:13
-
-
Save salamanders/cd42f99b8483e8d0d89f6edfa5b43a10 to your computer and use it in GitHub Desktop.
Guava Table<Long,String,String> to JSAT DataSet (Classification or Regression)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.Collection; | |
import java.util.Collections; | |
import java.util.HashSet; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.Map.Entry; | |
import java.util.Set; | |
import java.util.SortedMap; | |
import java.util.TreeMap; | |
import java.util.logging.Logger; | |
import com.google.common.base.Preconditions; | |
import com.google.common.base.Predicates; | |
import com.google.common.collect.BiMap; | |
import com.google.common.collect.HashBiMap; | |
import com.google.common.collect.Iterables; | |
import com.google.gson.Gson; | |
import com.google.gson.GsonBuilder; | |
import jsat.classifiers.CategoricalData; | |
/** | |
* Everything you could ever want to know about a column. | |
* | |
* @author benjaminhill@gmail | |
* | |
*/ | |
public class ColumnInfo { | |
protected static final Logger LOG = Logger.getLogger(ColumnInfo.class.getName()); | |
public static final Gson GSON = new GsonBuilder().setPrettyPrinting().create(); | |
/** | |
* Utility | |
* | |
* @param from | |
* any collection of "sortable" objects | |
* @return sorted unique List of Strings | |
*/ | |
private static List<String> collectionToSortedUniqueStringList(final Collection<Object> from) { | |
final Set<String> uniqueValues = new HashSet<>(); | |
for (final Object o : from) { | |
if (null != o) { | |
uniqueValues.add(String.valueOf(o)); | |
} | |
} | |
final List<String> uniqueValuesList = new ArrayList<>(uniqueValues); | |
Collections.sort(uniqueValuesList); | |
return uniqueValuesList; | |
} | |
/** | |
* Parse a table's column from unknown string-wrapped values into Longs, Doubles, and Strings. Take into account | |
* priority, and return each column with consistent types. | |
* | |
* @param rawTable | |
* @return Table with consistent columns | |
*/ | |
private static SortedMap<Long, Object> parseColumn(final Map<Long, String> rawColumn) { | |
final SortedMap<Long, Object> parsedColumn = new TreeMap<>(); | |
// Convert the entire column. Stinks if the last cell throws it to a higher, but meh, optimize later. | |
Class<?> currentWorstClass = Long.class; | |
for (final Entry<Long, String> cell : rawColumn.entrySet()) { | |
final Object parsedObject = parseToLowestObject(cell.getValue(), currentWorstClass); | |
if (null == parsedObject) { | |
continue; | |
} | |
parsedColumn.put(cell.getKey(), parsedObject); | |
if (!currentWorstClass.equals(parsedObject.getClass())) { | |
LOG.info("During column parse, bumping up lowest class from " + currentWorstClass.getSimpleName() + " to " | |
+ parsedObject.getClass().getSimpleName() + " because of '" + cell.getValue() + "'"); | |
currentWorstClass = parsedObject.getClass(); | |
} | |
} | |
// Now double check the column. Could optimize to only check upwards, but meh. | |
for (final Entry<Long, String> cell : rawColumn.entrySet()) { | |
final Object currentValue = parsedColumn.get(cell.getKey()); | |
if (null == currentValue || currentWorstClass.equals(currentValue.getClass())) { | |
continue; | |
} | |
if (Double.class.equals(currentWorstClass)) { | |
// It can only be a Long | |
parsedColumn.put(cell.getKey(), ((Long) currentValue).doubleValue()); | |
} else if (String.class.equals(currentWorstClass)) { | |
parsedColumn.put(cell.getKey(), String.valueOf(currentValue)); | |
} else { | |
throw new RuntimeException( | |
"No idea how to turn a " + currentValue.getClass().getName() + " into a " + currentWorstClass.getName()); | |
} | |
} | |
return parsedColumn; | |
} | |
/** | |
* Utility function to parse a string to the most basic possible type - Long, Double, or String. May return null. | |
* | |
* @param value | |
* @return | |
*/ | |
private static Object parseToLowestObject(final String value, final Class<?> currentClass) { | |
if (null == value || value.isEmpty()) { | |
return null; | |
} | |
if (String.class.equals(currentClass)) { | |
return value; | |
} | |
if (Long.class.equals(currentClass)) { | |
try { | |
return Long.valueOf(value); | |
} catch (final NumberFormatException nfe) { | |
// ignore instead of nest | |
} | |
} | |
try { | |
return Double.valueOf(value); | |
} catch (final NumberFormatException nfe) { | |
// ignore instead of nest | |
} | |
return value; | |
} | |
private final CategoricalData categoricalData; | |
private final BiMap<String, Integer> lookup = HashBiMap.create(); | |
private final String name; | |
private final SortedMap<Long, Object> parsedData; | |
private final Class<?> type; | |
/** | |
* Column will be parsed and saved as lowest common type (Long/Double/String) | |
* | |
* @param columnData | |
*/ | |
public ColumnInfo(final String columnName, final Map<Long, String> columnData) { | |
this.name = columnName; | |
parsedData = parseColumn(columnData); | |
type = Iterables.find(parsedData.values(), Predicates.notNull()).getClass(); | |
constructLabelLookups(); | |
categoricalData = isLookup() ? constructJSATCategoricalData() : null; | |
} | |
/** | |
* Once we have created a per-column label-to-int mapping, create a mapping from column names to the CategoricalData | |
* | |
* @param columnAndValueToInt | |
* @return | |
*/ | |
private CategoricalData constructJSATCategoricalData() { | |
Preconditions.checkState(!lookup.isEmpty(), "Tried to get CategoricalData for non-lookup column"); | |
final CategoricalData cd = new CategoricalData(parsedData.size()); | |
cd.setCategoryName(name); | |
for (final Entry<String, Integer> cell : lookup.entrySet()) { | |
cd.setOptionName(cell.getKey(), cell.getValue()); | |
} | |
return cd; | |
} | |
/** | |
* Map from column->Value->Integer so we can keep all the lookups for the CategoricalData. | |
* | |
* Has a hidden threshold of >=20 unique Long values means "not a category" | |
*/ | |
private void constructLabelLookups() { | |
if (String.class.equals(type) || Long.class.equals(type)) { | |
final List<String> uniqueValuesList = collectionToSortedUniqueStringList(parsedData.values()); | |
// Bail if too many long values | |
if (Long.class.equals(type) && uniqueValuesList.size() > 20) { | |
return; | |
} | |
for (int i = 0; i < uniqueValuesList.size(); i++) { | |
lookup.put(uniqueValuesList.get(i), i); | |
} | |
} | |
} | |
public CategoricalData getCategoricalData() { | |
Preconditions.checkNotNull(categoricalData, "Tried to get CategoricalData for non-lookup column"); | |
return categoricalData; | |
} | |
public String getName() { | |
return name; | |
} | |
public Class<?> getType() { | |
return type; | |
} | |
public boolean isLookup() { | |
return !lookup.isEmpty(); | |
} | |
/** | |
* Takes into account lookup tables for strings | |
* | |
* @param rowId | |
* @return | |
*/ | |
public Number getRowValue(final Number rowId) { | |
final Object value = parsedData.get(rowId); | |
if (!isLookup()) { | |
return (Number) value; | |
} | |
final Integer intValue = lookup.get(String.valueOf(value)); | |
Preconditions.checkNotNull(intValue); | |
return intValue; | |
} | |
public String getKeyFromLookupId(final int id) { | |
return lookup.inverse().get(id); | |
} | |
public Set<Long> getAllRowKeys() { | |
return parsedData.keySet(); | |
} | |
@Override | |
public int hashCode() { | |
return name.hashCode(); | |
} | |
@Override | |
public boolean equals(Object obj) { | |
if (this == obj) { | |
return true; | |
} | |
if (obj == null) { | |
return false; | |
} | |
if (!(obj instanceof ColumnInfo)) { | |
return false; | |
} | |
ColumnInfo other = (ColumnInfo) obj; | |
if (name == null) { | |
if (other.name != null) { | |
return false; | |
} | |
} else if (!name.equals(other.name)) { | |
return false; | |
} | |
return true; | |
} | |
@Override | |
public String toString() { | |
final Map<String, Object> tmp = new TreeMap<>(); | |
tmp.put("name", name); | |
tmp.put("type", type.getSimpleName().substring(0, 1)); | |
tmp.put("lookup", lookup); | |
tmp.put("sample", parsedData.subMap(0L, 25L)); | |
return GSON.toJson(tmp); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.Map.Entry; | |
import java.util.SortedMap; | |
import java.util.SortedSet; | |
import java.util.TreeMap; | |
import java.util.TreeSet; | |
import java.util.logging.Logger; | |
import com.google.common.base.Preconditions; | |
import com.google.common.collect.ImmutableMap; | |
import com.google.common.collect.Table; | |
import com.google.common.primitives.Ints; | |
import com.google.gson.Gson; | |
import com.google.gson.GsonBuilder; | |
import jsat.DataSet; | |
import jsat.classifiers.CategoricalData; | |
import jsat.classifiers.ClassificationDataSet; | |
import jsat.linear.DenseVector; | |
import jsat.regression.RegressionDataSet; | |
import jsat.utils.DoubleList; | |
/** | |
* Convert a Guava Table into a JSAT-friendly data source. | |
* | |
* Cribbed mainly from JSAT/JSAT/src/jsat/ARFFLoader.java and JSAT/src/jsat/io/CSV.java There has to be a better way to | |
* do this. | |
* | |
* @author benjaminhill@gmail | |
* | |
*/ | |
public class TableDataLoader { | |
protected static final Logger LOG = Logger.getLogger(TableDataLoader.class.getName()); | |
public static final Gson GSON = new GsonBuilder().create(); | |
private final SortedMap<String, ColumnInfo> columns = new TreeMap<>(); | |
public TableDataLoader(final Table<Long, String, String> rawTable) { | |
Preconditions.checkNotNull(rawTable); | |
for (final Entry<String, Map<Long, String>> column : rawTable.columnMap().entrySet()) { | |
final ColumnInfo ci = new ColumnInfo(column.getKey(), column.getValue()); | |
columns.put(ci.getName(), ci); | |
LOG.info(ci.toString()); | |
} | |
} | |
private void columnsToFeats(final ColumnInfo targetCi, final Number rowId, final DoubleList numericFeats, | |
final List<Integer> catFeats) { | |
for (final ColumnInfo ci : columns.values()) { | |
if (targetCi != ci) { | |
if (ci.isLookup()) { | |
catFeats.add(((Number) ci.getRowValue(rowId)).intValue()); | |
} else { | |
numericFeats.add(((Number) ci.getRowValue(rowId)).doubleValue()); | |
} | |
} | |
} | |
} | |
/** | |
* | |
* @param rawTable | |
* a Table of String-encoded data (with both nulls or empty strings allowed) | |
* @param outputColumnName | |
* which column you are trying to predict, may be Integer/String for Classification, Doubles for Regression | |
* @return | |
*/ | |
public DataSet<?> getDataSet(final String outputColumnName) { | |
Preconditions.checkNotNull(outputColumnName); | |
Preconditions.checkState(columns.containsKey(outputColumnName), "output column doesn't exist."); | |
final ColumnInfo targetCi = columns.get(outputColumnName); | |
final List<CategoricalData> inputCDs = new ArrayList<>(); | |
final SortedSet<Number> allRowKeys = new TreeSet<>(); | |
for (final ColumnInfo ci : columns.values()) { | |
allRowKeys.addAll(ci.getAllRowKeys()); | |
if (ci.isLookup() && targetCi != ci) { | |
inputCDs.add(ci.getCategoricalData()); | |
} | |
} | |
// Count the numbers after removing the category | |
final int categoricalColumnCount = inputCDs.size(); | |
final int numericalColumnCount = columns.size() - categoricalColumnCount - 1; | |
Preconditions.checkState(categoricalColumnCount >= 0); | |
Preconditions.checkState(numericalColumnCount >= 0); | |
LOG.info("Forking on output type (name:" + targetCi.getName() + ", isLookup:" + targetCi.isLookup() + ", cat:" | |
+ categoricalColumnCount + ", num:" + numericalColumnCount + ")"); | |
if (targetCi.isLookup()) { | |
LOG.info("Classification"); | |
return tableToDataSet_Classification(targetCi, inputCDs, allRowKeys, categoricalColumnCount, | |
numericalColumnCount); | |
} | |
LOG.info("Regression"); | |
return tableToDataSet_Regression(targetCi, inputCDs, allRowKeys, categoricalColumnCount, numericalColumnCount); | |
} | |
/** | |
* All done with the prep, time to build the actual data set (Classification) | |
* | |
* @param parsedTable | |
* @param outputColumnName | |
* @param columnTypes | |
* @param columnAndValueToInt | |
* @param categoricalDatas | |
* @return | |
*/ | |
private ClassificationDataSet tableToDataSet_Classification(final ColumnInfo targetCi, | |
final List<CategoricalData> inputCDs, final SortedSet<Number> allRowKeys, final int categoricalColumnCount, | |
final int numericalColumnCount) { | |
final CategoricalData targetCD = targetCi.getCategoricalData(); | |
final ClassificationDataSet cds = new ClassificationDataSet(numericalColumnCount, | |
inputCDs.toArray(new CategoricalData[inputCDs.size()]), targetCD); | |
// Add all data points | |
for (final Number rowId : allRowKeys) { | |
final DoubleList numericFeats = new DoubleList(numericalColumnCount); | |
final List<Integer> catFeats = new ArrayList<>(categoricalColumnCount); | |
final int outputClass = ((Number) targetCi.getRowValue(rowId)).intValue(); | |
columnsToFeats(targetCi, rowId, numericFeats, catFeats); | |
cds.addDataPoint(new DenseVector(numericFeats), Ints.toArray(catFeats), outputClass); | |
} | |
return cds; | |
} | |
/** | |
* All done with the prep, time to build the actual data set (Regression) | |
* | |
* @param parsedTable | |
* @param outputColumnName | |
* @param columnTypes | |
* @param columnAndValueToInt | |
* @param categoricalDatas | |
* @return | |
*/ | |
private RegressionDataSet tableToDataSet_Regression(final ColumnInfo targetCi, final List<CategoricalData> inputCDs, | |
final SortedSet<Number> allRowKeys, final int categoricalColumnCount, final int numericalColumnCount) { | |
final RegressionDataSet rds = new RegressionDataSet(numericalColumnCount, | |
inputCDs.toArray(new CategoricalData[inputCDs.size()])); | |
// Add all data points | |
for (final Number rowId : allRowKeys) { | |
final DoubleList numericFeats = new DoubleList(numericalColumnCount); | |
final List<Integer> catFeats = new ArrayList<>(categoricalColumnCount); | |
final double outputValue = ((Number) targetCi.getRowValue(rowId)).doubleValue(); | |
columnsToFeats(targetCi, rowId, numericFeats, catFeats); | |
try { | |
rds.addDataPoint(new DenseVector(numericFeats), Ints.toArray(catFeats), outputValue); | |
} catch (final RuntimeException re) { | |
LOG.severe("NO go:" + GSON.toJson(ImmutableMap.of("targetCi", targetCi.getName(), "rowId", rowId, | |
"numericFeats", numericFeats, "catFeats", catFeats))); | |
throw re; | |
} | |
} | |
return rds; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment