@zoltanctoth
Last active July 15, 2023 13:23
Writing a UDF for withColumn in PySpark
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
maturity_udf = udf(lambda age: "adult" if age >= 18 else "child", StringType())
df = spark.createDataFrame([{'name': 'Alice', 'age': 1}])
# withColumn returns a new DataFrame, so assign the result back
df = df.withColumn("maturity", maturity_udf(df.age))
df.show()
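The lambda above raises a TypeError when age is null, because Python 3 cannot compare None with 18, and the Spark job then fails in the Python worker. A minimal sketch of a null-safe variant, assuming you want nulls to stay null. The names maturity and with_maturity are illustrative, and the classification lives in a plain function so it can be tested without a SparkSession:

```python
def maturity(age):
    """Classify an age as 'adult' or 'child'; a null age stays null."""
    if age is None:
        return None
    return "adult" if age >= 18 else "child"


def with_maturity(df, age_col="age"):
    """Add a 'maturity' column to a Spark DataFrame using maturity()."""
    # Imported here so maturity() can be tested without Spark installed.
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType

    maturity_udf = udf(maturity, StringType())
    # withColumn returns a new DataFrame; the caller has to keep the result.
    return df.withColumn("maturity", maturity_udf(col(age_col)))
```

Usage in a PySpark session: `with_maturity(spark.createDataFrame([{'name': 'Alice', 'age': 1}])).show()`.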
@Sisyphuss

Thanks for the 2nd line.

@aloktiagi

How would you pass multiple columns of df to maturity_udf?
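One way is to pass each column as a separate argument: the UDF is called with one Python value per column, in the order you list them. A sketch, using a hypothetical describe function that combines the name and age columns from the gist:

```python
def describe(name, age):
    """Combine two column values into one string, e.g. 'Alice (child)'."""
    if name is None or age is None:
        return None
    return "%s (%s)" % (name, "adult" if age >= 18 else "child")


def with_description(df):
    """Add a hypothetical 'description' column built from 'name' and 'age'."""
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType

    describe_udf = udf(describe, StringType())
    # Each column passed here becomes one positional argument of describe().
    return df.withColumn("description", describe_udf(col("name"), col("age")))
```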

@rajenur

rajenur commented May 6, 2017

thanks z

@smanurung

How do you do it for nested fields?
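Two options, sketched against a hypothetical schema with an `address` struct column that has a `city` field: select the nested field with a dotted path and pass just that value, or pass the whole struct, which Spark hands to the Python function as a `pyspark.sql.Row` whose fields can be read by name. All function and column names here are illustrative:

```python
def shout(value):
    """Upper-case a string value; nulls stay null."""
    if value is None:
        return None
    return value.upper()


def city_from_struct(address):
    """Receive a whole struct value and read one field by name."""
    # A struct column arrives as a pyspark.sql.Row, which supports
    # lookup by field name such as address["city"].
    if address is None:
        return None
    return shout(address["city"])


def with_city_columns(df):
    """Assumes a hypothetical schema: address: struct<city: string, ...>."""
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType

    shout_udf = udf(shout, StringType())
    struct_udf = udf(city_from_struct, StringType())
    return (
        df
        # Option 1: pick the nested field with a dotted path, pass only its value.
        .withColumn("city_upper", shout_udf(col("address.city")))
        # Option 2: pass the whole struct and read fields inside Python.
        .withColumn("city_from_struct", struct_udf(col("address")))
    )
```

Option 1 keeps the UDF simple; option 2 is handy when the function needs several fields of the same struct.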

@mrandrewandrade

mrandrewandrade commented Mar 18, 2018

This is awesome, but I wanted to add a couple more examples and some background.

If your UDF is longer, it may be more readable as a standalone def instead of a lambda:

def return_age_bracket(age):
    if age is None:
        return 'N/A'  # a null age would otherwise raise a TypeError on comparison
    elif age <= 12:
        return 'Under 12'
    elif age <= 19:
        return 'Between 13 and 19'
    elif age < 65:
        return 'Between 19 and 65'
    else:
        return 'Over 65'

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

maturity_udf = udf(return_age_bracket, StringType())
df = spark.createDataFrame([{'name': 'Alice', 'age': 1}])
df = df.withColumn("age_bracket", maturity_udf(df.age))
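Because return_age_bracket is a plain Python function, its boundaries can be checked locally before it ever runs inside a slow UDF job. A minimal sketch with a null-safe version of the function, probing values on either side of each cut-off:

```python
def return_age_bracket(age):
    """Null-safe age bracket function, usable directly or via udf()."""
    if age is None:
        return 'N/A'
    elif age <= 12:
        return 'Under 12'
    elif age <= 19:
        return 'Between 13 and 19'
    elif age < 65:
        return 'Between 19 and 65'
    else:
        return 'Over 65'


# Values on either side of each cut-off, plus a null.
for age in [12, 13, 19, 20, 64, 65, None]:
    print(age, '->', return_age_bracket(age))
```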

Even on a small to medium dataset this can take many minutes to run. To see why, call df.explain() on the resulting DataFrame, which prints a physical plan like:

== Physical Plan ==
*(2) Project [Name#3, pythonUDF0#41 AS age_bracket#25]
+- BatchEvalPython [return_age_bracket(Age#5)], [Name#3, Age#5, pythonUDF0#41]

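The likely bottleneck is the BatchEvalPython step: every row is serialized out of the JVM, sent to a Python worker process, evaluated there and sent back, and Spark's Catalyst optimizer cannot look inside the Python function. Where possible, use the built-in functions in pyspark.sql.functions instead, since they run inside the JVM and can be optimized. In this example, F.when(condition, value).otherwise(value) does the same job without a UDF: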

from pyspark.sql import functions as F
df = spark.createDataFrame([{'name': 'Alice', 'age': 1}])
df = df.withColumn("age_bracket", F.when(df.age <= 12, 'Infant')
    .when((df.age >= 13) & (df.age <= 19), 'Adolescent')
    .when((df.age >= 19) & (df.age < 65), 'Adult')
    .when(df.age >= 65, 'Retired').otherwise('N/A'))

Now there is no BatchEvalPython step; the whole expression becomes a single CASE WHEN in the plan, something like:

== Physical Plan ==
*(1) Project [Name#3, CASE WHEN (Age#5 <= 12.0) THEN Infant WHEN ((Age#5 >= 13.0) && (Age#5 <= 19.0)) THEN Adolescent WHEN ((Age#5 >= 19.0) && (Age#5 < 65.0)) THEN Adult WHEN (Age#5 >= 65.0) THEN Retired ELSE N/A END AS age_bracket#42]


ghost commented Jan 7, 2019

Thanks ! :)

@swetaravi

I have a question. My data frame has date columns in the format 'Mmm dd,yyyy'. Can I use a UDF like this to convert them?

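# Change date fields
from datetime import datetime
from pyspark.sql import functions as fn
from pyspark.sql.types import DateType

review_date_udf = fn.udf(
    lambda x: datetime.strptime(x, ' %b %d, %Y'), DateType()
)

reviews_df = reviews_df.withColumn("dates", review_date_udf(reviews_df['dates']))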

But when I try to view the data frame, it throws an error: Caused by: java.net.SocketTimeoutException: Accept timed out. Any ideas how to solve this?
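Whatever the root cause of the timeout (the full stack trace is not shown), the parsing function itself is worth ruling out. Python's strptime turns whitespace in the format into "one or more whitespace characters", so ' %b %d, %Y' requires a leading space and a space after the comma, which does not match 'Mmm dd,yyyy'; a null value also makes strptime raise a TypeError. A minimal sketch of a more forgiving parser, assuming English month abbreviations and that unparseable values should become null. The name parse_review_date is hypothetical:

```python
from datetime import date, datetime
from typing import Optional


def parse_review_date(s: Optional[str]) -> Optional[date]:
    """Parse 'Jan 05,2019' or ' Jan 05, 2019' into a date; None if unparseable."""
    if s is None:
        return None
    # Normalize spacing: exactly one space after the comma, none at the ends.
    normalized = " ".join(s.replace(",", ", ").split())
    try:
        # DateType expects a datetime.date, so drop the time part.
        return datetime.strptime(normalized, "%b %d, %Y").date()
    except ValueError:
        return None


def with_parsed_dates(df, col_name="dates"):
    """Replace a string column with parsed dates via a Python UDF."""
    from pyspark.sql import functions as fn
    from pyspark.sql.types import DateType

    parse_udf = fn.udf(parse_review_date, DateType())
    return df.withColumn(col_name, parse_udf(fn.col(col_name)))
```

If every value has exactly the same layout, the built-in fn.to_date(column, format) avoids the Python round trip entirely; the format string then uses Spark's datetime patterns rather than strptime codes.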

@datbui

datbui commented Feb 3, 2020

Thanks !

@vinothkumar-dev

TypeError: a bytes-like object is required, not 'NoneType'

I am getting this error when trying mrandrewandrade's example.
How can I resolve it?

Thanks.

@abdu95

abdu95 commented Sep 2, 2021

It was nice to come across my teacher's code even after graduation. Thank you!

@Weiyu-Luo

I encountered this problem too. Have you solved it? Thank you!
