@fgassert
Created January 15, 2025 15:43
# --- dask_partitioned_dataset.py (first file of the gist; filename assumed) ---
from copy import deepcopy
from typing import Any, Callable, Literal

import dask.distributed
from dask.delayed import delayed
from kedro.io.core import (
    AbstractDataset,
    DatasetError,
)

from .robust_partitioned_dataset import RobustPartitionedDataset


class DaskPartitionedDataset(RobustPartitionedDataset):
    DEFAULT_CHECKPOINT_FILENAME = "_MANIFEST"

    def __init__(  # noqa: PLR0913
        self,
        *,
        path: str,
        dataset: str | type[AbstractDataset] | dict[str, Any],
        behavior: Literal['default', 'complete_missing', 'overwrite'] | None = "default",
        checkpoint: dict[str, Any] | None = None,
        filepath_arg: str = "filepath",
        filename_suffix: str = "",
        credentials: dict[str, Any] | None = None,
        load_args: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        dask_client_options: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ):
""" The DaskPartitionedDataset is a subclass of PartitionedDataset that
uses Dask to process and save partitions in parallel.
DaskPartitionedDataset adds some robustness for processing large datasets.
Errors saving individual partitions will not prevent the rest of the
partitions from being processed. A manifest file of completed partitions
will be saved, and you can skip processing partitions that have already
been saved by using the behavior:"complete_missing" keyword argument.
Args:
dask_client_options: Options to pass to the Dask Client constructor
to configure the connection to the Dask cluster.
e.g. {'address': 'http://localhost:8786'}
If no address is specificed, Dask will initiate a LocalCluster
to execute tasks, and keyword arguments will be passed to the
LocalCluster constructor.
behavior: 'default' | 'complete_missing' | 'overwrite'
The behavior to use when saving partitions.
'default': Save all partitions, overwriting any that already exist.
'complete_missing': Only save partitions that do not already exist.
'overwrite': Delete all partitions before saving.
"""
        super().__init__(
            path=path,
            dataset=dataset,
            behavior=behavior,
            checkpoint=checkpoint,
            filepath_arg=filepath_arg,
            filename_suffix=filename_suffix,
            credentials=credentials,
            load_args=load_args,
            fs_args=fs_args,
            metadata=metadata,
        )
        self._client_options = dask_client_options or {}
    def _save_new_partitions(self, data: dict[str, Callable[[], Any]]) -> None:
        def _save_partition(partition_data, dataset):
            # Lazy partitions arrive as callables; materialize them before saving.
            if callable(partition_data):
                partition_data = partition_data()  # noqa: PLW2901
            dataset.save(partition_data)

        errors = []
        tasks = []
        for partition_id, partition_data in data.items():
            kwargs = deepcopy(self._dataset_config)
            partition = self._partition_to_path(partition_id)
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)
            task = delayed(_save_partition)(partition_data, dataset)
            tasks.append(task)

        client = dask.distributed.Client(**self._client_options)
        futures = client.compute(tasks)
        futures_to_ids = dict(zip(futures, data.keys()))  # type: ignore
        for future in dask.distributed.as_completed(futures, raise_errors=False):
            if future.status == 'error':
                self._logger.warning(
                    f"Error saving partition {futures_to_ids[future]}, "
                    f"with exception: {future.exception()!r}"
                )
                errors.append(future.exception())
            else:
                self._completed_partitions.append(futures_to_ids[future])
            future.release()
        # Shut the client down before raising so workers are not left running on failure.
        client.shutdown()
        if errors:
            raise DatasetError(f"{len(errors)} errors occurred while saving partitions.")
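# Usage sketch (not part of the gist): a hypothetical Kedro catalog entry wiring this
# class up. The module path, underlying dataset type, and scheduler address below are
# illustrative assumptions.
#
#   shards:
#     type: my_project.datasets.DaskPartitionedDataset
#     path: data/07_model_output/shards
#     dataset: pandas.ParquetDataset
#     filename_suffix: ".parquet"
#     behavior: complete_missing
#     dask_client_options:
#       address: "tcp://localhost:8786"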
# --- multiprocessing_partitioned_dataset.py (second file of the gist; filename assumed) ---
import concurrent.futures
from copy import deepcopy
from typing import Any, Callable, Literal

from kedro.io.core import (
    AbstractDataset,
    DatasetError,
)

from .robust_partitioned_dataset import RobustPartitionedDataset


def _save_partition(partition_data, dataset):
    # Module-level so ProcessPoolExecutor can pickle it; functions nested inside a
    # method cannot be pickled and sent to worker processes.
    if callable(partition_data):
        partition_data = partition_data()  # noqa: PLW2901
    dataset.save(partition_data)


class MultiprocessingPartitionedDataset(RobustPartitionedDataset):
    DEFAULT_CHECKPOINT_FILENAME = "_MANIFEST"
    def __init__(  # noqa: PLR0913
        self,
        *,
        path: str,
        dataset: str | type[AbstractDataset] | dict[str, Any],
        behavior: Literal['default', 'complete_missing', 'overwrite'] | None = "default",
        checkpoint: dict[str, Any] | None = None,
        filepath_arg: str = "filepath",
        filename_suffix: str = "",
        credentials: dict[str, Any] | None = None,
        load_args: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ):
""" The MultiprocessingPartitionedDataset is a subclass of PartitionedDataset that
uses a ProcessPoolExecutor to process and save partitions in parallel.
MultiprocessingPartitionedDataset adds some robustness for processing large datasets.
Errors saving individual partitions will not prevent the rest of the
partitions from being processed. A manifest file of completed partitions
will be saved, and you can skip processing partitions that have already
been saved by using the behavior:"complete_missing" keyword argument.
Args:
dask_client_options: Options to pass to the Dask Client constructor
to configure the connection to the Dask cluster.
e.g. {'address': 'http://localhost:8786'}
If no address is specificed, Dask will initiate a LocalCluster
to execute tasks, and keyword arguments will be passed to the
LocalCluster constructor.
behavior: 'default' | 'complete_missing' | 'overwrite'
The behavior to use when saving partitions.
'default': Save all partitions, overwriting any that already exist.
'complete_missing': Only save partitions that do not already exist.
'overwrite': Delete all partitions before saving.
"""
        super().__init__(
            path=path,
            dataset=dataset,
            behavior=behavior,
            checkpoint=checkpoint,
            filepath_arg=filepath_arg,
            filename_suffix=filename_suffix,
            credentials=credentials,
            load_args=load_args,
            fs_args=fs_args,
            metadata=metadata,
        )
    def _save_new_partitions(self, data: dict[str, Callable[[], Any]]) -> None:
        errors = []
        with concurrent.futures.ProcessPoolExecutor() as pool:
            futures_to_ids = {}
            futures = []
            for partition_id, partition_data in data.items():
                kwargs = deepcopy(self._dataset_config)
                partition = self._partition_to_path(partition_id)
                kwargs[self._filepath_arg] = self._join_protocol(partition)
                dataset = self._dataset_type(**kwargs)
                # Submits the module-level _save_partition so the callable is picklable.
                future = pool.submit(_save_partition, partition_data, dataset)
                futures.append(future)
                futures_to_ids[future] = partition_id
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                    self._completed_partitions.append(futures_to_ids[future])
                except Exception as e:
                    errors.append(e)
                    self._logger.warning(
                        f"Error saving partition {futures_to_ids[future]}, with exception: {e!r}"
                    )
        if errors:
            raise DatasetError(f"{len(errors)} errors occurred while saving partitions.")
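# Usage sketch (not part of the gist): saving lazily with picklable callables. Because
# partition data and datasets are shipped to worker processes, both must be picklable;
# plain lambdas would not be. The build_shard helper and dataset config below are
# illustrative assumptions.
#
#   from functools import partial
#
#   ds = MultiprocessingPartitionedDataset(
#       path="data/07_model_output/shards",
#       dataset={"type": "pandas.CSVDataset", "save_args": {"index": False}},
#       filename_suffix=".csv",
#       behavior="complete_missing",
#   )
#   ds.save({f"part-{i:03d}": partial(build_shard, i) for i in range(10)})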
# --- robust_partitioned_dataset.py (third file of the gist) ---
import operator
from copy import deepcopy
from typing import Any, Callable, Literal

from cachetools import cachedmethod
from kedro.io.catalog_config_resolver import CREDENTIALS_KEY
from kedro.io.core import (
    VERSION_KEY,
    VERSIONED_FLAG_KEY,
    AbstractDataset,
    DatasetError,
    parse_dataset_definition,
)
from kedro_datasets.partitions import PartitionedDataset
from kedro_datasets.partitions.partitioned_dataset import (
    KEY_PROPAGATION_WARNING,
    _grandparent,
)

COMPLETE_MISSING = "complete_missing"
OVERWRITE = "overwrite"


class RobustPartitionedDataset(PartitionedDataset):
    DEFAULT_CHECKPOINT_FILENAME = "_MANIFEST"
    DEFAULT_CHECKPOINT_TYPE = "kedro_datasets.text.TextDataset"
    def __init__(  # noqa: PLR0913
        self,
        *,
        path: str,
        dataset: str | type[AbstractDataset] | dict[str, Any],
        behavior: Literal['default', 'complete_missing', 'overwrite'] | None = "default",
        checkpoint: dict[str, Any] | None = None,
        filepath_arg: str = "filepath",
        filename_suffix: str = "",
        credentials: dict[str, Any] | None = None,
        load_args: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ):
"""
RobustPartitionedDataset is a subclass of PartitionedDataset that adds some robustness
for processing large datasets. Errors saving individual partitions will not prevent the
rest of the partitions from being processed. A manifest file of completed partitions will
be saved, and you can skip processing partitions that have already been saved by using the
behavior: "complete_missing" keyword argument.
Args:
behavior: 'default' | 'complete_missing' | 'overwrite'
The behavior to use when saving partitions.
'default': Save all partitions, overwriting any that already exist.
'complete_missing': Only save partitions that do not already exist.
'overwrite': Delete all partitions before saving.
"""
        super().__init__(
            path=path,
            dataset=dataset,
            filepath_arg=filepath_arg,
            filename_suffix=filename_suffix,
            credentials=credentials,
            load_args=load_args,
            fs_args=fs_args,
            metadata=metadata,
        )
        self._checkpoint_config = self._parse_checkpoint_config(checkpoint)
        self._behavior = behavior
        self._completed_partitions = []

    @property
    def _checkpoint(self) -> AbstractDataset:
        type_, kwargs = parse_dataset_definition(self._checkpoint_config)
        return type_(**kwargs)  # type: ignore

    def _read_checkpoint_ids(self) -> list[str]:
        try:
            return self._checkpoint.load().splitlines()
        except DatasetError:
            return []
    def _parse_checkpoint_config(
        self, checkpoint_config: dict[str, Any] | None
    ) -> dict[str, Any]:
        checkpoint_config = deepcopy(checkpoint_config)
        checkpoint_config = checkpoint_config or {}
        for key in {VERSION_KEY, VERSIONED_FLAG_KEY} & checkpoint_config.keys():
            raise DatasetError(
                f"'{self.__class__.__name__}' does not support versioning of the "
                f"checkpoint. Please remove '{key}' key from the checkpoint definition."
            )
        default_checkpoint_path = self._sep.join(
            [self._normalized_path.rstrip(self._sep), self.DEFAULT_CHECKPOINT_FILENAME]
        )
        default_config = {
            "type": self.DEFAULT_CHECKPOINT_TYPE,
            self._filepath_arg: default_checkpoint_path,
        }
        if self._credentials:
            default_config[CREDENTIALS_KEY] = deepcopy(self._credentials)
        if CREDENTIALS_KEY in default_config.keys() & checkpoint_config.keys():
            self._logger.warning(
                KEY_PROPAGATION_WARNING,
                {"keys": CREDENTIALS_KEY, "target": "checkpoint"},
            )
        return {**default_config, **checkpoint_config}
    @cachedmethod(cache=operator.attrgetter("_partition_cache"))
    def _list_partitions(self) -> list[str]:
        checkpoint_path = self._filesystem._strip_protocol(
            self._checkpoint_config[self._filepath_arg]
        )
        dataset_is_versioned = VERSION_KEY in self._dataset_config
        checkpoint_ids = self._read_checkpoint_ids()

        def _is_valid_partition(partition) -> bool:
            if not partition.endswith(self._filename_suffix):
                return False
            if partition == checkpoint_path:
                return False
            if checkpoint_ids and self._path_to_partition(partition) not in checkpoint_ids:
                return False
            return True

        return [
            _grandparent(path) if dataset_is_versioned else path
            for path in self._filesystem.find(self._normalized_path, **self._load_args)
            if _is_valid_partition(path)
        ]
    def _load(self) -> dict[str, Callable[[], Any]]:
        partitions = {}
        for partition in self._list_partitions():
            kwargs = deepcopy(self._dataset_config)
            # join the protocol back since PySpark may rely on it
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            partition_id = self._path_to_partition(partition)
            partitions[partition_id] = dataset.load
        if not partitions:
            raise DatasetError(f"No partitions found in '{self._path}'")
        return partitions

    def _save_checkpoint(self, checkpoint_ids: list[str]) -> None:
        self._checkpoint.save("\n".join(sorted(checkpoint_ids)))
    def _save_new_partitions(self, data: dict[str, Callable[[], Any]]) -> None:
        def _save_partition(partition_data, dataset):
            if callable(partition_data):
                partition_data = partition_data()  # noqa: PLW2901
            dataset.save(partition_data)

        errors = []
        for partition_id, partition_data in data.items():
            kwargs = deepcopy(self._dataset_config)
            partition = self._partition_to_path(partition_id)
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)
            try:
                _save_partition(partition_data, dataset)
                self._completed_partitions.append(partition_id)
            except Exception as e:
                self._logger.warning(f"Error saving partition {partition_id}:\n{e}")
                errors.append(e)
        if errors:
            raise DatasetError(f"{len(errors)} errors occurred while saving partitions.")
    def _save(self, data: dict[str, Callable[[], Any]]) -> None:
        new_partitions = data.keys()
        checkpoint_ids = []
        if self._behavior == COMPLETE_MISSING:
            checkpoint_ids = self._read_checkpoint_ids()
            new_partitions = set(data.keys()) - set(checkpoint_ids)
            self._logger.info(
                f"Saving {self._path}: {len(data.keys()) - len(new_partitions)} "
                f"of {len(data.keys())} partitions already exist."
            )
        elif self._behavior == OVERWRITE and self._filesystem.exists(self._normalized_path):
            self._filesystem.rm(self._normalized_path, recursive=True)
        if not len(new_partitions):
            return
        if not self._filesystem.exists(self._normalized_path):
            self._filesystem.mkdir(self._normalized_path)
        self._completed_partitions = checkpoint_ids
        try:
            self._save_new_partitions({k: data[k] for k in new_partitions})
        except (KeyboardInterrupt, Exception) as e:
            self._logger.info("Saving checkpoint on Exception")
            self._save_checkpoint(self._completed_partitions)
            raise e
        self._save_checkpoint(self._completed_partitions)
        self._invalidate_caches()
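# Usage sketch (not part of the gist): the manifest written by _save_checkpoint is a plain
# text file with one partition id per line, stored by default at <path>/_MANIFEST. On a
# re-run with behavior: "complete_missing", ids already listed there are skipped. The path,
# dataset type, and df_jan/df_feb names below are illustrative assumptions.
#
#   ds = RobustPartitionedDataset(
#       path="data/07_model_output/shards",
#       dataset="pandas.ParquetDataset",
#       filename_suffix=".parquet",
#       behavior="complete_missing",
#   )
#   ds.save({"2025-01-01": lambda: df_jan, "2025-01-02": lambda: df_feb})
#   # After an interruption or partial failure, _MANIFEST records the partitions that did
#   # save; the next ds.save(...) call with the same keys only retries the missing ones.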