Created
March 28, 2019 16:46
-
-
Save acroz/b46ff5020d65137e1ff08e8139d3095a to your computer and use it in GitHub Desktop.
Search filter building in MLflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator | |
import uuid | |
from enum import Enum | |
from mlflow.entities import Run, RunInfo, RunData, Metric, Param | |
class KeyType(Enum):
    """Kind of run attribute that a filter key refers to."""

    # The key names an entry in ``run.data.params``.
    PARAM = "param"
    # The key names an entry in ``run.data.metrics``.
    METRIC = "metric"
class LogicalOperator(Enum):
    """Boolean connective used to merge the results of several filters."""

    AND = "and"
    OR = "or"


# Dispatch table: each connective maps to the builtin that reduces an
# iterable of per-child boolean results to a single boolean.
LOGICAL_OPERATOR_FUNCTIONS = {
    LogicalOperator.AND: all,
    LogicalOperator.OR: any,
}
class ComparisonOperator(Enum):
    """Comparison applied between a run's stored value and a filter value.

    Each member's value is the textual form of the comparison.
    """

    EQUAL = "=="
    NOT_EQUAL = "!="
    GREATER_THAN = ">"
    GREATER_THAN_EQUAL = ">="
    LESS_THAN = "<"
    LESS_THAN_EQUAL = "<="


# Dispatch table: each comparison maps to the two-argument function from the
# ``operator`` module that implements it as ``func(lhs, rhs)``.
COMPARISON_OPERATOR_FUNCTIONS = {
    ComparisonOperator.EQUAL: operator.eq,
    ComparisonOperator.NOT_EQUAL: operator.ne,
    ComparisonOperator.GREATER_THAN: operator.gt,
    ComparisonOperator.GREATER_THAN_EQUAL: operator.ge,
    ComparisonOperator.LESS_THAN: operator.lt,
    ComparisonOperator.LESS_THAN_EQUAL: operator.le,
}
class FilterComponentBase:
    """Mixin that lets filter components be combined with ``&`` and ``|``."""

    def __and__(self, other):
        """Return a combined component joining ``self`` and ``other`` with AND."""
        operands = [self, other]
        return CombinedFilterComponent(operands, LogicalOperator.AND)

    def __or__(self, other):
        """Return a combined component joining ``self`` and ``other`` with OR."""
        operands = [self, other]
        return CombinedFilterComponent(operands, LogicalOperator.OR)
class FilterComponent(FilterComponentBase):
    """Leaf filter comparing one metric or parameter of a run to a value.

    Args:
        key_type: ``KeyType`` member selecting metrics or params.
        key: Name of the metric/param to look up on the run.
        operator: ``ComparisonOperator`` member to apply.
        value: Right-hand side of the comparison.
    """

    def __init__(self, key_type, key, operator, value):
        self.key_type = key_type
        self.key = key
        self.operator = operator
        self.value = value

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.key_type}, {self.key}, "
            f"{self.operator}, {self.value})"
        )

    def filter(self, run):
        """Return the result of comparing the run's stored value to ``value``.

        The first entry of ``run.data.metrics`` or ``run.data.params`` (chosen
        by ``key_type``) whose ``key`` equals ``self.key`` supplies the
        left-hand side; ``self.value`` is the right-hand side.

        When no entry matches, the left-hand side is ``None``. Equality and
        inequality are still evaluated against ``None``; ordering comparisons
        report ``False`` (no match), because ordering ``None`` against a value
        would raise ``TypeError``.

        Raises:
            ValueError: If ``key_type`` is neither ``KeyType.METRIC`` nor
                ``KeyType.PARAM``.
            KeyError: If ``operator`` has no entry in
                ``COMPARISON_OPERATOR_FUNCTIONS``.
        """
        if self.key_type == KeyType.METRIC:
            entries = run.data.metrics
        elif self.key_type == KeyType.PARAM:
            entries = run.data.params
        else:
            raise ValueError("Invalid key type")

        entry = next((e for e in entries if e.key == self.key), None)
        lhs = None if entry is None else entry.value

        # Resolve the operator first so an unknown operator still fails
        # loudly, even when the key is missing.
        operator_func = COMPARISON_OPERATOR_FUNCTIONS[self.operator]

        equality_operators = (
            ComparisonOperator.EQUAL,
            ComparisonOperator.NOT_EQUAL,
        )
        if lhs is None and self.operator not in equality_operators:
            # A run without the key cannot satisfy an ordering constraint.
            return False
        return operator_func(lhs, self.value)
class CombinedFilterComponent(FilterComponentBase):
    """Filter node that merges the results of several child components.

    Args:
        children: Components whose ``filter`` results are merged.
        operator: ``LogicalOperator`` member selecting how they are merged.
    """

    def __init__(self, children, operator):
        self.children = children
        self.operator = operator

    def __repr__(self):
        return f"{self.__class__.__name__}({self.children}, {self.operator})"

    def filter(self, run):
        """Evaluate every child against ``run`` and reduce the results.

        All children are evaluated before the reduction is applied; there is
        no short-circuiting.
        """
        outcomes = [component.filter(run) for component in self.children]
        return LOGICAL_OPERATOR_FUNCTIONS[self.operator](outcomes)
class SearchFilter:
    """Wrapper around the root component of a filter expression.

    Args:
        root_filter_component: Component at the top of the expression tree.
    """

    def __init__(self, root_filter_component):
        self.root = root_filter_component

    def __and__(self, other):
        """Return a new ``SearchFilter`` whose root is ``self.root & other.root``."""
        merged_root = self.root & other.root
        return SearchFilter(merged_root)

    def __or__(self, other):
        """Return a new ``SearchFilter`` whose root is ``self.root | other.root``."""
        merged_root = self.root | other.root
        return SearchFilter(merged_root)

    def filter(self, run):
        """Delegate to the root component's ``filter`` and return its result."""
        root = self.root
        return root.filter(run)
class FilterComponentBuilder:
    """Build ``SearchFilter`` objects from Python comparison expressions.

    Comparing a builder with a value (``builder > 0.9``, ``builder == "foo"``,
    ...) does not return a boolean; it returns a ``SearchFilter`` wrapping a
    ``FilterComponent`` for the corresponding ``ComparisonOperator``.

    All six comparisons are implemented. ``__ne__`` in particular must be
    explicit: Python's default ``__ne__`` inverts the result of ``__eq__``,
    which here is a (truthy) ``SearchFilter``, so ``builder != value`` would
    otherwise silently evaluate to ``False``.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    unhashable.

    Args:
        key_type: ``KeyType`` member selecting metrics or params.
        name: Name of the metric/param the built filters refer to.
    """

    def __init__(self, key_type, name):
        self.key_type = key_type
        self.name = name

    def _build(self, comparison, value):
        """Wrap a ``FilterComponent`` for ``comparison`` in a ``SearchFilter``."""
        return SearchFilter(
            FilterComponent(self.key_type, self.name, comparison, value)
        )

    def __eq__(self, value):
        return self._build(ComparisonOperator.EQUAL, value)

    def __ne__(self, value):
        return self._build(ComparisonOperator.NOT_EQUAL, value)

    def __gt__(self, value):
        return self._build(ComparisonOperator.GREATER_THAN, value)

    def __ge__(self, value):
        return self._build(ComparisonOperator.GREATER_THAN_EQUAL, value)

    def __lt__(self, value):
        return self._build(ComparisonOperator.LESS_THAN, value)

    def __le__(self, value):
        return self._build(ComparisonOperator.LESS_THAN_EQUAL, value)
class Filter:
    """Entry point of the filter DSL, e.g. ``Filter.metric("accuracy") > 0.9``."""

    @staticmethod
    def metric(name):
        """Return a builder for filters on the metric called ``name``."""
        return FilterComponentBuilder(key_type=KeyType.METRIC, name=name)

    @staticmethod
    def param(name):
        """Return a builder for filters on the parameter called ``name``."""
        return FilterComponentBuilder(key_type=KeyType.PARAM, name=name)
# Usage: build a filter for runs with accuracy above 0.9 AND model_type "foo",
# then evaluate it against four synthetic runs.
search_filter = (Filter.metric("accuracy") > 0.9) & (
    Filter.param("model_type") == "foo"
)

# A single placeholder RunInfo is shared by every synthetic run below.
info = RunInfo(uuid.uuid4().hex, 0, "", "", "", "", "", "", 0, 0, "", "")

# Every (accuracy, model_type) combination, accuracy varying slowest.
for accuracy, model_type in [
    (acc, kind) for acc in [0.8, 0.95] for kind in ["foo", "bar"]
]:
    run = Run(
        info,
        RunData(
            metrics=[Metric(key="accuracy", timestamp=0, value=accuracy)],
            params=[Param(key="model_type", value=model_type)],
        ),
    )
    match = search_filter.filter(run)
    print(
        f"Run with accuracy {accuracy} and model_type {model_type}: "
        f"{match}"
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment