Last active
September 1, 2016 18:30
-
-
Save ssimeonov/89c1e57474e38e7d05f55e5687708ee7 to your computer and use it in GitHub Desktop.
Querying a DataFrame with SQL without explicitly registering a temporary table
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object DataFrameFunctions {

  /** Token in a SQL template that [[sql]] replaces with the generated temp table name. */
  final val TEMP_TABLE_PLACEHOLDER = "~tbl~"

  /** Executes a SQL statement on the dataframe.
   *
   * Behind the scenes, the dataframe is registered under a unique temporary
   * table name (see [[safeRegisterTempTable]]). Every occurrence of
   * `TEMP_TABLE_PLACEHOLDER` in `stmtTemplate` is replaced with that name
   * before the statement runs, and the temporary table is cleaned up afterwards.
   *
   * NOTE(review): the returned DataFrame is used after the temp table has been
   * cleaned up. This relies on Spark resolving the table reference when
   * `sqlContext.sql` builds the DataFrame. Confirm for the Spark version in use.
   *
   * @param df           input dataframe
   * @param stmtTemplate SQL statement template that uses the value of
   *                     `TEMP_TABLE_PLACEHOLDER` for the table name
   * @return the dataframe which is the output of the SQL statement
   */
  def sql(df: DataFrame, stmtTemplate: String): DataFrame =
    withTempTable(df, (tableName: String) => {
      val stmt = stmtTemplate.replace(TEMP_TABLE_PLACEHOLDER, tableName)
      df.sqlContext.sql(stmt)
    })

  /** Registers the dataframe as a temp table and executes a function, passing
   * in the name of the just-created temporary table. Cleans up at the end,
   * even if `f` throws.
   *
   * `f` should not return anything that looks the table up by name lazily,
   * because the table is gone once this method returns.
   *
   * @param df input dataframe
   * @param f  transformation function; receives the temp table name
   * @tparam B return type of the transformation function
   * @return result of the transformation function
   */
  def withTempTable[B](df: DataFrame, f: String => B): B = {
    val name = safeRegisterTempTable(df)
    try f(name)
    // NOTE(review): `ensureNoTempTable` is not a Spark `SQLContext` method;
    // presumably a project-defined extension that drops the table if present. Confirm.
    finally df.sqlContext.ensureNoTempTable(name)
  }

  /** Registers the dataframe with a unique temp table name.
   *
   * The name combines the current wall-clock millis with a random UUID (dashes
   * removed), so repeated or concurrent registrations do not collide.
   *
   * @param df     input dataframe
   * @param prefix prefix for the temp table; must contain only ASCII letters,
   *               digits and underscores so that the generated name is a valid
   *               unquoted SQL identifier
   * @return the name of the temp table `tmp_{prefix}_{millis}_{uuid}`
   * @throws IllegalArgumentException if `prefix` contains other characters
   */
  def safeRegisterTempTable(df: DataFrame, prefix: String = "tbl"): String = {
    // The name is later spliced verbatim into SQL text, so reject characters
    // (dots, dashes, spaces, ...) that would break or alter identifier parsing.
    require(
      prefix.matches("[A-Za-z0-9_]*"),
      s"Temp table prefix must contain only ASCII letters, digits and underscores: '$prefix'"
    )
    val uuid = java.util.UUID.randomUUID.toString.replace("-", "")
    val name = s"tmp_${prefix}_${System.currentTimeMillis}_$uuid"
    // NOTE(review): `registerTempTable` is deprecated since Spark 2.0 in favour of
    // `createOrReplaceTempView`; kept for compatibility with the Spark 1.x API used here.
    df.registerTempTable(name)
    name
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Extension methods that expose the [[DataFrameFunctions]] helpers directly
 * on a `DataFrame`. A value class, so the wrapper itself is not allocated.
 *
 * @param underlying the dataframe the extension methods operate on
 */
implicit class DataFrameOps(val underlying: DataFrame) extends AnyVal {

  /** Runs `stmtTemplate` against this dataframe. The template refers to the
   * dataframe through `DataFrameFunctions.TEMP_TABLE_PLACEHOLDER`.
   *
   * @param stmtTemplate SQL statement template
   * @return the dataframe produced by the SQL statement
   */
  def sql(stmtTemplate: String): DataFrame =
    DataFrameFunctions.sql(df = underlying, stmtTemplate = stmtTemplate)

  /** Registers this dataframe as a temporary table, applies `f` to the table
   * name, and cleans the table up afterwards.
   *
   * @param f function receiving the temporary table name
   * @tparam B result type of `f`
   * @return whatever `f` returns
   */
  def withTempTable[B](f: String => B): B =
    DataFrameFunctions.withTempTable[B](df = underlying, f = f)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment