Skip to content

Instantly share code, notes, and snippets.

@andrewm4894
Created May 8, 2019 16:27
Show Gist options
  • Select an option

  • Save andrewm4894/7751a5115b7073af1eb06fe9c5dc1dd0 to your computer and use it in GitHub Desktop.

Select an option

Save andrewm4894/7751a5115b7073af1eb06fe9c5dc1dd0 to your computer and use it in GitHub Desktop.
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Minimal local (non-Spark) DataVec example.
 *
 * <p>Reads a CSV file from the classpath, applies a {@link TransformProcess} that keeps only the
 * {@code DateTimeString} and {@code TransactionAmountUSD} columns, and prints the first few
 * records before and after the transformation.
 */
public class myExample {

    /** Maximum number of records printed from the start of each data set. */
    private static final int NUM_ROWS_TO_PRINT = 5;

    /**
     * Runs the example end to end: define schema, define transform, load CSV, transform, print.
     *
     * @param args ignored
     * @throws Exception if the CSV resource cannot be located or read, or the transform fails
     */
    public static void main(String[] args) throws Exception {
        // Describe the layout of the raw CSV: one entry per column, in file order.
        Schema inputDataSchema = new Schema.Builder()
                .addColumnString("DateTimeString")
                .addColumnsString("CustomerID", "MerchantID")
                .addColumnInteger("NumItemsInTransaction")
                .addColumnCategorical("MerchantCountryCode", Arrays.asList("USA", "CAN", "FR", "MX"))
                // $0.0 or more, no maximum limit, no NaN and no Infinite values
                .addColumnDouble("TransactionAmountUSD", 0.0, null, false, false)
                .addColumnCategorical("FraudLabel", Arrays.asList("Fraud", "Legit"))
                .build();

        // The only transformation step: drop every column except the two named here.
        TransformProcess tp = new TransformProcess.Builder(inputDataSchema)
                .removeAllColumnsExceptFor("DateTimeString", "TransactionAmountUSD")
                .build();

        File inputFile = new ClassPathResource("BasicDataVecExample/exampledata.csv").getFile();

        // Load every record into memory. The reader is opened in try-with-resources so the
        // underlying file handle is released even if reading fails part-way through.
        List<List<Writable>> originalData = new ArrayList<>();
        // Arguments: number of leading lines to skip (1, presumably a header row) and delimiter.
        try (RecordReader rr = new CSVRecordReader(1, ',')) {
            rr.initialize(new FileSplit(inputFile));
            while (rr.hasNext()) {
                originalData.add(rr.next());
            }
        }

        // Process the data locally (no Spark context involved).
        List<List<Writable>> processedData = LocalTransformExecutor.execute(originalData, tp);

        printFirstRows("=== BEFORE ===", originalData, NUM_ROWS_TO_PRINT);
        printFirstRows("=== AFTER ===", processedData, NUM_ROWS_TO_PRINT);
    }

    /**
     * Prints a header line followed by at most {@code maxRows} records from the start of
     * {@code rows}.
     *
     * <p>The count is clamped to the size of the list, so a data set shorter than
     * {@code maxRows} is printed in full instead of raising an
     * {@link IndexOutOfBoundsException}.
     *
     * @param header  line printed before the records
     * @param rows    records to print from; each record is printed via its {@code toString()}
     * @param maxRows upper bound on the number of records printed
     */
    private static void printFirstRows(String header, List<List<Writable>> rows, int maxRows) {
        System.out.println(header);
        int limit = Math.min(maxRows, rows.size());
        for (int i = 0; i < limit; i++) {
            System.out.println(rows.get(i));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment