@nguyenvulebinh
Last active August 8, 2023 15:08
Flatten a Spark DataFrame schema (including struct and array types)
import typing as T

import cytoolz.curried as tz
import pyspark
from pyspark.sql.functions import explode


def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """Collect the path (list of field names) to every leaf column, descending into structs."""
    columns = list()

    def helper(schm: pyspark.sql.types.StructType, prefix: list = None):
        if prefix is None:
            prefix = list()
        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)
    return columns


def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
    """Explode every top-level array column; return the new frame and whether any array was found."""
    have_array = False
    aliased_columns = list()
    for column, t_column in frame.dtypes:
        if t_column.startswith('array<'):
            have_array = True
            c = explode(frame[column]).alias(column)
        else:
            c = tz.get_in([column], frame)
        aliased_columns.append(c)
    return (frame.select(aliased_columns), have_array)


def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Flatten struct fields into top-level columns named parent:child."""
    aliased_columns = list()
    for col_spec in schema_to_columns(frame.schema):
        c = tz.get_in(col_spec, frame)
        if len(col_spec) == 1:
            aliased_columns.append(c)
        else:
            aliased_columns.append(c.alias(':'.join(col_spec)))
    return frame.select(aliased_columns)


def flatten_all(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Repeatedly flatten structs and explode arrays until the schema is completely flat."""
    frame = flatten_frame(frame)
    (frame, have_array) = flatten_array(frame)
    if have_array:
        return flatten_all(frame)
    else:
        return frame
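
A minimal usage sketch: the nested schema and sample row below are made up for illustration, but they exercise both the struct-flattening and array-exploding paths of flatten_all.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Hypothetical input: one struct column and one array column.
    df = spark.createDataFrame(
        [(("Hanoi", "100000"), [1, 2])],
        "address struct<city:string, zip:string>, scores array<int>",
    )

    flat = flatten_all(df)
    flat.printSchema()
    # Expected (roughly): columns address:city, address:zip, scores,
    # with one output row per element of the scores array.
    flat.show()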
@NataliaLaurova

@ghost
I faced a similar issue: I got the right structure and a flattened schema, but there was no data. The root cause is the explode function, which drops records when the column contains nulls. I fixed that by using explode_outer, which handles nulls and behaves like a left outer join in SQL. Hope this helps.
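
For reference, a minimal sketch of that swap, assuming the same flatten_array from the gist above; only the explode call changes:

    from pyspark.sql.functions import explode_outer

    def flatten_array(frame: pyspark.sql.DataFrame) -> T.Tuple[pyspark.sql.DataFrame, bool]:
        # Same as the original, but explode_outer keeps rows whose array column is
        # null or empty (it emits a single row with null instead of dropping the record).
        have_array = False
        aliased_columns = list()
        for column, t_column in frame.dtypes:
            if t_column.startswith('array<'):
                have_array = True
                c = explode_outer(frame[column]).alias(column)
            else:
                c = frame[column]
            aliased_columns.append(c)
        return (frame.select(aliased_columns), have_array)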
