import typing as T
import cytoolz.curried as tz
import pyspark
from pyspark.sql.functions import explode


def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """Return the list of column paths (one list of field names per leaf) in a possibly nested schema."""
    columns = list()

    def helper(schm: pyspark.sql.types.StructType, prefix: list = None):
        if prefix is None:
            prefix = list()
        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                # Recurse into nested structs, accumulating the field path.
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)
    return columns


def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    """Explode the first array column found; return the new frame and whether an array was present."""
    have_array = False
    aliased_columns = list()
    i = 0
    for column, t_column in frame.dtypes:
        if t_column.startswith('array<') and i == 0:
            have_array = True
            c = explode(frame[column]).alias(column)
            i = i + 1
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)


def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Flatten struct columns, naming each leaf column after its path joined with ':'."""
    aliased_columns = list()
    for col_spec in schema_to_columns(frame.schema):
        c = tz.get_in(col_spec, frame)
        if len(col_spec) == 1:
            aliased_columns.append(c)
        else:
            aliased_columns.append(c.alias(':'.join(col_spec)))
    return frame.select(aliased_columns)


def flatten_all(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Repeatedly flatten structs and explode arrays until no array columns remain."""
    frame = flatten_frame(frame)
    (frame, have_array) = flatten_array(frame)
    if have_array:
        return flatten_all(frame)
    else:
        return frame
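A minimal usage sketch, assuming an active SparkSession named spark; the sample schema and field names below are made up for illustration:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Hypothetical nested sample: one struct column and one array column.
df = spark.createDataFrame(
    [(1, (10, 20), ["x", "y"]), (2, (30, 40), ["z"])],
    "id INT, nested STRUCT<a: INT, b: INT>, items ARRAY<STRING>",
)
flatten_all(df).show()
# Struct fields become top-level columns named 'nested:a' and 'nested:b',
# and explode gives each element of 'items' its own row.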
Hi! I tested your script and some rows are lost in the process. I tested it on a small dataframe (16 rows as input) with a unique identifier column, and of the 16 different ids only 3 are found in the output dataframe.
I observed that when a struct or array column in the input dataframe has null values, the rows containing these nulls are deleted.
Edit: it's the use of explode that drops rows with null values in array columns; replace it with explode_outer to handle such cases.
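One way to apply that fix is to swap explode for explode_outer inside flatten_array. The sketch below assumes Spark 2.2+ (where explode_outer is available); the name flatten_array_outer is made up, and the body otherwise mirrors the gist's function:

from pyspark.sql.functions import explode_outer

def flatten_array_outer(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    """Like flatten_array, but keeps rows whose array column is null or empty."""
    have_array = False
    aliased_columns = list()
    i = 0
    for column, t_column in frame.dtypes:
        if t_column.startswith('array<') and i == 0:
            have_array = True
            # explode_outer emits a row with null instead of dropping the row entirely
            c = explode_outer(frame[column]).alias(column)
            i = i + 1
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)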
Thanks for your feedback @vchalmel
The entire row is being deleted? Even if there was data in other columns?
Do you have a sample I could try it on to improve the gist?
Yes, the entire row, cf. the pyspark.sql module's documentation:
pyspark.sql.functions.explode_outer(col)
Returns a new row for each element in the given array or map. Unlike explode, if the array/map is null or empty then null is produced.
df = spark.createDataFrame([(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], ("id", "an_array", "a_map"))
df.select("id", "an_array", explode("a_map")).show()
df.select("id", "an_array", explode_outer("a_map")).show()
df.select("id", "a_map", explode("an_array")).show()
df.select("id", "a_map", explode_outer("an_array")).show()
I am able to run this successfully, but I am getting duplicate values: the number of rows after flattening is doubled. Does anyone know why?
Rows with NULL values are also disappearing for me.
This one https://gist.github.com/nmukerje/e65cde41be85470e4b8dfd9a2d6aed50 has the explode_outer fix.
Yes. MapType does not work with this script.