Skip to content

Instantly share code, notes, and snippets.

@kagesenshi
Last active October 24, 2019 09:16
Show Gist options
  • Save kagesenshi/253dc3584fb01c828562be198a64d472 to your computer and use it in GitHub Desktop.
Save kagesenshi/253dc3584fb01c828562be198a64d472 to your computer and use it in GitHub Desktop.
Hubcap data transformation helpers
# Data cleansing helper functions
from pyspark.sql import functions as F
from pyspark.sql import Window, DataFrame
from IPython.display import HTML, Markdown
from pyspark.ml.feature import VectorAssembler
import pandas as pd
import matplotlib.pyplot as plt
def transform(self, function, *args, **kwargs):
    """Call ``function(self, *args, **kwargs)`` and return what it returns.

    Installed as ``DataFrame.transform`` by the guard below so the helpers in
    this module can be chained, e.g. ``df.transform(renameColumns, {'a': 'b'})``.
    """
    result = function(self, *args, **kwargs)
    return result


# Patch DataFrame only when it has no ``transform`` attribute (or it is None).
# NOTE(review): an existing native ``transform`` is left in place even if it
# does not forward extra positional/keyword arguments — confirm against the
# PySpark version in use.
if getattr(DataFrame, 'transform', None) is None:
    DataFrame.transform = transform
def removeColumnPrefixes(df, prefixes):
    """Strip a leading prefix from each column name that starts with one.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Input frame.
    prefixes : str or iterable of str
        Candidate prefixes, tried in order. The first (non-empty) prefix a
        column name starts with is removed from the *start* of the name only;
        at most one prefix is removed per column. A single string is treated
        as one prefix.

    Returns
    -------
    pyspark.sql.DataFrame
        ``df`` with the matching columns renamed.
    """
    if isinstance(prefixes, str):
        prefixes = [prefixes]
    # Materialise once so a generator is not exhausted after the first column.
    prefixes = list(prefixes)
    for col in df.columns:
        for prefix in prefixes:
            if prefix and col.startswith(prefix):
                # Slice off only the leading prefix; ``str.replace`` would also
                # delete later occurrences ('src_id_src_x' -> 'id_x').
                df = df.withColumnRenamed(col, col[len(prefix):])
                # ``col`` no longer exists after the rename, so no other prefix
                # can apply to it.
                break
    return df
def renameColumns(df, mapping):
    """Rename columns following ``mapping`` (``{old_name: new_name}``).

    Entries are applied in the mapping's iteration order, each checked
    against the columns present at that moment. An entry whose old name is
    missing is skipped with a printed warning instead of raising.
    """
    for ocol, target in mapping.items():
        if ocol not in df.columns:
            print(f"WARNING: column {ocol} unavailable")
            continue
        df = df.withColumnRenamed(ocol, target)
    return df
def dropColumns(df, columns):
    """Return ``df`` without the columns whose names appear in ``columns``.

    Names in ``columns`` that are not in ``df`` are ignored; the surviving
    columns keep their original order.
    """
    keep = list(filter(lambda name: name not in columns, df.columns))
    return df.select(*keep)
def trimColumns(df, columns):
    """Trim whitespace in the selected columns, turning blank values into null.

    Each column named in ``columns`` becomes ``trim(value)`` when that is a
    non-empty string and null otherwise (``when`` without ``otherwise``).
    Every other column is passed through unchanged, in its original position.
    """
    def _trimmed(name):
        stripped = F.trim(F.col(name))
        has_text = F.col(name).isNotNull() & (stripped != F.lit(''))
        return F.when(has_text, stripped).alias(name)

    projection = [
        _trimmed(name) if name in columns else F.col(name)
        for name in df.columns
    ]
    return df.select(*projection)
def trimAllStringColumns(df):
    """Apply ``trimColumns`` to every column whose Spark type name is 'string'."""
    string_columns = [
        name
        for name in df.columns
        if df.schema[name].dataType.typeName() == 'string'
    ]
    return trimColumns(df, string_columns)
def plotStringLengthDistribution(df, field, length_field_name='length', title=None):
    """Bar-plot the number of rows per string length of ``field``.

    ``field`` may be a column name or a Column expression. Bars are ordered
    by descending row count (most frequent length first), and the lengths are
    reported in a column named ``length_field_name``.
    """
    column = F.col(field) if isinstance(field, str) else field
    length_counts = (
        df.select(column)
        .withColumn(length_field_name, F.length(column))
        .groupBy(F.col(length_field_name))
        .count()
        .orderBy(F.col('count').desc())
    )
    axes = length_counts.toPandas().plot(x=length_field_name, kind='bar', title=title)
    display(axes)
    plt.show()
def findDuplicatesBy(df, fields):
    """Split ``df`` into duplicated key groups and rows with a unique key.

    ``fields`` is one column (name or expression) or a list/tuple of them.

    Returns ``[dupes, unique]``:

    * ``dupes`` - one row per key occurring more than once, with the key
      columns, ``records`` (all rows for that key collected as structs) and
      ``record_count``.
    * ``unique`` - the original rows (original columns only) whose key
      occurs exactly once.
    """
    keys = fields if isinstance(fields, (list, tuple)) else [fields]
    columns = df.columns

    dupes = (
        df.groupBy(*keys)
        .agg(
            F.collect_list(F.struct(*columns)).alias('records'),
            F.count('*').alias('record_count'),
        )
        .where('record_count > 1')
    )

    rows_per_key = F.count('*').over(Window.partitionBy(*keys)).alias('record_count')
    unique = (
        df.select(columns + [rows_per_key])
        .where('record_count == 1')
        .drop('record_count')
    )
    return [dupes, unique]
def stringToVector(df, field, type_field='iddoc_type', type_value='O'):
    """Encode the character structure of a string column as feature vectors.

    Each value of ``field`` is split into characters. Character position ``i``
    becomes a column ``letterNN`` (``NN`` = ``i`` zero-padded to two digits)
    holding 2 for an ASCII letter, 1 for a digit, and 0 for anything else or
    for positions beyond the end of the string. A ``str_length`` column is
    added, and every column except ``field`` is assembled into a
    ``features`` vector with ``VectorAssembler``.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
    field : str
        Name of the string column to analyse.
    type_field : str or None, default 'iddoc_type'
        Column used to pre-filter rows; pass ``None`` to analyse all rows.
    type_value : default 'O'
        Only rows with ``df[type_field] == type_value`` are analysed.

    Returns
    -------
    list
        ``[vect, df]``: ``vect`` holds the numeric position/length columns
        plus ``features``; ``df`` is the per-value frame (including ``field``)
        before assembly.
    """
    if type_field is not None:
        df = df.select(field).where(df[type_field] == type_value)
    else:
        df = df.select(field)

    # One row per (value, character position).
    df = (df.select(field, F.posexplode(F.split(field, '')).alias('position', 'letter'))
            .withColumn('letter_pos',
                        F.concat(F.lit('letter'), F.format_string('%02d', F.col('position')))))
    df = df.withColumn(
        'letter_type',
        F.when(F.regexp_extract(F.col('letter'), r'^[a-zA-Z]+$', 0) != '', 2)
         .when(F.regexp_extract(F.col('letter'), r'^[0-9]+$', 0) != '', 1))

    # Pivot character positions into columns: one row per distinct value.
    df = df.drop('position', 'letter').groupBy(field).pivot('letter_pos').agg(F.first('letter_type'))
    df = df.withColumn('str_length', F.length(F.col(field)))

    # Replace nulls (other characters / positions past the end) with 0 in a
    # single projection instead of one withColumn per column.
    df = df.select(*[
        F.when(F.col(c).isNotNull(), F.col(c)).otherwise(0).alias(c)
        for c in df.columns
    ])

    # Drop the analysed string column itself: VectorAssembler only accepts
    # numeric/boolean/vector inputs.
    vect_in = df.drop(field)
    vect_assembler = VectorAssembler(inputCols=vect_in.columns, outputCol='features')
    vect = vect_assembler.transform(vect_in)
    return [vect, df]
def _labelled_counts(df, label):
    """Count rows of ``df`` per value of the string expression ``label``.

    Returns a Spark frame with columns ``recorded`` (the label) and ``count``.
    """
    return df.select(label.alias('recorded')).groupBy(F.col('recorded')).count()


def _plot_pie_report(report, title):
    """Show a (``recorded``, ``count``) Spark frame as a percentage pie chart."""
    display(report.toPandas().set_index('recorded')
            .plot(y='count', autopct='%.2f', kind='pie', title=title))
    plt.axis("off")
    plt.show()


def _plot_duplicates_report(df, field, dataType):
    """Show how many distinct values of ``field`` are unique vs. duplicated.

    String values are compared after ``upper`` + ``trim``. When duplicates
    exist, also shows the distribution of duplicate-group sizes and a 10-row
    sample of duplicated values; otherwise prints "No duplicates found".
    """
    if dataType == 'string':
        df = df.withColumn(field, F.trim(F.upper(F.col(field))))
    dupes, unique = findDuplicatesBy(df, field)
    # ``dupes`` is read several times below; cache it and always release the
    # cache, even if a collect or plotting step raises.
    dupes = dupes.cache()
    try:
        unique_count = unique.select(F.lit('Unique').alias('label'), F.count('*').alias('count'))
        dupes_count = dupes.select(F.lit('Duplicate').alias('label'), F.count('*').alias('count'))
        dupe_report = unique_count.union(dupes_count).toPandas()
        has_duplicate = any(r['label'] == 'Duplicate' and r['count'] > 0
                            for r in dupe_report.to_dict('records'))
        if not has_duplicate:
            print("No duplicates found")
            return
        display(dupe_report.set_index('label').plot(y='count', kind='pie', autopct='%.2f', title='Duplicated Records'))
        plt.axis("off")
        plt.show()
        group_sizes = (dupes.groupBy('record_count').count()
                       .orderBy(F.col('count').desc()).limit(20).toPandas())
        display(group_sizes.plot(x='record_count', kind='bar', title='Distribution of total duplicated records'))
        plt.show()
        display(dupes.select(field, 'record_count').limit(10).toPandas())
        plt.show()
    finally:
        dupes.unpersist()


def profile(df, field, find_duplicates=False):
    """Display a visual profile of one column (notebook helper).

    Always shows a Null / Not Null pie chart. Numeric columns also get a
    Zero / Not Zero / Null pie chart. String columns get bar charts of the
    top-20 and bottom-20 values by count plus a length distribution. With
    ``find_duplicates=True`` a duplicate report is added (strings compared
    after upper-casing and trimming).

    Everything is rendered through ``display`` / matplotlib; returns None.
    """
    # ``DataType.typeName()`` values of the numeric Spark types.
    numeric_types = ('byte', 'short', 'integer', 'long', 'float', 'double', 'decimal')
    dataType = df.schema[field].dataType.typeName()
    display(Markdown(f'# Field {field} ({dataType})'))

    value = F.col(field)
    null_label = F.when(value.isNotNull(), F.lit('Not Null')).otherwise(F.lit('Null'))
    _plot_pie_report(_labelled_counts(df, null_label), 'Null Report')

    if dataType in numeric_types:
        # Nulls get their own slice instead of being counted as zeros.
        zero_label = (F.when(value.isNull(), F.lit('Null'))
                      .when(value != 0, F.lit('Not Zero'))
                      .otherwise(F.lit('Zero')))
        _plot_pie_report(_labelled_counts(df, zero_label), 'Zeros Report')

    if dataType == 'string':
        dist = df.groupBy(field).count()
        dist_top = dist.orderBy(F.col('count').desc()).limit(20)
        dist_bottom = dist.orderBy(F.col('count')).limit(20)
        display(dist_top.toPandas().plot(kind='bar', x=field, title='Top 20 values by count'))
        display(dist_bottom.toPandas().plot(kind='bar', x=field, title='Bottom 20 values by count'))
        plt.show()
        plotStringLengthDistribution(df, field, title='Length distribution')

    if find_duplicates:
        _plot_duplicates_report(df, field, dataType)
def cwsDateToTimestamp(df, fields):
    """Convert CWS-style numeric date columns in place.

    Each value is read as ``yyyyMMddHHmm`` followed by five further digits.
    The leading part (``value div 100000``) is parsed with ``unix_timestamp``
    (session time zone), and the column is replaced by
    ``epoch_seconds * 100000 + value % 100000``.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
    fields : str or iterable of str
        Column name(s) to convert; a single name is accepted.

    Returns
    -------
    pyspark.sql.DataFrame
    """
    ts_format = 'yyyyMMddHHmm'
    if isinstance(fields, str):
        fields = [fields]
    for field in fields:
        raw = F.col(field)
        # Integer-divide through an exact decimal. Plain ``/`` converts to
        # double, which cannot represent 17-digit values exactly (> 2**53) and
        # can round a value ending in 99999 up to the next minute.
        minute_part = (raw.cast('decimal(20,0)') / 100000).cast('bigint')
        df = df.withColumn(
            field,
            (F.unix_timestamp(minute_part.cast('string'), ts_format) * 100000) + (raw % 100000))
    return df
def listEmptyFields(df):
    """Return (and print, one per line) the columns that contain no non-null value.

    Every column of an empty frame is reported. All columns are checked in a
    single aggregation pass instead of one ``distinct().count()`` job per column.
    """
    columns = df.columns
    if not columns:
        return []
    # ``F.count(col)`` counts non-null values only.
    non_null_counts = df.agg(*[F.count(F.col(c)) for c in columns]).first()
    result = []
    for name, non_null in zip(columns, non_null_counts):
        if non_null == 0:
            result.append(name)
            print(name)
    return result
def computeFuzzySimilarity(df, id_field, fields,
                           match_column_prefix='matched',
                           distance_column_prefix='levdist'):
    """Pair rows with each other and compute per-field Levenshtein distances.

    ``df`` is reduced to ``id_field`` plus ``fields`` and joined with a copy of
    itself whose columns are renamed ``{match_column_prefix}_{name}``. Each
    unordered pair of rows is emitted once, oriented so that
    ``id_field < {match_column_prefix}_{id_field}``. For every name in
    ``fields`` a column ``{distance_column_prefix}_{name}`` holds
    ``levenshtein(name, {match_column_prefix}_{name})``.

    Assumes ``id_field`` values are unique (TODO confirm with callers); rows
    with a null id are never paired. This is a self cross-join, so the output
    grows quadratically with the number of rows.
    """
    df = df.select(id_field, *fields)
    matched_id_field = f'{match_column_prefix}_{id_field}'
    matched = df.select(*[
        F.col(col).alias(f'{match_column_prefix}_{col}') for col in df.columns
    ])
    # Joining on ``<`` keeps exactly one orientation of every pair and excludes
    # self-pairs, so no window-based de-duplication is needed.
    df = df.join(matched, df[id_field] < matched[matched_id_field])

    lev_cols = [
        F.levenshtein(F.col(field), F.col(f'{match_column_prefix}_{field}'))
         .alias(f'{distance_column_prefix}_{field}')
        for field in fields
    ]
    return df.select(*(df.columns + lev_cols))
def select(df, columns_mapping):
    """Project ``df`` from ``{output_name: source}`` pairs.

    ``source`` is either a column name (str) or a Column expression; in both
    cases it is aliased to ``output_name``. Output columns follow the
    mapping's iteration order.
    """
    def _as_column(source):
        return F.col(source) if isinstance(source, str) else source

    projection = [
        _as_column(source).alias(output_name)
        for output_name, source in columns_mapping.items()
    ]
    return df.select(*projection)
def reference_lookup(df, lookup_df, column, lookup_key_column, lookup_value_column, lookup_output_column, lookup_key_alias='df_lookup_key'):
    """Left-join one value column from a lookup table onto ``df``.

    Rows of ``lookup_df`` whose ``lookup_key_column`` equals ``df[column]``
    contribute ``lookup_value_column``, added to ``df`` as
    ``lookup_output_column`` (null when no key matches). The key is renamed to
    ``lookup_key_alias`` for the join and dropped afterwards. A key that
    appears several times in ``lookup_df`` repeats the matching ``df`` rows.
    """
    reference = lookup_df.select(
        F.col(lookup_key_column).alias(lookup_key_alias),
        F.col(lookup_value_column).alias(lookup_output_column),
    )
    condition = df[column] == reference[lookup_key_alias]
    joined = df.join(reference, condition, how='left')
    return joined.drop(lookup_key_alias)
def enrich(df, enrich_df, join_condition, columns_mapping):
    """Left-join ``enrich_df`` onto ``df`` and project the result.

    Every original ``df`` column is kept under its own name; entries of
    ``columns_mapping`` (``{output_name: column name or Column}``) add new
    output columns or override existing ones, and are resolved by ``select``.
    """
    columns = {c: c for c in df.columns}
    columns.update(columns_mapping)
    joined = df.join(enrich_df, join_condition, how='left')
    # Call ``select`` directly instead of ``joined.transform(select, columns)``:
    # the module's transform shim is only installed when DataFrame has no
    # ``transform``, and PySpark 3.0-3.2's native ``transform(func)`` does not
    # accept extra arguments.
    return select(joined, columns)
def union(df, df2):
    """Union two frames by column name, null-filling columns either side lacks.

    The result has the sorted union of both frames' column names. A column
    missing from one input is filled with null for that input's rows. Both
    sides are aligned to the same sorted column list before the positional
    ``union``, so columns are matched by name rather than by original order.
    """
    all_cols = sorted(set(df.columns) | set(df2.columns))

    def _aligned(frame):
        # Emit every output column in the shared order, padding gaps with null.
        present = set(frame.columns)
        return frame.select(*[
            c if c in present else F.lit(None).alias(c)
            for c in all_cols
        ])

    return _aligned(df).union(_aligned(df2))
def fix_timecol(df, cols):
    """Re-anchor time-of-day columns onto the 1970-01-01 epoch date.

    When a JDBC connection loads a TIME column as a UTC timestamp on the 1900
    epoch, time-zone conversion can misbehave, because historical zone offsets
    around 1900 differ from modern ones.

    Example: in MYT, 1900-01-01 00:00:00 UTC is loaded as 1899-12-31 22:38:21
    UTC, and ``from_utc_timestamp`` converts it to 1900-01-01 06:38:21 MYT.
    That is wrong: it should be 1901-01-01 07:30:00 or 1901-01-01 06:45:00
    (NOTE(review): the year may be a typo for 1900 - confirm).
    ``from_unixtime`` handles this more correctly.

    For each column named in ``cols``, the raw value is cast to long and
    formatted with ``from_unixtime``. Characters 12-19 (the ``HH:mm:ss`` part)
    are kept, prefixed with '1970-01-01 ', and parsed back with
    ``unix_timestamp``. The result is therefore seconds since the epoch, not a
    timestamp type. Null inputs stay null, and other columns pass through
    unchanged.
    """
    def _reanchored(name):
        time_of_day = F.from_unixtime(F.col(name).cast('long')).substr(12, 8)
        seconds = F.unix_timestamp(F.concat(F.lit('1970-01-01 '), time_of_day))
        return (F.when(F.col(name).isNotNull(), seconds)
                 .otherwise(F.lit(None))
                 .alias(name))

    return df.select(*[_reanchored(c) if c in cols else c for c in df.columns])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment