Created September 21, 2022 12:51
-
-
Save honno/6531d1e8d1acef9b3ef713200c76d91c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_strategies_namespace(
    xp: Any, *, api_version: Optional[NominalVersion] = None
) -> SimpleNamespace:
    """Return a SimpleNamespace for the Array API module ``xp``.

    NOTE(review): this is an excerpt. Only the api_version inference step is
    shown; what the returned namespace contains is elided (``...``) and is
    not documented here.

    When ``api_version`` is None, it is inferred as the newest entry of
    RELEASED_VERSIONS for which ``xp.zeros(1).__array_namespace__(api_version=...)``
    does not raise. ``xp`` is then rebound to the namespace returned by that call.

    Raises:
        InvalidArgument: when inferring, if ``xp`` has no ``zeros`` attribute,
            if ``xp.zeros(1)`` raises, or if no released version is accepted.
    """
    ...
    if api_version is None:
        # When api_version=None, we infer the most recent API version for which
        # the passed xp is valid. We go through the released versions in
        # descending order, passing them to array.__array_namespace__() until no
        # errors are raised, thus inferring that specific api_version is
        # supported. If errors are raised for all released versions, we raise
        # our own useful error.
        # NOTE(review): the f-strings below read xp.__name__ eagerly, so an xp
        # without __name__ raises AttributeError here rather than InvalidArgument.
        check_argument(
            hasattr(xp, "zeros"),
            f"Array module {xp.__name__} has no function zeros(), which is "
            "required when inferring api_version.",
        )
        errmsg = (
            f"Could not infer any api_version which module {xp.__name__} "
            f"supports. If you believe {xp.__name__} is indeed an Array API "
            "module, try explicitly passing an api_version."
        )
        try:
            # An actual array is needed because __array_namespace__ is an
            # array method, not a module-level function.
            array = xp.zeros(1)
        except Exception:
            raise InvalidArgument(errmsg)
        # Assumes RELEASED_VERSIONS is ordered oldest -> newest, so reversing
        # it tries the newest version first.
        for api_version in reversed(RELEASED_VERSIONS):
            # If __array_namespace__ raises, the exception is suppressed before
            # `break` is reached, so the loop moves on to the next older version.
            with contextlib.suppress(Exception):
                xp = array.__array_namespace__(api_version=api_version)
                break  # i.e. a valid xp and api_version has been inferred
        else:
            # Only reached when no iteration hit `break`: every version was rejected.
            raise InvalidArgument(errmsg)
    ...
# Tests ------------------------------------------------------------------------ | |
# test_partial_adopters.py | |
def test_raises_on_inferring_with_no_zeros_func():
    """Inferring api_version from an xp lacking zeros() fails with a clear message."""
    mock_xp = make_mock_xp(exclude=("zeros",))
    expectation = pytest.raises(InvalidArgument, match="has no function")
    with expectation:
        make_strategies_namespace(mock_xp)
def test_raises_on_erroneous_zeros_func():
    """When xp has erroneous zeros(), inferring api_version raises helpful error."""
    xp = make_mock_xp()
    # zeros still exists (so the hasattr() check passes) but is not callable,
    # so the xp.zeros(1) call used during inference raises.
    xp.zeros = None
    # Pin the message so we know the helpful "could not infer" error is raised,
    # not merely some InvalidArgument.
    with pytest.raises(InvalidArgument, match="Could not infer"):
        make_strategies_namespace(xp)
# test_strategies_namespace.py | |
class MockArray:
    """Minimal array stand-in that supports a fixed set of api_versions.

    Used to exercise api_version inference: ``__array_namespace__`` accepts
    only the versions given at construction (plus ``None``).
    """

    def __init__(self, supported_versions: Tuple[NominalVersion, ...]):
        # Duplicate versions would indicate a broken test parametrization.
        assert len(set(supported_versions)) == len(supported_versions)  # sanity check
        self.supported_versions = supported_versions

    def __array_namespace__(self, *, api_version: Optional[NominalVersion] = None):
        """Return a mock array module, or raise if ``api_version`` is unsupported.

        ``api_version=None`` is always accepted. The returned namespace is named
        "foopy", and its ``zeros()`` builds a new MockArray with the same
        supported versions, so inference code can round-trip through it.

        Raises:
            ValueError: if ``api_version`` is given and is not in
                ``supported_versions``.
        """
        if api_version is not None and api_version not in self.supported_versions:
            # Raise a real exception. A bare `raise` outside an except block
            # only fails by accident, as RuntimeError("No active exception to
            # reraise").
            raise ValueError(
                f"api_version={api_version!r} is not supported; "
                f"supported versions are {self.supported_versions!r}"
            )
        return SimpleNamespace(
            __name__="foopy", zeros=lambda _: MockArray(self.supported_versions)
        )
# Despite the name, these are the non-empty prefixes of RELEASED_VERSIONS:
# (v1,), (v1, v2), ..., i.e. a module supporting every release up to some point.
version_permutations: List[Tuple[NominalVersion, ...]] = [
    RELEASED_VERSIONS[: stop + 1] for stop in range(len(RELEASED_VERSIONS))
]
@pytest.mark.parametrize(
    "supported_versions", version_permutations, ids="-".join
)
def test_version_inferrence(supported_versions):
    """The newest version a module supports is the one that gets inferred."""
    mock_xp = MockArray(supported_versions).__array_namespace__()
    strategies_ns = make_strategies_namespace(mock_xp)
    newest_version = supported_versions[-1]
    assert strategies_ns.api_version == newest_version
def test_raises_on_inferring_with_no_supported_versions():
    """When xp supports no versions, inferring api_version raises helpful error."""
    # MockArray(()) rejects every explicit api_version, so inference tries each
    # released version, exhausts them all, and falls through to the
    # "could not infer" error.
    xp = MockArray(()).__array_namespace__()
    with pytest.raises(InvalidArgument, match="Could not infer"):
        make_strategies_namespace(xp)
@pytest.mark.parametrize(
    ("api_version", "supported_versions"),
    [
        pytest.param(prefix[-1], prefix[:-1], id=prefix[-1])
        for prefix in version_permutations
    ],
)
def test_warns_on_specifying_unsupported_version(api_version, supported_versions):
    """An explicit api_version unknown to xp is still used, but with a warning."""
    # Each case pairs a version with only the versions listed before it, so the
    # requested api_version is never among those the mock module accepts.
    mock_xp = MockArray(supported_versions).__array_namespace__()
    # Make zeros() unusable: the inference path calls xp.zeros(1), so this also
    # checks that passing api_version explicitly skips inference.
    mock_xp.zeros = None
    with pytest.warns(HypothesisWarning):
        strategies_ns = make_strategies_namespace(mock_xp, api_version=api_version)
    assert strategies_ns.api_version == api_version
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.