trip vitals model
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import java.time.{Instant, ZoneId, ZonedDateTime}

import spark.implicits._
// Convert epoch milliseconds to epoch seconds.
val toEpoch: UserDefinedFunction =
  udf[Long, Long]((t: Long) => t / 1000L)
// Coolant readings of exactly 0 or 120 are sensor sentinels; replace them
// with a nominal 40.0 (the same default used for nulls below).
val fixCoolant: UserDefinedFunction =
  udf[Double, Double]((c: Double) => if (c.round == 0 || c.round == 120) 40.0 else c)
// Great-circle distance in km between two (lat, lon) points, rounded to 3
// decimal places. Negative coordinates mark a missing fix (see the lag
// default of -1 below) and yield a distance of 0.0.
val haversine: UserDefinedFunction =
  udf[Double, Double, Double, Double, Double]((lat1: Double, lon1: Double, lat2: Double, lon2: Double) => {
    import math._
    if (lat1 < 0 || lon1 < 0 || lat2 < 0 || lon2 < 0) 0.000
    else {
      val dLat = (lat2 - lat1).toRadians
      val dLon = (lon2 - lon1).toRadians
      val a = pow(sin(dLat / 2), 2) + pow(sin(dLon / 2), 2) * cos(lat1.toRadians) * cos(lat2.toRadians)
      val c = 2 * asin(sqrt(a))
      val dist = 6372.8 * c // mean Earth radius in km
      BigDecimal(dist).setScale(3, BigDecimal.RoundingMode.HALF_UP).toDouble
    }
  })
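
// A quick sanity check for the haversine UDF (the coordinates are an
// illustrative assumption, not from the original gist): Bengaluru to Chennai
// is roughly 290 km great-circle.
// Seq((12.9716, 77.5946, 13.0827, 80.2707)).toDF("lat1", "lon1", "lat2", "lon2")
//   .select(haversine($"lat1", $"lon1", $"lat2", $"lon2").alias("km"))
//   .show() // ~290.3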
// Bucket an epoch-millisecond timestamp into a 15-minute window in IST,
// formatted as yyyy_MM_dd_HH_mm (minutes floored to 0, 15, 30, or 45).
val timebucket: UserDefinedFunction = udf[String, Long]((e: Long) => {
  val secs = e / 1000
  val i = Instant.ofEpochSecond(secs)
  val IST: ZoneId = ZoneId.of("Asia/Kolkata")
  val zdt = ZonedDateTime.ofInstant(i, IST)
  val (y, m, d, h, mm) = (zdt.getYear, zdt.getMonthValue, zdt.getDayOfMonth, zdt.getHour, zdt.getMinute)
  val mm_bkt = (mm / 15) * 15
  f"${y}%04d_${m}%02d_${d}%02d_${h}%02d_${mm_bkt}%02d"
})
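
// Sanity check (the expected value follows from the range_start comment below):
// timebucket(1551983400000L) == "2019_03_08_00_00" // 2019-03-08 00:00 IST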
val cols = Array("vehicle", "trip_id", "latitude", "longitude", "coolant", "engine_oil_pressure", "speed")

val range_start = 1551983400000L // 2019-03-08 00:00 IST
val range_stop = 1556649000000L // 2019-05-01 00:00 IST

// TODO: take the S3 path for month 3 and then month 4 as command-line args (an S3 suffix?)
val data =
  spark.read.parquet("trip_vitals/")
    // keep trips that fall inside the date range
    .filter(s"departure_timestamp > $range_start and arrival_timestamp < $range_stop")
    // drop rows missing an oil-pressure reading or a GPS fix
    .filter("engine_oil_pressure is not null")
    .filter("latitude is not null")
    .filter("longitude is not null")
    // default missing coolant readings, then normalize sentinel values
    .na.fill(40.0, Seq("coolant"))
    .withColumn("coolant", fixCoolant($"coolant"))
    .withColumnRenamed("gps_timestamp", "epoch")
    .select("epoch", cols: _*)
// Per-trip, per-vehicle window ordered by time; lag gives the previous GPS
// fix so consecutive points can be paired. The default of -1 marks the first
// point of each trip, which haversine maps to a distance of 0.0.
val trip_vehicle_window = Window.partitionBy("trip_id", "vehicle").orderBy("epoch")
val lag_lat = lag(col("latitude"), 1, -1).over(trip_vehicle_window)
val lag_lon = lag(col("longitude"), 1, -1).over(trip_vehicle_window)

val dist = haversine(col("latitude"), col("longitude"), lag_lat, lag_lon)
val withDist = data.withColumn("distance", dist)
// withDist.orderBy("trip_id", "vehicle", "epoch").filter("distance > 10").show(100)

// aggregation functions
val num_instances = count("*").alias("num_instances")
val mean_coolant = mean(col("coolant")).alias("mean_coolant")
val stddev_coolant = stddev(col("coolant")).alias("stddev_coolant")
val max_coolant = max(col("coolant")).alias("max_coolant")
val mean_speed = mean(col("speed")).alias("mean_speed")
val max_speed = max(col("speed")).alias("max_speed")
val total_dist = sum(col("distance")).alias("total_dist")
// count near-stationary samples (speed below 1)
val count_speed_0 = sum((col("speed") < 1).cast("integer")).alias("count_speed_0")
// count samples with a positive oil-pressure reading (threshold of 0 kept from the original)
val num_hi_oil_pressure = sum((col("engine_oil_pressure") > 0).cast("integer")).alias("num_hi_oil_pressure")
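
// The boolean-cast trick above counts matching rows; a toy check (the values
// are illustrative assumptions, not from the gist):
// Seq(0.0, 0.5, 12.0).toDF("speed")
//   .agg(sum(($"speed" < 1).cast("integer")).alias("n_slow")).show() // 2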
val withTimeBucket = withDist.withColumn("timebucket", timebucket(col("epoch")))

// 15-minute vitals summary per vehicle and trip
val summary =
  withTimeBucket
    .groupBy("vehicle", "trip_id", "timebucket")
    .agg(num_instances, mean_coolant, stddev_coolant,
      max_coolant, mean_speed, max_speed,
      num_hi_oil_pressure, total_dist, count_speed_0)
//dummy.groupBy("k").agg(min("v"), max("v"), count("*")).show
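
// A minimal sketch of persisting the result (the output path and write mode
// are assumptions, not from the original gist):
// summary
//   .write
//   .mode("overwrite")
//   .partitionBy("timebucket")
//   .parquet("trip_vitals_summary/")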