// BulkLoadTester.java (forked from deinspanjer/BulkLoadTester.java)
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.CharBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Properties;
import java.util.Scanner;
import java.util.concurrent.Executors;
import java.util.zip.GZIPOutputStream;
import com.vertica.jdbc.VerticaConnection;
import com.vertica.jdbc.VerticaCopyStream;
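/**
 * Benchmarks several ways of bulk loading a Vertica table: JDBC batch
 * inserts (targeting WOS or ROS) and COPY FROM STDIN in delimited and
 * NATIVE binary form, each optionally gzip-compressed on the wire.
 *
 * Usage: BulkLoadTester <jdbc:vertica://<host>:<port>/<db>> <user> <pass>
 *        <run method> <batch size> <input file> [dryrun]
 */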
public class BulkLoadTester {
private static long start;
private static int rowCounter;
private static int batchCounter;
private static boolean isDryRun = false;
private enum methods {
batch_wos, batch_ros, copy_binary, copy_binary_gzip, copy_varchar, copy_varchar_gzip, copy_delimiter, copy_delimiter_gzip
}
private static final byte ZERO = (byte)0;
private static final byte ONE = (byte)1;
public static void main(String[] args) {
if (args.length < 6 || args.length > 7) {
usage(null);
return;
}
if (args.length == 7 && "dryrun".equals(args[6])) {
isDryRun = true;
}
try {
Class.forName("com.vertica.jdbc.Driver");
} catch (ClassNotFoundException e) {
System.err.println("Could not find the JDBC driver class.");
return;
}
Properties myProp = new Properties();
myProp.put("user", args[1]);
myProp.put("password", args[2]);
Connection conn;
try {
conn = DriverManager.getConnection(args[0], myProp);
conn.setAutoCommit(false);
} catch (Exception e) {
usage(e);
return;
}
try {
int batchSize = Integer.parseInt(args[4]);
File input = new File(args[5]);
Scanner scanner = new Scanner(input);
scanner.useDelimiter("[\t\n]");
// Skip header
scanner.nextLine();
createTable(conn);
start = System.currentTimeMillis();
switch (methods.valueOf(args[3])) {
case batch_ros:
runBatch(conn, scanner, true, batchSize);
break;
case batch_wos:
runBatch(conn, scanner, false, batchSize);
break;
case copy_binary:
runCopyBinary(conn, scanner, batchSize, "UNCOMPRESSED");
break;
case copy_delimiter:
runCopyDelimiter(conn, scanner, batchSize, "UNCOMPRESSED");
break;
case copy_varchar:
runCopyVarchar(conn, scanner, batchSize, "UNCOMPRESSED");
break;
case copy_binary_gzip:
runCopyBinary(conn, scanner, batchSize, "GZIP");
break;
case copy_varchar_gzip:
runCopyVarchar(conn, scanner, batchSize, "GZIP");
break;
case copy_delimiter_gzip:
runCopyDelimiter(conn, scanner, batchSize, "GZIP");
break;
}
printResult(conn, ((System.currentTimeMillis() - start) / 1000));
} catch (Exception e) {
usage(e);
return;
}
}
private static void runCopyVarchar(Connection conn, Scanner scanner, int batchSize, String compressed) {
System.err.println("method not yet implemented");
}
private static void runCopyDelimiter(final Connection conn, Scanner scanner, int batchSize, String compression) throws SQLException, IOException {
PipedOutputStream pipedOutputStream = new PipedOutputStream();
final PipedInputStream pipedInputStream = new PipedInputStream(pipedOutputStream);
OutputStream stream;
if ("GZIP".equals(compression)) {
stream = new GZIPOutputStream(pipedOutputStream);
} else {
stream = pipedOutputStream;
}
final String stmt = String.format("COPY test_load FROM STDIN %s DELIMITER '|' DIRECT", compression);
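// For example, with GZIP compression this format string produces:
//   COPY test_load FROM STDIN GZIP DELIMITER '|' DIRECT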
Thread copyWorker = startCopyWorker(conn, stmt, pipedInputStream);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream,Charset.forName("UTF-8")));
Object[] row = new Object[6];
while (scanner.hasNext()) {
buildRow(scanner, row);
writer.write(String.valueOf(row[0]));
writer.write('|');
writer.write(String.valueOf(row[1]));
writer.write('|');
writer.write(String.valueOf(row[2]));
writer.write('|');
writer.write(String.valueOf(row[3]));
writer.write('|');
writer.write(String.valueOf(row[4]));
writer.write('|');
writer.write(String.valueOf(row[5]));
writer.write('\n');
if (scanner.hasNextLine()) {
scanner.nextLine();
}
if (++rowCounter % 100000 == 0) {
System.err.format("\r%,10d rows %,3d batches (%,6d r/s)", rowCounter, batchCounter, (rowCounter / Math.max(1, (System.currentTimeMillis() - start) / 1000)));
}
}
writer.close();
// Wait for the COPY worker to finish and commit before the main thread
// reuses the connection; JDBC connections are not safe for concurrent use.
try {
copyWorker.join();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
private static void runCopyBinary(Connection conn, Scanner scanner, int batchSize, String compression) throws IOException {
// Column byte widths as 32bit ints
// boolean(1), int(4), int(4), varchar(-1), char(36), float/double(8)
int[] columnByteWidths = new int[] { 1, 4, 4, -1, 36, 8 };
int columnCount = columnByteWidths.length;
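// VerticaBinaryFormatBitSet is assumed to be a companion class shipped
// alongside this file in the original gist (not shown here): it keeps one
// null bit per column and writes them into the buffer as the packed null
// bitmap that precedes each NATIVE row's data.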
VerticaBinaryFormatBitSet rowNulls = new VerticaBinaryFormatBitSet(columnCount);
int rowHeaderSize = 4 + rowNulls.numBytes();
int rowVarcharCount = 0;
int rowMaxSize = rowHeaderSize;
for (int i : columnByteWidths) {
if (i == -1) {
rowVarcharCount++;
} else {
rowMaxSize += i;
}
}
rowMaxSize += (rowVarcharCount * (4 + 65000)); // 4-byte length word plus max VARCHAR data
int bufferSize = rowMaxSize * (batchSize == 0 ? 1000 : batchSize);
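// Worked example for this table (derived from the widths above):
//   header 4 + 1 (null bitmap) = 5, fixed columns 1+4+4+36+8 = 53,
//   one VARCHAR at 4 + 65000 = 65004, so rowMaxSize = 65062 bytes.
// With the default batch size of 1000 that is a ~65 MB staging buffer.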
PipedOutputStream pipedOutputStream = new PipedOutputStream();
final PipedInputStream pipedInputStream = new PipedInputStream(pipedOutputStream);
OutputStream stream;
if ("GZIP".equals(compression)) {
stream = new GZIPOutputStream(pipedOutputStream);
} else {
stream = pipedOutputStream;
}
WritableByteChannel channel = Channels.newChannel(stream);
final String stmt = String.format("COPY test_load FROM STDIN %s NATIVE DIRECT", compression);
Thread copyWorker = startCopyWorker(conn, stmt, pipedInputStream);
ByteBuffer bb = ByteBuffer.allocate(bufferSize);
bb.order(ByteOrder.LITTLE_ENDIAN);
bb.clear();
// Buffer at max size for any char or varchar
ByteBuffer eb;
CharBuffer cb = CharBuffer.allocate(65000);
Charset charset = Charset.forName("UTF-8");
CharsetEncoder encoder = charset.newEncoder();
// File signature: the 11 bytes NATIVE\n\317\r\n\000. These must be written
// as raw bytes; running them through the UTF-8 encoder would expand \317
// (0xCF) into two bytes and corrupt the signature.
bb.put(new byte[] { 'N', 'A', 'T', 'I', 'V', 'E', '\n', (byte) 0xCF, '\r', '\n', ZERO });
// Header area length: 5 bytes for the next three puts plus 4 bytes per column
bb.putInt(5 + (4 * columnCount));
// NATIVE file version
bb.putShort((short)1);
// Filler (Always 0)
bb.put(ZERO);
// Number of columns
bb.putShort((short) columnCount);
for (int i : columnByteWidths) {
bb.putInt(i);
}
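// At this point the NATIVE file header is complete:
//   11-byte signature, 4-byte header area length, 2-byte version (1),
//   1 filler byte (0), 2-byte column count, then a 4-byte width per
//   column with -1 marking variable-width (VARCHAR) columns.
// Each row that follows is: 4-byte row data length, null bitmap, then
// the column values in declaration order.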
// In PDI, the rows come as an array of Objects whose types correspond to the available metadata for the row.
// Since we have to predetermine the size of the row, let's build that array here.
Object[] row = new Object[columnCount];
while (scanner.hasNext()) {
rowNulls.clear();
checkAndFlushBuffer(rowMaxSize, channel, bb);
buildRow(scanner, row);
if (scanner.hasNextLine()) {
scanner.nextLine();
}
int rowDataSize = 0;
// record the start of this row so we can come back and update the size and nulls
int rowDataSizeFieldPosition = bb.position();
bb.putInt(rowDataSize);
int rowNullsFieldPosition = bb.position();
rowNulls.writeBytes(bb);
if (row[0] != null) {
rowDataSize += 1;
bb.put(((Boolean)row[0]).booleanValue() ? ONE : ZERO);
} else {
rowNulls.setBit(0);
}
if (row[1] != null) {
rowDataSize += 4;
bb.putInt((Integer)row[1]);
} else {
rowNulls.setBit(1);
}
if (row[2] != null) {
rowDataSize += 4;
bb.putInt((Integer)row[2]);
} else {
rowNulls.setBit(2);
}
if (row[3] != null) {
cb.put((String)row[3]);
cb.flip();
eb = encoder.encode(cb);
cb.clear();
// 4 bytes for the int declaring the field data length, plus the encoded string bytes
rowDataSize += 4 + eb.limit();
// limit() is the number of bytes in the encoded string
bb.putInt(eb.limit());
bb.put(eb);
} else {
rowNulls.setBit(3);
}
//XXX: I'm not space padding this field because I know my dataset, but I should.
if (row[4] != null) {
rowDataSize += columnByteWidths[4];
cb.put((String)row[4]);
cb.flip();
eb = encoder.encode(cb);
cb.clear();
if (eb.limit() != columnByteWidths[4]) {
throw new IOException(String.format("CHAR string has an invalid byte size. Expected %d, got %d for '%s' (bytes: %s)", columnByteWidths[4], eb.limit(), row[4], Arrays.toString(eb.array())));
}
bb.put(eb);
} else {
rowNulls.setBit(4);
}
if (row[5] != null) {
rowDataSize += 8;
bb.putDouble((Double)row[5]);
} else {
rowNulls.setBit(5);
}
// Now fill in the row header
bb.putInt(rowDataSizeFieldPosition, rowDataSize);
rowNulls.writeBytes(rowNullsFieldPosition, bb);
if (++rowCounter % 100000 == 0) {
System.err.format("\r%,10d rows %,3d batches (%,6d r/s)", rowCounter, batchCounter, (rowCounter / Math.max(1, (System.currentTimeMillis() - start) / 1000)));
}
}
// Flush whatever remains in the buffer
bb.flip();
channel.write(bb);
bb.clear();
batchCounter++;
channel.close();
stream.close();
// Wait for the COPY worker to finish and commit before the main thread
// reuses the connection.
try {
copyWorker.join();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
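// buildRow consumes one tab-delimited record in column order:
//   boolean \t int \t int \t string \t 36-char uuid \t double
// e.g. a line like (illustrative values, assumed from the column types):
//   true\t1\t42\thello\t123e4567-e89b-12d3-a456-426614174000\t3.14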
private static void buildRow(Scanner scanner, Object[] row) {
row[0] = scanner.nextBoolean();
row[1] = scanner.nextInt();
row[2] = scanner.nextInt();
row[3] = scanner.next();
row[4] = scanner.next();
row[5] = scanner.nextDouble();
}
private static void checkAndFlushBuffer(int rowMaxSize, WritableByteChannel channel, ByteBuffer bb)
throws IOException {
if (bb.position() + rowMaxSize > bb.capacity()) {
bb.flip();
channel.write(bb);
bb.clear();
batchCounter++;
}
}
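// The piped-stream pattern: the caller writes rows into a
// PipedOutputStream on its own thread while this worker thread feeds the
// connected PipedInputStream to VerticaCopyStream. The pipe's fixed-size
// buffer means producer and consumer must run concurrently; doing both
// on one thread would deadlock once the pipe fills.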
private static Thread startCopyWorker(final Connection conn, final String stmt, final PipedInputStream pipedInputStream) {
Thread worker = Executors.defaultThreadFactory().newThread(new Runnable() {
public void run() {
try {
if (isDryRun) {
while (pipedInputStream.read() != -1) {
pipedInputStream.skip(pipedInputStream.available());
}
} else {
VerticaCopyStream stream = new VerticaCopyStream((VerticaConnection)conn, stmt);
stream.start();
stream.addStream(pipedInputStream);
stream.execute();
stream.finish();
conn.commit();
}
} catch (Exception e) {
e.printStackTrace();
}
}
});
worker.start();
return worker;
}
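// DirectBatchInsert=true asks the driver to load batch inserts straight
// into ROS (disk) rather than WOS (memory); this is what distinguishes
// the batch_ros and batch_wos run methods.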
private static void runBatch(Connection conn, Scanner scanner, boolean isDirect, int batchSize) throws SQLException, FileNotFoundException {
((VerticaConnection) conn).setProperty("DirectBatchInsert", isDirect);
// Create the prepared statement
PreparedStatement pstmt = conn.prepareStatement("INSERT INTO test_load"
+ "(_bool, _seq, _int, _string, _uuid, _float)"
+ " VALUES(?,?,?,?,?,?)");
Object[] row = new Object[6];
while (scanner.hasNext()) {
buildRow(scanner,row);
pstmt.setBoolean(1, (Boolean)row[0]);
pstmt.setInt(2, (Integer)row[1]);
pstmt.setInt(3, (Integer)row[2]);
pstmt.setString(4, (String)row[3]);
pstmt.setString(5, (String)row[4]);
pstmt.setDouble(6, (Double)row[5]);
pstmt.addBatch();
if (scanner.hasNextLine()) {
scanner.nextLine();
}
if (++rowCounter % 100000 == 0) {
System.err.format("\r%,10d rows %,3d batches (%,6d r/s)", rowCounter, batchCounter, (rowCounter / Math.max(1, (System.currentTimeMillis() - start) / 1000)));
}
if (batchSize > 0 && rowCounter % batchSize == 0) {
if (!isDryRun) {
pstmt.executeBatch();
}
pstmt.clearBatch();
batchCounter++;
}
}
// If we didn't do an execute on the last iteration of the loop
if (batchSize == 0 || rowCounter % batchSize != 0) {
if (!isDryRun) {
pstmt.executeBatch();
}
batchCounter++;
}
// Commit the transaction to close the COPY statement that Vertica's driver uses internally for batch inserts
conn.commit();
}
private static void usage(Exception e) {
System.err.println("Usage: BulkLoadTester <jdbc:vertica://<host>:<port>/<db>> <user> <pass> <run method> <batch size> <input file> [dryrun]");
System.err.format("Run Methods: %s%n", Arrays.toString(methods.values()));
if (e != null) {
e.printStackTrace();
}
System.exit(1);
}
private static void printResult(Connection conn, long seconds) throws SQLException {
int rows;
if (isDryRun) {
rows = rowCounter;
} else {
Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery("SELECT COUNT(1) FROM test_load;");
rs.next();
rows = rs.getInt(1);
}
System.out.format("\r%,d rows %sinserted with %,d batches in %s (%,d r/s)%n", rows, (isDryRun ? "NOT " : ""), batchCounter, formatDuration(seconds), rows / Math.max(1, seconds));
// Cleanup
conn.close();
}
private static void createTable(Connection conn) throws SQLException {
Statement stmt = conn.createStatement();
stmt.execute("DROP TABLE IF EXISTS test_load CASCADE");
stmt.execute("CREATE TABLE test_load(_bool BOOLEAN, _seq INTEGER, _int INTEGER, _string VARCHAR(13), _uuid CHAR(36), _float FLOAT); CREATE PROJECTION test_load_unseg_super(_bool, _seq, _int, _string, _uuid, _float) AS SELECT _bool, _seq, _int, _string, _uuid, _float FROM test_load UNSEGMENTED ALL NODES;");
}
private static String formatDuration(long seconds) {
if (seconds <= 60) {
return String.format("%d seconds", seconds);
} else if (seconds <= 60 * 60) {
long min = seconds / 60;
long rem = seconds % 60;
return String.format("%dm %ds", min, rem);
} else if (seconds <= 60 * 60 * 24) {
long hour = seconds / (60 * 60);
long rem = seconds % (60 * 60);
return String.format("%dh %dm %ds", hour, rem / 60, rem % 60);
} else {
long days = seconds / (60 * 60 * 24);
long rem = seconds % (60 * 60 * 24);
return String.format("%dd %dh %dm %ds", days, rem / (60 * 60), (rem % (60 * 60)) / 60, rem % 60);
}
}
}