Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save karpanGit/4884d500ef905e42b4f36d3de9d5e3cb to your computer and use it in GitHub Desktop.

Select an option

Save karpanGit/4884d500ef905e42b4f36d3de9d5e3cb to your computer and use it in GitHub Desktop.
PySpark: create a DataFrame with columns of struct (Row) type
# Experiment: building DataFrames whose columns are struct (Row) values,
# and aggregating rows into arrays of structs with collect_list.
import os
os.environ['PYSPARK_PYTHON'] = '/bin/python3'
os.environ['PYSPARK_DRIVER_PYTHON'] = '/bin/python3'
# Put the HDP Spark 2 client libraries on the Python path so pyspark imports resolve.
import sys
sys.path = ['/usr/hdp/current/spark2-client/python',
            '/usr/hdp/current/spark2-client/python/lib/py4j-0.10.4-src.zip'] + sys.path
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql import functions as f  # needed below for f.col / f.rand / f.collect_list / f.struct
from pyspark.sql.types import StructType, StructField, LongType, StringType

spark = SparkSession.builder.appName('learn').master('yarn').enableHiveSupport().getOrCreate()

# Nested schema: a StructField may itself take a StructType as its data type,
# which is how a "struct column" is declared.
small_schema = StructType([
    StructField('id_inside', LongType(), nullable=False),
    StructField('name_inside', StringType(), nullable=False),
])
schema = StructType([
    StructField('id_outside', LongType(), nullable=False),
    StructField('complex', small_schema, nullable=False)
])

# Method 1: supply the struct values explicitly as Row objects.
data = [
    [10, Row(id_inside=100, name_inside='str1')],
    [20, Row(id_inside=200, name_inside='str2')],
]
df = spark.createDataFrame(data, schema=schema)
df.show()
df.printSchema()

# Method 2: supply the struct values as plain lists; the schema tells Spark
# how to interpret the nested fields, so the result is identical to method 1.
data = [
    [10, [100, 'str1']],
    [20, [200, 'str2']],
]
df = spark.createDataFrame(data, schema=schema)
df.show()
df.printSchema()
# In both cases the second column is of type struct (i.e. Row):
# root
# |-- id_outside: long (nullable = false)
# |-- complex: struct (nullable = false)
# | |-- id_inside: long (nullable = false)
# | |-- name_inside: string (nullable = false)

# Collecting a struct column to the driver yields Row objects.
print(df.take(1)[0]['complex'])
# Row(id_inside=100, name_inside='str1')
print(type(df.take(1)[0]['complex']))
# <class 'pyspark.sql.types.Row'>

# More complex example: group 10M rows into ~1000 groups and aggregate each
# group's (id, v) pairs into an array-of-structs column.
df = (spark
      .range(0, 10 * 1000 * 1000)
      .withColumn('id', (f.col('id') / 1000).cast('integer'))
      .withColumn('v', f.rand()))
df.cache()       # cache before counting so the groupby below reuses the data
df.count()       # materialize the cache
res = df.groupby('id').agg(f.collect_list(f.struct(df['id'], df['v'])).alias('rows'))
res.printSchema()
# root
# |-- id: integer (nullable = true)
# |-- rows: array (nullable = true)
# | |-- element: struct (containsNull = true)
# | | |-- id: integer (nullable = true)
# | | |-- v: double (nullable = false)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment