Created January 26, 2018 00:39
Save gbraccialli/a0d0650caa474402150f138e88c5db9f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * One node of the sample client hierarchy.
 *
 * The fields are `java.lang.Integer` rather than `Int` so that they can hold `null`:
 * top-level clients in the sample data are created with a `null` parent.
 *
 * @param id     identifier of this client
 * @param name   display name of this client
 * @param parent id of the direct parent client, or `null` for a top-level client
 * @param value  numeric value attached to this client
 */
case class Client(
  id: Integer,
  name: String,
  parent: Integer,
  value: Integer
)
/**
 * Puts `element` in front of `arr` when it is an `Int`; returns `arr` unchanged when it is `null`.
 *
 * Despite its name this *prepends*: the new element becomes the head of the result.
 * The ancestor loop in this script looks up the head of `parents` (`getItem(0)`) to
 * find the next ancestor, so the head must always be the highest ancestor found so far.
 *
 * `element` is typed `Any` because it receives the nullable `parent` column.
 * Such a value arrives either as `null` or as a boxed integer, and the `Int`
 * pattern matches the boxed integer.
 *
 * @param arr     ancestors collected so far (head = highest ancestor)
 * @param element next ancestor id to add, or `null` when there is none
 * @return `element +: arr` for an `Int`, otherwise `arr` for `null`
 * @throws IllegalArgumentException if `element` is neither `null` nor an `Int`
 */
def append(arr: Seq[Int], element: Any): Seq[Int] = element match {
  case null   => arr
  case i: Int => i +: arr
  // Fail with a descriptive message instead of an opaque MatchError.
  case other =>
    throw new IllegalArgumentException(
      s"append expects an Int or null element, got ${other.getClass.getName}: $other")
}
// Spark UDF wrapper around `append`. It is applied to the `parents` array and the
// nullable `parent` column in the ancestor-resolution loop.
val udfAppend = udf((arr: Seq[Int], element: Any) => append(arr, element))
// Sample hierarchy:
// - "aa" (id 1) has three direct children: 11, 12 and 13.
// - 11 has two children (111, 112), and 111 has two children (1111, 1112).
// - "bb" (id 2) stands alone.
// A null parent marks a top-level client.
val df = Seq(
  Client(id = 1,    name = "aa",        parent = null, value = 423),
  Client(id = 11,   name = "aa1.1",     parent = 1,    value = 456),
  Client(id = 12,   name = "aa1.2",     parent = 1,    value = 657),
  Client(id = 13,   name = "aa1.3",     parent = 1,    value = 234),
  Client(id = 111,  name = "aa1.1.1",   parent = 11,   value = 964),
  Client(id = 112,  name = "aa1.1.2",   parent = 11,   value = 238),
  Client(id = 1111, name = "aa1.1.1.1", parent = 111,  value = 853),
  Client(id = 1112, name = "aa1.1.1.2", parent = 111,  value = 924),
  Client(id = 2,    name = "bb",        parent = null, value = 423)
).toDF()
// Lookup table from each client (exposed as `id2`) to its direct parent. The id is
// renamed so the join in the loop can tell it apart from the other side's `id`.
// It is cached and materialized with `count` because it is joined on every iteration.
val dfAux = df.select($"id".as("id2"), $"parent").cache
dfAux.count

// Starting state: every client carries a `parents` array seeded with its direct
// parent, or an empty array for a top-level client (null parent).
val dfStart = df.select(
  $"id",
  $"name",
  $"value",
  when($"parent".isNull, lit(Array.empty[Int]))
    .otherwise(array($"parent"))
    .as("parents")
).cache
dfStart.count
// Walks every client up the hierarchy, one level per pass.
//
// Each pass:
// - joins the head of `parents` (the highest ancestor found so far) against `dfAux`;
// - prepends that ancestor's parent with `udfAppend`;
// - sets `top` to "Y" when the head has no parent, otherwise "N".
//
// Every intermediate DataFrame is appended to `buf`; `buf.last` is the final state.
val buf = scala.collection.mutable.ArrayBuffer.empty[org.apache.spark.sql.DataFrame]
buf += dfStart
val maxIter = 10
var iteration = 0
var allAtTop = false
while (iteration < maxIter && !allAtTop) {
  //TODO option to stop cyclics (a cycle never reaches the top, so it runs until maxIter)
  iteration += 1
  println("iteration: " + iteration)
  val dfPrevious = buf.last
  val dfCurrent = dfPrevious
    .join(dfAux, dfPrevious("parents").getItem(0) === dfAux("id2"), "left")
    .withColumn("parents", udfAppend($"parents", $"parent"))
    .withColumn("top", when($"parent".isNull, "Y").otherwise("N"))
    .drop("parent", "id2")
    .cache
  // `cache` is lazy. Compute the new pass while the previous pass is still cached.
  // Otherwise, unpersisting the previous pass discards the only materialized data,
  // and every later pass re-runs the whole join chain from `df`.
  dfCurrent.count()
  dfPrevious.unpersist()
  buf += dfCurrent
  // Once every row has reached the top, further passes cannot change `parents` or
  // `top`, so stop early instead of always running maxIter passes.
  allAtTop = dfCurrent.filter($"top" === "N").count() == 0
}
// Final state produced by the ancestor loop.
val dfFinalIter = buf.last
// For every ancestor id, count the rows that list it in `parents` and sum their `value`.
// This covers all descendants, not only direct children.
val dfAgg = dfFinalIter
  .select(explode($"parents").as("id"), $"value")
  .groupBy("id")
  .agg(count(lit(1)).as("children"), sum($"value").as("total_children"))
// Clients that appear in no `parents` array (no descendants) have no aggregate row.
// Report 0 for them instead of the null the left join would leave behind.
val dfResult = dfFinalIter
  .join(dfAgg, Seq("id"), "left")
  .withColumn("children", coalesce($"children", lit(0L)))
  .withColumn("total_children", coalesce($"total_children", lit(0L)))
df.sort("id").show
dfResult.sort("id").show
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.