import pandas as pd


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Slice a given pandas.DataFrame into partitions and convert each slice into Arrow
    record batches, serialized as bytes so they can be sent to the JVM. If a schema is
    passed in, its data types will be used to coerce the data in the pandas-to-Arrow
    conversion.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType, Row, DataType, StringType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched
    step = -(-len(pdf) // parallelism)  # round int up
    pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

    # Create Arrow record batches
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]

    return map(bytearray, map(ArrowSerializer().dumps, batches))


def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Map rdd of pandas dataframes to arrow record batches
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If schema is not defined, get it from the first dataframe
    if schema is None:
        schema = [str(x) if not isinstance(x, basestring) else
                  (x.encode('utf-8') if not isinstance(x, str) else x)
                  for x in prdd.map(lambda x: x.columns).first()]

    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema
    return df


from pyspark.sql import SparkSession
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
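For reference, a minimal usage sketch of the patched method (not part of the gist): it assumes Spark 2.x, where pyspark.serializers.ArrowSerializer and PythonSQLUtils.arrowPayloadToDataFrame exist as used above, and an RDD whose elements are pandas DataFrames; the app name, column names and data below are made up for illustration.

# Usage sketch (illustrative only): build an RDD of pandas DataFrames and turn it
# into a Spark DataFrame with the monkey-patched method defined above.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pandas-rdd-example").getOrCreate()

# Two small pandas DataFrames with the same columns, one per partition
pdfs = [pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']}),
        pd.DataFrame({'a': [3, 4], 'b': ['z', 'w']})]
prdd = spark.sparkContext.parallelize(pdfs, 2)

df = spark.createFromPandasDataframesRDD(prdd)
df.show()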
Hi @linar-jether, Thank you so much for your time. :-)
I have just checked, it works perfectly.
Hi @linar-jether, I got stuck again when migrating the above code to Spark 3.0.0.
Can you please help me? I created a JIRA a few days ago explaining the whole issue. Thanks.
Hi @tahashmi, this can be done using the snippet below.
Also, this PR (SPARK-32846) might be useful, as it uses user-facing APIs (but requires conversion to a pandas RDD); a rough sketch of that route follows the example below.
from pyspark.sql import SparkSession
import pyarrow as pa


def _arrow_record_batch_dumps(rb):
    return bytearray(rb.serialize())


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)


if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    sc = spark.sparkContext
    ardd = spark.sparkContext.parallelize([0, 1, 2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.pandas.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame

    # Filter out and cache arrow record batches
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()
    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(),
                                                spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema
    df.show()
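And here is a rough sketch of the user-facing route mentioned above (my own illustration, not the SPARK-32846 code): each RecordBatch is turned into plain Python rows on the executors via to_pydict and fed to spark.createDataFrame (a per-batch to_pandas round-trip would be the "pandas RDD" variant). This avoids the internal PythonSQLUtils call at the cost of losing the zero-copy Arrow path; it assumes ardd is the RDD of pyarrow.RecordBatch objects from the example above, before the serialization step.

# Illustrative sketch only: user-facing alternative to PythonSQLUtils.toDataFrame.
def rb_to_rows(rb):
    # RecordBatch -> dict of column name -> list of Python values -> list of row tuples
    cols = rb.to_pydict()
    return list(zip(*cols.values()))

row_rdd = ardd.flatMap(rb_to_rows)
df = spark.createDataFrame(row_rdd, schema='c0 smallint, c1 int')
df.show()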
Thank you so much, @linar-jether! It's very helpful.
I want to avoid Pandas conversion because my data is in Arrow RecordBatches on all worker nodes.
@linar-jether @tahashmi While running your code I am facing the error below:
Traceback (most recent call last):
File "recordbatch.py", line 48, in
spark._wrapped._jsqlContext)
AttributeError: 'SparkSession' object has no attribute '_wrapped'
Am I missing something?
Hi @tahashmi, this seems to work for me:
The issue was using flatMap on the record batch, causing it to iterate over the arrays.