Hubcap data transformation helpers
# Data cleansing helper functions
from pyspark.sql import functions as F
from pyspark.sql import Window, DataFrame
from IPython.display import display, HTML, Markdown
from pyspark.ml.feature import VectorAssembler
import pandas as pd
import matplotlib.pyplot as plt
# Backport DataFrame.transform for Spark versions that lack it (< 3.0).
# Passing extra *args/**kwargs through transform natively requires Spark 3.3+,
# so the chained calls below rely on either this patch or a recent Spark.
def transform(self, function, *args, **kwargs):
    return function(self, *args, **kwargs)

if getattr(DataFrame, 'transform', None) is None:
    DataFrame.transform = transform
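# A minimal usage sketch (not part of the original helpers): chain several of
# the helpers below through DataFrame.transform. The sample data and column
# names ('raw_name', 'src_country') are illustrative assumptions, and the
# extra-argument style assumes the patch above (or Spark 3.3+).
def _example_transform_chaining(spark):
    df = spark.createDataFrame(
        [(' Alice ', 'src_country_MY'), ('Bob', 'src_country_SG')],
        ['raw_name', 'src_country'])
    return (df.transform(trimAllStringColumns)              # strip whitespace in string columns
              .transform(removeColumnPrefixes, ['src_'])    # 'src_country' -> 'country'
              .transform(renameColumns, {'raw_name': 'name'}))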
# Rename columns by stripping any of the given prefixes from their names.
def removeColumnPrefixes(df, prefixes):
    for col in df.columns:
        for prefix in prefixes:
            if col.startswith(prefix):
                df = df.withColumnRenamed(col, col.replace(prefix, ''))
    return df
# Rename columns according to a {old_name: new_name} mapping, warning on
# columns that are not present.
def renameColumns(df, mapping):
    for ocol, newcol in mapping.items():
        if ocol in df.columns:
            df = df.withColumnRenamed(ocol, newcol)
        else:
            print(f"WARNING: column {ocol} unavailable")
    return df
# Drop the given columns, silently ignoring names that do not exist.
def dropColumns(df, columns):
    return df.select(*[col for col in df.columns if col not in columns])
# Trim whitespace in the given string columns; values that are null or blank
# after trimming become null.
def trimColumns(df, columns):
    cols = []
    for col in df.columns:
        if col in columns:
            cols.append(F.when(F.col(col).isNotNull() & (F.trim(F.col(col)) != F.lit('')),
                               F.trim(F.col(col))).alias(col))
        else:
            cols.append(F.col(col))
    return df.select(*cols)
# Apply trimColumns to every string-typed column in the DataFrame.
def trimAllStringColumns(df):
    string_fields = []
    for field in df.columns:
        dataType = df.schema[field].dataType.typeName()
        if dataType == 'string':
            string_fields.append(field)
    return trimColumns(df, string_fields)
# Plot a bar chart of string length frequencies for a column (notebook only).
def plotStringLengthDistribution(df, field, length_field_name='length', title=None):
    if isinstance(field, str):
        field = F.col(field)
    df = (df.select(field)
            .withColumn(length_field_name, F.length(field))
            .groupBy(F.col(length_field_name)).count()
            .orderBy(F.col('count').desc()))
    display(df.toPandas().plot(x=length_field_name, kind='bar', title=title))
    plt.show()
# Split a DataFrame by the given key field(s): returns [dupes, unique], where
# dupes has one row per duplicated key (with the colliding rows collected into
# a 'records' array and their total in 'record_count') and unique contains the
# rows whose key occurs exactly once.
def findDuplicatesBy(df, fields):
    if not isinstance(fields, (list, tuple)):
        fields = [fields]
    columns = df.columns
    select_columns = columns + [F.count('*').over(Window.partitionBy(*fields)).alias('record_count')]
    grouping = df.select(select_columns)
    dupes = (df.groupBy(*fields)
               .agg(F.collect_list(F.struct(*columns)).alias('records'),
                    F.count('*').alias('record_count'))
               .where('record_count > 1'))
    unique = grouping.where('record_count == 1').drop('record_count')
    return [dupes, unique]
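# A minimal usage sketch (illustrative data, not from the original): split a
# DataFrame into duplicated and unique rows keyed on an 'email' column.
def _example_find_duplicates(spark):
    df = spark.createDataFrame(
        [(1, 'a@example.com'), (2, 'a@example.com'), (3, 'b@example.com')],
        ['id', 'email'])
    dupes, unique = findDuplicatesBy(df, 'email')
    # dupes  -> one row for 'a@example.com' with record_count == 2
    # unique -> the single row for 'b@example.com'
    return dupes, unique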
# Encode the character structure of a string column as a numeric feature vector:
# each character position becomes a column (2 = letter, 1 = digit, 0 = other/missing),
# plus the overall string length, assembled with VectorAssembler.
def stringToVector(df, field):
    # NOTE: dataset-specific filter; assumes the input has an 'iddoc_type' column
    df = df.select(field).where(df['iddoc_type'] == 'O')
    # id structure analysis: explode into (position, character) pairs
    df = (df.select(field, F.posexplode(F.split(field, '')).alias('position', 'letter'))
            .withColumn('letter_pos', F.concat(F.lit('letter'), F.format_string('%02d', F.col('position')))))
    df = df.withColumn('letter_type', F.when(F.regexp_extract(F.col('letter'), r'^[a-zA-Z]+$', 0) != '', 2)
                                       .when(F.regexp_extract(F.col('letter'), r'^[0-9]+$', 0) != '', 1))
    df = df.drop('position', 'letter').groupBy(field).pivot('letter_pos').agg(F.first('letter_type'))
    df = df.withColumn('str_length', F.length(F.col(field)))
    for c in df.columns:
        df = df.withColumn(c, F.when(F.col(c).isNotNull(), F.col(c)).otherwise(0))
    # drop the raw string column so VectorAssembler only sees numeric inputs
    vect_in = df.drop(field)
    vect_assembler = VectorAssembler(inputCols=vect_in.columns, outputCol='features')
    vect = vect_assembler.transform(vect_in)
    return [vect, df]
# Render a notebook profiling report for a single column: null ratio, zero ratio
# for numeric types, value and length distributions for strings, and an optional
# duplicate analysis.
def profile(df, field, find_duplicates=False):
    dataType = df.schema[field].dataType.typeName()
    display(Markdown(f'# Field {field} ({dataType})'))
    nullreport = df.select(F.when(F.col(field).isNotNull(), F.lit('Not Null'))
                            .otherwise(F.lit('Null')).alias('recorded')).groupBy(F.col('recorded')).count()
    display(nullreport.toPandas().set_index('recorded').plot(y='count', autopct='%.2f', kind='pie', title='Null Report'))
    plt.axis("off")
    plt.show()
    if dataType in ['integer', 'long', 'float']:
        zerosreport = df.select(F.when(F.col(field).isNotNull() & (F.col(field) != 0), F.lit('Not Zero'))
                                 .otherwise(F.lit('Zero')).alias('recorded')).groupBy(F.col('recorded')).count()
        display(zerosreport.toPandas().set_index('recorded').plot(y='count', autopct='%.2f', kind='pie', title='Zeros Report'))
        plt.axis("off")
        plt.show()
    if dataType == 'string':
        dist = df.groupBy(field).count()
        dist_top = dist.orderBy(F.col('count').desc()).limit(20)
        dist_bottom = dist.orderBy(F.col('count')).limit(20)
        display(dist_top.toPandas().plot(kind='bar', x=field, title='Top 20 values by count'))
        display(dist_bottom.toPandas().plot(kind='bar', x=field, title='Bottom 20 values by count'))
        plt.show()
        plotStringLengthDistribution(df, field, title='Length distribution')
    if find_duplicates:
        if dataType == 'string':
            # normalize case and whitespace before checking string duplicates
            dupes, unique = findDuplicatesBy(df.withColumn(field, F.trim(F.upper(F.col(field)))), field)
        else:
            dupes, unique = findDuplicatesBy(df, field)
        dupes = dupes.cache()
        unique_count = unique.select(F.lit('Unique').alias('label'), F.count('*').alias('count'))
        dupes_count = dupes.select(F.lit('Duplicate').alias('label'), F.count('*').alias('count'))
        dupe_report = unique_count.union(dupes_count).toPandas()
        has_duplicate = False
        for r in dupe_report.to_dict('records'):
            if r['label'] == 'Duplicate' and r['count'] > 0:
                has_duplicate = True
        if has_duplicate:
            display(dupe_report.set_index('label').plot(y='count', kind='pie', autopct='%.2f', title='Duplicated Records'))
            plt.axis("off")
            plt.show()
            display(dupes.groupBy('record_count').count().orderBy(F.col('count').desc()).limit(20)
                         .toPandas().plot(x='record_count', kind='bar', title='Distribution of total duplicated records'))
            plt.show()
            display(dupes.select(field, 'record_count').limit(10).toPandas())
        else:
            print("No duplicates found")
        dupes.unpersist()
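# Intended to be called from a Jupyter/Databricks cell so the display() calls
# render inline, e.g. (hypothetical DataFrame and column name):
#
#     profile(customers_df, 'email', find_duplicates=True)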
# Convert CWS-style numeric datetime columns where the leading digits encode
# yyyyMMddHHmm and the last five digits are an extra component: the leading part
# is converted to unix time (seconds) and scaled back up, with the trailing five
# digits carried over unchanged.
def cwsDateToTimestamp(df, fields):
    ts_format = 'yyyyMMddHHmm'
    for field in fields:
        df = df.withColumn(field,
                           ((F.unix_timestamp((F.col(field) / 100000).cast('bigint').cast('string'),
                                              ts_format)) * 100000) + (F.col(field) % 100000))
    return df
# Return (and print) the columns that contain no non-null values at all.
def listEmptyFields(df):
    result = []
    for f in df.columns:
        if df.select(f).where(F.col(f).isNotNull()).distinct().count() == 0:
            result.append(f)
            print(f)
    return result
# Compute pairwise Levenshtein distances between records on the given fields.
# Every record is compared against every other record (a self cross-join), each
# unordered pair is kept only once, and one distance column is produced per field.
def computeFuzzySimilarity(df, id_field, fields,
                           match_column_prefix='matched',
                           distance_column_prefix='levdist'):
    df = df.select(id_field, *fields)
    columns = df.columns
    match_columns = []
    matched_id_field = f'{match_column_prefix}_{id_field}'
    for col in columns:
        match_column_name = f'{match_column_prefix}_{col}'
        match_columns.append(F.col(col).alias(match_column_name))
    df2 = df.select(*match_columns)
    df = df.join(df2, df[id_field] != df2[matched_id_field])
    # build an order-independent pair key so that (A, B) and (B, A) dedupe to one row
    unique_window = F.concat(
        F.when(F.col(id_field) < F.col(matched_id_field), F.col(id_field))
         .otherwise(F.col(matched_id_field)).cast('string'),
        F.lit('-----'),
        F.when(F.col(id_field) < F.col(matched_id_field), F.col(matched_id_field))
         .otherwise(F.col(id_field)).cast('string'))
    unique_window_name = f'unique_{id_field}'
    df = df.select(
        F.row_number().over(Window.partitionBy(unique_window).orderBy(F.col(id_field))).alias(unique_window_name),
        *df.columns
    ).where(F.col(unique_window_name) == 1).drop(unique_window_name)
    lev_cols = []
    for field in fields:
        match_column_name = f'{match_column_prefix}_{field}'
        distance_column_name = f'{distance_column_prefix}_{field}'
        lev_col = F.levenshtein(F.col(field), F.col(match_column_name)).alias(distance_column_name)
        lev_cols.append(lev_col)
    df = df.select(*(df.columns + lev_cols))
    return df
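# A minimal usage sketch (illustrative data, not from the original): pairwise
# Levenshtein distances on 'name', one output row per unordered pair of ids.
def _example_fuzzy_similarity(spark):
    df = spark.createDataFrame(
        [(1, 'Jon Smith'), (2, 'John Smith'), (3, 'Alice Tan')],
        ['customer_id', 'name'])
    return computeFuzzySimilarity(df, 'customer_id', ['name'])
    # -> columns: customer_id, name, matched_customer_id, matched_name, levdist_name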
# Select and rename columns from a {output_name: source} mapping, where the
# source is either a column name or a Column expression.
def select(df, columns_mapping):
    columns = []
    for k, v in columns_mapping.items():
        if isinstance(v, str):
            columns.append(F.col(v).alias(k))
        else:
            columns.append(v.alias(k))
    return df.select(*columns)
# Left-join a lookup table onto df, adding lookup_output_column where
# df[column] matches lookup_df[lookup_key_column].
def reference_lookup(df, lookup_df, column, lookup_key_column, lookup_value_column, lookup_output_column,
                     lookup_key_alias='df_lookup_key'):
    lookup_df = lookup_df.select(F.col(lookup_key_column).alias(lookup_key_alias),
                                 F.col(lookup_value_column).alias(lookup_output_column))
    df = df.join(lookup_df, df[column] == lookup_df[lookup_key_alias], how='left').drop(lookup_key_alias)
    return df
# Left-join enrich_df onto df and project the original columns plus any extra
# columns described in columns_mapping (see select above).
def enrich(df, enrich_df, join_condition, columns_mapping):
    columns = dict([(c, c) for c in df.columns])
    columns.update(columns_mapping)
    df = df.join(enrich_df, join_condition, how='left')
    df = df.transform(select, columns)
    return df
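# A minimal usage sketch (illustrative data, not from the original): left-join a
# country reference table onto transactions and keep the original columns plus a
# mapped 'country' column. The table and column names are assumptions.
def _example_enrich(spark):
    txn = spark.createDataFrame([(1, 'MY', 10.0)], ['txn_id', 'country_code', 'amount'])
    ref = spark.createDataFrame([('MY', 'Malaysia')], ['code', 'country_name'])
    return enrich(txn, ref, txn['country_code'] == ref['code'],
                  {'country': F.col('country_name')})
    # -> columns: txn_id, country_code, amount, country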
# Union two DataFrames whose column sets may differ: missing columns are filled
# with nulls and both sides are aligned on the sorted union of column names.
def union(df, df2):
    all_cols = sorted(set(df.columns + df2.columns))
    df1_cols = []
    df2_cols = []
    for c in all_cols:
        if c not in df.columns:
            df1_cols.append(F.lit(None).alias(c))
        else:
            df1_cols.append(c)
    for c in all_cols:
        if c not in df2.columns:
            df2_cols.append(F.lit(None).alias(c))
        else:
            df2_cols.append(c)
    df = df.select(*df1_cols)
    df2 = df2.select(*df2_cols)
    return df.union(df2)
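# A minimal usage sketch (illustrative data, not from the original): union two
# DataFrames with different column sets; missing columns come back as null.
def _example_union(spark):
    a = spark.createDataFrame([(1, 'x')], ['id', 'old_field'])
    b = spark.createDataFrame([(2, 'y')], ['id', 'new_field'])
    return union(a, b)   # columns (sorted by name): id, new_field, old_field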
def fix_timecol(df, cols):
    """
    When a column is a TIME column and the JDBC connection loaded the UTC data as
    timestamps on the 1900 epoch, timezone conversion may behave strangely because
    of historical timezone rules in the early 1900s.
    E.g. in the MYT timezone, 1900-01-01 00:00:00 UTC would be loaded as
    1899-12-31 22:38:21 UTC, and `from_utc_timestamp` converts that value to
    1900-01-01 06:38:21 MYT, which is wrong: it should be 1900-01-01 07:30:00 or
    1900-01-01 06:45:00. `from_unixtime` handles this more correctly.
    This helper attempts to fix the time columns by reinterpreting the raw internal
    integer timestamp against the 1970 epoch.
    """
    columns = []
    for c in df.columns:
        if c in cols:
            # keep only the time-of-day part and re-anchor it on 1970-01-01
            cf = F.concat(F.lit('1970-01-01 '),
                          F.from_unixtime(F.col(c).cast('long'))
                           .substr(12, 8))
            cf = F.unix_timestamp(cf)
            columns.append(
                F.when(F.col(c).isNotNull(), cf)
                 .otherwise(F.lit(None)).alias(c))
        else:
            columns.append(c)
    return df.select(*columns)
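# A minimal usage sketch (illustrative data, not from the original): a TIME
# column loaded as a 1900-epoch timestamp gets its time-of-day part re-anchored
# on 1970-01-01 as unix seconds.
def _example_fix_timecol(spark):
    import datetime
    df = spark.createDataFrame([(1, datetime.datetime(1900, 1, 1, 8, 30, 0))],
                               ['id', 'opening_time'])
    return fix_timecol(df, ['opening_time'])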