Skip to content

Instantly share code, notes, and snippets.

@anthonny
Forked from longcao/SparkCopyPostgres.scala
Created August 8, 2017 07:03
Show Gist options
  • Save anthonny/817517f37f4977e8daae279d38e83bb0 to your computer and use it in GitHub Desktop.
Save anthonny/817517f37f4977e8daae279d38e83bb0 to your computer and use it in GitHub Desktop.
COPY Spark DataFrame rows to PostgreSQL (via JDBC)
import java.io.InputStream
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
// Target PostgreSQL database; host, database name and credentials are elided here.
val jdbcUrl = "jdbc:postgresql://..." // db credentials elided

// JDBC properties passed to Spark's connection factory. The only entry pins the driver class
// explicitly to the PostgreSQL driver.
val connectionProperties = new java.util.Properties()
connectionProperties.setProperty("driver", "org.postgresql.Driver")
// Spark reads the "driver" connection property so users can override the JDBC driver it would
// otherwise select for the URL. Without it, Spark may pick the Redshift driver (if it is on the
// classpath), whose connections can't be used with the PostgreSQL CopyManager.
// https://github.com/apache/spark/blob/v1.6.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L44-51
// NOTE(review): this (url, Properties) overload matches the Spark version linked above; later
// Spark releases changed the signature of createConnectionFactory — confirm for your version.
// `java.sql.Connection` is fully qualified because it is not among this file's imports.
val cf: () => java.sql.Connection = JdbcUtils.createConnectionFactory(jdbcUrl, connectionProperties)
// Convert every partition (an `Iterator[Row]`) to bytes (InputStream)
/**
 * Streams the rows of one partition as delimited text lines, e.g. for `COPY ... FROM STDIN`.
 *
 * Each row is rendered with `row.mkString(delimiter)` followed by `"\n"`. The result is encoded
 * as UTF-8 explicitly, not in the JVM's platform default charset. Rows are pulled lazily, one
 * line at a time, so the partition is never materialized in memory.
 *
 * NOTE: field values are neither quoted nor escaped. A value containing the delimiter, a newline
 * or (in CSV mode) a quote character will corrupt the stream. `mkString` renders null fields as
 * the literal string "null", which is why the COPY statement below uses `NULL 'null'`.
 *
 * @param rows      the rows of one partition
 * @param delimiter string placed between the fields of each row
 * @return a stream over the encoded lines; reads return -1 once every row has been consumed
 */
def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
  import java.nio.charset.StandardCharsets

  // One UTF-8 encoded line per row, produced lazily as the stream is consumed.
  val lines: Iterator[Array[Byte]] =
    rows.map(row => (row.mkString(delimiter) + "\n").getBytes(StandardCharsets.UTF_8))

  new InputStream {
    // The line currently being served, and the index of its next unread byte.
    private[this] var current: Array[Byte] = Array.emptyByteArray
    private[this] var pos: Int = 0

    /** Advances to a line with unread bytes; returns false once every row has been consumed. */
    private def fill(): Boolean = {
      while (pos >= current.length && lines.hasNext) {
        current = lines.next()
        pos = 0
      }
      pos < current.length
    }

    override def read(): Int =
      if (fill()) {
        val unsigned = current(pos) & 0xff // make the signed byte an unsigned int from 0-255
        pos += 1
        unsigned
      } else -1

    // Bulk read: copies whole slices of the encoded lines instead of relying on the default
    // implementation, which calls read() once per byte. Presumably this is the path taken by
    // buffered consumers such as CopyManager.copyIn.
    override def read(buf: Array[Byte], off: Int, len: Int): Int = {
      if (buf == null) throw new NullPointerException
      if (off < 0 || len < 0 || len > buf.length - off) throw new IndexOutOfBoundsException
      if (len == 0) 0
      else {
        var copied = 0
        while (copied < len && fill()) {
          val n = math.min(len - copied, current.length - pos)
          System.arraycopy(current, pos, buf, off + copied, n)
          pos += n
          copied += n
        }
        if (copied == 0) -1 else copied
      }
    }
  }
}
// Beware: this will open a db connection for every partition of your DataFrame.
// NOTE(review): `frame` (the DataFrame to write) is defined elsewhere, and `my_schema._mytable`
// is a placeholder target table — adjust both before use.
frame.foreachPartition { rows =>
  val conn = cf()
  try {
    // JDBC 4 unwrap instead of a cast, so connections wrapped by a pool or proxy also work.
    val cm = new CopyManager(conn.unwrap(classOf[BaseConnection]))
    cm.copyIn(
      """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""", // adjust COPY settings as you desire, options from https://www.postgresql.org/docs/9.5/static/sql-copy.html
      rowsToInputStream(rows, "\t")) // must match the DELIMITER in the COPY statement
  } finally {
    // Release the connection even if COPY fails part-way through the partition.
    conn.close()
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment