Skip to content

Instantly share code, notes, and snippets.

@acroz
Created March 28, 2019 16:46
Show Gist options
  • Save acroz/b46ff5020d65137e1ff08e8139d3095a to your computer and use it in GitHub Desktop.
Save acroz/b46ff5020d65137e1ff08e8139d3095a to your computer and use it in GitHub Desktop.
Search filter building in MLflow
import operator
import uuid
from enum import Enum
from mlflow.entities import Run, RunInfo, RunData, Metric, Param
class KeyType(Enum):
    """Which collection of a run a filter key is looked up in.

    ``PARAM`` selects the run's params and ``METRIC`` its metrics.
    """

    PARAM = "param"
    METRIC = "metric"
class LogicalOperator(Enum):
    """Boolean connective used to join the results of child filters."""

    AND = "and"
    OR = "or"


# Each connective is implemented by the builtin that folds an iterable of
# truth values: ``all`` for conjunction, ``any`` for disjunction.
LOGICAL_OPERATOR_FUNCTIONS = dict(
    [
        (LogicalOperator.AND, all),
        (LogicalOperator.OR, any),
    ]
)
class ComparisonOperator(Enum):
    """Binary comparison between a run's stored value and a filter value."""

    EQUAL = "=="
    NOT_EQUAL = "!="
    GREATER_THAN = ">"
    GREATER_THAN_EQUAL = ">="
    LESS_THAN = "<"
    LESS_THAN_EQUAL = "<="


# Each comparison maps to the matching two-argument function from the stdlib
# ``operator`` module; it is called as ``func(run_value, filter_value)``.
COMPARISON_OPERATOR_FUNCTIONS = dict(
    [
        (ComparisonOperator.EQUAL, operator.eq),
        (ComparisonOperator.NOT_EQUAL, operator.ne),
        (ComparisonOperator.GREATER_THAN, operator.gt),
        (ComparisonOperator.GREATER_THAN_EQUAL, operator.ge),
        (ComparisonOperator.LESS_THAN, operator.lt),
        (ComparisonOperator.LESS_THAN_EQUAL, operator.le),
    ]
)
class FilterComponentBase:
    """Base for filter-tree nodes, giving them ``&`` / ``|`` composition.

    ``a & b`` and ``a | b`` each build a new two-child
    ``CombinedFilterComponent`` (left operand first); neither operand is
    modified.
    """

    def _combine(self, other, logical_operator):
        # Shared by both operators: wrap exactly this node and ``other``.
        return CombinedFilterComponent([self, other], logical_operator)

    def __and__(self, other):
        return self._combine(other, LogicalOperator.AND)

    def __or__(self, other):
        return self._combine(other, LogicalOperator.OR)
class FilterComponent(FilterComponentBase):
    """Leaf filter: compare one metric or param of a run against a value.

    Args:
        key_type: ``KeyType.METRIC`` or ``KeyType.PARAM``; selects whether
            ``run.data.metrics`` or ``run.data.params`` is searched.
        key: name of the metric/param to look up.
        operator: ``ComparisonOperator`` applied as ``run_value <op> value``.
        value: right-hand side of the comparison.
    """

    def __init__(self, key_type, key, operator, value):
        # NOTE: the ``operator`` parameter shadows the stdlib module inside
        # this method only; the name is kept since callers may pass it by
        # keyword.
        self.key_type = key_type
        self.key = key
        self.operator = operator
        self.value = value

    def __repr__(self):
        # ``!r`` on key/value so string values are shown quoted.
        return "{}({}, {!r}, {}, {!r})".format(
            self.__class__.__name__,
            self.key_type,
            self.key,
            self.operator,
            self.value,
        )

    def _lookup_value(self, run):
        """Return the value stored under ``self.key`` in ``run``, or None.

        Raises:
            ValueError: if ``key_type`` is not a known ``KeyType``.
        """
        if self.key_type == KeyType.METRIC:
            entries = run.data.metrics
        elif self.key_type == KeyType.PARAM:
            entries = run.data.params
        else:
            raise ValueError("Invalid key type")
        # NOTE(review): assumes ``metrics``/``params`` are iterables of
        # entities with ``.key``/``.value`` (as built in the usage example);
        # some MLflow versions expose dicts here instead — confirm.
        entry = next((e for e in entries if e.key == self.key), None)
        return entry.value if entry is not None else None

    def filter(self, run):
        """Return whether ``run`` satisfies this comparison.

        A run lacking the key is treated as having the value ``None``. For
        ``==`` / ``!=`` that ``None`` is compared directly. For the ordering
        operators (``>``, ``>=``, ``<``, ``<=``) the run does not match,
        rather than raising ``TypeError`` from ordering ``None`` against a
        value.

        Raises:
            ValueError: if ``key_type`` is not a known ``KeyType``.
            KeyError: if ``operator`` is not a known ``ComparisonOperator``.
        """
        lhs = self._lookup_value(run)
        operator_func = COMPARISON_OPERATOR_FUNCTIONS[self.operator]
        if lhs is None and self.operator not in (
            ComparisonOperator.EQUAL,
            ComparisonOperator.NOT_EQUAL,
        ):
            # Missing key: an ordering comparison cannot be satisfied.
            return False
        return operator_func(lhs, self.value)
class CombinedFilterComponent(FilterComponentBase):
    """Inner filter node joining child filters with a logical operator.

    Args:
        children: iterable of filter nodes, each exposing ``filter(run)``.
        operator: ``LogicalOperator`` choosing how child results combine.
    """

    def __init__(self, children, operator):
        # Copy so later mutation of the caller's list can't change this node.
        self.children = list(children)
        self.operator = operator

    def __repr__(self):
        return "{}({}, {})".format(
            self.__class__.__name__, self.children, self.operator
        )

    def filter(self, run):
        """Return whether ``run`` satisfies the combination of the children.

        Children are evaluated lazily, so ``all`` / ``any`` short-circuit:
        AND stops at the first falsy child result, OR at the first truthy
        one. With no children, AND yields True and OR yields False.
        """
        combiner = LOGICAL_OPERATOR_FUNCTIONS[self.operator]
        return combiner(child.filter(run) for child in self.children)
class SearchFilter:
    """Public wrapper around the root node of a filter expression tree.

    Combine filters with ``&`` and ``|``; each returns a new ``SearchFilter``
    whose root is ``self.root & other.root`` / ``self.root | other.root``.
    Call ``filter(run)`` to evaluate the whole tree against one run.
    """

    def __init__(self, root_filter_component):
        self.root = root_filter_component

    def __repr__(self):
        # Same ``ClassName(...)`` style as the filter-node reprs.
        return "{}({!r})".format(self.__class__.__name__, self.root)

    def __and__(self, other):
        """Return a new filter whose root is ``self.root & other.root``."""
        return SearchFilter(self.root & other.root)

    def __or__(self, other):
        """Return a new filter whose root is ``self.root | other.root``."""
        return SearchFilter(self.root | other.root)

    def filter(self, run):
        """Return the root node's ``filter(run)`` result."""
        return self.root.filter(run)
class FilterComponentBuilder:
    """Turn a comparison against a metric/param name into a ``SearchFilter``.

    Applying any of ``==``, ``!=``, ``>``, ``>=``, ``<``, ``<=`` to an
    instance yields a ``SearchFilter`` wrapping one ``FilterComponent``,
    e.g. ``Filter.metric("accuracy") >= 0.9``.

    Because ``__eq__`` is overloaded to build a filter instead of testing
    equality, instances are unhashable and must not be put in sets or dict
    keys or used with ``in`` checks (the filter it returns is truthy).
    """

    # Overriding __eq__ already disables hashing; make that explicit.
    __hash__ = None

    def __init__(self, key_type, name):
        self.key_type = key_type
        self.name = name

    def _build(self, comparison, value):
        """Wrap one ``FilterComponent`` for this key in a ``SearchFilter``."""
        return SearchFilter(
            FilterComponent(self.key_type, self.name, comparison, value)
        )

    def __eq__(self, value):
        return self._build(ComparisonOperator.EQUAL, value)

    def __ne__(self, value):
        # Without this override, Python's default ``!=`` negates the truthy
        # result of ``__eq__`` and returns the bool False, not a filter.
        return self._build(ComparisonOperator.NOT_EQUAL, value)

    def __gt__(self, value):
        return self._build(ComparisonOperator.GREATER_THAN, value)

    def __ge__(self, value):
        return self._build(ComparisonOperator.GREATER_THAN_EQUAL, value)

    def __lt__(self, value):
        return self._build(ComparisonOperator.LESS_THAN, value)

    def __le__(self, value):
        return self._build(ComparisonOperator.LESS_THAN_EQUAL, value)
class Filter:
    """Factories that start a search-filter expression.

    ``Filter.metric(name)`` and ``Filter.param(name)`` return a
    ``FilterComponentBuilder`` bound to that key; comparing the builder to a
    value produces a ``SearchFilter``.
    """

    @staticmethod
    def param(name):
        """Return a builder targeting the run param called ``name``."""
        return FilterComponentBuilder(key_type=KeyType.PARAM, name=name)

    @staticmethod
    def metric(name):
        """Return a builder targeting the run metric called ``name``."""
        return FilterComponentBuilder(key_type=KeyType.METRIC, name=name)
def main():
    """Demonstrate building a search filter and evaluating it on sample runs.

    The filter matches runs whose ``accuracy`` metric is greater than 0.9 AND
    whose ``model_type`` param equals ``"foo"``. Of the four sample runs,
    only (accuracy=0.95, model_type="foo") should print ``True``.
    """
    search_filter = (Filter.metric("accuracy") > 0.9) & (
        Filter.param("model_type") == "foo"
    )

    # One RunInfo is shared by every sample run; only RunData varies.
    # NOTE(review): the positional RunInfo signature and list-valued
    # metrics/params follow the MLflow version this gist targeted; newer
    # MLflow releases differ — confirm before reuse.
    info = RunInfo(uuid.uuid4().hex, 0, "", "", "", "", "", "", 0, 0, "", "")
    for accuracy in [0.8, 0.95]:
        for model_type in ["foo", "bar"]:
            run = Run(
                info,
                RunData(
                    metrics=[Metric(key="accuracy", timestamp=0, value=accuracy)],
                    params=[Param(key="model_type", value=model_type)],
                ),
            )
            match = search_filter.filter(run)
            print(
                f"Run with accuracy {accuracy} and model_type {model_type}: "
                f"{match}"
            )


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment