@jgbos
Last active April 14, 2022 00:39
Base class definition of a workflow
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from hydra.core.utils import JobReturn
from hydra_zen import make_config

from .hydra import launch, zen


# Inheriting from ABC is what makes the @abstractmethod decorator below enforceable
class Workflows(ABC):
    """
    Workflows are the essential part of any machine learning experimentation. Classes
    that extend this base class are intended to help:

    - Allow the user to write interpretable task APIs instead of using the uber config
      object as an input
    - Collate metrics and provide them in a useful format (e.g., ``xarray.DataArray``)
    - Run large-scale distributed jobs via the CLI or ``hydra_zen.launch``

    Workflows have a number of useful properties and methods:

    1. The user has full control of the inputs and outputs of their evaluation task.
    2. A workflow validates that the configuration input contains the necessary
       properties to execute the evaluation task, i.e.,
       ``zen(self.evaluation_task).validate(config)``.
    3. A workflow can be executed using multirun (multiple workflow experiments). Since
       a workflow also runs a multirun experiment internally, this becomes "uber"
       distributed :)
    4. The configuration and return metrics for every job executed in a workflow
       are stored in the class object.
    5. A workflow provides methods to output simple data products so a user can
       examine the metrics for each job (e.g., ``self.to_xarray`` returns an
       ``xarray.DataArray`` of all the metrics for each job).
    """

    metrics: List[Mapping[str, Any]]
    cfgs: List[Mapping[str, Any]]
    jobs: List[JobReturn]

    def __init__(self, eval_task_cfg=None) -> None:
        # we can do validation checks here
        self.eval_task_cfg = (
            eval_task_cfg if eval_task_cfg is not None else make_config()
        )
        self.validate()

    def validate(self):
        raise NotImplementedError()

    @staticmethod
    def evaluation_task(eval_task_cfg) -> Mapping[str, Any]:
        raise NotImplementedError()

    @abstractmethod
    def run(
        self,
        *,
        sweepdir: str = ".",
        launcher: Optional[str] = None,
        additional_overrides: Optional[Sequence[str]] = None,
        **kwargs
    ) -> None:
        """
        TODO: How do we handle additional inputs from the user?
        """
        raise NotImplementedError()

    def to_xarray(self):
        raise NotImplementedError()

    def jobs_post_process(self) -> None:
        # save and unpack the config and return value of every launched job
        self.cfgs = []
        self.metrics = []
        for j in self.jobs:
            self.cfgs.append(j.cfg)
            self.metrics.append(j.return_value)
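The base class leaves `validate`, `evaluation_task`, `run`, and `to_xarray` for subclasses to fill in. The sketch below shows one way a subclass might wire these together. It rests on several assumptions. It does not import Hydra, hydra_zen, or the gist's local `.hydra` module, so `FakeJobReturn`, `validate_task_config`, `ToyWorkflow`, and `collate` are all hypothetical stand-ins. `validate_task_config` only approximates the idea behind `zen(self.evaluation_task).validate(config)` using `inspect.signature`; it is not that function's actual behavior. `run` loops serially over a Cartesian product of values instead of launching a Hydra multirun. `collate` returns a plain dict instead of an `xarray.DataArray`.

```python
import inspect
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Sequence


@dataclass
class FakeJobReturn:
    """Hypothetical stand-in for hydra's JobReturn: only the two fields
    that jobs_post_process reads."""
    cfg: Mapping[str, Any]
    return_value: Mapping[str, Any]


def validate_task_config(task: Callable[..., Any], cfg: Mapping[str, Any]) -> None:
    """Hypothetical stdlib stand-in for zen(task).validate(cfg): raise if cfg
    lacks any parameter of `task` that has no default value."""
    named_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    missing = [
        name
        for name, p in inspect.signature(task).parameters.items()
        if p.kind in named_kinds
        and p.default is inspect.Parameter.empty
        and name not in cfg
    ]
    if missing:
        raise ValueError(f"config is missing required fields: {missing}")


class ToyWorkflow:
    """Hypothetical workflow mirroring the Workflows pattern without Hydra."""

    def __init__(self, eval_task_cfg: Mapping[str, Any]) -> None:
        self.eval_task_cfg = dict(eval_task_cfg)
        self.validate()

    def validate(self) -> None:
        validate_task_config(self.evaluation_task, self.eval_task_cfg)

    @staticmethod
    def evaluation_task(lr: float, epochs: int = 1) -> Dict[str, float]:
        # Hypothetical task with an interpretable signature, instead of
        # receiving one uber config object.
        return {"loss": 1.0 / (lr * epochs), "lr": lr}

    def run(self, sweep: Mapping[str, Sequence[Any]]) -> None:
        # Serial stand-in for a multirun: one job per point in the
        # Cartesian product of the swept values.
        names = list(sweep)
        self.jobs = []
        for values in itertools.product(*(sweep[n] for n in names)):
            cfg = {**self.eval_task_cfg, **dict(zip(names, values))}
            result = self.evaluation_task(**cfg)
            self.jobs.append(FakeJobReturn(cfg=cfg, return_value=result))
        self.jobs_post_process()

    def jobs_post_process(self) -> None:
        # Same unpacking as Workflows.jobs_post_process.
        self.cfgs = [j.cfg for j in self.jobs]
        self.metrics = [j.return_value for j in self.jobs]

    def collate(self) -> Dict[str, List[Any]]:
        # Dependency-free stand-in for to_xarray: metric name -> per-job values.
        table = defaultdict(list)
        for m in self.metrics:
            for key, value in m.items():
                table[key].append(value)
        return dict(table)
```

Take `ToyWorkflow({"lr": 0.5})` as an example. Calling `run({"lr": [0.5, 0.25], "epochs": [1, 2]})` executes four toy jobs, and `collate()` then returns `{"loss": [2.0, 1.0, 4.0, 2.0], "lr": [0.5, 0.5, 0.25, 0.25]}`. Constructing `ToyWorkflow({})` raises `ValueError` because the required `lr` field is missing. This signature-based check only works because the evaluation task is a plain function with named parameters.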