Created November 20, 2024 20:23
Save maneeshdisodia/b24105339a774f6d5e2d60a41d7b8a8c to your computer and use it in GitHub Desktop.
PySpark groupBy with a pandas UDF that takes extra arguments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Small demo frame: two groups keyed by "id", with a float column "v".
# NOTE(review): `spark` is not defined in this snippet; presumably an active
# SparkSession provided by the pyspark shell or a notebook — confirm.
df = spark.createDataFrame(
    data=[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    schema=("id", "v"),
)
def _subtract_scaled_key(pdf, by, column, value):
    """Return a copy of *pdf* where ``column`` becomes ``column - by * value``.

    Only the ``by`` and ``column`` columns are kept, so the result matches the
    ``"<by> long, <column> double"`` schema declared by ``my_function``.
    """
    # Select the two schema columns first: applyInPandas expects the returned
    # frame's columns to match the declared schema.
    subset = pdf[[by, column]]
    # assign(**{column: ...}) writes to the column named by `column`; a literal
    # keyword such as assign(v=...) would always write a column called "v".
    return subset.assign(**{column: subset[column] - subset[by] * value})


def my_function(df, by="id", column="v", value=1.0):
    """Subtract ``value`` times the group key from ``column`` within each group.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Input frame containing at least the columns ``by`` and ``column``.
    by : str
        Grouping column; declared as ``long`` in the output schema.
    column : str
        Numeric column to adjust; declared as ``double`` in the output schema.
    value : float
        Multiplier applied to the group key before it is subtracted.

    Returns
    -------
    pyspark.sql.DataFrame
        Frame with schema ``"<by> long, <column> double"`` in which each
        ``column`` entry equals ``column - by * value``.
    """
    schema = "{} long, {} double".format(by, column)

    def subtract_value(pdf):
        # pdf is the pandas.DataFrame holding the rows of a single `by` group;
        # `by`, `column` and `value` are captured from the enclosing call.
        return _subtract_scaled_key(pdf, by, column, value)

    return df.groupby(by).applyInPandas(subtract_value, schema)
# Demo: subtract 2 * id from v within each id group and print the result.
adjusted = my_function(df, by="id", column="v", value=2.0)
adjusted.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.