Last active: March 14, 2018 05:47
-
-
Save BryanCutler/2d2ae04e81fa96ba4b61dc095726419f to your computer and use it in GitHub Desktop.
Vectorized UDFs in Python (SPARK-21190)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataFrame(object):
    """Excerpt of the Python ``DataFrame`` class showing the proposed ``asPandas`` entry point.

    NOTE(review): the ``...`` below appears to be a placeholder for the rest of
    the existing class body, which this sketch elides.
    """

    ...

    def asPandas(self):
        """Return an ``ArrowDataFrame`` wrapping this DataFrame.

        :return: ``ArrowDataFrame(self)``; the wrapper keeps a reference to this
            DataFrame and builds its pandas-oriented operations on top of it.
        """
        return ArrowDataFrame(self)

class ArrowDataFrame(object):
    """
    Wraps a Python DataFrame to group/window then apply using ``pandas.DataFrame``.

    :param data_frame: the Python ``DataFrame`` being wrapped. Its ``_jdf``,
        ``_jcols``, ``_sc`` and ``sql_ctx`` attributes are used by the methods
        below; the wrapper itself does not have them.
    """

    def __init__(self, data_frame):
        self.df = data_frame
        # Filled in on first access to ``rdd`` and reused afterwards.
        self._lazy_rdd = None

    @property
    def rdd(self):
        """Return the wrapped DataFrame's data as an ``ArrowRDD``, created once and cached."""
        if self._lazy_rdd is None:
            # The Java DataFrame handle lives on the wrapped DataFrame, not on
            # this wrapper (which has no ``_jdf`` attribute).
            jrdd = self.df._jdf.javaToPython()
            self._lazy_rdd = ArrowRDD(jrdd, self.df._sc)
        return self._lazy_rdd

    def groupBy(self, *cols):
        """Group the wrapped DataFrame by ``cols``.

        :param cols: grouping columns, passed through to the wrapped
            DataFrame's ``_jcols``.
        :return: an ``ArrowGroupedData`` built from the Java-side grouping and
            the wrapped DataFrame's ``sql_ctx``.
        """
        jgd = self.df._jdf.groupBy(self.df._jcols(*cols))
        return ArrowGroupedData(jgd, self.df.sql_ctx)

    def windowOver(self, window_spec):
        """Apply over a window specification (not implemented yet).

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

class ArrowGroupedData(GroupedData):
    """
    Wraps a Python GroupedData object to process groups as ``pandas.DataFrame``.
    """

    def __init__(self, jgd, sql_ctx):
        # Same arguments as the GroupedData base class: ``jgd`` is presumably the
        # Java-side grouped data handle (see how groupBy results are passed in).
        super(ArrowGroupedData, self).__init__(jgd, sql_ctx)

    def agg(self, f):
        """Apply function ``f`` to each group (not implemented yet).

        :param f: function to apply to each group; presumably it receives each
            group as a ``pandas.DataFrame`` -- confirm once implemented.
        :raises NotImplementedError: always, until group-wise application exists.
        """
        # The placeholder ``return DataFrame(...)`` passed a single Ellipsis to
        # the DataFrame constructor and failed with an unrelated TypeError;
        # raise an explicit error instead.
        raise NotImplementedError("agg is not implemented yet")

class ArrowRDD(object):
    """
    Wraps a Python RDD to deserialize using Arrow into ``pandas.DataFrame`` for processing.

    :param jrdd: Java RDD handle used to build the underlying ``RDD`` when
        ``pipelined_rdd`` is not given.
    :param ctx: context passed to ``RDD`` and kept as ``self.ctx`` (used by ``toDF``).
    :param pipelined_rdd: an already-built RDD to wrap directly; ``jrdd`` is
        then ignored.
    """

    def __init__(self, jrdd, ctx, pipelined_rdd=None):
        # Stored so that ``toDF`` (and wrappers created by ``_wrap_rdd``) can
        # reach it; previously it was never assigned.
        self.ctx = ctx
        if pipelined_rdd is None:
            self._rdd = RDD(jrdd, ctx, jrdd_deserializer=ArrowPandasSerializer())
        else:
            self._rdd = pipelined_rdd

    def _wrap_rdd(self, rdd):
        """Wrap a derived RDD, carrying over this wrapper's deserializer and context."""
        # Presumably so the derived RDD's results are also read back with the
        # Arrow deserializer of the source RDD.
        rdd._jrdd_deserializer = self._rdd._jrdd_deserializer
        return ArrowRDD(jrdd=None, ctx=self.ctx, pipelined_rdd=rdd)

    def map(self, f, preservesPartitioning=False):
        """Map ``f`` over the wrapped RDD and return the result as an ``ArrowRDD``."""
        rdd = self._rdd.map(f, preservesPartitioning=preservesPartitioning)
        return self._wrap_rdd(rdd)

    def reduce(self, f):
        """Delegate to the wrapped RDD's ``reduce``."""
        return self._rdd.reduce(f)

    def count(self):
        """Delegate to the wrapped RDD's ``count``."""
        return self._rdd.count()

    def collect(self):
        """Delegate to the wrapped RDD's ``collect``."""
        return self._rdd.collect()

    def toDF(self):
        """Convert the wrapped RDD into an ``ArrowDataFrame``.

        NOTE(review): ``createDataFrame`` is a SQLContext/SparkSession method in
        PySpark, while ``ArrowDataFrame.rdd`` passes a SparkContext (``df._sc``)
        as ``ctx`` -- confirm which context is intended here.
        """
        # NOTE(review): called with no arguments; presumably it should derive the
        # schema from this RDD's Arrow data -- confirm.
        schema = convert_arrow_schema()
        # Pass the underlying RDD: createDataFrame does not understand this wrapper.
        return ArrowDataFrame(self.ctx.createDataFrame(self._rdd, schema))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.