Created
May 24, 2016 11:49
-
-
Save andfanilo/83918c32ee4a072dca1138657bcd6666 to your computer and use it in GitHub Desktop.
An implementation of DataFrame comparison functions from spark-testing-base's DataFrameSuiteBase trait in specs2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package utils | |
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.specs2.matcher.{Expectable, Matcher}
import org.specs2.mutable.Specification
/**
 * Utility trait to compare DataFrames and Rows inside specs2 unit tests.
 */
trait DataFrameTesting extends Specification {

  /** Maximum number of mismatching rows reported by `dataFrameEquals` (at least one is always shown). */
  val maxUnequalRowsToShow = 10

  /**
   * Utility method to create a DataFrame from an in-memory sequence of rows.
   *
   * @param sqlContext context used to parallelize `seq` and build the DataFrame
   * @param seq        content of the DataFrame
   * @param schema     schema of the DataFrame
   * @return DataFrame with the given schema and content
   */
  def createTestDataframe(sqlContext: SQLContext, seq: Seq[Row], schema: StructType): DataFrame =
    sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(seq), schema)

  /**
   * specs2 matcher checking that the DataFrame under test ("left") equals `res` ("right").
   *
   * The checks run in order and stop at the first failure:
   *  1. the schemas must be equal;
   *  2. the row counts must be equal;
   *  3. rows are paired by position and compared with `DataFrameTesting.rowApproxEquals`,
   *     using `tol` as the tolerance for Float/Double columns.
   * On failure the report lists up to `maxUnequalRowsToShow` mismatching rows, lowest row index first.
   *
   * `result(...)` below is the helper inherited from specs2's `Matcher`: it builds the match
   * result from a success flag, an "ok" message, a "ko" message and the expectable.
   *
   * @param res expected ("right") DataFrame
   * @param tol max accepted absolute difference between Float/Double values
   * @return a specs2 matcher, e.g. `actualDf must dataFrameEquals(expectedDf)`
   */
  def dataFrameEquals(res: DataFrame, tol: Double = 0.01): Matcher[DataFrame] = new Matcher[DataFrame] {
    def apply[S <: DataFrame](s: Expectable[S]) = {
      val actual: DataFrame = s.value
      if (actual.schema != res.schema) {
        result(test = false,
          "same schema",
          s"Schema of left dataframe \n ${actual.schema} \n " +
            s"is different from schema of right dataframe \n ${res.schema} \n",
          s)
      } else {
        // Bind each RDD once so the very same instances are cached, compared and unpersisted.
        val leftRdd = actual.rdd
        val rightRdd = res.rdd
        // Cache before counting: the count jobs then fill the cache reused by the row comparison.
        leftRdd.cache()
        rightRdd.cache()
        try {
          val leftCount = leftRdd.count()
          val rightCount = rightRdd.count()
          if (leftCount != rightCount) {
            result(test = false,
              "same length",
              s"Length of left dataframe $leftCount is different from length of right dataframe $rightCount",
              s)
          } else {
            val unequalContent = firstMismatches(leftRdd, rightRdd, tol)
            result(unequalContent.isEmpty, "same dataframe", unequalContent.mkString("\n"), s)
          }
        } finally {
          leftRdd.unpersist()
          rightRdd.unpersist()
        }
      }
    }
  }

  /**
   * Pairs the rows of `left` and `right` by position and describes the mismatching pairs.
   *
   * @return at most `maxUnequalRowsToShow` descriptions (at least one when any mismatch exists),
   *         ordered by row index; an empty array means every pair matched
   */
  private def firstMismatches(left: RDD[Row], right: RDD[Row], tol: Double): Array[String] = {
    // A non-positive limit must never hide mismatches and let unequal frames pass.
    val limit = math.max(1, maxUnequalRowsToShow)
    // `tol` is a plain method parameter here, so the Spark closures below capture only its value.
    zipWithIndex(left).join(zipWithIndex(right))
      .map { case (idx, (r1, r2)) => (idx, r1, r2, DataFrameTesting.rowApproxEquals(r1, r2, tol)) }
      .filter { case (_, _, _, (isEqual, _)) => !isEqual }
      .map { case (idx, r1, r2, (_, errorMessage)) =>
        (idx, s"On row index $idx, \n left row ${r1.toString()} \n right row ${r2.toString()} \n $errorMessage")
      }
      .takeOrdered(limit) // tuples order by row index first, and row indices are unique
      .map(_._2)
  }

  /** Keys each element of `rdd` by its position so two RDDs can be joined row by row. */
  private def zipWithIndex[U](rdd: RDD[U]): RDD[(Long, U)] =
    rdd.zipWithIndex().map { case (row, idx) => (idx, row) }
}
object DataFrameTesting {

  /**
   * Approximate equality between two rows, based on `equals` from [[org.apache.spark.sql.Row]].
   *
   * Columns are compared left to right and the first difference found is reported:
   *  - a column must be null in both rows or in neither;
   *  - `Array[Byte]` values are compared element by element;
   *  - `Float` / `Double` values match when both are NaN, or when neither is NaN and their
   *    absolute difference is at most `tol`;
   *  - `java.math.BigDecimal` values are compared with `compareTo`, so scale is ignored;
   *  - any other pair of values is compared with `==`.
   *
   * Column names come from `r1.schema`. Rows built without a schema (e.g. `Row(...)`)
   * are reported with a positional `#idx` name instead.
   *
   * @param r1  left row to compare
   * @param r2  right row to compare
   * @param tol max acceptable absolute difference for Float/Double values, should be less than 1
   * @return (true, "ok") if the rows are equal given the tolerance, otherwise
   *         (false, message describing the first mismatching column)
   */
  def rowApproxEquals(r1: Row, r2: Row, tol: Double): (Boolean, String) =
    if (r1.length != r2.length) (false, "rows don't have the same length")
    else
      (0 until r1.length).iterator
        .map(idx => columnMismatch(r1, r2, idx, tol))
        .collectFirst { case Some(message) => message } // stops at the first mismatching column
        .fold((true, "ok"))(message => (false, message))

  /** Describes how column `idx` differs between `r1` and `r2`, or returns `None` when it matches. */
  private def columnMismatch(r1: Row, r2: Row, idx: Int, tol: Double): Option[String] =
    if (r1.isNullAt(idx) != r2.isNullAt(idx)) Some(s"there is a null value on column ${columnName(r1, idx)}")
    else if (r1.isNullAt(idx)) None // null in both rows
    else {
      val where = s"on column ${columnName(r1, idx)} (column index $idx)"
      (r1.get(idx), r2.get(idx)) match {
        case (b1: Array[Byte], b2: Array[Byte]) =>
          // Arrays have identity equality and toString, so compare and print their contents.
          if (java.util.Arrays.equals(b1, b2)) None
          else Some(s"${b1.mkString("[", ", ", "]")} is not equal to ${b2.mkString("[", ", ", "]")} $where")
        case (f1: Float, f2: Float) =>
          // Both NaN: |NaN - NaN| > tol is false, so the values match.
          if (java.lang.Float.isNaN(f1) != java.lang.Float.isNaN(f2) || Math.abs(f1 - f2) > tol)
            Some(s"$f1 is not equal to $f2 at $tol tolerance $where")
          else None
        case (d1: Double, d2: Double) =>
          if (java.lang.Double.isNaN(d1) != java.lang.Double.isNaN(d2) || Math.abs(d1 - d2) > tol)
            Some(s"$d1 is not equal to $d2 at $tol tolerance $where")
          else None
        case (d1: java.math.BigDecimal, d2: java.math.BigDecimal) =>
          if (d1.compareTo(d2) == 0) None else Some(s"$d1 is not equal to $d2 $where")
        case (o1, o2) =>
          if (o1 == o2) None else Some(s"$o1 is not equal to $o2 $where")
      }
    }

  /** Name of column `idx` from the row's schema, or `#idx` when the row carries no schema. */
  private def columnName(row: Row, idx: Int): String =
    Option(row.schema).map(_.fieldNames(idx)).getOrElse(s"#$idx")
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import utils.DataFrameTesting | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.types._ | |
class TestSpec extends DataFrameTesting {
  import org.apache.spark.sql.SQLContext
  import org.apache.spark.{SparkConf, SparkContext}

  /** Local Spark SQL context used to build the DataFrames of this spec. */
  lazy val sqlContext: SQLContext = new SQLContext(
    SparkContext.getOrCreate(new SparkConf().setMaster("local[2]").setAppName("DataFrameTesting"))
  )

  val schema = StructType(Seq(
    StructField("Name", StringType, nullable = true),
    StructField("Number", DoubleType, nullable = true)
  ))

  "DataFrameTesting" should {
    "compare similar dataframes" in {
      // Given
      val input = createTestDataframe(sqlContext, Seq(
        Row("Fanilo", 3000.0),
        Row("Me", 3000.5),
        Row("Myself", 4000.0)
      ), schema)
      val expectedDataFrame = createTestDataframe(sqlContext, Seq(
        Row("Fanilo", 3000.2),
        Row("Me", 3000.3),
        Row("Myself", 4000.0)
      ), schema)
      // Then
      // NOTE: rows 0 and 1 differ by more than the default 0.01 tolerance, so this example
      // fails on purpose to demonstrate the mismatch report.
      input must dataFrameEquals(expectedDataFrame)
      // Output should be:
      // On row index 0,
      //  left row [Fanilo,3000.0]
      //  right row [Fanilo,3000.2]
      //  3000.0 is not equal to 3000.2 at 0.01 tolerance on column Number (column index 1)
      // On row index 1,
      //  left row [Me,3000.5]
      //  right row [Me,3000.3]
      //  3000.5 is not equal to 3000.3 at 0.01 tolerance on column Number (column index 1)
    }
  }

  // Stop the local Spark context once every example has run.
  step(sqlContext.sparkContext.stop())
}
Can you please confirm whether the result() method is custom or imported from some API? Please provide a reference.
Hey @mintuchoudhary! I honestly don't remember; it was 6 years ago and I've forgotten all of my Scala days :-(
I did find part of the code I was using and it was
import org.specs2.mutable.SpecificationWithJUnit
trait DataFrameTesting extends SpecificationWithJUnit {
...
}
instead, so maybe result comes from SpecificationWithJUnit.
Otherwise I don't know, sorry!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Please let me know the Scala version, as I am getting an error.