@longcao
Last active September 11, 2024 18:55
COPY Spark DataFrame rows to PostgreSQL (via JDBC)
import java.io.InputStream
import java.sql.Connection

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.{ DataFrame, Row }

import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

val jdbcUrl = s"jdbc:postgresql://..." // db credentials elided

val connectionProperties = {
  val props = new java.util.Properties()
  props.setProperty("driver", "org.postgresql.Driver")
  props
}

// Spark reads the "driver" property to allow users to override the default driver selected;
// otherwise it picks the Redshift driver, which doesn't support JDBC CopyManager.
// https://github.com/apache/spark/blob/v1.6.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L44-51
val cf: () => Connection = JdbcUtils.createConnectionFactory(jdbcUrl, connectionProperties)

// Convert every partition (an `Iterator[Row]`) to bytes (InputStream)
def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
  val bytes: Iterator[Byte] = rows.map { row =>
    (row.mkString(delimiter) + "\n").getBytes
  }.flatten

  new InputStream {
    override def read(): Int =
      if (bytes.hasNext) {
        bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
      } else {
        -1
      }
  }
}

// Beware: this will open a db connection for every partition of your DataFrame.
frame.foreachPartition { rows =>
  val conn = cf()
  val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
  cm.copyIn(
    // adjust COPY settings as you desire, options from https://www.postgresql.org/docs/9.5/static/sql-copy.html
    """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""",
    rowsToInputStream(rows, "\t"))
  conn.close()
}
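
Since each partition opens its own connection, it can help to cap the partition count before writing, and to close the connection even when COPY fails. A minimal sketch, assuming a SparkSession named `spark` and a hypothetical Parquet source path:

// Hypothetical usage: `spark` and the source path are assumptions, not part of the gist.
val frame: DataFrame = spark.read.parquet("/path/to/source")

frame.coalesce(8).foreachPartition { rows => // at most 8 simultaneous PostgreSQL connections
  val conn = cf()
  try {
    val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
    cm.copyIn(
      """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""",
      rowsToInputStream(rows, "\t"))
  } finally {
    conn.close() // unlike the snippet above, this closes the connection on failure too
  }
}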
@melin commented Dec 26, 2023

Slightly modified to handle CSV escaping, and updated for Spark 2.2:

import java.io.InputStream
import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.sql.{DataFrame, Row}
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

object CopyHelper {

  def rowsToInputStream(rows: Iterator[Row]): InputStream = {
    val bytes: Iterator[Byte] = rows.map { row =>
      (row.toSeq
        .map { v =>
          if (v == null) {
            """\N"""
          } else {
            "\"" + v.toString.replaceAll("\"", "\"\"") + "\""
          }
        }
        .mkString("\t") + "\n").getBytes
    }.flatten

    new InputStream {
      override def read(): Int =
        if (bytes.hasNext) {
          bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
        } else {
          -1
        }
    }
  }

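  // Note: the `properties` argument is accepted but never used by this method.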
  def copyIn(driver: String,
             url: String,
             user: String,
             password: String,
             properties: Properties)(df: DataFrame, table: String): Unit = {
    df.foreachPartition { rows =>
      Class.forName(driver)

      val conn = DriverManager.getConnection(url, user, password)

      try {
        val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
        cm.copyIn(
          s"COPY $table " + """FROM STDIN WITH (NULL '\N', FORMAT CSV, DELIMITER E'\t')""",
          rowsToInputStream(rows))
        ()
      } finally {
        conn.close()
      }
    }
  }
}
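
A usage sketch for the helper above; the driver class is real, but the URL, credentials, and table name are placeholders:

// Hypothetical invocation - connection details and table are placeholders.
CopyHelper.copyIn(
  driver = "org.postgresql.Driver",
  url = "jdbc:postgresql://localhost:5432/mydb",
  user = "user",
  password = "password",
  properties = new Properties())(df, "my_schema.my_table")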

Avoid calling the replaceAll method (it compiles a regex and allocates a new string per value); the version below, from https://github.com/melin/datatunnel, escapes quotes byte by byte into a ByteBuffer instead:

import java.io.InputStream
import java.nio.ByteBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.jdbc.JdbcOptionsInWrite
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.{DataFrame, Row}

import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

// https://gist.github.com/longcao/bb61f1798ccbbfa4a0d7b76e49982f84
object CopyHelper extends Logging {

  private val fieldDelimiter = ","

  def rowsToInputStream(rows: Iterator[Row]): InputStream = {
    val bytes: Iterator[Byte] = rows.flatMap {
      row => {
        val columns = row.toSeq.map { v =>
          if (v == null) {
            Array[Byte]('\\', 'N')
          } else {
            v.toString.getBytes()
          }
        }

        val bytesSize = columns.map(_.length).sum
        // Worst case per row: every content byte is a quote and gets doubled, plus
        // 2 surrounding quotes and a delimiter per column, plus the trailing newline.
        // (A flat `bytesSize * 2 + 10` can overflow for rows with several quote-heavy columns.)
        val byteBuffer = ByteBuffer.allocate(bytesSize * 2 + columns.length * 3 + 1)

        var index: Int = 0
        columns.foreach(bytes => {
          if (index > 0) {
            byteBuffer.put(fieldDelimiter.getBytes)
          }

          if (bytes.length == 2 && bytes(0) == '\\'.toByte && bytes(1) == 'N'.toByte) {
            byteBuffer.put(bytes)
          } else {
            byteBuffer.put('"'.toByte)
            bytes.foreach(ch => {
              if (ch == '"'.toByte) {
                byteBuffer.put('"'.toByte).put('"'.toByte)
              } else {
                byteBuffer.put(ch)
              }
            })
            byteBuffer.put('"'.toByte)
          }

          index = index + 1
        })

        byteBuffer.put('\n'.toByte)
        byteBuffer.flip()
        val bytesArray = new Array[Byte](byteBuffer.remaining)
        byteBuffer.get(bytesArray, 0, bytesArray.length)
        // println(new String(bytesArray)) // debug only: dumps every encoded row to stdout
        bytesArray
      }
    }

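    // Scala 2.12+ SAM conversion: this lambda becomes the InputStream.read() implementation.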
    () => if (bytes.hasNext) {
      bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
    } else {
      -1
    }
  }

  def copyIn(parameters: Map[String, String])(df: DataFrame, table: String): Unit = {
    df.rdd.foreachPartition { rows =>
      val options = new JdbcOptionsInWrite(parameters)
      val dialect = JdbcDialects.get(options.url)
      val conn = dialect.createConnectionFactory(options)(-1)
      try {
        val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
        val sql = s"COPY $table FROM STDIN WITH (NULL '\\N', FORMAT CSV, DELIMITER E'${fieldDelimiter}')";
        logInfo(s"copy from sql: $sql")
        //LogUtils.info(s"copy from sql: $sql")
        cm.copyIn(sql, rowsToInputStream(rows))
        ()
      } finally {
        conn.close()
      }
    }
  }
}
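
A usage sketch for the map-based variant; `JdbcOptionsInWrite` requires at least the `url` and `dbtable` keys, and every value below is a placeholder:

// Hypothetical invocation - connection values are placeholders.
val parameters = Map(
  "url" -> "jdbc:postgresql://localhost:5432/mydb",
  "dbtable" -> "my_schema.my_table",
  "driver" -> "org.postgresql.Driver",
  "user" -> "user",
  "password" -> "password"
)
CopyHelper.copyIn(parameters)(df, "my_schema.my_table")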
