Last active
April 14, 2022 00:39
-
-
Save jgbos/c2602dfad87af7897d0f3c5f3ad03e95 to your computer and use it in GitHub Desktop.
Base class definition of a workflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import abstractmethod | |
from collections import defaultdict | |
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union | |
from hydra.core.utils import JobReturn | |
from hydra_zen import make_config | |
from .hydra import launch, zen | |
class Workflows:
    """Base class for machine-learning experiment workflows.

    Classes that extend this base are intended to help:

    - Allow the user to write interpretable task APIs instead of using the
      uber config object as an input
    - Collate metrics and provide them in a useful format
      (e.g., ``xarray.DataArray``)
    - Run large-scale distributed jobs via CLI or ``hydra_zen.launch``

    Workflows have a number of useful properties and methods:

    1. The user has full control of the inputs and outputs of their
       evaluation task
    2. A workflow will validate that the configuration input contains the
       necessary properties to execute the evaluation task,
       ``zen(self.evaluation_task).validate(config)``
    3. A workflow can be executed using multirun (multiple workflow
       experiments). Since a workflow will also run a multirun experiment,
       this becomes "uber" distributed :)
    4. The configuration and return metrics for every job executed in a
       workflow are stored in the class object
    5. A workflow provides methods to output simple data products for a user
       to examine metrics for each job (e.g., ``self.to_xarray`` returns an
       ``xarray.DataArray`` of all the metrics for each job)
    """

    # Populated by ``jobs_post_process`` after ``run`` has filled ``self.jobs``.
    metrics: List[Mapping[str, Any]]  # per-job metric mappings (job return values)
    cfgs: List[Mapping[str, Any]]  # per-job resolved configurations
    jobs: List[JobReturn]  # raw Hydra job results

    def __init__(self, eval_task_cfg=None) -> None:
        """Store the evaluation-task config (defaulting to an empty
        ``hydra_zen.make_config()``) and run subclass validation.

        Parameters
        ----------
        eval_task_cfg : optional
            Configuration object for the evaluation task. When ``None``, an
            empty ``make_config()`` is used.

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides ``validate``.
        """
        # we can do validation checks here
        self.eval_task_cfg = (
            eval_task_cfg if eval_task_cfg is not None else make_config()
        )
        self.validate()

    def validate(self):
        """Check that ``self.eval_task_cfg`` can drive ``evaluation_task``.

        Subclasses must override; the base implementation always raises.
        """
        raise NotImplementedError()

    @staticmethod
    def evaluation_task(eval_task_cfg) -> Mapping[str, Any]:
        """Run one evaluation for the given config and return its metrics.

        Subclasses must override; the base implementation always raises.
        """
        raise NotImplementedError()

    # NOTE(review): ``@abstractmethod`` has no enforcement effect here because
    # ``Workflows`` does not use ``abc.ABCMeta``/``abc.ABC``; it is kept as
    # documentation of intent only.  Making the class an ABC would change
    # instantiation behavior for existing callers, so it is left as-is.
    @abstractmethod
    def run(
        self,
        *,
        sweepdir: str = ".",
        launcher: Optional[str] = None,
        additional_overrides: Optional[Sequence[str]] = None,
        **kwargs
    ) -> None:
        """Launch the workflow's (possibly multirun) jobs.

        Parameters
        ----------
        sweepdir : str
            Directory in which sweep/job outputs are written.
        launcher : Optional[str]
            Name of the Hydra launcher to use, if any.
        additional_overrides : Optional[Sequence[str]]
            Extra Hydra override strings appended to the launch.
        **kwargs
            Subclass-specific inputs.

        TODO: How do we handle that the user will have additional inputs?
        """
        raise NotImplementedError()

    def to_xarray(self):
        """Return job metrics as an ``xarray.DataArray`` (subclass hook)."""
        raise NotImplementedError()

    def jobs_post_process(self) -> None:
        """Unpack ``self.jobs`` into ``self.cfgs`` and ``self.metrics``.

        Assumes ``self.jobs`` has already been populated by ``run``.
        """
        # Comprehensions replace the original append loop; same order, same
        # contents.
        self.cfgs = [j.cfg for j in self.jobs]
        self.metrics = [j.return_value for j in self.jobs]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment