Created
February 23, 2018 23:09
-
-
Save dirkraft/2ae6b4c1c9e4e9123ed64920a8c293bc to your computer and use it in GitHub Desktop.
Postgres COPY helper in Scala using the Postgres JDBC driver. Made for rows that can be represented as TSV.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.postgresql.copy.CopyManager | |
import org.postgresql.core.BaseConnection | |
import org.slf4j.LoggerFactory | |
import java.nio.charset.Charset | |
import java.sql.Connection | |
/** Postgres helper to bulk load data into a db as quickly as possible.
  *
  * NOTE(review): `db` is not referenced anywhere in this class — confirm whether it is still
  * needed.
  */
class PgCopy(db: DbAccess) extends Serializable {
  import java.nio.charset.StandardCharsets
  import java.util.Locale
  import scala.util.control.NonFatal

  private val log = LoggerFactory.getLogger(classOf[PgCopy])

  /** Batch inserts rows into a Postgres database using
    *  - a temp table definition
    *  - a COPY statement into the temp table
    *  - given SQL to select from the temp table into the real table however the caller sees fit
    *
    * The temp table is dropped and the JDBC `Statement` is closed afterwards, whether or not the
    * load succeeds. If a cleanup step fails after an earlier failure, the cleanup error is
    * attached to the original exception as a suppressed exception instead of replacing it. No
    * commit or rollback is issued; transaction handling is left to the caller.
    *
    * See `encodeTsvCell` for the kinds of value types supported in `tempTableRows`.
    *
    * @param table target table which will have rows merged into it (only used for logging here)
    * @param tempTableRows rows to be copied into the temp table, one `Array` of cells per row. See
    *   `encodeTsvCell` for supported value types.
    * @param createTempTableSql create temp table. The name of the table will be extracted
    *   from this SQL (see `detectTempTableName`) for the COPY operation and eventually be dropped.
    * @param mergeSql SQL which (presumably) selects from the temp table into the real table
    *   which can provide postgres `ON CONFLICT` behavior
    * @return the number of rows updated as a result of `mergeSql`
    * @throws IllegalArgumentException if no temp table name can be found in `createTempTableSql`;
    *   in that case nothing is executed against the database
    */
  def copy(
      table: String,
      tempTableRows: Seq[Array[Any]],
      createTempTableSql: String,
      mergeSql: String
  )(conn: Connection): Int = {
    // Parse the name before touching the database, so a statement we cannot parse never leaves
    // an orphaned temp table behind on the connection.
    val tempTableName: String = detectTempTableName(createTempTableSql)
    log.debug(s"Detected temp table name as $tempTableName")

    val st = conn.createStatement()
    withCleanup {
      st.execute(createTempTableSql)
      withCleanup {
        val numRows = copyRows(conn, tempTableName, tempTableRows)
        log.info(s"Merging $numRows from temp table $tempTableName into $table")
        st.executeUpdate(mergeSql)
      } {
        st.execute(s"DROP TABLE $tempTableName")
      }
    } {
      st.close()
    }
  }

  /** Streams every row as one UTF-8 encoded TSV line into `COPY <tempTableName> FROM STDIN`.
    *
    * @return the row count reported by the driver once the COPY completes
    */
  private def copyRows(conn: Connection, tempTableName: String, rows: Seq[Array[Any]]): Long = {
    val copyMan: CopyManager = new CopyManager(conn.unwrap(classOf[BaseConnection]))
    val copyOp = copyMan.copyIn(s"COPY $tempTableName FROM STDIN NULL AS ''")
    // Resolved once instead of once per row.
    val utf8: Charset = StandardCharsets.UTF_8
    withCleanup {
      // Long counters: Int would overflow on multi-GB loads.
      var numWritten = 0L
      var bytesWritten = 0L
      rows.foreach { cells =>
        val bytes = (cells.map(encodeTsvCell).mkString("\t") + "\n").getBytes(utf8)
        copyOp.writeToCopy(bytes, 0, bytes.length)
        numWritten += 1
        bytesWritten += bytes.length
        if (numWritten % 10000 == 0) {
          log.debug(s"COPY wrote $numWritten rows (${bytesWritten / 1000} KB) so far.")
        }
      }
      val numRows = copyOp.endCopy()
      log.info(s"Processed $numWritten rows. CopyManager reports $numRows. These should agree.")
      if (numRows != numWritten) {
        log.warn(s"COPY row count mismatch: wrote $numWritten rows, CopyManager reports $numRows.")
      }
      numRows
    } {
      // Still active only if writing or endCopy failed; cancel so the connection can be reused.
      if (copyOp.isActive) {
        copyOp.cancelCopy()
      }
    }
  }

  /** Runs `body` and then always runs `cleanup`, like `try/finally`.
    *
    * Unlike a plain `finally`, a non-fatal failure in `cleanup` does not replace a failure that
    * `body` already raised: it is added to it via `addSuppressed`. When `body` succeeds, a
    * `cleanup` failure propagates normally.
    */
  private def withCleanup[A](body: => A)(cleanup: => Any): A = {
    val result =
      try body
      catch {
        // Catch-all is deliberate: the exception is always rethrown after cleanup, exactly as
        // `finally` would do.
        case primary: Throwable =>
          try cleanup
          catch { case NonFatal(secondary) => primary.addSuppressed(secondary) }
          throw primary
      }
    cleanup
    result
  }

  /** Extracts the temp table name from a `CREATE TEMP TABLE` statement.
    *
    * Recognizes `TEMP`/`TEMPORARY`, the optional `GLOBAL`/`LOCAL` and `IF NOT EXISTS` modifiers,
    * an optional `pg_temp.` schema prefix, and statements that span several lines. The name is
    * returned lower-cased, which matches how Postgres folds unquoted identifiers. A double-quoted
    * name has its quotes dropped and is lower-cased too, so mixed-case quoted names are not
    * supported.
    *
    * @throws IllegalArgumentException if no temp table name can be found in `sql`
    */
  def detectTempTableName(sql: String): String = {
    val createTempTable = (
      """(?i)\bcreate\s+(?:(?:global|local)\s+)?temp(?:orary)?\s+table\s+""" +
        """(?:if\s+not\s+exists\s+)?(?:pg_temp\s*\.\s*)?"?([a-z_][\w$]*)"""
    ).r
    createTempTable
      .findFirstMatchIn(sql)
      .map(_.group(1).toLowerCase(Locale.ROOT))
      .getOrElse {
        throw new IllegalArgumentException(s"Couldn't detect table name in: $sql")
      }
  }

  /** Encodes a single cell for embedding in TSV data sent to a Postgres COPY execution.
    *
    * Everything eventually becomes a string, but some types have special handling for convenience.
    *  - null or None becomes the empty string. COPY here specifies `NULL AS ''`, so such cells
    *    load as SQL NULL. As a consequence, a non-null empty string also loads as NULL.
    *  - Some(value) is unwrapped and `value` is encoded with these same rules.
    *  - Any Iterable or Array becomes a Postgres array literal such as {"abc","d,ef",NULL}. Every
    *    element is double-quoted, so commas, braces, quotes, whitespace and the text "NULL"
    *    survive. null/None elements become SQL NULL elements. An empty collection becomes the
    *    empty string (NULL) to save space in the DB.
    *  - Everything else gets the toString treatment.
    * The result is escaped for the COPY text format.
    */
  private def encodeTsvCell(cell: Any): String = cell match {
    case null | None      => ""
    case Some(value)      => encodeTsvCell(value)
    case els: Iterable[_] => escapeCopyText(encodePgArray(els))
    case arr: Array[_]    => escapeCopyText(encodePgArray(arr.toSeq))
    case other            => escapeCopyText(other.toString)
  }

  /** Escapes a value for the COPY text format, in which backslash is the escape character. */
  private def escapeCopyText(s: String): String =
    // Escape the escape character first so the escapes added below are not doubled.
    s.replace("\\", "\\\\")
      // Then escape the characters that would otherwise break the TSV row/column layout.
      .replace("\n", "\\n")
      .replace("\r", "\\r")
      .replace("\t", "\\t")

  /** Renders elements as a Postgres array literal; an empty collection yields "".
    *
    * Elements are traversed via `iterator`, so e.g. a `Set` cannot merge two distinct elements
    * that happen to render to the same text.
    */
  private def encodePgArray(els: Iterable[_]): String =
    if (els.isEmpty) ""
    else els.iterator.map(encodeArrayElement).mkString("{", ",", "}")

  /** Renders one array element: NULL for null/None, otherwise a double-quoted string with `\` and
    * `"` backslash-escaped per the Postgres array input syntax.
    */
  private def encodeArrayElement(el: Any): String = el match {
    case null | None => "NULL"
    case Some(value) => encodeArrayElement(value)
    case value =>
      val escaped = value.toString.replace("\\", "\\\\").replace("\"", "\\\"")
      "\"" + escaped + "\""
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment