Last active
February 12, 2020 06:13
-
-
Save umbertogriffo/112a02848d8be269f23757c9656df908 to your computer and use it in GitHub Desktop.
DataFrameSuite allows you to check whether two DataFrames are equal. You can assert DataFrame equality using the method assertDataFrameEquals. When DataFrames contain doubles or Spark MLlib Vectors, you can assert that the DataFrames are approximately equal using the method assertDataFrameApproximateEquals.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package test.com.idlike.junit.df | |
import breeze.numerics.abs | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.functions.col | |
import org.apache.spark.sql.{Column, DataFrame, Row} | |
/** | |
* Created by Umberto on 06/02/2017. | |
*/ | |
object DataFrameSuite {

  /**
   * Compares two [[DataFrame]]s for equality.
   *
   * This handles cases where the DataFrames may have duplicate rows, rows in
   * different orders, and/or columns in different orders:
   * 1. Check that the two schemas are equal (case-insensitive string comparison).
   * 2. Check that the row counts are equal.
   * 3. Check that there are no unequal rows (set difference in both directions).
   *
   * @param a         first DataFrame
   * @param b         second DataFrame
   * @param isRelaxed when true, compare the DataFrames as-is (duplicates are
   *                  collapsed by `except`); when false, append a per-row
   *                  occurrence count so duplicate multiplicity is also compared
   * @return true if the DataFrames are considered equal
   */
  def assertDataFrameEquals(a: DataFrame, b: DataFrame, isRelaxed: Boolean): Boolean = {
    try {
      a.rdd.cache
      b.rdd.cache

      // 1. Check the equality of the two schemas.
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString)) {
        return false
      }

      // 2. Check the number of rows in the two DataFrames.
      if (a.count() != b.count()) {
        return false
      }

      // 3. Check that there are no unequal rows.
      val aColumns: Array[String] = a.columns
      val bColumns: Array[String] = b.columns
      // Sort column names so DataFrames with the same columns in a different
      // order still compare equal.
      scala.util.Sorting.quickSort(aColumns)
      scala.util.Sorting.quickSort(bColumns)
      val aSeq: Seq[Column] = aColumns.map(col(_))
      val bSeq: Seq[Column] = bColumns.map(col(_))

      val (aPrime, bPrime) =
        if (isRelaxed) {
          (a, b)
        } else {
          // Append a per-row occurrence count so duplicate rows are handled
          // correctly. groupBy is order-insensitive and except is a set
          // operation, so the sort(...) the original code performed before
          // groupBy was redundant and has been dropped.
          // Fix: the original grouped b using aSeq in its sort step; b must be
          // driven by its own (bSeq) columns.
          (a.groupBy(aSeq: _*).count(), b.groupBy(bSeq: _*).count())
        }

      // except() is a set difference; equality requires it to be empty in
      // BOTH directions. (The original `c1 != c2 || c1 != 0 || c2 != 0` was
      // logically equivalent to this simpler form.)
      if (aPrime.except(bPrime).count() != 0 || bPrime.except(aPrime).count() != 0) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }

  /**
   * Compares two [[DataFrame]]s containing doubles (or floats, decimals, or
   * MLlib Vectors) for approximate equality.
   * 1. Check that the two schemas are equal.
   * 2. Check that the row counts are equal.
   * 3. Check that there are no rows differing by more than `tol`.
   *
   * NOTE(review): rows are paired by zipWithIndex, i.e. by partition/row
   * order — this comparison is order-SENSITIVE, unlike assertDataFrameEquals.
   *
   * @param a   first DataFrame
   * @param b   second DataFrame
   * @param tol max acceptable per-value tolerance; should be less than 1
   * @return true if the DataFrames are approximately equal
   */
  def assertDataFrameApproximateEquals(a: DataFrame, b: DataFrame, tol: Double): Boolean = {
    try {
      a.rdd.cache
      b.rdd.cache

      // 1. Check the equality of the two schemas.
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString)) {
        return false
      }

      // 2. Check the number of rows in the two DataFrames.
      if (a.count() != b.count()) {
        return false
      }

      // 3. Pair rows by index and look for any pair outside the tolerance.
      val aIndexValue = zipWithIndex(a.rdd)
      val bIndexValue = zipWithIndex(b.rdd)
      val unequalRDD = aIndexValue.join(bIndexValue).filter { case (idx, (r1, r2)) =>
        !DataFrameSuite.approxEquals(r1, r2, tol)
      }
      // take(1) stops at the first offending pair instead of counting them all.
      if (unequalRDD.take(1).length != 0) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }

  /** Re-keys an RDD by its zipWithIndex position: (row, idx) becomes (idx, row). */
  def zipWithIndex[U](rdd: RDD[U]): RDD[(Long, U)] =
    rdd.zipWithIndex().map { case (row, idx) => (idx, row) }

  /**
   * Approximate equality, based on equals from [[Row]].
   *
   * Numeric fields (Float, Double, BigDecimal, MLlib Vector elements) compare
   * within `tol`; byte arrays compare element-wise; everything else uses `==`.
   * Two NaNs in the same position are considered equal.
   *
   * @param r1  first row
   * @param r2  second row
   * @param tol max acceptable absolute difference per numeric value
   * @return true if the rows are approximately equal
   */
  def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
    if (r1.length != r2.length) {
      return false
    } else {
      var idx = 0
      val length = r1.length
      while (idx < length) {
        // Null in one row but not the other means unequal.
        if (r1.isNullAt(idx) != r2.isNullAt(idx))
          return false
        if (!r1.isNullAt(idx)) {
          val o1 = r1.get(idx)
          val o2 = r2.get(idx)
          o1 match {
            case b1: Array[Byte] =>
              if (!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) return false
            case f1: Float =>
              // NaN must match NaN; abs(NaN - NaN) > tol is false, so two NaNs pass.
              if (java.lang.Float.isNaN(f1) != java.lang.Float.isNaN(o2.asInstanceOf[Float])) return false
              if (abs(f1 - o2.asInstanceOf[Float]) > tol) return false
            case d1: Double =>
              if (java.lang.Double.isNaN(d1) != java.lang.Double.isNaN(o2.asInstanceOf[Double])) return false
              if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false
            case d1: java.math.BigDecimal =>
              if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false
            case v1: org.apache.spark.ml.linalg.Vector =>
              val arr1: Array[Double] = v1.toArray
              val arr2: Array[Double] = o2.asInstanceOf[org.apache.spark.ml.linalg.Vector].toArray
              if (arr1.length != arr2.length) return false
              var i = 0
              while (i < arr1.length) {
                if (abs(arr1(i) - arr2(i)) > tol) return false
                i += 1
              }
            case _ =>
              if (o1 != o2) return false
          }
        }
        idx += 1
      }
    }
    true
  }
}
Line 49: "scala.util.Sorting.quickSort(aColumns)" -> "scala.util.Sorting.quickSort(bColumns)"
Here's a really good book to help you write effective scala: https://www.amazon.co.uk/Scala-Impatient-Cay-S-Horstmann/dp/0134540565/
Line 49: "scala.util.Sorting.quickSort(aColumns)" -> "scala.util.Sorting.quickSort(bColumns)"
@alexkon Thank you very much! I've updated it.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
what is "isRelaxed"?