Last active
December 23, 2015 14:12
-
-
Save vidma/28fa938d9e51b1c2fa9b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.DataFrame | |
implicit class DataFrameWithPivot(df: DataFrame) {
  /**
   * Transposes metrics held in several columns into several rows, with one metric per row
   * (a.k.a. unpivot).
   *
   * Every column of `df` that is not listed in `metricColumns` is treated as a dimension and is
   * repeated on each output row.
   *
   * Given:
   * {{{
   * |------|-----------|--------|
   * | dim1 | new_users | buyers |
   * |------|-----------|--------|
   * | d    | 100       | 20     |
   * |------|-----------|--------|
   * }}}
   *
   * Returns:
   * {{{
   * |------|-------------|--------------|
   * | dim1 | metric_name | metric_value |
   * |------|-------------|--------------|
   * | d    | new_users   | 100          |
   * | d    | buyers      | 20           |
   * |------|-------------|--------------|
   * }}}
   *
   * The result can then be used, for instance, inside `PivotChart` by choosing `metric_name` as
   * the Y dimension:
   *
   * `PivotChart(df.pivotColumnsIntoRows(metricColumns = Seq("new_users", "buyers")))`
   *
   * @param metricColumns names of the columns to transpose into rows; must be non-empty
   * @param nameColumn    name of the output column that holds the source metric column's name
   * @param valueColumn   name of the output column that holds the metric's value
   * @return a DataFrame with the dimension columns followed by `nameColumn` and `valueColumn`,
   *         containing one row per (input row, metric column) pair
   * @throws IllegalArgumentException if `metricColumns` is empty, if `nameColumn` equals
   *                                  `valueColumn`, or if either of them collides with a
   *                                  dimension column
   */
  def pivotColumnsIntoRows(
      metricColumns: Seq[String],
      nameColumn: String = "metric_name",
      valueColumn: String = "metric_value"): DataFrame = {
    import org.apache.spark.sql.functions._

    // Without this check `reduce` below fails with an opaque "empty.reduceLeft".
    require(metricColumns.nonEmpty, "pivotColumnsIntoRows: metricColumns must not be empty")
    require(nameColumn != valueColumn,
      s"pivotColumnsIntoRows: nameColumn and valueColumn must differ (both are '$nameColumn')")

    val dimensions = df.columns.diff(metricColumns).toSeq
    // A dimension with the same name as an output column would yield duplicate column names.
    val clashes = dimensions.filter(d => d == nameColumn || d == valueColumn)
    require(clashes.isEmpty,
      s"pivotColumnsIntoRows: dimension column(s) ${clashes.mkString(", ")} clash with the " +
        "output column names; pass different nameColumn/valueColumn")

    val dimensionColumns = dimensions.map(df(_))
    // One narrow projection per metric, then stacked vertically.
    // NOTE(review): `safeUnionAll` is defined outside this snippet; presumably a union helper
    // that tolerates column/type differences — confirm.
    metricColumns.map { metricName =>
      val metricNameColumn = lit(metricName) as nameColumn
      val metricValueColumn = df(metricName) as valueColumn
      df.select(dimensionColumns :+ metricNameColumn :+ metricValueColumn: _*)
    }.reduce(_ safeUnionAll _)
  }
}
// ----------------------------------------
// Unit tests (ScalaTest); not needed in the notebook.
// NOTE(review): `toDF` and `getAllValuesAsMap` come from the surrounding test fixture
// (Spark implicits and a project-level Row helper); they are not defined in this snippet.
describe("pivotColumnsIntoRows") {
  it("transposes metrics in multiple columns into multiple rows, with one metric per row") {
    val df = Seq(
      ("d", 100, 20)
    ).toDF("dim1", "new_users", "buyers")

    df.pivotColumnsIntoRows(metricColumns = Seq("new_users", "buyers"))
      .collect()
      .map(_.getAllValuesAsMap) should contain theSameElementsAs Seq(
        Map("dim1" -> "d", "metric_name" -> "new_users", "metric_value" -> 100),
        Map("dim1" -> "d", "metric_name" -> "buyers", "metric_value" -> 20)
      )
  }

  it("repeats every non-metric column on each output row, for every input row") {
    val df = Seq(
      ("d", "x", 100, 20),
      ("e", "y", 5, 1)
    ).toDF("dim1", "dim2", "new_users", "buyers")

    df.pivotColumnsIntoRows(metricColumns = Seq("new_users", "buyers"))
      .collect()
      .map(_.getAllValuesAsMap) should contain theSameElementsAs Seq(
        Map("dim1" -> "d", "dim2" -> "x", "metric_name" -> "new_users", "metric_value" -> 100),
        Map("dim1" -> "d", "dim2" -> "x", "metric_name" -> "buyers", "metric_value" -> 20),
        Map("dim1" -> "e", "dim2" -> "y", "metric_name" -> "new_users", "metric_value" -> 5),
        Map("dim1" -> "e", "dim2" -> "y", "metric_name" -> "buyers", "metric_value" -> 1)
      )
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment