Created
August 11, 2021 06:18
-
-
Save eavae/57aef3e68e5593e6c8ba2544104c6a3a to your computer and use it in GitHub Desktop.
An idea about collecting hyperparameters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import inspect | |
import typing | |
class SearchSpacesDict(OrderedDict):
    """OrderedDict variant in which the most recently assigned key is last.

    A plain ``OrderedDict`` leaves an existing key at its original position
    when its value is replaced. This subclass instead relocates the key to
    the tail on every assignment, so iteration order reflects the order of
    the latest writes.
    """

    def __setitem__(self, key, value):
        # Let OrderedDict store the value, then push the key to the tail.
        super().__setitem__(key, value)
        self.move_to_end(key, last=True)
def mangling(cls):
    """Return the attribute name under which ``cls`` stores its search space.

    The name embeds ``cls.__name__``; for example a class named ``Food``
    maps to ``'__Food_search_spaces__'``.
    """
    return '__{}_search_spaces__'.format(cls.__name__)
def get_search_space(cls):
    """Return the search-space dict owned directly by ``cls``, creating it if absent.

    The dict is stored on ``cls`` under the name returned by ``mangling(cls)``.
    Only ``cls``'s own namespace is consulted. An attribute of the same name
    inherited from a base class (which happens when a subclass reuses a base
    class's ``__name__``) is ignored, so decorators applied to the subclass
    never mutate the base class's search space.

    :param cls: the class whose search space is requested.
    :returns: the ``SearchSpacesDict`` attached directly to ``cls``.
    """
    attr_name = mangling(cls)
    # vars(cls) is the class's own namespace; hasattr/getattr would also
    # return an attribute found on any class further up the MRO.
    if attr_name not in vars(cls):
        setattr(cls, attr_name, SearchSpacesDict())
    return vars(cls)[attr_name]
class Tunable():
    """Base class for objects whose hyper-parameters are declared via class decorators."""

    def __init__(self, *args, **kargs) -> None:
        # Accepts and ignores any arguments so subclasses may forward
        # arbitrary constructor arguments here.
        pass

    @classmethod
    def get_search_spaces(cls):
        """Collect the search spaces declared on ``cls`` and its Tunable ancestors.

        Classes are visited in MRO order, most-derived first. When several
        classes declare the same hyper-parameter name, the declaration
        closest to ``cls`` in the MRO wins, so a subclass can override a
        parent's search space. Each name keeps the position of its first
        (most-derived) occurrence.

        :returns: a new ``SearchSpacesDict`` mapping parameter name to space.
        """
        cs = SearchSpacesDict()
        for c in inspect.getmro(cls):
            if not issubclass(c, Tunable):
                continue
            for name, space in get_search_space(c).items():
                # A more-derived class already supplied this name; keep it.
                if name not in cs:
                    cs[name] = space
        return cs
def get_tunable_class(cls):
    """Return the Tunable classes found in the method resolution order of ``cls``.

    Classes are listed most-derived first (the order of ``inspect.getmro``).
    ``Tunable`` itself is included, because ``issubclass(Tunable, Tunable)``
    is true.

    :param cls: any object; a non-class yields an empty list.
    :returns: list of classes ``c`` in ``cls``'s MRO with ``issubclass(c, Tunable)``.
    """
    if not inspect.isclass(cls):
        return []
    # A distinct loop name avoids shadowing the ``cls`` parameter.
    return [klass for klass in inspect.getmro(cls) if issubclass(klass, Tunable)]
def choice(name, options):
    """Build a class decorator that declares a categorical hyper-parameter.

    :param name: key under which ``options`` is recorded in the class's
        search space.
    :param options: candidate values. Any option that is itself a
        ``Tunable`` subclass also contributes its own search spaces, which
        are copied into the decorated class's search space before ``name``
        is recorded.
    :returns: a decorator that registers the parameter and returns the
        class unchanged.
    """
    def inner(cls):
        registry = get_search_space(cls)
        for option in options:
            # Plain values (non-Tunable options) carry no nested search space.
            if not (inspect.isclass(option) and issubclass(option, Tunable)):
                continue
            nested = option.get_search_spaces()
            for key in nested:
                registry[key] = nested[key]
        # TODO: build a proper search-space object instead of the raw options.
        registry[name] = options
        return cls
    return inner
def real(name, bound):
    """Build a class decorator that declares a real-valued hyper-parameter.

    :param name: key under which ``bound`` is recorded in the class's
        search space.
    :param bound: the parameter's bounds, stored exactly as given (the
        examples in this file pass a two-element list).
    :returns: a decorator that registers the parameter and returns the
        class unchanged.
    """
    def register(cls):
        # TODO: build a proper search-space object instead of the raw bound.
        get_search_space(cls)[name] = bound
        return cls
    return register
""" | |
Simple example: collecting multiple search spaces with class decorators
""" | |
@choice('tunable_choice', [1, 2, 3])
@real('tunable_real', [1, 10])
class Hello(Tunable):
    """Example declaring one categorical and one real-valued hyper-parameter.

    Decorators apply bottom-up, so 'tunable_real' is registered before
    'tunable_choice'.
    """

    def __init__(self, list_var, dict_var=1, **kargs) -> None:
        # The example constructor intentionally does nothing.
        pass


# print(Hello.get_search_spaces())
""" | |
Deep inheritance | |
""" | |
@choice('have_seed', [True, False])
class Food(Tunable):
    """Example base class with a single boolean hyper-parameter, 'have_seed'."""
@choice('color', ['red', 'green'])
class Apple(Food, Tunable):
    """Example subclass of Food.

    Its collected search spaces contain its own 'color' together with the
    'have_seed' parameter declared on Food.
    """


# print(Food.get_search_spaces())
# print(Apple.get_search_spaces())
""" | |
class composition | |
""" | |
@choice('size', ['s', 'm', 'x'])
class Noddle(Food, Tunable):
    """Second example Food subclass, adding a 'size' hyper-parameter."""
@choice('eat', [Apple, Noddle])
class Lunch(Tunable):
    """Example whose 'eat' parameter chooses between two Tunable classes.

    Because both options are Tunable subclasses, ``choice`` also copies the
    search spaces of Apple and Noddle into Lunch's own search space.
    """


# print(Lunch.get_search_spaces())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment