COPY Spark DataFrame rows to PostgreSQL (via JDBC)
import java.io.InputStream
import java.sql.Connection

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.{ DataFrame, Row }

import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

val jdbcUrl = s"jdbc:postgresql://..." // db credentials elided

val connectionProperties = {
  val props = new java.util.Properties()
  props.setProperty("driver", "org.postgresql.Driver")
  props
}

// Spark reads the "driver" property to let users override the default driver it selects;
// otherwise it may pick the Redshift driver, which doesn't support the JDBC CopyManager.
// https://github.com/apache/spark/blob/v1.6.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L44-51
val cf: () => Connection = JdbcUtils.createConnectionFactory(jdbcUrl, connectionProperties)

// Convert every partition (an `Iterator[Row]`) to bytes (InputStream)
def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
  val bytes: Iterator[Byte] = rows.map { row =>
    (row.mkString(delimiter) + "\n").getBytes
  }.flatten

  new InputStream {
    override def read(): Int = if (bytes.hasNext) {
      bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
    } else {
      -1
    }
  }
}

// Beware: this will open a db connection for every partition of your DataFrame.
frame.foreachPartition { rows =>
  val conn = cf()
  val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
  cm.copyIn(
    // Adjust COPY settings as you desire; options at https://www.postgresql.org/docs/9.5/static/sql-copy.html
    """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""",
    rowsToInputStream(rows, "\t"))
  conn.close()
}
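For context, a minimal sketch of how the snippet above might be driven; the SparkSession setup, the DataFrame standing in for `frame`, and the table schema are illustrative assumptions, not part of the original gist.

import org.apache.spark.sql.SparkSession

// Hypothetical setup: a SparkSession and a small DataFrame standing in for `frame`.
val spark = SparkSession.builder().appName("copy-to-postgres").getOrCreate()
import spark.implicits._

// The target table (my_schema._mytable in the COPY statement above) is assumed to already
// exist with matching columns, e.g. CREATE TABLE my_schema._mytable (id BIGINT, name TEXT);
val frame = Seq((1L, "alice"), (2L, "bob")).toDF("id", "name")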
Slightly modified to deal with escaping and Spark 2.2
import java.io.InputStream
import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.sql.{DataFrame, Row}
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

object CopyHelper {
  def rowsToInputStream(rows: Iterator[Row]): InputStream = {
    val bytes: Iterator[Byte] = rows.map { row =>
      (row.toSeq
        .map { v =>
          if (v == null) {
            """\N"""
          } else {
            "\"" + v.toString.replaceAll("\"", "\"\"") + "\""
          }
        }
        .mkString("\t") + "\n").getBytes
    }.flatten

    new InputStream {
      override def read(): Int = if (bytes.hasNext) {
        bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
      } else {
        -1
      }
    }
  }

  def copyIn(driver: String, url: String, user: String, password: String, properties: Properties)(df: DataFrame, table: String): Unit = {
    df.foreachPartition { rows =>
      Class.forName(driver)
      val conn = DriverManager.getConnection(url, user, password)
      try {
        val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
        cm.copyIn(
          s"COPY $table " +
            """FROM STDIN WITH (NULL '\N', FORMAT CSV, DELIMITER E'\t')""",
          rowsToInputStream(rows))
        ()
      } finally {
        conn.close()
      }
    }
  }
}
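A minimal usage sketch for the helper above; the JDBC URL, credentials, and table name are placeholders, not values from the original comment.

// Hypothetical call site: `df` is an existing DataFrame and the target table already exists.
val props = new java.util.Properties()
CopyHelper.copyIn(
  driver = "org.postgresql.Driver",
  url = "jdbc:postgresql://localhost:5432/mydb",
  user = "username",
  password = "secret",
  properties = props)(df, "my_schema.my_table")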
Avoid calling the replaceAll method: https://github.com/melin/datatunnel
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.jdbc.JdbcOptionsInWrite
import org.apache.spark.sql.jdbc.JdbcDialects
import java.io.InputStream
import org.apache.spark.sql.{DataFrame, Row}
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
import java.nio.ByteBuffer
// https://gist.github.com/longcao/bb61f1798ccbbfa4a0d7b76e49982f84
object CopyHelper extends Logging {
  private val fieldDelimiter = ","

  def rowsToInputStream(rows: Iterator[Row]): InputStream = {
    val bytes: Iterator[Byte] = rows.flatMap { row =>
      // Render each column as raw bytes, using \N for SQL NULL.
      val columns = row.toSeq.map { v =>
        if (v == null) {
          Array[Byte]('\\'.toByte, 'N'.toByte)
        } else {
          v.toString.getBytes()
        }
      }

      val bytesSize = columns.map(_.length).sum
      // Worst case: every content byte is a quote that gets doubled, plus two quotes and a
      // delimiter per column and the trailing newline.
      val byteBuffer = ByteBuffer.allocate(bytesSize * 2 + columns.size * 3 + 1)
      var index: Int = 0
      columns.foreach(bytes => {
        if (index > 0) {
          byteBuffer.put(fieldDelimiter.getBytes)
        }
        if (bytes.length == 2 && bytes(0) == '\\'.toByte && bytes(1) == 'N'.toByte) {
          // The NULL marker goes through unquoted so COPY recognizes it.
          byteBuffer.put(bytes)
        } else {
          // Quote the value and escape embedded quotes by doubling them,
          // byte by byte, instead of calling String.replaceAll per value.
          byteBuffer.put('"'.toByte)
          bytes.foreach(ch => {
            if (ch == '"'.toByte) {
              byteBuffer.put('"'.toByte).put('"'.toByte)
            } else {
              byteBuffer.put(ch)
            }
          })
          byteBuffer.put('"'.toByte)
        }
        index = index + 1
      })
      byteBuffer.put('\n'.toByte)

      byteBuffer.flip()
      val bytesArray = new Array[Byte](byteBuffer.remaining)
      byteBuffer.get(bytesArray, 0, bytesArray.length)
      println(new String(bytesArray)) // debug output: the generated CSV line for this row
      bytesArray
    }

    new InputStream {
      override def read(): Int = if (bytes.hasNext) {
        bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
      } else {
        -1
      }
    }
  }

  def copyIn(parameters: Map[String, String])(df: DataFrame, table: String): Unit = {
    df.rdd.foreachPartition { rows =>
      val options = new JdbcOptionsInWrite(parameters)
      val dialect = JdbcDialects.get(options.url)
      val conn = dialect.createConnectionFactory(options)(-1)
      try {
        val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
        val sql = s"COPY $table FROM STDIN WITH (NULL '\\N', FORMAT CSV, DELIMITER E'${fieldDelimiter}')"
        logInfo(s"copy from sql: $sql")
        cm.copyIn(sql, rowsToInputStream(rows))
        ()
      } finally {
        conn.close()
      }
    }
  }
}
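A usage sketch for this variant, with placeholder connection details; JdbcOptionsInWrite expects at least the url and dbtable keys (plus driver and credentials) in the parameter map, while the table argument is what ends up in the COPY statement.

// Hypothetical invocation; all connection details below are placeholders.
val parameters = Map(
  "url"      -> "jdbc:postgresql://localhost:5432/mydb",
  "dbtable"  -> "my_schema.my_table",
  "driver"   -> "org.postgresql.Driver",
  "user"     -> "username",
  "password" -> "secret"
)
CopyHelper.copyIn(parameters)(df, "my_schema.my_table")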
Here foreachPartition is called on the RDD (df.rdd.foreachPartition) rather than on the DataFrame, which fixes it.