@rjzamora
Created September 30, 2024 15:11
{
"cells": [
{
"cell_type": "markdown",
"id": "10575011-5099-4216-b294-33c99daf5e26",
"metadata": {},
"source": [
"# cuDF-Polars + Dask: POC Demo\n",
"\n",
"**cuDF Branch**: [rjzamora:cudf-polars-dask](https://github.com/rapidsai/cudf/compare/branch-24.12...rjzamora:cudf:cudf-polars-dask)"
]
},
{
"cell_type": "markdown",
"id": "119e52f0-3343-4994-ad30-a44ede3e7556",
"metadata": {},
"source": [
"## TODO\n",
"\n",
"- **Serialization**: The current implementation does not yet support multi-GPU execution, because the Dask serialization primitives are not defined at the `pylibcudf` level.\n",
"- **Task fusion**: The current implementation does not yet fuse IO tasks with follow-up compute tasks when possible.\n",
"- Most `IR`/`Expr` logic has yet to be implemented as `DaskNode`/`DaskExprNode` classes."
]
},
{
"cell_type": "markdown",
"id": "07df307b-94b7-48c4-895e-a6abde9a2f5c",
"metadata": {},
"source": [
"## Key Changes\n",
"\n",
"### Build & execute a task graph in `cudf_polars/callback.py`\n",
"\n",
"In order to leverage Dask for parallel execution, we need to intercept the control flow in `callback.py`. Rather than traversing the `IR` graph and evaluating it eagerly, we start by building a task graph:\n",
"\n",
"```python\n",
"def _callback(\n",
" ir: IR,\n",
" ...\n",
") -> pl.DataFrame:\n",
" ...\n",
"\n",
" if CUDF_POLARS_DASK == \"TRUE\":\n",
" \n",
" # Extract `DaskNode` object from `IR` node\n",
" dask_node = ir._dask_node()\n",
" # Build task graph\n",
" dsk = dask_node._task_graph()\n",
" # Execute task graph\n",
" result = get(dsk, dask_node._key)\n",
" # Return Polars object\n",
" return result.to_polars()\n",
"```\n",
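"\n",
"For reference, the low-level Dask graph protocol assumed above is just a dict mapping keys to literal values or task tuples, executed by a scheduler's `get` function. A minimal, self-contained sketch (independent of cuDF-Polars; `add_one` is a stand-in for real work):\n",
"\n",
"```python\n",
"# A Dask graph is a plain dict: keys map to literal values or task tuples.\n",
"from dask.threaded import get\n",
"\n",
"\n",
"def add_one(x):\n",
"    return x + 1\n",
"\n",
"\n",
"dsk = {\n",
"    \"data\": 41,                   # a literal value\n",
"    \"result\": (add_one, \"data\"),  # a task tuple: (callable, *args)\n",
"}\n",
"\n",
"print(get(dsk, \"result\"))  # prints 42\n",
"```\n",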
"\n",
"### Define `DaskNode` classes to encapsulate Dask-specific logic\n",
"\n",
"The only way we can build a task graph in `callback.py` is if the `IR`/`Expr` nodes can be \"lowered\" into objects that implement the necessary logic for task-graph generation. Here we use the `DaskNode` and `DaskExprNode` classes for `IR` and `Expr` objects, respectively:\n",
"\n",
"```python\n",
"class DaskNode:\n",
" \"\"\" Dask-specific version of an IR node. \"\"\"\n",
"\n",
" __slots__ = (\"_ir\",)\n",
" _ir: IR\n",
" \"\"\"IR object linked to this node.\"\"\"\n",
"\n",
" @cached_property\n",
" def _key(self) -> str:\n",
" ...\n",
"\n",
" def _ir_dependencies(self):\n",
" ...\n",
"\n",
" @property\n",
" def _npartitions(self) -> int:\n",
" # A DaskNode must implement _npartitions\n",
" raise NotImplementedError(f\"Partition count for {type(self).__name__}\")\n",
"\n",
" def _tasks(self) -> MutableMapping[Any, Any]:\n",
" # A DaskNode must implement _tasks\n",
" raise NotImplementedError(f\"Generate tasks for {type(self).__name__}\")\n",
"\n",
" def _task_graph(self) -> MutableMapping[Any, Any]:\n",
" ...\n",
"\n",
"\n",
"class DaskExprNode(DaskNode):\n",
" \"\"\"Dask-specific version of an Expr node. \"\"\"\n",
"\n",
" __slots__ = (\"_ir\", \"_expr\", \"_name\")\n",
" _ir: IR\n",
" \"\"\"IR object linked to this node.\"\"\"\n",
" _expr: Expr\n",
" \"\"\"Expr object linked to this node.\"\"\"\n",
" _name: str\n",
" \"\"\"Name of the column produced by this node.\"\"\"\n",
"\n",
" ...\n",
"```\n",
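"\n",
"The `_task_graph` body is elided above; a plausible sketch of the idea (using a hypothetical `ToyNode`, not the actual implementation) is that a node's full graph is its own `_tasks` merged with the graphs of its dependencies, collected recursively:\n",
"\n",
"```python\n",
"class ToyNode:\n",
"    \"\"\"Hypothetical stand-in for a DaskNode.\"\"\"\n",
"\n",
"    def __init__(self, name, deps=()):\n",
"        self.name = name\n",
"        self.deps = deps\n",
"\n",
"    def _tasks(self):\n",
"        # One placeholder task here; real nodes emit one task per partition.\n",
"        return {self.name: (str.upper, self.name + \"-input\")}\n",
"\n",
"    def _task_graph(self):\n",
"        # Merge this node's tasks with those of all dependencies.\n",
"        dsk = self._tasks()\n",
"        for dep in self.deps:\n",
"            dsk.update(dep._task_graph())\n",
"        return dsk\n",
"\n",
"\n",
"root = ToyNode(\"root\", deps=(ToyNode(\"leaf\"),))\n",
"print(sorted(root._task_graph()))  # prints ['leaf', 'root']\n",
"```\n",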
"\n",
"\n",
"### Implement the necessary `DaskNode` subclasses\n",
"\n",
"Lowering a specific `IR` node requires a corresponding `DaskNode` subclass. For example, to lower a lazy Parquet `Scan` operation, we implement a `ReadParquet(DaskNode)` class:\n",
"\n",
"```python\n",
"class ReadParquet(DaskNode):\n",
" _ir: Scan\n",
"\n",
" @property\n",
" def _npartitions(self):\n",
" return len(self._ir.paths)\n",
"\n",
" @staticmethod\n",
" def read_parquet(path, columns, nrows, skip_rows, predicate, schema):\n",
" \"\"\"Read parquet data.\"\"\"\n",
" tbl_w_meta = plc.io.parquet.read_parquet(\n",
" plc.io.SourceInfo([path]),\n",
" columns=columns,\n",
" nrows=nrows,\n",
" skip_rows=skip_rows,\n",
" )\n",
" df = DataFrame.from_table(\n",
" tbl_w_meta.tbl,\n",
" # TODO: consider nested column names?\n",
" tbl_w_meta.column_names(include_children=False),\n",
" )\n",
" assert all(c.obj.type() == schema[c.name] for c in df.columns)\n",
" if predicate is None:\n",
" return df\n",
" else:\n",
" (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)\n",
" return df.filter(mask)\n",
"\n",
" def _tasks(self) -> MutableMapping[Any, Any]:\n",
" key = self._key\n",
" with_columns = self._ir.with_columns\n",
" n_rows = self._ir.n_rows\n",
" skip_rows = self._ir.skip_rows\n",
" predicate = self._ir.predicate\n",
" schema = self._ir.schema\n",
" return {\n",
" (key, i): (\n",
" self.read_parquet,\n",
" path,\n",
" with_columns,\n",
" n_rows,\n",
" skip_rows,\n",
" predicate,\n",
" schema,\n",
" )\n",
" for i, path in enumerate(self._ir.paths)\n",
" }\n",
"```\n",
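"\n",
"The task layout this produces is one task per path, keyed by `(name, partition_index)`. A generic sketch of the same pattern (with a hypothetical `fake_read` in place of `read_parquet`), including a final task that gathers the partitioned keys, mirroring the concatenate task that appears in the graph dump later in this demo:\n",
"\n",
"```python\n",
"from dask.threaded import get\n",
"\n",
"\n",
"def fake_read(path):  # hypothetical stand-in for read_parquet\n",
"    return f\"<table from {path}>\"\n",
"\n",
"\n",
"paths = [\"part.0.parquet\", \"part.1.parquet\", \"part.2.parquet\"]\n",
"name = \"readparquet-demo\"\n",
"\n",
"# One task per file, keyed by (name, partition index):\n",
"dsk = {(name, i): (fake_read, path) for i, path in enumerate(paths)}\n",
"# A final task gathers the partitioned keys into a single result:\n",
"dsk[name] = (list, [(name, i) for i in range(len(paths))])\n",
"\n",
"print(get(dsk, name))\n",
"```\n",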
"\n",
"#### Dealing with `Expr` evaluation\n",
"\n",
"Since `IR` evaluation often requires the evaluation of a separate `Expr` graph, we also need a `DaskExprNode` class to implement this nested logic. For now, it seems reasonable to link an `DaskExprNode` to both an `IR` **and** an `Expr` object. In order to preserve `NamedExpr` properties, the `DaskExprNode` also tracks the name of the `NamedColumn` it operates on.\n"
]
},
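{
"cell_type": "markdown",
"id": "b7e2a4c1-5d3f-4e8a-9b2c-0f1d2e3a4b5c",
"metadata": {},
"source": [
"A toy sketch of how an expression task can compose with the IR task that produced its partition (hypothetical `make_partition`/`select_column` helpers, mirroring the `DaskCol._op` tasks in the graph dumps below):\n",
"\n",
"```python\n",
"from dask.threaded import get\n",
"\n",
"\n",
"def make_partition():\n",
"    # Stand-in for one DataFrame partition produced by an IR task.\n",
"    return {\"x\": [1, 2, 3]}\n",
"\n",
"\n",
"def select_column(df, name):\n",
"    # Stand-in for evaluating a column expression on one partition.\n",
"    return df[name]\n",
"\n",
"\n",
"dsk = {\n",
"    (\"frame\", 0): (make_partition,),\n",
"    (\"col-x\", 0): (select_column, (\"frame\", 0), \"x\"),\n",
"}\n",
"print(get(dsk, (\"col-x\", 0)))  # prints [1, 2, 3]\n",
"```"
]
},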
{
"cell_type": "markdown",
"id": "802f242a-a61f-4b7f-97f2-fed5d85ca5d8",
"metadata": {},
"source": [
"## Usage Example\n",
"\n",
"Create a toy Parquet Dataset:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2c584872-de55-4e8f-abf5-0c98ba12f9d2",
"metadata": {},
"outputs": [],
"source": [
"import dask.dataframe as dd\n",
"\n",
"dd.from_dict(\n",
" {\"x\": range(30), \"y\": [\"cat\", \"dog\", \"fish\"] * 10},\n",
" npartitions=3,\n",
").to_parquet(\"demo_parquet\", write_index=False)"
]
},
{
"cell_type": "markdown",
"id": "c5bb4ab5-2431-4b8f-b635-84cc915a360d",
"metadata": {},
"source": [
"Opt into Dask execution in cuDF-Polars and perform a lazy scan of `\"demo_parquet\"`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "39b15e1e-e291-4112-88e4-d6a369103711",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import polars as pl\n",
"\n",
"os.environ[\"CUDF_POLARS_DASK\"] = \"TRUE\"\n",
"path = \"demo_parquet\"\n",
"df = pl.scan_parquet(path)"
]
},
{
"cell_type": "markdown",
"id": "71bc3144-9ca6-4bdf-a573-65f66f270af0",
"metadata": {},
"source": [
"Call `collect` with the GPU engine:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dc2edd90-7be2-4894-9df1-a0ed0efa7f81",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
" text-align: right;\n",
" white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (30, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>x</th><th>y</th></tr><tr><td>i64</td><td>str</td></tr></thead><tbody><tr><td>0</td><td>&quot;cat&quot;</td></tr><tr><td>1</td><td>&quot;dog&quot;</td></tr><tr><td>2</td><td>&quot;fish&quot;</td></tr><tr><td>3</td><td>&quot;cat&quot;</td></tr><tr><td>4</td><td>&quot;dog&quot;</td></tr><tr><td>&hellip;</td><td>&hellip;</td></tr><tr><td>25</td><td>&quot;dog&quot;</td></tr><tr><td>26</td><td>&quot;fish&quot;</td></tr><tr><td>27</td><td>&quot;cat&quot;</td></tr><tr><td>28</td><td>&quot;dog&quot;</td></tr><tr><td>29</td><td>&quot;fish&quot;</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (30, 2)\n",
"┌─────┬──────┐\n",
"│ x ┆ y │\n",
"│ --- ┆ --- │\n",
"│ i64 ┆ str │\n",
"╞═════╪══════╡\n",
"│ 0 ┆ cat │\n",
"│ 1 ┆ dog │\n",
"│ 2 ┆ fish │\n",
"│ 3 ┆ cat │\n",
"│ 4 ┆ dog │\n",
"│ … ┆ … │\n",
"│ 25 ┆ dog │\n",
"│ 26 ┆ fish │\n",
"│ 27 ┆ cat │\n",
"│ 28 ┆ dog │\n",
"│ 29 ┆ fish │\n",
"└─────┴──────┘"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.collect(engine=\"gpu\")"
]
},
{
"cell_type": "markdown",
"id": "1635c33f-9435-4c48-bc29-c72e5e4f855d",
"metadata": {},
"source": [
"When you call `collect` in this example, cuDF-Polars will build and execute a task graph that looks something like this:\n",
"\n",
"```\n",
"{'readparquet-275df159f57b275dd0b1805f49b9c375': (<bound method DataFrame.concatenate of <class 'cudf_polars.containers.dataframe.DataFrame'>>,\n",
" [('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 0),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 1),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 2)]),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 0): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.0.parquet',\n",
" None,\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'y': <pylibcudf.types.DataType object at 0x7fe7e80ae610>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 1): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.1.parquet',\n",
" None,\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'y': <pylibcudf.types.DataType object at 0x7fe7e80ae610>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 2): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.2.parquet',\n",
" None,\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'y': <pylibcudf.types.DataType object at 0x7fe7e80ae610>})}\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "0b59b321-044e-48e3-8bc0-b9252c851713",
"metadata": {},
"source": [
"We can also perform a simple `sum` aggregation:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "07dbe993-8dee-4a4b-8ecd-c26f4c6acf93",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
" text-align: right;\n",
" white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (1, 1)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>x</th></tr><tr><td>i64</td></tr></thead><tbody><tr><td>435</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (1, 1)\n",
"┌─────┐\n",
"│ x │\n",
"│ --- │\n",
"│ i64 │\n",
"╞═════╡\n",
"│ 435 │\n",
"└─────┘"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.select(pl.sum(\"x\")).collect(engine=\"gpu\")"
]
},
{
"cell_type": "markdown",
"id": "d167c417-ee44-4481-af01-17419e70e4ac",
"metadata": {},
"source": [
"In this case, we will build and execute a slightly more complex task graph:\n",
"\n",
"```\n",
"{'daskselect-b005fc5162d938fe186aa1908a3c81f2': ('daskselect-b005fc5162d938fe186aa1908a3c81f2',\n",
" 0),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226', 0): (<function SumAgg._chunk at 0x7fe510e527a0>,\n",
" (<function DaskCol._op at 0x7fe510e52520>,\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 0),\n",
" 'x'),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226', 1): (<function SumAgg._chunk at 0x7fe510e527a0>,\n",
" (<function DaskCol._op at 0x7fe510e52520>,\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 1),\n",
" 'x'),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226', 2): (<function SumAgg._chunk at 0x7fe510e527a0>,\n",
" (<function DaskCol._op at 0x7fe510e52520>,\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375',\n",
" 2),\n",
" 'x'),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>),\n",
" ('concat-sumagg-14654d51f6722c8b42c1b1e7cd881226', 0): (<function SumAgg._concat at 0x7fe510e52840>,\n",
" [('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 0),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 1),\n",
" ('chunk-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 2)]),\n",
" ('daskselect-b005fc5162d938fe186aa1908a3c81f2', 0): (<function DaskSelect._op at 0x7fe510e52340>,\n",
" [('sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 0)],\n",
" True),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 0): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.0.parquet',\n",
" ['x'],\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 1): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.1.parquet',\n",
" ['x'],\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>}),\n",
" ('readparquet-275df159f57b275dd0b1805f49b9c375', 2): (<function ReadParquet.read_parquet at 0x7fe510e52200>,\n",
" 'demo_parquet/part.2.parquet',\n",
" ['x'],\n",
" -1,\n",
" 0,\n",
" None,\n",
" {'x': <pylibcudf.types.DataType object at 0x7fe7e80aedb0>}),\n",
" ('sumagg-14654d51f6722c8b42c1b1e7cd881226', 0): (<function SumAgg._finalize at 0x7fe510e528e0>,\n",
" ('concat-sumagg-14654d51f6722c8b42c1b1e7cd881226',\n",
" 0),\n",
" <pylibcudf.aggregation.Aggregation object at 0x7fe51028a110>,\n",
" <pylibcudf.types.DataType object at 0x7fe7e80aedb0>,\n",
" 'x')}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f33d2e9-1873-4a06-afeb-83c69a46fa21",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}