Created
August 29, 2023 21:39
-
-
Save ianmcook/f70fc185d29ae97bdf85ffe0378c68e0 to your computer and use it in GitHub Desktop.
Use Substrait expressions to filter and project PyArrow datasets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import atexit
import pathlib
import shutil
import tempfile

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyarrow.parquet as pq
# Create a small two-file Parquet dataset in a fresh temporary directory.
# tempfile.mkdtemp() leaves deletion to the caller, so remove the directory
# (and everything written into it) when the interpreter exits.
base = pathlib.Path(tempfile.mkdtemp(prefix="pyarrow-"))
atexit.register(shutil.rmtree, base, ignore_errors=True)
(base / "parquet_dataset").mkdir(exist_ok=True)
# Column 'b' is drawn at random, so the printed values below are only an
# example and will differ from run to run.
table = pa.table({'a': range(10), 'b': np.random.randn(10), 'c': [1, 2] * 5})
# Table.slice(offset, length): rows 0-4 go to data1, rows 5..end to data2.
pq.write_table(table.slice(0, 5), base / "parquet_dataset/data1.parquet")
pq.write_table(table.slice(5), base / "parquet_dataset/data2.parquet")
dataset = ds.dataset(base / "parquet_dataset", format="parquet")
print(dataset.to_table().to_pandas())
## a b c
## 0 0 -0.207184 1
## 1 1 -0.317578 2
## 2 2 0.650184 1
## 3 3 0.902984 2
## 4 4 1.153264 1
## 5 5 -1.981604 2
## 6 6 -0.907781 1
## 7 7 1.018262 2
## 8 8 -0.813167 1
## 9 9 0.266288 2
# Create a Boolean-valued PyArrow expression and serialize it to a Substrait
# expression. Serialization needs a schema because Substrait expressions are
# bound to one: field names are resolved to *positions*, which is why the
# deserialized expression below prints as FieldPath(0) / FieldPath(1).
# Use the dataset's own schema so those positions (and the column types,
# e.g. int64 for 'a') line up with the columns actually being scanned.
# A hand-written schema with a different column order would make the
# positional references point at the wrong columns.
expr_in = (pc.field("a") < pc.scalar(3)) | (pc.field("b") > pc.scalar(1))
schema = dataset.schema
# Named substrait_bytes rather than `bytes` to avoid shadowing the builtin.
substrait_bytes = expr_in.to_substrait(schema)
# Deserialize the Substrait expression back to a Boolean PyArrow expression.
expr_out = pc.Expression.from_substrait(substrait_bytes)
print(expr_out)
## ((FieldPath(0) < 3) or (FieldPath(1) > 1))
# Use the expression to filter the dataset (keep rows where a < 3 or b > 1).
result = dataset.to_table(filter=expr_out)
print(result.to_pandas())
## a b c
## 0 0 -0.207184 1
## 1 1 -0.317578 2
## 2 2 0.650184 1
## 3 4 1.153264 1
## 4 7 1.018262 2
# Create a dictionary of PyArrow expressions (output column name ->
# expression) and serialize each one to Substrait, reusing the schema
# defined above so field positions match the dataset.
exprs_in = {
    "a_renamed": pc.field("a"),
    "b_doubled": pc.multiply(pc.field("b"), 2),
}
dict_of_bytes = {
    name: expr.to_substrait(schema) for name, expr in exprs_in.items()
}
# Deserialize the Substrait expressions back to PyArrow expressions.
exprs_out = {
    name: pc.Expression.from_substrait(buf)
    for name, buf in dict_of_bytes.items()
}
# Use the expressions to project columns on the dataset; the dict keys
# become the output column names.
result = dataset.to_table(columns=exprs_out)
print(result.to_pandas())
## a_renamed b_doubled
## 0 0 -0.414369
## 1 1 -0.635156
## 2 2 1.300367
## 3 3 1.805968
## 4 4 2.306528
## 5 5 -3.963209
## 6 6 -1.815562
## 7 7 2.036524
## 8 8 -1.626335
## 9 9 0.532576
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Until Arrow 14.0.0 is released, running this requires a development build of PyArrow with
`PYARROW_WITH_SUBSTRAIT` enabled and Arrow C++ with `ARROW_SUBSTRAIT` enabled.