Created
May 21, 2021 09:51
-
-
Save pmeier/ea35bdffb597b35f4f6592c5ac201cd4 to your computer and use it in GitHub Desktop.
Check inter-category type promotion behavior of 0d-tensors for array API compatibility
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from typing import Collection | |
import networkx as nx | |
# overwrite this with the array API that you want to test | |
import numpy as array_api | |
def maybe_add_dtype(
    graph: "nx.Graph",  # quoted so ``nx`` is not evaluated when the def runs
    name: str,
    promotes_to_names: Collection[str] = (),
) -> None:
    """Add the dtype called ``name`` and its direct promotion edges to ``graph``.

    Dtypes are looked up as attributes of ``array_api``. If ``name`` is not an
    attribute of ``array_api``, the graph is left untouched. Otherwise the dtype
    is added as a node, and for every name in ``promotes_to_names`` that also
    exists on ``array_api`` an edge ``dtype -> promoted_dtype`` is added.
    Target names missing from ``array_api`` are skipped silently.

    Args:
        graph: Graph mutated in place; edges point from a dtype to the
            dtype(s) it promotes to directly.
        name: Attribute name of the dtype on ``array_api`` (e.g. ``"int8"``).
        promotes_to_names: Attribute names of the dtypes that ``name``
            promotes to directly.
    """
    try:
        dtype = getattr(array_api, name)
    except AttributeError:
        # The array API under test does not provide this dtype at all.
        return

    graph.add_node(dtype)
    # Distinct loop variable so the ``name`` parameter is not shadowed and
    # keeps referring to the source dtype for the whole function.
    for promoted_name in promotes_to_names:
        try:
            promoted_dtype = getattr(array_api, promoted_name)
        except AttributeError:
            # Target dtype is unavailable; just omit this edge.
            continue
        graph.add_edge(dtype, promoted_dtype)
# Integral promotion graph. Each row names a dtype and the dtype(s) it promotes
# to directly: signed types widen to the next signed type; unsigned types widen
# to the next unsigned type and to the next-wider signed type. int64 and uint64
# are sinks here (no outgoing edges). Rows are processed in order, so node
# insertion order is deterministic.
integral_graph = nx.DiGraph()
for _name, _promotes_to in (
    ("int8", ("int16",)),
    ("int16", ("int32",)),
    ("int32", ("int64",)),
    ("int64", ()),
    ("uint8", ("uint16", "int16")),
    ("uint16", ("uint32", "int32")),
    ("uint32", ("uint64", "int64")),
    ("uint64", ()),
):
    maybe_add_dtype(integral_graph, _name, _promotes_to)
# Drop the loop temporaries so they do not linger at module level.
del _name, _promotes_to
# Floating-point promotion graph: float32 widens to float64, which is a sink.
floating_graph = nx.DiGraph()
for _name, _promotes_to in (
    ("float32", ("float64",)),
    ("float64", ()),
):
    maybe_add_dtype(floating_graph, _name, _promotes_to)
# Drop the loop temporaries so they do not linger at module level.
del _name, _promotes_to
# For every ordered pair of dtypes within one promotion graph, compare the
# dtype that ``array_api.result_type`` reports for (0d array, 1-element array)
# against the expected dtype, and print every mismatch.
#
# The expected dtype is the lowest common ancestor of the pair in the
# *reversed* graph: there, edges point from a wider dtype to the dtypes that
# promote to it, so the ancestors of a node are everything it promotes to.
for graph in (integral_graph, floating_graph):
    reverse_graph = graph.reverse()
    for dtype_0d, dtype_nd in itertools.product(graph.nodes, repeat=2):
        dtype_expected = nx.lowest_common_ancestor(
            reverse_graph, dtype_0d, dtype_nd
        )
        # ``lowest_common_ancestor`` returns None when the pair shares no
        # promotion target (e.g. int8 and uint64); such pairs are skipped.
        # Compare against None explicitly rather than by truthiness: dtype
        # objects are not guaranteed to be truthy (NumPy's ``np.dtype``
        # defines ``__len__`` as its field count, 0 for plain dtypes, so
        # absent a ``__bool__`` they evaluate as false and would be skipped).
        if dtype_expected is None:
            continue

        a = array_api.empty((), dtype=dtype_0d)
        b = array_api.empty((1,), dtype=dtype_nd)
        dtype_actual = array_api.result_type(a, b)
        if dtype_actual != dtype_expected:
            print(f"0d {dtype_0d} + nd {dtype_nd} = {dtype_actual} != {dtype_expected}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is cool! Thanks.