Skip to content

Instantly share code, notes, and snippets.

@dirkraft
Created February 23, 2018 23:09
Show Gist options
  • Save dirkraft/2ae6b4c1c9e4e9123ed64920a8c293bc to your computer and use it in GitHub Desktop.
Save dirkraft/2ae6b4c1c9e4e9123ed64920a8c293bc to your computer and use it in GitHub Desktop.
Postgres COPY helper in Scala using the Postgres JDBC driver. Made for rows that can be represented as TSV.
import java.nio.charset.Charset
import java.sql.Connection
import java.util.Locale

import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal
/** Postgres helper to bulk load data into a db as quickly as possible. */
class PgCopy(db: DbAccess) extends Serializable {

  // slf4j's Logger interface is not Serializable, so keep it out of the serialized state and
  // re-create it lazily wherever this instance is deserialized.
  @transient private lazy val log = LoggerFactory.getLogger(classOf[PgCopy])

  /** Batch inserts rows into a Postgres table by
   *  - creating a temp table with `createTempTableSql`,
   *  - streaming `tempTableRows` into that temp table with `COPY ... FROM STDIN`,
   *  - running `mergeSql`, which should select from the temp table into the real table however
   *    the caller sees fit (e.g. with postgres `ON CONFLICT` behavior),
   *  - dropping the temp table, whether or not the COPY and merge succeeded.
   *
   * @param table name of the target table; only used in log messages here
   * @param tempTableRows rows to COPY into the temp table, one array per row with cells in the
   *   temp table's column order (the COPY names no column list). See `encodeTsvCell` for the
   *   supported value types.
   * @param createTempTableSql a `CREATE [GLOBAL | LOCAL] {TEMP | TEMPORARY} TABLE [IF NOT EXISTS]
   *   name ...` statement. The table name is extracted from it (see `detectTempTableName`) for
   *   the COPY and the final DROP.
   * @param mergeSql SQL which (presumably) selects from the temp table into the real table
   * @param conn connection all statements run on; it is not closed here
   * @return the update count returned by executing `mergeSql`
   * @throws IllegalArgumentException if no temp table name can be detected in
   *   `createTempTableSql`; nothing is executed against the database in that case
   */
  def copy(
    table: String,
    tempTableRows: Seq[Array[Any]],
    createTempTableSql: String,
    mergeSql: String
  )(conn: Connection): Int = {
    // Parse the name before touching the database. If the statement can't be handled, fail
    // without first creating a table we would not know how to drop.
    val tempTableName: String = detectTempTableName(createTempTableSql)
    log.debug(s"Detected temp table name as $tempTableName")
    val st = conn.createStatement()
    try {
      st.execute(createTempTableSql)
      // The temp table exists from here on: always drop it, but never let a failing DROP (e.g. in
      // an already-aborted transaction) hide the reason the COPY or merge failed.
      runWithCleanup {
        val numRows = copyRows(conn, tempTableName, tempTableRows)
        log.info(s"Merging $numRows from temp table $tempTableName into $table")
        st.executeUpdate(mergeSql)
      } {
        st.execute(s"DROP TABLE $tempTableName")
      }
    } finally {
      st.close()
    }
  }

  /** Streams `rows` into `tempTableName` with `COPY ... FROM STDIN NULL AS ''`.
   *
   * @return the row count reported by `CopyIn.endCopy()`
   */
  private def copyRows(conn: Connection, tempTableName: String, rows: Seq[Array[Any]]): Long = {
    val copyMan: CopyManager = new CopyManager(conn.unwrap(classOf[BaseConnection]))
    val copyIn = copyMan.copyIn(s"COPY $tempTableName FROM STDIN NULL AS ''")
    val utf8 = Charset.forName("UTF-8") // hoisted: looked up once, not once per row
    runWithCleanup {
      // Longs: a bulk load can easily exceed Int.MaxValue bytes.
      var numWritten = 0L
      var bytesWritten = 0L
      rows.foreach { cells: Array[Any] =>
        val tsvRow = cells.map(encodeTsvCell).mkString("\t") + "\n"
        val bytes = tsvRow.getBytes(utf8)
        copyIn.writeToCopy(bytes, 0, bytes.length)
        numWritten += 1
        bytesWritten += bytes.length
        if (numWritten % 10000 == 0) {
          log.debug(s"COPY wrote $numWritten rows (${bytesWritten / 1000} KB) so far.")
        }
      }
      val numRows = copyIn.endCopy()
      log.info(s"Processed $numWritten rows. CopyManager reports $numRows. These should agree.")
      if (numRows != numWritten) {
        log.warn(s"Row count mismatch: wrote $numWritten rows but COPY reported $numRows.")
      }
      numRows
    } {
      // Still active only if something failed before endCopy() completed; cancel rather than
      // leave the COPY open on the connection.
      if (copyIn.isActive) {
        copyIn.cancelCopy()
      }
    }
  }

  /** Evaluates `body`, then always evaluates `cleanup` (like try/finally). If `body` threw, a
   * non-fatal exception from `cleanup` is attached to the original as a suppressed exception
   * instead of replacing it, mirroring Java's try-with-resources.
   */
  private def runWithCleanup[A](body: => A)(cleanup: => Any): A = {
    var failure: Option[Throwable] = None
    try body
    catch {
      // Only recorded so `finally` can attach suppressed exceptions; always rethrown.
      case t: Throwable =>
        failure = Some(t)
        throw t
    } finally {
      failure match {
        case None => cleanup
        case Some(t) =>
          try cleanup
          catch { case NonFatal(c) => t.addSuppressed(c) }
      }
    }
  }

  /** Extracts the table name from a `CREATE TEMP TABLE` statement.
   *
   * Accepts `CREATE [GLOBAL | LOCAL] {TEMP | TEMPORARY} TABLE [IF NOT EXISTS] name`, matched
   * case-insensitively anywhere in `sql`, with any whitespace (including newlines) between
   * words. A schema-qualified name (`schema.name`) is returned whole. The name is lowercased
   * with `Locale.ROOT`, matching how Postgres folds unquoted identifiers.
   * NOTE: a leading double quote is skipped, but quoted mixed-case names are not supported.
   *
   * @throws IllegalArgumentException if no such statement is found in `sql`
   */
  def detectTempTableName(sql: String): String = {
    val pattern =
      """(?i)\bcreate\s+(?:(?:global|local)\s+)?temp(?:orary)?\s+table\s+(?:if\s+not\s+exists\s+)?"?(\w+(?:\.\w+)?)""".r
    pattern.findFirstMatchIn(sql) match {
      case Some(m) => m.group(1).toLowerCase(Locale.ROOT)
      case None => throw new IllegalArgumentException(s"Couldn't detect table name in: $sql")
    }
  }

  /** Encodes a single cell for embedding in TSV data sent to a Postgres COPY execution.
   *
   * Everything eventually becomes a string, but some types have special handling for convenience:
   *  - null or None becomes the empty string. Because `copy` issues COPY with `NULL AS ''`, it is
   *    loaded as SQL NULL, so a genuinely empty string cannot be loaded.
   *  - Some(value) is unwrapped, and `value` is encoded by these same rules.
   *  - An Iterable becomes a Postgres array literal such as {"abc","d,e"}. Each element is
   *    double-quoted, so commas, quotes, braces and backslashes are allowed. null/None elements
   *    become array NULLs. Nested collections are not expanded; their toString becomes a single
   *    element. An empty Iterable becomes the empty string (i.e. NULL), to save space in the DB.
   *  - Anything else uses its toString.
   * The result is escaped for COPY's text format.
   */
  private def encodeTsvCell(cell: Any): String = cell match {
    // Turn null or None into an empty cell, which is interpreted via the SQL `NULL AS ''`.
    case null | None => ""
    // Recurse so Some(null) and Some(collection) get the same treatment as the bare value.
    case Some(value) => encodeTsvCell(value)
    // Light support for SQL arrays.
    case values: Iterable[_] => escapePgSequences(encodeArray(values))
    // Everything else gets the toString treatment.
    case other => escapePgSequences(other.toString)
  }

  /** Renders `values` as a Postgres array literal such as {"abc","d,e"}, or "" when empty. */
  private def encodeArray(values: Iterable[_]): String =
    if (values.isEmpty) ""
    // Iterate instead of `values.map` so a Set input cannot collapse elements.
    else values.iterator.map(encodeArrayElement).mkString("{", ",", "}")

  /** Encodes one array element at the array-literal level: NULL for null/None, otherwise a
   * double-quoted value with `\` and `"` backslash-escaped, as the array input syntax requires.
   */
  private def encodeArrayElement(element: Any): String = element match {
    case null | None => "NULL"
    case Some(value) => encodeArrayElement(value)
    case other => "\"" + other.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
  }

  /** Escapes the characters that COPY's text format treats specially in a field. */
  private def escapePgSequences(s: String): String =
    // Escape the escape character itself first, so the escapes added below aren't doubled.
    s.replace("\\", "\\\\")
      // Then escape the literals that would mess up the TSV row/column structure.
      .replace("\n", "\\n")
      .replace("\r", "\\r")
      .replace("\t", "\\t")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment