spark_scala_python_udf_battle
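
A head-to-head benchmark of Spark UDF flavors: plain Python UDFs, pandas (vectorized) UDFs, and native Scala UDFs, run against two generated datasets. Measured times appear as inline comments.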
//scala create datasets
import org.apache.spark.sql.functions._
import scala.util.Random

def randomStr(size: Int): String = Random.alphanumeric.take(size).mkString

val udfRandomStr = udf(randomStr _)
val dfRnd = (1 to 30000).toDF.repartition(3000)
val dfRnd2 = (1 to 10).toDF.withColumnRenamed("value", "value2")

//creates 2.8GB dataset with 300,000 rows (30,000 x 10)
dfRnd.crossJoin(broadcast(dfRnd2)).
  withColumn("text", udfRandomStr(lit(10000))).
  withColumn("del1", udfRandomStr(lit(2))).
  withColumn("del2", udfRandomStr(lit(2))).
  write.mode("overwrite").save("randomDF")
val dfRnd = (1 to 100000).toDF.repartition(3000)
val dfRnd2 = (1 to 2000).toDF.withColumnRenamed("value", "value2")

//creates 3.1GB dataset with 200,000,000 rows (100,000 x 2,000)
dfRnd.crossJoin(broadcast(dfRnd2)).
  withColumn("lat", udf{(a: Any) => -90 + 180*Random.nextDouble}.apply($"value")).
  withColumn("lon", udf{(a: Any) => -180 + 360*Random.nextDouble}.apply($"value")).
  write.mode("overwrite").save("randomDF_geo")
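
//"randomDF" (300,000 rows of 10,000-char random text plus two 2-char
//delimiters) and "randomDF_geo" (200,000,000 rows of random lat/lon pairs)
//are the inputs for every benchmark below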
#PYTHON TIMES 1000
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import avg, udf, substring, col
from pyspark.sql.types import StringType, DoubleType

df = spark.read.load("randomDF_geo").cache()
df.count()  #materialize the cache before timing

#plain Python UDF: evaluated one row at a time
def times1000(field):
    return field * 1000.00
udfTimes1000 = udf(times1000, DoubleType())

#pandas UDF: `field` arrives as a whole pandas Series per Arrow batch
@pandas_udf('double', PandasUDFType.SCALAR)
def pandasUdf_times1000(field):
    return field * 1000
#1.2 minutes
df.select(udfTimes1000(df.lat).alias("output")).agg(avg("output")).show()

#32 seconds
df.select(pandasUdf_times1000(df.lat).alias("output")).agg(avg("output")).show()
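
#the pandas UDF is roughly 2x faster here: `field * 1000` runs as a single
#vectorized pandas operation per Arrow batch, while the plain UDF pickles and
#evaluates one row at a time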
#PYTHON GEOHASH
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import avg, udf, substring, col
from pyspark.sql.types import StringType, DoubleType
import geohash

df = spark.read.load("randomDF_geo").cache()
df.count()

#plain Python UDF
def geohash_pyspark(lat, lon):
    return geohash.encode(lat, lon)
udfGeohash = udf(geohash_pyspark, StringType())

#pandas UDF, but the body is row-wise: DataFrame.apply calls geohash.encode
#once per row, so nothing is actually vectorized
@pandas_udf('string', PandasUDFType.SCALAR)
def geohash_pandas_udf(series_lat, series_lon):
    df = pd.DataFrame({'lat': series_lat, 'lon': series_lon})
    return pd.Series(df.apply(lambda row: geohash.encode(row['lat'], row['lon']), axis=1))
#2.7 minutes
(df.select(udfGeohash(df.lat, df.lon).alias("geohash"))
   .withColumn("first3", substring(col("geohash"), 1, 3))
   .groupBy("first3").count().show())

#23 minutes
(df.select(geohash_pandas_udf(df.lat, df.lon).alias("geohash"))
   .withColumn("first3", substring(col("geohash"), 1, 3))
   .groupBy("first3").count().show())
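
#the pandas UDF loses badly because DataFrame.apply(..., axis=1) is itself a
#Python-level per-row loop, so the Arrow batching only adds overhead. A variant
#worth trying (hypothetical, not benchmarked in this gist): same pandas UDF
#with a plain list comprehension instead of apply.
@pandas_udf('string', PandasUDFType.SCALAR)
def geohash_pandas_udf_zip(series_lat, series_lon):
    #per-row loop without the DataFrame/apply machinery
    return pd.Series([geohash.encode(lat, lon)
                      for lat, lon in zip(series_lat, series_lon)])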
#PYTHON STRING EXTRACT
#returns the substring between the first del1 and the next del2
def strExtract(text, del1, del2):
    start = text.find(del1)
    end = text.find(del2, start)
    if start > -1 and end > -1 and end > start + len(del1):
        return text[start + len(del1):end]
    else:
        return "invalid"
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

#pandas UDF variant 1: explicit row-by-row loop over the three series
@pandas_udf('string', PandasUDFType.SCALAR)
def pandasUdf(series_text, series_delim1, series_delim2):
    outputs = []
    for text, del1, del2 in zip(series_text, series_delim1, series_delim2):
        outputs.append(strExtract(text, del1, del2))
    return pd.Series(outputs)

#pandas UDF variant 2: row-wise DataFrame.apply
@pandas_udf('string', PandasUDFType.SCALAR)
def pandasUdf2(series_text, series_delim1, series_delim2):
    df = pd.DataFrame({'text': series_text, 'delim1': series_delim1, 'delim2': series_delim2})
    return pd.Series(df.apply(lambda row: strExtract(row['text'], row['delim1'], row['delim2']), axis=1))

#pandas UDF variant 3: one concatenated input column, split inside the UDF
def strExtractSplit(concat):
    parts = str(concat).split("|")
    return strExtract(parts[0], parts[1], parts[2])

@pandas_udf('string', PandasUDFType.SCALAR)
def pandasUdf3(fields_concat):
    return fields_concat.apply(lambda row: strExtractSplit(row))
from pyspark.sql.types import StringType, DoubleType
from pyspark.sql.functions import udf, concat, lit
udfExtract = udf(strExtract, StringType())

df = spark.read.load("randomDF").cache()
df.count()

#16 seconds
(df.select(udfExtract(df.text, df.del1, df.del2).alias("output"))
   .groupBy("output").count().orderBy("count", ascending=False).show())

#16 seconds
(df.select(pandasUdf(df.text, df.del1, df.del2).alias("output"))
   .groupBy("output").count().orderBy("count", ascending=False).show())

#16 seconds
(df.select(pandasUdf2(df.text, df.del1, df.del2).alias("output"))
   .groupBy("output").count().orderBy("count", ascending=False).show())

#15 seconds
(df.select(pandasUdf3(concat(df.text, lit('|'), df.del1, lit('|'), df.del2)).alias("output"))
   .groupBy("output").count().orderBy("count", ascending=False).show())
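
#all four variants land within a second of each other: strExtract is scalar
#string work either way, so shipping rows to Python as Arrow batches changes
#nothing once the function itself is the bottleneck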
//SCALA times 1000
import org.apache.spark.sql.functions._

val df = spark.read.load("randomDF_geo").cache()
df.count()

//1 second
df.select(udf{(a: Double) => a*1000.00}.apply($"lat").alias("output")).agg(avg("output")).show()
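
//1 second vs 32 seconds (pandas UDF) and 1.2 minutes (plain Python UDF):
//a Scala UDF runs inside the JVM, so no row ever crosses into a Python worker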
//SCALA GEOHASH
//spark-shell --packages com.github.davidmoten:geo:0.7.1
import com.github.davidmoten.geo._

def geohash(lat: Double, lon: Double): String = GeoHash.encodeHash(lat, lon)
val udfGeohash = udf(geohash _)

val df = spark.read.load("randomDF_geo").cache()
df.count()

//23 seconds
df.select(udfGeohash($"lat", $"lon").alias("geohash")).
  withColumn("first3", substring(col("geohash"), 1, 3)).
  groupBy("first3").count().show()
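
//23 seconds vs 2.7 minutes for the equivalent plain Python UDF: the same
//per-row geohash work, minus the JVM-to-Python serialization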
//SCALA STRING EXTRACT
def strExtract(input: String, del1: String, del2: String): String = {
  val start = input.indexOf(del1)
  val end = input.indexOf(del2, start)
  if (start > -1 && end > -1 && end > start + del1.length)
    input.substring(start + del1.length, end)
  else
    "invalid"
}
val udfExtract = udf(strExtract _)

val df = spark.read.load("randomDF").cache()
df.count()

//10 seconds
df.select(udfExtract(df.col("text"), df.col("del1"), df.col("del2")).alias("output")).
  groupBy("output").count().orderBy(desc("count")).show()
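
//scoreboard: Scala wins every round (1s / 23s / 10s); pandas UDFs only beat
//plain Python UDFs when the body is genuinely vectorized, and can be far
//slower when it loops row by row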