# Pyspark Flatten json
# Forked from nmukerje/Pyspark Flatten json (gist by @bjornjorgensen,
# last active April 20, 2022)
from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType
from pyspark.sql.functions import col, explode_outer, map_keys
def flatten_test(df, sep="_"):
"""Returns a flattened dataframe.
.. versionadded:: x.X.X
Parameters
----------
sep : str
Delimiter for flatted columns. Default `_`
Notes
-----
Don`t use `.` as `sep`
It won't work on nested data frames with more than one level.
And you will have to use `columns.name`.
Flattening Map Types will have to find every key in the column.
This can be slow.
Examples
--------
    data_mixed = [
        {
            "state": "Florida",
            "shortname": "FL",
            "info": {"governor": "Rick Scott"},
            "counties": [
                {"name": "Dade", "population": 12345},
                {"name": "Broward", "population": 40000},
                {"name": "Palm Beach", "population": 60000},
            ],
        },
        {
            "state": "Ohio",
            "shortname": "OH",
            "info": {"governor": "John Kasich"},
            "counties": [
                {"name": "Summit", "population": 1234},
                {"name": "Cuyahoga", "population": 1337},
            ],
        },
    ]
    data_mixed = spark.createDataFrame(data=data_mixed)
    data_mixed.printSchema()
    root
     |-- counties: array (nullable = true)
     |    |-- element: map (containsNull = true)
     |    |    |-- key: string
     |    |    |-- value: string (valueContainsNull = true)
     |-- info: map (nullable = true)
     |    |-- key: string
     |    |-- value: string (valueContainsNull = true)
     |-- shortname: string (nullable = true)
     |-- state: string (nullable = true)

    data_mixed_flat = flatten_test(data_mixed, sep=":")
    data_mixed_flat.printSchema()
    root
     |-- shortname: string (nullable = true)
     |-- state: string (nullable = true)
     |-- counties:name: string (nullable = true)
     |-- counties:population: string (nullable = true)
     |-- info:governor: string (nullable = true)
    data = [
        {
            "id": 1,
            "name": "Cole Volk",
            "fitness": {"height": 130, "weight": 60},
        },
        {"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}},
        {
            "id": 2,
            "name": "Faye Raker",
            "fitness": {"height": 130, "weight": 60},
        },
    ]
    df = spark.createDataFrame(data=data)
    df.printSchema()
    root
     |-- fitness: map (nullable = true)
     |    |-- key: string
     |    |-- value: long (valueContainsNull = true)
     |-- id: long (nullable = true)
     |-- name: string (nullable = true)

    df_flat = flatten_test(df, sep=":")
    df_flat.printSchema()
    root
     |-- id: long (nullable = true)
     |-- name: string (nullable = true)
     |-- fitness:height: long (nullable = true)
     |-- fitness:weight: long (nullable = true)
    data_struct = [
        (("James", None, "Smith"), "OH", "M"),
        (("Anna", "Rose", ""), "NY", "F"),
        (("Julia", "", "Williams"), "OH", "F"),
        (("Maria", "Anne", "Jones"), "NY", "M"),
        (("Jen", "Mary", "Brown"), "NY", "M"),
        (("Mike", "Mary", "Williams"), "OH", "M"),
    ]
    schema = StructType([
        StructField("name", StructType([
            StructField("firstname", StringType(), True),
            StructField("middlename", StringType(), True),
            StructField("lastname", StringType(), True),
        ])),
        StructField("state", StringType(), True),
        StructField("gender", StringType(), True),
    ])
    df_struct = spark.createDataFrame(data=data_struct, schema=schema)
    df_struct.printSchema()
    root
     |-- name: struct (nullable = true)
     |    |-- firstname: string (nullable = true)
     |    |-- middlename: string (nullable = true)
     |    |-- lastname: string (nullable = true)
     |-- state: string (nullable = true)
     |-- gender: string (nullable = true)

    df_struct_flat = flatten_test(df_struct, sep=":")
    df_struct_flat.printSchema()
    root
     |-- state: string (nullable = true)
     |-- gender: string (nullable = true)
     |-- name:firstname: string (nullable = true)
     |-- name:middlename: string (nullable = true)
     |-- name:lastname: string (nullable = true)
    """
    # collect the complex fields (arrays, structs and maps) in the schema
    complex_fields = {
        field.name: field.dataType
        for field in df.schema.fields
        if isinstance(field.dataType, (ArrayType, StructType, MapType))
    }
    while len(complex_fields) != 0:
        col_name = list(complex_fields.keys())[0]

        # StructType: promote every sub-field to a top-level column,
        # i.e. flatten structs
        if isinstance(complex_fields[col_name], StructType):
            expanded = [col(col_name + "." + k).alias(col_name + sep + k)
                        for k in [n.name for n in complex_fields[col_name]]]
            df = df.select("*", *expanded).drop(col_name)

        # ArrayType: turn the array elements into rows
        # (explode_outer keeps rows whose array is null or empty)
        elif isinstance(complex_fields[col_name], ArrayType):
            df = df.withColumn(col_name, explode_outer(col_name))

        # MapType: collect every distinct key and turn each into a column;
        # the distinct() pass over all keys is what can be slow
        elif isinstance(complex_fields[col_name], MapType):
            keys_df = df.select(explode_outer(map_keys(col(col_name)))).distinct()
            keys = [row[0] for row in keys_df.collect()]
            key_cols = [col(col_name).getItem(k).alias(col_name + sep + k)
                        for k in keys]
            df = df.select([c for c in df.columns if c != col_name] + key_cols)

        # recompute the remaining complex fields in the new schema
        complex_fields = {
            field.name: field.dataType
            for field in df.schema.fields
            if isinstance(field.dataType, (ArrayType, StructType, MapType))
        }
    return df
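

# A minimal end-to-end usage sketch, not part of the original gist: the file
# name "events.json" and the session setup below are illustrative assumptions.
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # spark.read.json infers a nested schema (structs and arrays) from the file
    nested_df = spark.read.json("events.json")  # hypothetical input path
    nested_df.printSchema()

    # flatten all structs, arrays and maps into flat top-level columns
    flat_df = flatten_test(nested_df, sep="_")
    flat_df.printSchema()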