Created
February 26, 2018 20:59
-
-
Save Daenyth/1f6f1ffa8f70349185403c6d05ef26a3 to your computer and use it in GitHub Desktop.
A Slick profile extension that enables native PostgreSQL batch upserts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.github.tminglei.slickpg.ExPostgresProfile | |
import slick.SlickException | |
import slick.ast.ColumnOption.PrimaryKey | |
import slick.ast.{ColumnOption, FieldSymbol, Insert, Node, Select} | |
import slick.compiler.{InsertCompiler, Phase, QueryCompiler} | |
import slick.dbio.{Effect, NoStream} | |
import slick.jdbc.InsertBuilderResult | |
import slick.lifted.Query | |
// format: off | |
/**
  * Support for multi-row PostgreSQL upserts based on `INSERT ... ON CONFLICT`.
  *
  * [[SlickUpsert.MultiUpsertPgProfile]] extends `ExPostgresProfile` with an `insertOrUpdateAll`
  * insert extension method that sends all rows through one JDBC batch.
  */
object SlickUpsert {

  trait MultiUpsertPgProfile extends ExPostgresProfile {

    /** Creates the insert extension methods using this profile's composer, which adds
      * `insertOrUpdateAll` on top of the counting insert composer. */
    override def createInsertActionExtensionMethods[T](
        compiled: CompiledInsert): InsertActionComposerImpl[T] =
      new InsertActionComposerImpl[T](compiled)

    /** Implicit enrichment of a query with the insert extension methods, compiled through
      * [[compileInsert]] so that the multi-upsert artifacts are available. */
    implicit def multiUpsertExtensionMethods[U, C[_]](
        q: Query[_, U, C]): InsertActionComposerImpl[U] =
      createInsertActionExtensionMethods[U](compileInsert(q.toNode))

    /** Insert compiler that considers all columns and generates SQL via [[MultiUpsertBuilder]]. */
    lazy val multiUpsertCompiler = QueryCompiler(
      Phase.assignUniqueSymbols,
      Phase.inferTypes,
      new InsertCompiler(InsertCompiler.AllColumns),
      new JdbcInsertCodeGen(insert => new MultiUpsertBuilder(insert)))

    /**
      * Builds `insert into ... values (...) on conflict (<pk columns>) do update set ...`
      * (or `do nothing` when every inserted column is part of the primary key).
      *
      * See slick-pg's NativeUpsertBuilder for reference.
      */
    class MultiUpsertBuilder(ins: Insert) extends super.InsertBuilder(ins) {
      /* NOTE: pk defined by using method `primaryKey` and pk defined with `PrimaryKey` can only have one,
               here we let table ddl to help us ensure this. */
      private lazy val funcDefinedPKs =
        table.profileTable.asInstanceOf[Table[_]].primaryKeys

      // Auto-increment columns that are not primary keys are left out of the insert column list;
      // everything else is inserted.
      private lazy val (nonPkAutoIncSyms, insertingSyms) =
        syms.toSeq.partition { s =>
          s.options.contains(ColumnOption.AutoInc) && !s.options.contains(ColumnOption.PrimaryKey)
        }

      // Primary-key columns (declared per column or via `primaryKey`) become the conflict target;
      // the remaining "soft" columns are the ones overwritten on conflict.
      private lazy val (pkSyms, softSyms) = insertingSyms.partition { sym =>
        sym.options.contains(ColumnOption.PrimaryKey) || funcDefinedPKs.exists { pk =>
          pk.columns
            .collect { case Select(_, f: FieldSymbol) => f }
            .exists(_.name == sym.name)
        }
      }

      private lazy val insertNames = insertingSyms.map(fs => quoteIdentifier(fs.name))
      private lazy val pkNames = pkSyms.map(fs => quoteIdentifier(fs.name))
      private lazy val softNames = softSyms.map(fs => quoteIdentifier(fs.name))

      override def buildInsert: InsertBuilderResult = {
        // The column list and the VALUES placeholders are derived from `insertingSyms`, the same
        // symbols `transformMapping` binds first. Using every column here would emit one extra `?`
        // per non-PK auto-increment column and shift the VALUES positions away from the bound
        // parameters.
        val start = insertNames.mkString(s"insert into $tableName (", ",", ") ")
        val insertVars = insertNames.map(_ => "?").mkString("(", ",", ")")
        val insert = s"$start values $insertVars"
        // The trailing predicate only exists to consume the padded parameters appended by
        // `transformMapping`.
        // NOTE(review): three placeholders match the padding only for exactly one non-PK
        // auto-increment column - confirm no table has more than one.
        val conflictWithPadding = "conflict (" + pkNames.mkString(", ") + ")" + (
          if (nonPkAutoIncSyms.isEmpty) "" else "where ? is null or ?=?"
        )
        val updateOrNothing =
          if (softNames.isEmpty) "nothing"
          else "update set " + softNames.map(n => s"$n=EXCLUDED.$n").mkString(",")
        new InsertBuilderResult(table, s"$insert on $conflictWithPadding do $updateOrNothing", syms)
      }

      // Parameter order: the inserted columns first, then each non-PK auto-increment column three
      // times to feed the padding predicate built in `buildInsert`.
      override def transformMapping(n: Node) =
        reorderColumns(n, insertingSyms ++ nonPkAutoIncSyms ++ nonPkAutoIncSyms ++ nonPkAutoIncSyms)
    }

    /** Compiled insert that additionally exposes the artifacts produced by [[multiUpsertCompiler]]. */
    class MultiUpsertCompiledInsert(node: Node)
        extends JdbcCompiledInsert(node) {
      lazy val multiUpsert = compile(multiUpsertCompiler)
    }

    override def compileInsert(tree: Node) = new MultiUpsertCompiledInsert(tree)

    protected class InsertActionComposerImpl[U](
        override val compiled: CompiledInsert)
        extends super.CountingInsertActionComposerImpl[U](compiled) {

      // `compileInsert` in this profile always returns a MultiUpsertCompiledInsert, which is what
      // the cast relies on.
      private def multiUpsertArtifacts =
        compiled.asInstanceOf[MultiUpsertCompiledInsert].multiUpsert

      /** Insert every row of `values` in one JDBC batch; a row whose primary key already exists
        * in the table updates the existing record instead.
        *
        * @param values the rows to insert or update
        * @return an action yielding the result computed by `retManyBatch` from the batch update counts
        * @throws SlickException when the action is created for a table without a primary key
        */
      def insertOrUpdateAll(values: Iterable[U]): ProfileAction[MultiInsertResult, NoStream, Effect.Write] =
        new MultiInsertOrUpdateAction(values)

      class MultiInsertOrUpdateAction(values: Iterable[U])
          extends SimpleJdbcProfileAction[MultiInsertResult](
            "MultiInsertOrUpdateAction",
            Vector(multiUpsertArtifacts.sql)) {

        // True when any available insert artifact reports a primary key, declared either through
        // `primaryKey` on the table or through the `PrimaryKey` column option.
        private def tableHasPrimaryKey: Boolean =
          List(compiled.upsert, compiled.checkInsert, compiled.updateInsert)
            .filter(_ != null)
            .exists(artifacts =>
              artifacts.ibr.table.profileTable.asInstanceOf[Table[_]].primaryKeys.nonEmpty
                || artifacts.ibr.fields.exists(_.options.contains(PrimaryKey))
            )

        // `on conflict (...)` needs a conflict target, so fail as soon as the action is constructed.
        if (!tableHasPrimaryKey)
          throw new SlickException("InsertOrUpdate is not supported on a table without PK.")

        override def run(ctx: Backend#Context, sql: Vector[String]) =
          nativeUpsert(values, sql.head)(ctx.session)

        /** Binds each value to the prepared upsert statement, adds it to the batch, and executes
          * the batch once at the end. */
        protected def nativeUpsert(values: Iterable[U], sql: String)(
            implicit session: Backend#Session): MultiInsertResult =
          preparedInsert(sql, session) { st =>
            st.clearParameters()
            // Resolve the converter once instead of once per row.
            val converter = multiUpsertArtifacts.converter
            values.foreach { value =>
              converter.set(value, st)
              st.addBatch()
            }
            val counts = st.executeBatch()
            retManyBatch(st, values, counts)
          }
      }
    }
  }
}
// format: on |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment