Created
June 13, 2018 16:41
-
-
Save alexkon/90d602b17404db2c2857497754bb8d6d to your computer and use it in GitHub Desktop.
DataFrameSuite allows you to check if two DataFrames are equal. You can assert the DataFrames' equality using the method assertDataFrameEquals. When DataFrames contain doubles or Spark ML Vectors, you can assert that the DataFrames are approximately equal using the method assertDataFrameApproximateEquals.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import breeze.numerics.abs | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.functions.col | |
import org.apache.spark.sql.{Column, DataFrame, Row} | |
/**
 * Test utilities for asserting (approximate) equality of two Spark [[DataFrame]]s.
 *
 * Originally created by Umberto on 06/02/2017
 * (https://gist.github.com/umbertogriffo/112a02848d8be269f23757c9656df908).
 * Added minor fix by alexkon.
 */
object DataFrameSuite {

  /**
   * Compares two [[DataFrame]]s for equality.
   * This approach correctly handles cases where the DataFrames may have duplicate rows,
   * rows in different orders, and/or columns in different orders.
   *
   * 1. Check the two schemas are equal (case-insensitive string comparison)
   * 2. Check the row counts are equal
   * 3. Check there are no unequal rows
   *
   * @param a         first DataFrame
   * @param b         second DataFrame
   * @param isRelaxed when true, skip the duplicate/row-order normalization and
   *                  compare the DataFrames as-is via `except`
   * @return true if the DataFrames are considered equal
   */
  def assertDataFrameEquals(a: DataFrame, b: DataFrame, isRelaxed: Boolean): Boolean = {
    try {
      a.rdd.cache
      b.rdd.cache

      // 1. Check the equality of the two schemas.
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString)) {
        return false
      }

      // 2. Check the number of rows in the two DataFrames.
      if (a.count() != b.count()) {
        return false
      }

      // 3. Check there are no unequal rows.
      // Sort the column names so that column order does not affect the comparison.
      val aColumns: Array[String] = a.columns
      val bColumns: Array[String] = b.columns
      scala.util.Sorting.quickSort(aColumns)
      scala.util.Sorting.quickSort(bColumns)
      val aSeq: Seq[Column] = aColumns.map(col(_))
      val bSeq: Seq[Column] = bColumns.map(col(_))

      val (aPrime, bPrime) =
        if (isRelaxed) {
          (a, b)
        } else {
          // Group on the sorted columns and append a duplicate count so that
          // duplicate rows and row ordering are handled correctly.
          // BUGFIX: b must be normalized with its OWN columns (bSeq), not aSeq
          // (only observable when the schemas differ in column-name case).
          // Note: a pre-groupBy sort was removed — groupBy/except are
          // order-insensitive, so the sort was dead work.
          (a.groupBy(aSeq: _*).count(), b.groupBy(bSeq: _*).count())
        }

      // `except` in both directions detects rows present on one side only.
      if (aPrime.except(bPrime).count() != 0 || bPrime.except(aPrime).count() != 0) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }

  /**
   * Compares two [[DataFrame]]s containing doubles (or floats, decimals, ML Vectors)
   * for approximate equality.
   *
   * 1. Check the two schemas are equal
   * 2. Check the row counts are equal
   * 3. Check there are no unequal rows (element-wise, within tolerance)
   *
   * NOTE(review): rows are matched by `zipWithIndex` position, so this comparison
   * assumes both DataFrames produce their rows in the same order — confirm callers
   * guarantee ordering before relying on this for unsorted data.
   *
   * @param a   first DataFrame
   * @param b   second DataFrame
   * @param tol max acceptable tolerance, should be less than 1.
   * @return true if every pair of corresponding rows is approximately equal
   */
  def assertDataFrameApproximateEquals(a: DataFrame, b: DataFrame, tol: Double): Boolean = {
    try {
      a.rdd.cache
      b.rdd.cache

      // 1. Check the equality of the two schemas.
      if (!a.schema.toString().equalsIgnoreCase(b.schema.toString)) {
        return false
      }

      // 2. Check the number of rows in the two DataFrames.
      if (a.count() != b.count()) {
        return false
      }

      // 3. Pair rows by position and look for any pair outside tolerance.
      val aIndexValue = zipWithIndex(a.rdd)
      val bIndexValue = zipWithIndex(b.rdd)
      val unequalRDD = aIndexValue.join(bIndexValue).filter { case (_, (r1, r2)) =>
        !DataFrameSuite.approxEquals(r1, r2, tol)
      }
      // take(1) avoids scanning the whole RDD once a mismatch is found.
      if (unequalRDD.take(1).nonEmpty) {
        return false
      }
    } finally {
      a.rdd.unpersist()
      b.rdd.unpersist()
    }
    true
  }

  /** Re-keys an RDD by its element position: (row, idx) becomes (idx, row), for joining. */
  def zipWithIndex[U](rdd: RDD[U]): RDD[(Long, U)] =
    rdd.zipWithIndex().map { case (row, idx) => (idx, row) }

  /**
   * Approximate equality, based on equals from [[Row]].
   * Floats, doubles and ML Vector elements are compared within `tol`;
   * byte arrays by content; BigDecimals by `compareTo`; everything else by `==`.
   * NaN on one side only is unequal; NaN on both sides is treated as equal.
   *
   * @param r1  first row
   * @param r2  second row
   * @param tol max acceptable absolute difference for numeric fields
   * @return true if the rows are approximately equal
   */
  def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
    if (r1.length != r2.length) {
      return false
    }
    var idx = 0
    val length = r1.length
    while (idx < length) {
      // Null on exactly one side means the rows differ.
      if (r1.isNullAt(idx) != r2.isNullAt(idx))
        return false
      if (!r1.isNullAt(idx)) {
        val o1 = r1.get(idx)
        val o2 = r2.get(idx)
        o1 match {
          case b1: Array[Byte] =>
            if (!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) return false
          case f1: Float =>
            if (java.lang.Float.isNaN(f1) != java.lang.Float.isNaN(o2.asInstanceOf[Float])) return false
            if (abs(f1 - o2.asInstanceOf[Float]) > tol) return false
          case d1: Double =>
            if (java.lang.Double.isNaN(d1) != java.lang.Double.isNaN(o2.asInstanceOf[Double])) return false
            if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false
          case d1: java.math.BigDecimal =>
            // compareTo ignores scale differences (e.g. 1.0 vs 1.00), unlike equals.
            if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false
          case v1: org.apache.spark.ml.linalg.Vector =>
            val arr1: Array[Double] = v1.toArray
            val arr2: Array[Double] = o2.asInstanceOf[org.apache.spark.ml.linalg.Vector].toArray
            if (arr1.length != arr2.length) return false
            for (i <- arr1.indices) {
              if (abs(arr1(i) - arr2(i)) > tol) return false
            }
          case _ =>
            if (o1 != o2) return false
        }
      }
      idx += 1
    }
    true
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment