Created September 30, 2024 15:11
{
"cells": [
{
"cell_type": "markdown",
"id": "10575011-5099-4216-b294-33c99daf5e26",
"metadata": {},
"source": [
"# cuDF-Polars + Dask: POC Demo\n",
"\n",
"**cuDF Branch**: [rjzamora:cudf-polars-dask](https://github.com/rapidsai/cudf/compare/branch-24.12...rjzamora:cudf:cudf-polars-dask)"
]
},
{
"cell_type": "markdown",
"id": "119e52f0-3343-4994-ad30-a44ede3e7556",
"metadata": {},
"source": [
"## TODO\n",
"\n",
"- **Serialization**: The current implementation does not yet support multi-GPU execution, because the Dask serialization primitives are not defined at the `pylibcudf` level.\n",
"- **Task fusion**: The current implementation does not yet fuse IO tasks with follow-up compute tasks where possible.\n",
"- Most `IR`/`Expr` logic has yet to be implemented as `DaskNode`/`DaskExprNode` classes."
]
},
{
"cell_type": "markdown",
"id": "07df307b-94b7-48c4-895e-a6abde9a2f5c",
"metadata": {},
"source": [
"## Key Changes\n",
"\n",
"### Build & execute a task graph in `cudf_polars/callback.py`\n",
"\n",
"In order to leverage Dask for parallel execution, we need to intercept the control flow in `callback.py`. Rather than traversing the `IR` graph and evaluating it eagerly, we start by building a task graph:\n",
"\n",
"```python\n",
"def _callback(\n",
"    ir: IR,\n",
"    ...\n",
") -> pl.DataFrame:\n",
"    ...\n",
"\n",
"    if CUDF_POLARS_DASK == \"TRUE\":\n",
"\n",
"        # Extract `DaskNode` object from `IR` node\n",
"        dask_node = ir._dask_node()\n",
"        # Build task graph\n",
"        dsk = dask_node._task_graph()\n",
"        # Execute task graph\n",
"        result = get(dsk, dask_node._key)\n",
"        # Return Polars object\n",
"        return result.to_polars()\n",
"```\n",
"\n",
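"The graph consumed by `get` above is a plain Dask task graph: a dict mapping each key to a task tuple of the form `(callable, *args)`, where an argument may itself be the key of another task. As a minimal standalone sketch (nothing here comes from this branch; `tiny_get` is a hypothetical stand-in for Dask's `get`):\n",
"\n",
"```python\n",
"# A task graph is just a dict: key -> (callable, *args)\n",
"dsk = {\n",
"    (\"chunk\", 0): (sum, [1, 2, 3]),\n",
"    (\"chunk\", 1): (sum, [4, 5, 6]),\n",
"    \"total\": (lambda a, b: a + b, (\"chunk\", 0), (\"chunk\", 1)),\n",
"}\n",
"\n",
"def tiny_get(dsk, key):\n",
"    \"\"\"Resolve `key` by recursively evaluating its task arguments.\"\"\"\n",
"    func, *args = dsk[key]\n",
"    resolved = [\n",
"        tiny_get(dsk, a) if isinstance(a, (str, tuple)) and a in dsk else a\n",
"        for a in args\n",
"    ]\n",
"    return func(*resolved)\n",
"\n",
"tiny_get(dsk, \"total\")  # 21\n",
"```\n",
"\n",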
"### Define `DaskNode` classes to encapsulate Dask-specific logic\n",
"\n",
"To build a task graph in `callback.py`, the `IR`/`Expr` nodes must first be \"lowered\" into objects that implement the necessary logic for task-graph generation. Here we use the `DaskNode` and `DaskExprNode` classes for `IR` and `Expr` objects, respectively:\n",
"\n",
"```python\n",
"class DaskNode:\n",
"    \"\"\"Dask-specific version of an IR node.\"\"\"\n",
"\n",
"    __slots__ = (\"_ir\",)\n",
"    _ir: IR\n",
"    \"\"\"IR object linked to this node.\"\"\"\n",
"\n",
"    @cached_property\n",
"    def _key(self) -> str:\n",
"        ...\n",
"\n",
"    def _ir_dependencies(self):\n",
"        ...\n",
"\n",
"    @property\n",
"    def _npartitions(self) -> int:\n",
"        # A DaskNode must implement _npartitions\n",
"        raise NotImplementedError(f\"Partition count for {type(self).__name__}\")\n",
"\n",
"    def _tasks(self) -> MutableMapping[Any, Any]:\n",
"        # A DaskNode must implement _tasks\n",
"        raise NotImplementedError(f\"Generate tasks for {type(self).__name__}\")\n",
"\n",
"    def _task_graph(self) -> MutableMapping[Any, Any]:\n",
"        ...\n",
"\n",
"\n",
"class DaskExprNode(DaskNode):\n",
"    \"\"\"Dask-specific version of an Expr node.\"\"\"\n",
"\n",
"    __slots__ = (\"_ir\", \"_expr\", \"_name\")\n",
"    _ir: IR\n",
"    \"\"\"IR object linked to this node.\"\"\"\n",
"    _expr: Expr\n",
"    \"\"\"Expr object linked to this node.\"\"\"\n",
"    _name: str\n",
"    \"\"\"Name of the column produced by this node.\"\"\"\n",
"\n",
"    ...\n",
"```\n",
"\n",
"\n",
"### Implement the necessary `DaskNode` subclasses\n",
"\n",
"To lower a specific `IR` node, we need to implement a corresponding `DaskNode` subclass. For example, to lower a lazy Parquet `Scan` operation, we need a `ReadParquet(DaskNode)` class:\n",
"\n",
"```python\n",
"class ReadParquet(DaskNode):\n",
"    _ir: Scan\n",
"\n",
"    @property\n",
"    def _npartitions(self):\n",
"        return len(self._ir.paths)\n",
"\n",
"    @staticmethod\n",
"    def read_parquet(path, columns, nrows, skip_rows, predicate, schema):\n",
"        \"\"\"Read parquet data.\"\"\"\n",
"        tbl_w_meta = plc.io.parquet.read_parquet(\n",
"            plc.io.SourceInfo([path]),\n",
"            columns=columns,\n",
"            nrows=nrows,\n",
"            skip_rows=skip_rows,\n",
"        )\n",
"        df = DataFrame.from_table(\n",
"            tbl_w_meta.tbl,\n",
"            # TODO: consider nested column names?\n",
"            tbl_w_meta.column_names(include_children=False),\n",
"        )\n",
"        assert all(c.obj.type() == schema[c.name] for c in df.columns)\n",
"        if predicate is None:\n",
"            return df\n",
"        else:\n",
"            (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)\n",
"            return df.filter(mask)\n",
"\n",
"    def _tasks(self) -> MutableMapping[Any, Any]:\n",
"        key = self._key\n",
"        with_columns = self._ir.with_columns\n",
"        n_rows = self._ir.n_rows\n",
"        skip_rows = self._ir.skip_rows\n",
"        predicate = self._ir.predicate\n",
"        schema = self._ir.schema\n",
"        return {\n",
"            (key, i): (\n",
"                self.read_parquet,\n",
"                path,\n",
"                with_columns,\n",
"                n_rows,\n",
"                skip_rows,\n",
"                predicate,\n",
"                schema,\n",
"            )\n",
"            for i, path in enumerate(self._ir.paths)\n",
"        }\n",
"```\n",
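"The `_task_graph` method itself is elided above. As a hedged sketch (a standalone toy, not code from this branch; `StubNode` and its attributes are hypothetical), one way to assemble a single global graph is to merge the `_tasks()` of a node with those of all of its dependencies:\n",
"\n",
"```python\n",
"class StubNode:\n",
"    \"\"\"Toy stand-in for a DaskNode.\"\"\"\n",
"\n",
"    def __init__(self, tasks, deps=()):\n",
"        self._t = dict(tasks)\n",
"        self._deps = list(deps)\n",
"\n",
"    def _tasks(self):\n",
"        return dict(self._t)\n",
"\n",
"    def _ir_dependencies(self):\n",
"        return self._deps\n",
"\n",
"    def _task_graph(self):\n",
"        # Merge this node's tasks with those of every dependency\n",
"        graph, stack = {}, [self]\n",
"        while stack:\n",
"            node = stack.pop()\n",
"            graph.update(node._tasks())\n",
"            stack.extend(node._ir_dependencies())\n",
"        return graph\n",
"\n",
"\n",
"scan = StubNode({(\"scan\", 0): (sum, [1, 2])})\n",
"select = StubNode({(\"select\", 0): (len, [(\"scan\", 0)])}, deps=[scan])\n",
"select._task_graph()  # contains both (\"scan\", 0) and (\"select\", 0)\n",
"```\n",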
"\n",
"#### Dealing with `Expr` evaluation\n",
"\n",
"Since `IR` evaluation often requires the evaluation of a separate `Expr` graph, we also need a `DaskExprNode` class to implement this nested logic. For now, it seems reasonable to link a `DaskExprNode` to both an `IR` **and** an `Expr` object. In order to preserve `NamedExpr` properties, the `DaskExprNode` also tracks the name of the `NamedColumn` it operates on.\n"
]
},
{
"cell_type": "markdown",
"id": "802f242a-a61f-4b7f-97f2-fed5d85ca5d8",
"metadata": {},
"source": [
"## Usage Example\n",
"\n",
"Create a toy Parquet dataset:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2c584872-de55-4e8f-abf5-0c98ba12f9d2",
"metadata": {},
"outputs": [],
"source": [
"import dask.dataframe as dd\n",
"\n",
"dd.from_dict(\n",
"    {\"x\": range(30), \"y\": [\"cat\", \"dog\", \"fish\"] * 10},\n",
"    npartitions=3,\n",
").to_parquet(\"demo_parquet\", write_index=False)"
]
},
{
"cell_type": "markdown",
"id": "c5bb4ab5-2431-4b8f-b635-84cc915a360d",
"metadata": {},
"source": [
"Opt into Dask execution in cuDF-Polars and perform a lazy scan of `\"demo_parquet\"`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "39b15e1e-e291-4112-88e4-d6a369103711",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import polars as pl\n",
"\n",
"os.environ[\"CUDF_POLARS_DASK\"] = \"True\"\n",
"path = \"demo_parquet\"\n",
"df = pl.scan_parquet(path)"
]
},
{
"cell_type": "markdown",
"id": "71bc3144-9ca6-4bdf-a573-65f66f270af0",
"metadata": {},
"source": [
"Call `collect` with the GPU engine:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dc2edd90-7be2-4894-9df1-a0ed0efa7f81",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
"  text-align: right;\n",
"  white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (30, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>x</th><th>y</th></tr><tr><td>i64</td><td>str</td></tr></thead><tbody><tr><td>0</td><td>&quot;cat&quot;</td></tr><tr><td>1</td><td>&quot;dog&quot;</td></tr><tr><td>2</td><td>&quot;fish&quot;</td></tr><tr><td>3</td><td>&quot;cat&quot;</td></tr><tr><td>4</td><td>&quot;dog&quot;</td></tr><tr><td>&hellip;</td><td>&hellip;</td></tr><tr><td>25</td><td>&quot;dog&quot;</td></tr><tr><td>26</td><td>&quot;fish&quot;</td></tr><tr><td>27</td><td>&quot;cat&quot;</td></tr><tr><td>28</td><td>&quot;dog&quot;</td></tr><tr><td>29</td><td>&quot;fish&quot;</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (30, 2)\n",
"┌─────┬──────┐\n",
"│ x   ┆ y    │\n",
"│ --- ┆ ---  │\n",
"│ i64 ┆ str  │\n",
"╞═════╪══════╡\n",
"│ 0   ┆ cat  │\n",
"│ 1   ┆ dog  │\n",
"│ 2   ┆ fish │\n",
"│ 3   ┆ cat  │\n",
"│ 4   ┆ dog  │\n",
"│ …   ┆ …    │\n",
"│ 25  ┆ dog  │\n",
"│ 26  ┆ fish │\n",
"│ 27  ┆ cat  │\n",
"│ 28  ┆ dog  │\n",
"│ 29  ┆ fish │\n",
"└─────┴──────┘"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.collect(engine=\"gpu\")"
]
},
{
"cell_type": "markdown",
"id": "1635c33f-9435-4c48-bc29-c72e5e4f855d",
"metadata": {},
"source": [
"When you call `collect` in this example, cuDF-Polars will build and execute a task graph that looks something like this:\n",
"\n",
"```\n",
"{'readparquet-275df159f57b275dd0b1805f49b9c375': (<bound method DataFrame.concatenate of <class 'cudf_polars.containers.dataframe.DataFrame'>>,\n",
" [('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 0),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 1),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 2)]),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 0): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.0.parquet',\n",
" None,\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'y': <pylibcudf.types.DataType object at 0x7fe7e80ae610>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 1): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.1.parquet',\n",
" None,\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'y': <pylibcudf.types.DataType object at 0x7fe7e80ae610>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 2): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.2.parquet',\n",
" None,\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'y': <pylibcudf.types.DataType object at 0x7fe7e80ae610>})}\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "0b59b321-044e-48e3-8bc0-b9252c851713",
"metadata": {},
"source": [
"We can also perform a simple `sum` aggregation:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "07dbe993-8dee-4a4b-8ecd-c26f4c6acf93",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
"  text-align: right;\n",
"  white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (1, 1)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>x</th></tr><tr><td>i64</td></tr></thead><tbody><tr><td>435</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (1, 1)\n",
"┌─────┐\n",
"│ x   │\n",
"│ --- │\n",
"│ i64 │\n",
"╞═════╡\n",
"│ 435 │\n",
"└─────┘"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.select(pl.sum(\"x\")).collect(engine=\"gpu\")"
]
},
{
"cell_type": "markdown",
"id": "d167c417-ee44-4481-af01-17419e70e4ac",
"metadata": {},
"source": [
"In this case, we will build and execute a slightly more complex task graph:\n",
"\n",
"```\n",
"{'daskselect-b005fc5162d938fe186aa1908a3c81f2': ('daskselect-b005fc5162d938fe186aa1908a3c81f2',\n",
" 0),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226', 0): (<function SumAgg._chunk at 0x7fe510e527a0>,\n",
" (<function DaskCol._op at 0x7fe510e52520>,\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 0),\n",
" 'x'),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226', 1): (<function SumAgg._chunk at 0x7fe510e527a0>,\n",
" (<function DaskCol._op at 0x7fe510e52520>,\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 1),\n",
" 'x'),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226', 2): (<function SumAgg._chunk at 0x7fe510e527a0>,\n",
" (<function DaskCol._op at 0x7fe510e52520>,\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 2),\n",
" 'x'),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>),\n",
" ('concat-sumagg-14654d51f6722c8b42c1b1e7cd881226', 0): (<function SumAgg._concat at 0x7fe510e52840>,\n",
" [('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 0),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 1),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 2)]),\n",
" ('daskselect-b005fc5162d938fe186aa1908a3c81f2', 0): (<function DaskSelect._op at 0x7fe510e52340>,\n",
" [('sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 0)],\n",
" True),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 0): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.0.parquet',\n",
" ['x'],\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 1): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.1.parquet',\n",
" ['x'],\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 2): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.2.parquet',\n",
" ['x'],\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>}),\n",
" ('sumagg-14654d51f6722c8b42c1b1e7cd881226', 0): (<function SumAgg._finalize at 0x7fe510e528e0>,\n",
" ('concat-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 0),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'x')}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f33d2e9-1873-4a06-afeb-83c69a46fa21",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |