@franciscodara
Forked from nmukerje/Pyspark Flatten json
Created January 20, 2023 13:27
from pyspark.sql.types import ArrayType, StructType
from pyspark.sql.functions import col, explode_outer

# Flatten arrays of structs and structs
def flatten(df):
    # compute complex fields (arrays and structs) in the schema
    complex_fields = dict([(field.name, field.dataType)
                           for field in df.schema.fields
                           if type(field.dataType) == ArrayType or type(field.dataType) == StructType])

    while len(complex_fields) != 0:
        col_name = list(complex_fields.keys())[0]
        print("Processing: " + col_name + " Type: " + str(type(complex_fields[col_name])))

        # if StructType, promote all sub-elements to top-level columns,
        # i.e. flatten structs
        if type(complex_fields[col_name]) == StructType:
            expanded = [col(col_name + '.' + k).alias(col_name + '_' + k)
                        for k in [n.name for n in complex_fields[col_name]]]
            df = df.select("*", *expanded).drop(col_name)

        # if ArrayType, turn the array elements into rows using explode_outer,
        # i.e. explode arrays
        elif type(complex_fields[col_name]) == ArrayType:
            df = df.withColumn(col_name, explode_outer(col_name))

        # recompute the remaining complex fields in the schema
        complex_fields = dict([(field.name, field.dataType)
                               for field in df.schema.fields
                               if type(field.dataType) == ArrayType or type(field.dataType) == StructType])

    return df

# Apply to an existing nested DataFrame `df`
df = flatten(df)
df.printSchema()
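
To see the function in action, here is a minimal usage sketch (not part of the original gist): it assumes an active SparkSession named spark, builds a tiny nested DataFrame from made-up sample data via spark.read.json (which infers JSON objects as structs and JSON lists as arrays), and applies flatten() to it.

# Minimal usage sketch -- assumes an active SparkSession named `spark`;
# the sample payload below is illustrative only.
import json

payload = {
    "id": 1,
    "customer": {"name": "Ana", "address": {"city": "Lisbon", "zip": "1000-001"}},
    "orders": [{"sku": "A1", "qty": 2}, {"sku": "B2", "qty": 1}],
}

# spark.read.json accepts an RDD of JSON strings and infers nested objects
# as StructType and lists as ArrayType, which is what flatten() operates on.
nested_df = spark.read.json(spark.sparkContext.parallelize([json.dumps(payload)]))
nested_df.printSchema()   # customer is a struct, orders is an array of structs

flat_df = flatten(nested_df)
flat_df.show(truncate=False)
# After flattening, the nested fields become flat columns such as
# id, customer_name, customer_address_city, customer_address_zip, orders_qty, orders_sku,
# with one row per exploded `orders` element (column order may vary).

Because arrays are handled with explode_outer, rows whose arrays are null or empty are kept (with null values) rather than dropped. Note that exploding several independent arrays on the same row multiplies the row count, so the output can grow quickly for deeply nested data.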