PySpark code to take a DataFrame and repartition it into an optimal number of partitions for generating 300 MB to 1 GB parquet files.
import re
from math import ceil

import pyspark.sql.functions as F
import pyspark.sql.types as T

def repartition_for_writing(df):
    """
    Repartition df so each partition targets roughly a 1 GB snappy-compressed parquet file.
    """
    count = df.count()
    sampled_df = get_sampled_df(df, count=count)
    string_column_sizes = get_string_column_sizes(sampled_df)
    num_files = get_num_files(count, df.schema, string_column_sizes)
    print(num_files)
    return df.repartition(num_files)

def get_sampled_df(df, count=None):
    if not count:
        count = df.count()
    # Sample roughly 100k rows (or the whole DataFrame if it has fewer)
    sample_size = 100000.0
    raw_fraction = sample_size / count
    clamped_fraction = min(raw_fraction, 1.0)
    return df.sample(withReplacement=False, fraction=clamped_fraction)

def get_string_column_sizes(df):
    # Compute the average length of each string column in the sampled DataFrame
    ddf = df
    string_cols = []
    for column in ddf.schema:
        if isinstance(column.dataType, T.StringType):
            ddf = ddf.withColumn('{}__length'.format(column.name), F.length(column.name))
            string_cols.append('{}__length'.format(column.name))
    sizes = ddf.groupBy().avg(*string_cols).first().asDict()
    cleaned_sizes = {}
    for k, v in sizes.items():
        # avg() names its output column 'avg(<col>__length)'; recover the original column name
        col_name = re.search(r'avg\((.+)__length\)', k).groups()[0]
        cleaned_sizes[col_name] = v
    return cleaned_sizes

def get_num_files(rows, schema, string_column_sizes):
    record_size = _get_record_size(schema, string_column_sizes)
    return _get_files_based_on_file_size(rows, record_size)

def _get_record_size(schema, string_field_sizes):
    # Estimate the uncompressed size of a single record, in bits
    size_mapping = get_size_mapping()
    record_size = 0
    for field in schema:
        _type = field.dataType.typeName()
        if _type == "string":
            # Fetch our avg size for a string field, convert from bytes to bits
            field_size = string_field_sizes[field.name] * 8
        else:
            field_size = size_mapping[_type]
        record_size = record_size + field_size
    return record_size

def _get_files_based_on_file_size(rows, record_size):
    compression_ratio = get_compression_ratio("snappy")
    data_size_in_bits = record_size * rows
    data_size_in_bytes = data_size_in_bits / 8
    data_size_mb = data_size_in_bytes / (1024 * 1024)
    data_size_compressed = data_size_mb * compression_ratio
    # Aim for 1 GB files
    num_files = data_size_compressed / 1024
    return int(ceil(num_files))
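
# A worked example of the sizing arithmetic above (illustrative numbers, not
# taken from any real dataset): 100,000,000 rows at roughly 400 bits per record
# is 400 * 100000000 / 8 / (1024 * 1024) ~= 4768 MB uncompressed, ~= 2861 MB at
# the assumed 0.6 snappy ratio, so ceil(2861 / 1024) = 3 output files.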

def get_size_mapping():
    """
    Size mapping (in bits) for non-string fields in Parquet files
    """
    PARQUET_SIZE_MAPPING = {
        "short": 32,
        "integer": 32,
        "long": 64,
        "boolean": 1,
        "float": 32,
        "double": 64,
        "decimal": 64,
        "date": 32,       # Assume no date64
        "timestamp": 96,  # Assume legacy timestamp
    }
    return PARQUET_SIZE_MAPPING

def get_compression_ratio(compression):
    """
    Return a floating point scalar for the size after compression.
    """
    if compression == "snappy":
        return 0.6
    return 1.0
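
A minimal usage sketch (assuming an active SparkSession, an existing DataFrame named df, and a purely illustrative output path; neither is defined in the gist):

    repartitioned = repartition_for_writing(df)
    repartitioned.write.parquet("s3://your-bucket/your-table/", compression="snappy")

Because Spark writes one parquet file per partition, the partition count returned here translates directly into the number and approximate size of the output files.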