Created
October 14, 2015 09:43
-
-
Save alexanderscott/7c9b87bdb11f4b6a4ba1 to your computer and use it in GitHub Desktop.
MySQL connection pooling in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Taken largely from https://gist.github.com/tsuna/2245176 | |
import java.sql._ | |
import java.util.concurrent.ArrayBlockingQueue | |
import java.util.concurrent.Executors | |
import java.util.concurrent.ThreadFactory | |
import java.util.concurrent.atomic.{AtomicLong, AtomicInteger} | |
import java.util.concurrent.TimeUnit.MILLISECONDS | |
import scala.collection.mutable.ArrayBuffer | |
import org.slf4j.LoggerFactory | |
import com.mysql.jdbc.log.Log | |
import com.twitter.conversions.time._ | |
import com.twitter.util.Duration | |
import com.twitter.util.Future | |
import com.twitter.util.FuturePool | |
import scala.collection.{JavaConversions, JavaConverters} | |
import java.util.concurrent.atomic.LongAdder | |
/**
 * Configuration for a connection pool.
 * @param servers A list of "ip:port".
 * @param user MySQL username to connect as.
 * @param pass MySQL password for that user.
 * @param schema Name of the DB schema to use.
 */
final case class PoolConfig(servers: Seq[String], user: String, pass: String, schema: String) {
  /** Like `equals' but ignores the order in `servers' in case they were shuffled. */
  def equivalentTo(other: PoolConfig): Boolean =
    user == other.user && pass == other.pass && schema == other.schema &&
      servers.sorted == other.servers.sorted

  /**
   * Same shape as the generated case-class `toString`, but with the password
   * redacted so that logging a config (or anything embedding one) doesn't
   * leak credentials.
   */
  override def toString: String =
    "PoolConfig(" + servers + "," + user + ",<redacted>," + schema + ")"
}
/**
 * Pairs a JDBC connection with the "host:port" of the server it talks to.
 * JDBC offers no reliable way to get that address back out of a `Connection`,
 * so we carry it alongside to be able to reconnect to the same server after
 * something goes wrong.
 * @param server A "host:port" string.
 * @param connection The underlying JDBC connection.
 */
final case class MySQLConnection(server: String, connection: Connection) {
  /** Prepares `sql` on the underlying JDBC connection. */
  def prepareStatement(sql: String): PreparedStatement =
    connection.prepareStatement(sql)

  /** Closes the underlying JDBC connection. */
  def close(): Unit =
    connection.close()
}
/**
 * A connection pool for asynchronous operations.
 * For each connection, there is a dedicated thread, because MySQL doesn't
 * have an asynchronous RPC protocol, and because JDBC doesn't have an
 * asynchronous API.
 * @param cfg Configuration for this connection pool.
 * @param options A "query string" appended as-is to every JDBC URL, leading
 * `?' included (e.g. `ConnectionPool.jdbcOptions`).
 * @param readonly Whether or not to set the connection read-only mode.
 * @param appName Name of the current app (e.g. "honeybadger"). Statements
 * sent through `select`, `insert` and `update` are prefixed with an SQL
 * comment containing it.
 */
final class ConnectionPool(cfg: PoolConfig,
                           options: String,
                           val readonly: Boolean,
                           appName: String) {
  import ConnectionPool._

  ensureDriverLoaded

  // Scala initializes fields in declaration order, so every field must be
  // declared before `createConnections()' runs below.

  /** How many queries did we send to MySQL. */
  private[this] val queries = new LongAdder()

  /** How many exceptions we got from JDBC. */
  private[this] val exceptions = new LongAdder()

  @volatile private[this] var conf = cfg

  /** Worker threads, kept at one thread per connection. */
  private[this] val pool = makePool(cfg.servers.length)

  @volatile private[this] var connections: ArrayBlockingQueue[MySQLConnection] = _

  createConnections() // Populates `connections'.

  /** Returns the current configuration of this pool. */
  def config: PoolConfig = conf

  /**
   * Attempts to apply the new configuration given to this pool.
   * Changes are applied atomically without disrupting ongoing traffic.
   * If successful, this closes all the previous connections, replaces them
   * all with new connections, and resizes the worker thread pool so that it
   * has one thread per new connection.
   * The new connections are opened before any state is touched: if that
   * throws, both the configuration and the connection pool remain unchanged
   * (connections already opened for the new configuration are closed).
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to return.
   * @param newcfg The new configuration to apply to this pool. The
   * configuration is assumed to be sane.
   * @throws SQLException if something bad happens (e.g. being unable to open
   * a connection to any one of the hosts for whatever reason).
   */
  def updateConfig(newcfg: PoolConfig): Unit = synchronized {
    // Synchronized so that only one change is applied at a time.
    val newconns = openConnections(newcfg) // May throw: nothing changed yet.
    val prevconns = connections // volatile-read
    val prevcfg = conf // volatile-read
    conf = newcfg // volatile-write
    connections = newconns // commit: volatile-write
    // Success! Keep one worker thread per connection. With more threads
    // than connections, concurrent queries would find the queue empty.
    try {
      resizeThreadPool(newcfg.servers.length)
    } catch {
      case e: Exception =>
        log.warn("Failed to resize the thread pool after reloading a new"
          + " configuration", e)
    }
    // Now dispose of the previous connections, to not leak them.
    try {
      closeAllConnections(prevcfg, prevconns)
    } catch {
      case e: Exception =>
        log.warn("Uncaught exception while closing an old connection after"
          + " reloading a new configuration", e)
    }
  }

  /** Creates and populates all the connections for this pool. */
  private def createConnections(): Unit = {
    connections = openConnections(conf) // commit: volatile-write
  }

  /**
   * Opens one connection per server of `poolCfg`.
   * If any of them fails to open, the ones already opened are closed before
   * the exception is propagated, so a failed attempt doesn't leak connections.
   */
  private def openConnections(poolCfg: PoolConfig): ArrayBlockingQueue[MySQLConnection] = {
    val newconns = new ArrayBlockingQueue[MySQLConnection](poolCfg.servers.length)
    try {
      poolCfg.servers foreach { server => // server is already "ip:port".
        newconns.add(newConnection(server, poolCfg))
      }
      newconns
    } catch {
      case e: Exception =>
        var opened = newconns.poll()
        while (opened != null) {
          try opened.close() catch { case closeErr: Exception => e.addSuppressed(closeErr) }
          opened = newconns.poll()
        }
        throw e
    }
  }

  /**
   * Resizes the worker thread pool to `size` threads, if the underlying
   * executor is a `ThreadPoolExecutor` (which `makePool` is expected to
   * create). Excess threads exit once they become idle.
   */
  private def resizeThreadPool(size: Int): Unit =
    pool.executor match {
      case executor: java.util.concurrent.ThreadPoolExecutor =>
        // The core size must never exceed the maximum size, so the order of
        // the two calls depends on whether we grow or shrink.
        if (size > executor.getMaximumPoolSize) {
          executor.setMaximumPoolSize(size)
          executor.setCorePoolSize(size)
        } else {
          executor.setCorePoolSize(size)
          executor.setMaximumPoolSize(size)
        }
      case _ => () // Not resizable; leave it as is.
    }

  /** Returns the number of queries sent to MySQL. */
  def queriesSent: Long = queries.longValue()

  /** Returns the number of exceptions caught while talking to MySQL. */
  def exceptionsCaught: Long = exceptions.longValue()

  /** Closes all connections and releases all threads. */
  def shutdown(): Unit = {
    pool.executor.shutdown()
    closeAllConnections(conf, connections)
  }

  /**
   * Executes a SELECT statement on the database.
   * @param f Function called on each row returned by the database. This
   * function is called with the connection locked, so if this function takes
   * time, it will prevent the connection from being reused for another query.
   * @param sql The SQL statement, e.g. "SELECT foo FROM t WHERE id = ?"
   * @param params The parameters to substitute in the `?' placeholders.
   * These parameters don't need to be escaped as prepared statements are used
   * and they already prevent SQL injections. `null` and `None` bind SQL NULL.
   * @return A future sequence of things returned by `f'.
   * @throws SQLException (async) if something bad happens (sorry I don't know more).
   */
  def select[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, tagged(sql), params))
  }

  /**
   * Executes an INSERT statement on the database.
   * Same contract as `select`, except that `f` is only called if the
   * statement produces a result set: a plain INSERT only produces an update
   * count, in which case the future holds an empty sequence.
   */
  def insert[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, tagged(sql), params))
  }

  /**
   * Executes an UPDATE statement on the database.
   * Same contract as `insert`: a plain UPDATE yields an empty sequence.
   */
  def update[T](f: ResultSet => T, sql: String, params: Seq[Any]): Future[Seq[T]] = {
    pool(execute(f, tagged(sql), params))
  }

  /** Prefixes `sql` with an SQL comment containing `appName`. */
  private def tagged(sql: String): String = "/*" + appName + "*/ " + sql

  /**
   * Synchronously runs `sql` (as given, without the app tag) on one of the
   * pooled connections, on the calling thread.
   * Connection-level errors cause the connection to be replaced. Every
   * exception is logged, counted and rethrown.
   * @return `f` applied to every row produced by the statement, or an empty
   * sequence if the statement only produced an update count.
   * @throws IllegalStateException if no connection is available.
   */
  def execute[T](f: ResultSet => T, sql: String, params: Seq[Any]): Seq[T] = {
    queries.increment()
    val connpool = this.connections // volatile-read
    var connection = connpool.poll()
    if (connection == null) { // Should never happen.
      // We keep as many worker threads as connections, so this can only
      // happen if a thread is leaking a connection, or if `execute' is called
      // directly from threads outside of the worker pool.
      val e = new IllegalStateException("WTF? Couldn't get a connection from the pool.")
      exceptions.increment()
      log.error(e.getMessage)
      throw e
    }
    try {
      runStatement(connection, f, sql, params)
    } catch {
      case e: SQLSyntaxErrorException =>
        logAndRethrow(connection, "Syntax error in SQL query",
          sql, params, e)
      case e: SQLIntegrityConstraintViolationException =>
        logAndRethrow(connection, "Integrity constraint violated by SQL query",
          sql, params, e)
      case e: SQLFeatureNotSupportedException =>
        logAndRethrow(connection, "Feature not supported in SQL query",
          sql, params, e)
      case e: SQLDataException =>
        logAndRethrow(connection, "Data exception caused by SQL query",
          sql, params, e)
      case e @ (_: SQLRecoverableException | _: SQLNonTransientException) =>
        // The remaining kinds of SQLNonTransientException are typically
        // connection-level problems, so let's close this connection and get a
        // new one.
        // For a SQLRecoverableException the JDK javadoc manual says that "the
        // recovery operation must include closing the current connection and
        // getting a new connection".
        // The `finally' block below puts the replacement back in the pool.
        connection = reconnect(connection, e)
        // TODO(tsuna): If we wanted we could retry once here.
        logAndRethrow(connection, "Error on connection when trying to execute",
          sql, params, e)
      case e: Throwable =>
        // TODO(tsuna): Should we close the connection here? I'm not sure.
        logAndRethrow(connection, "Uncaught exception", sql, params, e)
    } finally {
      // Always return the connection to the pool.
      connpool.put(connection)
    }
  }

  /**
   * Prepares, binds and runs `sql` on `connection`, applying `f` to each row
   * it produces. Uses `execute` rather than `executeQuery` because the latter
   * rejects INSERT / UPDATE / DELETE statements.
   */
  private def runStatement[T](connection: MySQLConnection, f: ResultSet => T,
                              sql: String, params: Seq[Any]): Seq[T] = {
    val statement = connection.prepareStatement(sql)
    try {
      bindParameters(statement, params)
      if (log.isDebugEnabled)
        log.debug(connection.server + ": " + sql
          + " " + params.mkString("(", ", ", ")"))
      // `execute' returns true when the statement produced a result set, and
      // false when it only produced an update count.
      if (statement.execute()) {
        val rs = statement.getResultSet
        try {
          val results = new ArrayBuffer[T]
          while (rs.next()) {
            results += f(rs)
          }
          results
        } finally {
          rs.close()
        }
      } else {
        Seq.empty[T]
      }
    } finally {
      statement.close()
    }
  }

  /**
   * Closes a connection that hit a connection-level error and opens a
   * replacement to the same server with the current configuration.
   * Secondary failures (closing, reconnecting) are attached to `cause` as
   * suppressed exceptions so they don't mask it. If reconnecting fails, the
   * closed connection is returned so that it goes back to the pool, and a
   * later query that fails on it goes through this path again.
   */
  private def reconnect(connection: MySQLConnection, cause: Throwable): MySQLConnection = {
    try connection.close() // If we double-close it's OK, it's a no-op.
    catch { case e: Exception => cause.addSuppressed(e) }
    try {
      newConnection(connection.server, conf)
    } catch {
      case e: Exception =>
        cause.addSuppressed(e)
        connection
    }
  }

  /** Logs an exception, along with its whole chain of causes, and rethrows it. */
  private def logAndRethrow(connection: MySQLConnection, msg: String,
                            sql: String, params: Seq[Any], e: Throwable): Nothing = {
    // This function must never throw a new exception of its own.
    exceptions.increment()
    val cause = new StringBuilder
    var exception = e
    var seen = List.empty[Throwable] // Guards against cyclic cause chains.
    // Get names & messages of all exceptions in the chain.
    while (exception != null && !seen.exists(_ eq exception)) {
      cause.append(", caused by ")
        .append(exception.getClass.getName)
        .append(": ")
        .append(exception.getMessage)
      seen = exception :: seen
      exception = exception.getCause // previous exception causing this one.
    }
    log.error(connection.server + ": " + msg + ": " + sql
      + " with params " + params.mkString("(", ", ", ")")
      + cause)
    throw e
  }

  private def bindParameters(statement: PreparedStatement,
                             params: TraversableOnce[Any]): Unit = {
    bindParameters(statement, 1, params)
  }

  /**
   * Binds `params` to the `?' placeholders of `statement`, starting at
   * `startIndex`. Collections and tuples (any `Product`) are flattened into
   * consecutive placeholders; `null` and `None` bind SQL NULL.
   * @return The index of the next placeholder to bind.
   * @throws IllegalArgumentException for unsupported parameter types.
   */
  private def bindParameters(statement: PreparedStatement,
                             startIndex: Int,
                             params: TraversableOnce[Any]): Int = {
    var index = startIndex
    for (param <- params) {
      param match {
        // Must come first: `None' is a `Product' with no elements and would
        // otherwise bind nothing, shifting all the following parameters.
        case null | None => statement.setNull(index, Types.NULL)
        case i: Int => statement.setInt(index, i)
        case l: Long => statement.setLong(index, l)
        case s: String => statement.setString(index, s)
        case l: TraversableOnce[_] =>
          index = bindParameters(statement, index, l) - 1
        case p: Product =>
          index = bindParameters(statement, index, p.productIterator.toList) - 1
        //case ab: Array[Byte] => statement.setBytes(index, ab)
        case b: Boolean => statement.setBoolean(index, b)
        case s: Short => statement.setShort(index, s)
        case f: Float => statement.setFloat(index, f)
        case d: Double => statement.setDouble(index, d)
        case _ =>
          throw new IllegalArgumentException("Unsupported data type "
            + param.asInstanceOf[AnyRef].getClass.getName + ": " + param)
      }
      index += 1
    }
    index
  }

  /**
   * Returns a new MySQL connection configured with `poolCfg` and this
   * pool's `options`.
   * @param server A "host:port" string.
   */
  private def newConnection(server: String, poolCfg: PoolConfig): MySQLConnection = {
    val connection = DriverManager.getConnection(
      "jdbc:mysql://" + server + "/" + poolCfg.schema + options,
      poolCfg.user, poolCfg.pass)
    try {
      connection.setReadOnly(readonly)
    } catch {
      case e: Exception =>
        // Don't leak the connection we just opened.
        try connection.close() catch { case closeErr: Exception => e.addSuppressed(closeErr) }
        throw e
    }
    MySQLConnection(server, connection)
  }

  // Deliberately doesn't print the whole PoolConfig, which holds the password.
  override def toString: String =
    "ConnectionPool(servers=" + conf.servers.mkString("[", ", ", "]") +
      ", user=" + conf.user + ", schema=" + conf.schema + ")"
}
object ConnectionPool {
  private val log = LoggerFactory.getLogger(getClass)

  /**
   * Makes sure the MySQL JDBC driver has registered itself with the JDBC
   * `DriverManager`.
   * @throws AssertionError if the MySQL JDBC connector class can't be found.
   */
  private def ensureDriverLoaded: Unit =
    try {
      // The driver registers itself from its static initializer, but a bare
      // `classOf' reference only loads the class without initializing it.
      // Class.forName does initialize it. Going through `classOf' still makes
      // the compiler check that the connector is on the classpath.
      Class.forName(classOf[com.mysql.jdbc.Driver].getName)
      ()
    } catch {
      case e: ClassNotFoundException =>
        throw new AssertionError("MySQL JDBC connector missing.", e)
    }

  /** Default options we use to connect to MySQL */
  val jdbcOptions: String = "?" + {
    val options = Map(
      "connectTimeout" -> 4.seconds,
      "socketTimeout" -> 2.seconds,
      "useServerPrepStmts" -> true,
      "cachePrepStmts" -> true,
      "cacheResultSetMetadata" -> true,
      "cacheServerConfiguration" -> true,
      "logger" -> classOf[MySQLLogger]
    )
    options.toList.map { case (option, value) =>
      option + "=" + (value match {
        case c: Class[_] => c.getName
        case d: Duration => d.inMilliseconds
        case _ => value
      })
    } mkString "&"
  }

  /** Creates a pool whose connections are put in read-only mode. */
  def readOnly(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, true, appName)

  /** Creates a pool whose connections are not put in read-only mode. */
  def readWrite(config: PoolConfig, appName: String): ConnectionPool =
    new ConnectionPool(config, jdbcOptions, false, appName)

  /**
   * Creates a thread-pool of `size` threads to use the given connections.
   * Threads are named "MySQL-1", "MySQL-2", and so on.
   */
  private def makePool(size: Int) = {
    val factory = new ThreadFactory {
      val id = new AtomicInteger(0)
      def newThread(r: Runnable) =
        new Thread(r, "MySQL-" + id.incrementAndGet)
    }
    FuturePool(Executors.newFixedThreadPool(size, factory))
  }

  /**
   * Closes all the connections from the given pool with the given config.
   * Waits up to 500ms for each connection to be available in the queue. A
   * connection that is still in use after that is leaked, and logged as such.
   * Failing to close one connection doesn't prevent closing the others. The
   * first such failure is rethrown once all of them have been processed.
   * <strong>WARNING:</strong> this function is blocking, and might take a
   * while (maybe several seconds) to clear up the pool.
   */
  private def closeAllConnections(conf: PoolConfig,
                                  connections: ArrayBlockingQueue[MySQLConnection]): Unit = {
    val total = conf.servers.length
    var failure: Option[Exception] = None
    for (i <- 1 to total) {
      // We're not serving a query to an end-user, and our goal is to
      // close the connection but we don't want to wait forever in case
      // the connection is somehow badly stuck. So allow quite a bit of
      // time to grab a connection.
      val connection = connections.poll(500, MILLISECONDS)
      if (connection == null) {
        log.error("Timeout while trying to get connection #" + i + " / "
          + total + ", connection will be leaked.")
      } else {
        try {
          connection.close()
        } catch {
          case e: Exception =>
            failure match {
              case Some(first) => first.addSuppressed(e)
              case None => failure = Some(e)
            }
        }
      }
    }
    failure.foreach(e => throw e)
  }
}
/**
 * Adapter for MySQL's JDBC logging (otherwise it goes to stderr by default),
 * forwarding every message to an SLF4J logger. It is selected through the
 * `logger` JDBC option in `ConnectionPool.jdbcOptions`.
 * @param name Name of the SLF4J logger to write to.
 */
private final class MySQLLogger(name: String) extends Log {
  val log = LoggerFactory.getLogger(name)

  def isDebugEnabled: Boolean = log.isDebugEnabled
  def isErrorEnabled: Boolean = log.isErrorEnabled
  // SLF4J has no FATAL level, so fatal messages are logged at ERROR level.
  def isFatalEnabled: Boolean = log.isErrorEnabled
  def isInfoEnabled: Boolean = log.isInfoEnabled
  def isTraceEnabled: Boolean = log.isTraceEnabled
  def isWarnEnabled: Boolean = log.isWarnEnabled

  /**
   * Renders a message handed to us by the driver as text.
   * The `Log` interface accepts arbitrary objects, and a logger must never
   * throw into the code that is logging, so non-String messages (including
   * null) are rendered with `String.valueOf` instead of being rejected.
   */
  private def render(msg: Any): String =
    msg match {
      case m: String => m
      case other => String.valueOf(other.asInstanceOf[AnyRef])
    }

  def logDebug(msg: Any): Unit = {
    log.debug(render(msg))
  }

  def logDebug(msg: Any, e: Throwable): Unit = {
    log.debug(render(msg), e)
  }

  def logError(msg: Any): Unit = {
    log.error(render(msg))
  }

  def logError(msg: Any, e: Throwable): Unit = {
    log.error(render(msg), e)
  }

  def logFatal(msg: Any): Unit = {
    log.error("** FATAL ** " + render(msg)) // Keep going anyway.
  }

  def logFatal(msg: Any, e: Throwable): Unit = {
    log.error("** FATAL ** " + render(msg), e) // Keep going anyway.
  }

  def logInfo(msg: Any): Unit = {
    log.info(render(msg))
  }

  def logInfo(msg: Any, e: Throwable): Unit = {
    log.info(render(msg), e)
  }

  def logTrace(msg: Any): Unit = {
    log.trace(render(msg))
  }

  def logTrace(msg: Any, e: Throwable): Unit = {
    log.trace(render(msg), e)
  }

  def logWarn(msg: Any): Unit = {
    log.warn(render(msg))
  }

  def logWarn(msg: Any, e: Throwable): Unit = {
    log.warn(render(msg), e)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment