Last active
May 19, 2024 19:41
-
-
Save RaczeQ/cb4c65c3626ae410b63a5e6caa71b6dd to your computer and use it in GitHub Desktop.
PyArrow multiprocessing with streamed results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import multiprocessing
from pathlib import Path
from queue import Empty, Queue
from time import sleep
from typing import Callable, Optional

import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
def _intersection_worker(
    queue: Queue[tuple[str, int]],
    save_path: Path,
    function: Callable[[pa.Table], pa.Table],
    columns: Optional[list[str]] = None,
    max_retries: int = 3,
) -> None:  # pragma: no cover
    """
    Process row groups from a shared task queue and stream results into one parquet file.

    The worker keeps taking ``(file_name, row_group_index)`` tasks from ``queue`` until it
    observes the queue as empty. For every task it reads the row group, skips it if it has
    no rows, applies ``function`` and appends the result to
    ``save_path / f"{pid}.parquet"``, where ``pid`` is this process's id. The output file is
    created lazily on the first non-empty row group, and its schema is taken from that
    first result.

    A task that raises is logged and put back on the queue so it can be retried. This
    worker re-queues the same task at most ``max_retries`` times; after that the task is
    logged and dropped, so a row group that always fails cannot loop forever.

    Args:
        queue (Queue[tuple[str, int]]): Shared queue of ``(file_name, row_group_index)`` tasks.
        save_path (Path): Directory in which this worker's output file is written.
        function (Callable[[pa.Table], pa.Table]): Function applied to each row group table.
        columns (Optional[list[str]]): Columns to read from each row group.
            Defaults to `None` (all columns).
        max_retries (int): How many times this worker re-queues a failing task before
            giving up on it. Attempts made by other workers are counted separately.
            Defaults to 3.
    """
    logger = logging.getLogger(__name__)
    current_pid = multiprocessing.current_process().pid
    filepath = save_path / f"{current_pid}.parquet"
    failures: dict[tuple[str, int], int] = {}
    writer: Optional[pq.ParquetWriter] = None
    try:
        while not queue.empty():
            try:
                task = queue.get(block=True, timeout=1)
            except Empty:
                # Another worker took the last task between empty() and get();
                # re-check the queue instead of treating this as a failure.
                continue
            file_name, row_group_index = task
            try:
                pq_file = pq.ParquetFile(file_name)
                row_group_table = pq_file.read_row_group(row_group_index, columns=columns)
                if len(row_group_table) == 0:
                    continue
                result_table = function(row_group_table)
                if writer is None:
                    # The output schema is fixed by the first result this worker produces.
                    writer = pq.ParquetWriter(filepath, result_table.schema)
                writer.write_table(result_table)
            except Exception:
                attempts = failures.get(task, 0) + 1
                failures[task] = attempts
                if attempts <= max_retries:
                    logger.exception(
                        "Failed to process row group %d of %s (attempt %d), re-queueing",
                        row_group_index,
                        file_name,
                        attempts,
                    )
                    queue.put(task)
                else:
                    logger.exception(
                        "Failed to process row group %d of %s %d times, giving up",
                        row_group_index,
                        file_name,
                        attempts,
                    )
    finally:
        # Always finalize the parquet file, even if the loop is interrupted.
        if writer is not None:
            writer.close()
def map_parquet_dataset(
    dataset_path: Path,
    destination_path: Path,
    function: Callable[[pa.Table], pa.Table],
    columns: Optional[list[str]] = None,
    num_workers: Optional[int] = None,
) -> None:
    """
    Apply a function over a parquet dataset in a multiprocessing environment.

    Every row group of every file in the dataset becomes one task on a shared queue.
    Worker processes apply ``function`` to each row group and stream the results into
    their own ``<pid>.parquet`` file inside ``destination_path``, so the output is split
    across multiple files. A tqdm progress bar tracks how many tasks have been taken off
    the queue.

    Args:
        dataset_path (Path): Path of the parquet dataset.
        destination_path (Path): Path of the destination directory. Created if missing.
        function (Callable[[pa.Table], pa.Table]): Function to apply over a row group table.
            Will save resulting table in a new parquet file. It is passed to child
            processes, so it must be picklable when a non-fork start method is used
            (e.g. a module-level function rather than a lambda).
        columns (Optional[list[str]]): List of columns to read. Defaults to `None`.
        num_workers (Optional[int]): Number of worker processes.
            Defaults to `None`, meaning `multiprocessing.cpu_count()`.

    Raises:
        ValueError: If ``num_workers`` is less than 1.
    """
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers}")

    dataset = pq.ParquetDataset(dataset_path)
    # The manager owns the shared queue. Using it as a context manager shuts its
    # server process down on exit instead of leaking it.
    with multiprocessing.Manager() as manager:
        queue: Queue[tuple[str, int]] = manager.Queue()
        for pq_file in dataset.files:
            for row_group in range(pq.ParquetFile(pq_file).num_row_groups):
                queue.put((pq_file, row_group))
        total = queue.qsize()
        destination_path.mkdir(parents=True, exist_ok=True)

        # Bound before `try` so the `finally` clause can always iterate it.
        processes: list[multiprocessing.Process] = []
        try:
            processes = [
                multiprocessing.Process(
                    target=_intersection_worker,
                    args=(queue, destination_path, function, columns),
                )
                for _ in range(num_workers)
            ]
            for p in processes:
                p.start()
            # Progress counts tasks taken off the queue. This includes tasks still in
            # flight and tasks that may be re-queued after a failure, so it is approximate.
            with tqdm(total=total) as bar:
                while any(process.is_alive() for process in processes):
                    bar.n = total - queue.qsize()
                    bar.refresh()
                    sleep(1)
                bar.n = total
                bar.refresh()
            # Reap the finished workers.
            for p in processes:
                p.join()
        finally:
            # In case of exception, stop and reap every worker that is still running.
            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment