Created
November 8, 2024 19:55
-
-
Save makslevental/bcf09040e77948098450a4a6b43ada99 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: compiler/src/iree/compiler/API/CMakeLists.txt | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
diff --git a/compiler/src/iree/compiler/API/CMakeLists.txt b/compiler/src/iree/compiler/API/CMakeLists.txt | |
--- a/compiler/src/iree/compiler/API/CMakeLists.txt (revision be41632fdca0ba6102d48a8cf7108fca67297056) | |
+++ b/compiler/src/iree/compiler/API/CMakeLists.txt (date 1731095153333) | |
@@ -25,6 +25,7 @@ | |
MLIRCAPITransformDialect | |
MLIRCAPITransformDialectTransforms | |
MLIRCAPITransforms | |
+    # NOTE(review): added unconditionally, but third_party/stablehlo is only
+    # configured under if(IREE_INPUT_STABLEHLO) in the root CMakeLists.txt, so
+    # StablehloCAPI presumably does not exist when that option is OFF -- confirm.
+ StablehloCAPI | |
iree::compiler::API::Internal::CompilerDriver | |
iree::compiler::API::Internal::IREECompileToolEntryPoint | |
iree::compiler::API::Internal::IREEMLIRLSPServerToolEntryPoint | |
@@ -76,6 +77,8 @@ | |
obj.MLIRCAPITransforms | |
obj.MLIRCAPITransformDialect | |
obj.MLIRCAPITransformDialectTransforms | |
+    # NOTE(review): StableHLO's C API is listed twice: as its object library and
+    # as the plain library. The dependency loop below reads
+    # INTERFACE_LINK_LIBRARIES for STATIC_LIBRARY entries. Confirm that the
+    # plain entry does not also contribute its objects a second time, which
+    # could cause duplicate symbols.
+ obj.StablehloCAPI | |
+ StablehloCAPI | |
iree_compiler_API_Internal_CompilerDriver.objects | |
iree_compiler_API_Internal_IREECompileToolEntryPoint.objects | |
iree_compiler_API_Internal_IREEGPUDialectCAPI.objects | |
@@ -89,6 +92,42 @@ | |
message(STATUS "Bundling additional exports into libIREECompiler.so: ${IREE_COMPILER_API_ADDL_EXPORT_OBJECTS}") | |
list(APPEND _EXPORT_OBJECT_LIBS ${IREE_COMPILER_API_ADDL_EXPORT_OBJECTS}) | |
endif() | |
+ | |
+# Collect the name of every property this CMake knows about (one per line from
+# `--help-property-list`) for the debug helpers below. Skipped when
+# CMAKE_PROPERTY_LIST is already set.
+if(NOT CMAKE_PROPERTY_LIST)
+  # Query the CMake that is running this configure, not whatever `cmake` happens
+  # to be first on PATH (which may be a different version, or missing).
+  execute_process(
+    COMMAND "${CMAKE_COMMAND}" --help-property-list
+    OUTPUT_VARIABLE CMAKE_PROPERTY_LIST
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+  # Turn the newline-separated output into a CMake list. Literal semicolons are
+  # escaped first so they are not mistaken for list separators.
+  string(REGEX REPLACE ";" "\\\\;" CMAKE_PROPERTY_LIST "${CMAKE_PROPERTY_LIST}")
+  string(REGEX REPLACE "\n" ";" CMAKE_PROPERTY_LIST "${CMAKE_PROPERTY_LIST}")
+  list(REMOVE_DUPLICATES CMAKE_PROPERTY_LIST)
+endif()
+ | |
+# Debug helper: print the full CMAKE_PROPERTY_LIST as a single NOTICE message.
+function(print_properties)
+  set(_all_property_names "${CMAKE_PROPERTY_LIST}")
+  message(NOTICE "CMAKE_PROPERTY_LIST = ${_all_property_names}")
+endfunction()
+ | |
+# Debug helper: for each name in CMAKE_PROPERTY_LIST that is explicitly set on
+# `target`, print "<target> <property> = <value>". Per-configuration names
+# ("..._<CONFIG>") are expanded for CMAKE_BUILD_TYPE only.
+#
+#   target - name of an existing target; a STATUS message is printed and the
+#            function returns if no such target exists.
+function(print_target_properties target)
+  if(NOT TARGET "${target}")
+    message(STATUS "There is no target named '${target}'")
+    return()
+  endif()
+
+  # Per-config property names use the upper-case configuration name
+  # (e.g. COMPILE_DEFINITIONS_RELEASE), so "Release" must become "RELEASE".
+  string(TOUPPER "${CMAKE_BUILD_TYPE}" _config)
+
+  foreach(_property ${CMAKE_PROPERTY_LIST})
+    if(_property MATCHES "<CONFIG>")
+      # No build type (e.g. a multi-config generator): nothing to substitute.
+      if(_config STREQUAL "")
+        continue()
+      endif()
+      string(REPLACE "<CONFIG>" "${_config}" _property "${_property}")
+    endif()
+
+    # Any remaining placeholder (e.g. <LANG>) cannot name a real property.
+    if(_property MATCHES "<")
+      continue()
+    endif()
+
+    # Fix https://stackoverflow.com/questions/32197663/how-can-i-remove-the-the-location-property-may-not-be-read-from-target-error-i
+    if(_property STREQUAL "LOCATION" OR _property MATCHES "^LOCATION_" OR _property MATCHES "_LOCATION$")
+      continue()
+    endif()
+
+    get_property(_was_set TARGET "${target}" PROPERTY "${_property}" SET)
+    if(_was_set)
+      get_target_property(_value "${target}" "${_property}")
+      message("${target} ${_property} = ${_value}")
+    endif()
+  endforeach()
+endfunction()
set(_EXPORT_OBJECT_SRCS) | |
set(_EXPORT_OBJECT_DEPS) | |
@@ -136,11 +175,19 @@ | |
# that show like generator expressions are showing up in link lines, this is | |
# the culprit. Look at the export_objects_debug.txt to confirm. Then, add | |
# another level of fix upstream if you like pain. | |
- list(APPEND _EXPORT_OBJECT_DEPS "$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<TARGET_PROPERTY:${_object_lib},LINK_LIBRARIES>>>>>") | |
+  # Optional configure-time diagnostics. Configure with
+  # -DIREE_COMPILER_API_DEBUG_EXPORT_OBJECTS=ON to dump each library's properties.
+  if(IREE_COMPILER_API_DEBUG_EXPORT_OBJECTS)
+    message(STATUS "Export object library: ${_object_lib}")
+    print_target_properties("${_object_lib}")
+  endif()
+
+  get_target_property(_object_lib_type "${_object_lib}" TYPE)
+  if(_object_lib_type STREQUAL "OBJECT_LIBRARY")
+    # Generator expression: evaluated at generate time, so link libraries added
+    # to the target after this point are still included.
+    list(APPEND _EXPORT_OBJECT_DEPS "$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<GENEX_EVAL:$<TARGET_PROPERTY:${_object_lib},LINK_LIBRARIES>>>>>>>>")
+  elseif(_object_lib_type STREQUAL "STATIC_LIBRARY")
+    # Configure-time snapshot of the library's usage requirements. An unset
+    # property reads back as "_static_link_libs-NOTFOUND", which must not be
+    # appended as a link dependency.
+    get_target_property(_static_link_libs "${_object_lib}" INTERFACE_LINK_LIBRARIES)
+    if(_static_link_libs)
+      list(APPEND _EXPORT_OBJECT_DEPS "${_static_link_libs}")
+    endif()
+  else()
+    message(WARNING
+      "Not exporting link dependencies of '${_object_lib}': unsupported target type '${_object_lib_type}'")
+  endif()
endforeach() | |
# UNCOMMENT TO DEBUG WHAT IS GOING ON. | |
-# file(GENERATE OUTPUT export_objects_debug.txt CONTENT "OBJECTS:${_EXPORT_OBJECT_SRCS}\n\nDEPS:${_EXPORT_OBJECT_DEPS}") | |
+# Configure with -DIREE_COMPILER_API_DEBUG_EXPORT_OBJECTS=ON to write the
+# resolved object sources and link dependencies to export_objects_debug.txt in
+# the build tree, instead of writing it on every configure.
+if(IREE_COMPILER_API_DEBUG_EXPORT_OBJECTS)
+  file(GENERATE OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/export_objects_debug.txt"
+    CONTENT "OBJECTS:${_EXPORT_OBJECT_SRCS}\n\nDEPS:${_EXPORT_OBJECT_DEPS}")
+endif()
# Disable .so.0 style naming/linking. In order to be consistent across platforms | |
# and bindings, we will embed a major version in the library name when it is time. | |
Index: CMakeLists.txt | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
diff --git a/CMakeLists.txt b/CMakeLists.txt | |
--- a/CMakeLists.txt (revision be41632fdca0ba6102d48a8cf7108fca67297056) | |
+++ b/CMakeLists.txt (date 1731090415714) | |
@@ -836,12 +836,6 @@ | |
iree_llvm_add_usage_requirements(MLIRSupport IREELLVMIncludeSetup) | |
# Add external projects. | |
- | |
- message(STATUS "Configuring llvm-external-projects/mlir-iree-dialects") | |
- list(APPEND CMAKE_MESSAGE_INDENT " ") | |
- iree_llvm_add_external_project(mlir-iree-dialects ${CMAKE_CURRENT_SOURCE_DIR}/llvm-external-projects/iree-dialects) | |
- list(POP_BACK CMAKE_MESSAGE_INDENT) | |
- | |
if(IREE_INPUT_STABLEHLO) | |
message(STATUS "Configuring third_party/stablehlo") | |
list(APPEND CMAKE_MESSAGE_INDENT " ") | |
@@ -849,6 +843,11 @@ | |
list(POP_BACK CMAKE_MESSAGE_INDENT) | |
endif() | |
+  # NOTE(review): now configured after third_party/stablehlo rather than before
+  # it, presumably so that StableHLO targets already exist when iree-dialects is
+  # configured -- confirm.
+  message(STATUS "Configuring llvm-external-projects/mlir-iree-dialects")
+  list(APPEND CMAKE_MESSAGE_INDENT " ")
+  # Quoted so a source tree whose path contains spaces stays a single argument.
+  iree_llvm_add_external_project(mlir-iree-dialects "${CMAKE_CURRENT_SOURCE_DIR}/llvm-external-projects/iree-dialects")
+  list(POP_BACK CMAKE_MESSAGE_INDENT)
+ | |
# Ensure that LLVM-based dependencies needed for testing are included. | |
add_dependencies(iree-test-deps FileCheck) | |
if(IREE_LLD_TARGET) | |
Index: compiler/bindings/python/CMakeLists.txt | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt | |
--- a/compiler/bindings/python/CMakeLists.txt (revision be41632fdca0ba6102d48a8cf7108fca67297056) | |
+++ b/compiler/bindings/python/CMakeLists.txt (date 1731092599045) | |
@@ -181,9 +181,111 @@ | |
IREECompilerDialectsModule.cpp | |
EMBED_CAPI_LINK_LIBS | |
iree_compiler_API_SharedImpl | |
+    # NOTE(review): added unconditionally, while the StableHLO Python bindings
+    # further down are only declared under if(IREE_INPUT_STABLEHLO). Confirm that
+    # StablehloCAPI exists when that option is OFF.
+ StablehloCAPI | |
PRIVATE_LINK_LIBS | |
LLVMSupport | |
) | |
+ | |
+if(IREE_INPUT_STABLEHLO)
+  # Ship the upstream StableHLO Python bindings (check, chlo, stablehlo and vhlo
+  # dialects plus their native extension modules) inside the IREE compiler
+  # Python package.
+  set(STABLEHLO_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/stablehlo")
+  set(STABLEHLO_PYTHON_SOURCE_DIR "${STABLEHLO_SOURCE_DIR}/stablehlo/integrations/python")
+  # NOTE(review): directory-scoped include path, presumably needed by the
+  # extension sources below to find StableHLO headers. Prefer a target-scoped
+  # include once the generated extension targets can be referenced.
+  include_directories("${STABLEHLO_SOURCE_DIR}")
+
+  # Pure-Python dialect bindings. For each dialect <Name>, a source group
+  # <Name>PythonSources.Dialects holds the bindings built from
+  # dialects/<Name>Ops.td plus the wrapper dialects/<name>.py.
+  foreach(_stablehlo_dialect IN ITEMS Check Chlo Stablehlo Vhlo)
+    string(TOLOWER "${_stablehlo_dialect}" _stablehlo_dialect_name)
+    declare_mlir_python_sources(${_stablehlo_dialect}PythonSources.Dialects
+      ADD_TO_PARENT IREEPythonSources
+    )
+    declare_mlir_dialect_python_bindings(
+      ADD_TO_PARENT ${_stablehlo_dialect}PythonSources.Dialects
+      ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir"
+      TD_FILE "dialects/${_stablehlo_dialect}Ops.td"
+      SOURCES "dialects/${_stablehlo_dialect_name}.py"
+      DIALECT_NAME ${_stablehlo_dialect_name})
+  endforeach()
+
+  ##############################################################################
+  # Extensions
+  ##############################################################################
+  # Native sources come from STABLEHLO_PYTHON_SOURCE_DIR as computed above, which
+  # is an absolute path into third_party/stablehlo. Do not re-set it to a
+  # "/.."-prefixed path: that would resolve from the filesystem root.
+  # NOTE(review): only StablehloCAPI is added to libIREECompiler's exported
+  # libraries in this patch. Confirm that CheckCAPI, ChloCAPI and VhloCAPI are
+  # also available to these modules.
+
+  declare_mlir_python_extension(CheckPythonExtensions.Main
+    MODULE_NAME _check
+    ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects
+    SOURCES
+      "${STABLEHLO_PYTHON_SOURCE_DIR}/CheckModule.cpp"
+    EMBED_CAPI_LINK_LIBS
+      CheckCAPI
+    PRIVATE_LINK_LIBS
+      LLVMSupport
+  )
+
+  declare_mlir_python_extension(ChloPythonExtensions.Main
+    MODULE_NAME _chlo
+    ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects
+    SOURCES
+      "${STABLEHLO_PYTHON_SOURCE_DIR}/ChloModule.cpp"
+    EMBED_CAPI_LINK_LIBS
+      ChloCAPI
+    PRIVATE_LINK_LIBS
+      LLVMSupport
+  )
+
+  declare_mlir_python_extension(StablehloPythonExtensions.Main
+    MODULE_NAME _stablehlo
+    ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects
+    SOURCES
+      "${STABLEHLO_PYTHON_SOURCE_DIR}/StablehloApi.cpp"
+      "${STABLEHLO_PYTHON_SOURCE_DIR}/StablehloModule.cpp"
+    EMBED_CAPI_LINK_LIBS
+      StablehloCAPI
+    PRIVATE_LINK_LIBS
+      LLVMSupport
+  )
+
+  declare_mlir_python_extension(VhloPythonExtensions.Main
+    MODULE_NAME _vhlo
+    ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects
+    SOURCES
+      "${STABLEHLO_PYTHON_SOURCE_DIR}/VhloModule.cpp"
+    EMBED_CAPI_LINK_LIBS
+      VhloCAPI
+    PRIVATE_LINK_LIBS
+      LLVMSupport
+  )
+endif()
################################################################################ | |
# Generate packages and shared library | |
Index: compiler/bindings/python/test/ir/stablehlo.py | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
diff --git a/compiler/bindings/python/test/ir/stablehlo.py b/compiler/bindings/python/test/ir/stablehlo.py | |
new file mode 100644 | |
--- /dev/null (date 1731095512163) | |
+++ b/compiler/bindings/python/test/ir/stablehlo.py (date 1731095512163) | |
@@ -0,0 +1,388 @@ | |
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved. | |
+# Copyright 2022 The StableHLO Authors. | |
+# | |
+# Licensed under the Apache License, Version 2.0 (the "License"); | |
+# you may not use this file except in compliance with the License. | |
+# You may obtain a copy of the License at | |
+# | |
+# http://www.apache.org/licenses/LICENSE-2.0 | |
+# | |
+# Unless required by applicable law or agreed to in writing, software | |
+# distributed under the License is distributed on an "AS IS" BASIS, | |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+# See the License for the specific language governing permissions and | |
+# limitations under the License. | |
+# ============================================================================== | |
+"""Tests for StableHLO Python APIs.""" | |
+ | |
+# pylint: disable=wildcard-import,undefined-variable | |
+ | |
+import io | |
+import re | |
+from iree.compiler import ir, passmanager as pm | |
+from iree.compiler.dialects import stablehlo | |
+ | |
+import numpy as np | |
+ | |
+ | |
+def run(f):
+    """Decorator: call ``f`` right away inside a new MLIR context with the
+    StableHLO dialect registered, then hand ``f`` back unchanged."""
+    with ir.Context() as ctx:
+        stablehlo.register_dialect(ctx)
+        f()
+    return f
+ | |
+ | |
+@run
+def test_channel_handle():
+    """Checks that ChannelHandle.get keeps the handle id and channel type."""
+    channel = stablehlo.ChannelHandle.get(handle=1, type=2)
+    assert channel is not None
+    assert channel.handle == 1
+    assert channel.channel_type == 2
+ | |
+ | |
+@run
+def test_comparison_direction_attr():
+    """Checks ComparisonDirectionAttr construction, printing and `.value`."""
+    direction = stablehlo.ComparisonDirectionAttr.get("EQ")
+    assert direction is not None
+    assert str(direction) == "#stablehlo<comparison_direction EQ>"
+    assert direction.value == "EQ"
+ | |
+ | |
+@run
+def test_comparison_type_attr():
+    """Checks ComparisonTypeAttr construction, printing and `.value`."""
+    comparison_type = stablehlo.ComparisonTypeAttr.get("FLOAT")
+    assert comparison_type is not None
+    assert str(comparison_type) == "#stablehlo<comparison_type FLOAT>"
+    assert comparison_type.value == "FLOAT"
+ | |
+ | |
+@run
+def test_conv_dimension_numbers():
+    """Checks ConvDimensionNumbers construction, printing and field accessors."""
+    attr = stablehlo.ConvDimensionNumbers.get(
+        input_batch_dimension=0,
+        input_feature_dimension=1,
+        input_spatial_dimensions=[2, 3, 4],
+        kernel_input_feature_dimension=0,
+        kernel_output_feature_dimension=1,
+        kernel_spatial_dimensions=[2, 3],
+        output_batch_dimension=0,
+        output_feature_dimension=1,
+        output_spatial_dimensions=[2, 3],
+    )
+    # Check for None before printing, as the other tests do; otherwise a failed
+    # construction would surface as a confusing string mismatch instead.
+    assert attr is not None
+    assert str(attr) == ("#stablehlo.conv<[b, f, 0, 1, 2]x[i, o, 0, 1]->"
+                         "[b, f, 0, 1]>")
+    assert attr.input_batch_dimension == 0
+    assert attr.input_feature_dimension == 1
+    assert attr.input_spatial_dimensions == [2, 3, 4]
+    assert attr.kernel_input_feature_dimension == 0
+    assert attr.kernel_output_feature_dimension == 1
+    assert attr.kernel_spatial_dimensions == [2, 3]
+    assert attr.output_batch_dimension == 0
+    assert attr.output_feature_dimension == 1
+    assert attr.output_spatial_dimensions == [2, 3]
+ | |
+ | |
+@run
+def test_dot_algorithm():
+    """Checks DotAlgorithm construction, printing and field accessors."""
+    # BF16_BF16_F32_X3
+    attr = stablehlo.DotAlgorithm.get(
+        lhs_precision_type=ir.BF16Type.get(),
+        rhs_precision_type=ir.BF16Type.get(),
+        accumulation_type=ir.F32Type.get(),
+        lhs_component_count=1,
+        rhs_component_count=1,
+        num_primitive_operations=3,
+        allow_imprecise_accumulation=False,
+    )
+    assert attr is not None
+    assert str(attr) == ("#stablehlo.dot_algorithm<lhs_precision_type = bf16, "
+                         "rhs_precision_type = bf16, accumulation_type = f32, "
+                         "lhs_component_count = 1, rhs_component_count = 1, "
+                         "num_primitive_operations = 3, "
+                         "allow_imprecise_accumulation = false>")
+    assert isinstance(attr.lhs_precision_type, ir.BF16Type)
+    assert isinstance(attr.rhs_precision_type, ir.BF16Type)
+    assert isinstance(attr.accumulation_type, ir.F32Type)
+    assert attr.lhs_component_count == 1
+    assert attr.rhs_component_count == 1
+    assert attr.num_primitive_operations == 3
+    # `is False` rather than `== False` (flake8 E712); also pins a real bool.
+    assert attr.allow_imprecise_accumulation is False
+ | |
+ | |
+@run
+def test_dot_dimension_numbers():
+    """Checks DotDimensionNumbers construction, printing and field accessors."""
+    dot_dims = stablehlo.DotDimensionNumbers.get(
+        lhs_batching_dimensions=[0, 1],
+        rhs_batching_dimensions=[2, 3],
+        lhs_contracting_dimensions=[4, 5],
+        rhs_contracting_dimensions=[6, 7],
+    )
+    assert dot_dims is not None
+    expected = (
+        "#stablehlo.dot<lhs_batching_dimensions = [0, 1], "
+        "rhs_batching_dimensions = [2, 3], "
+        "lhs_contracting_dimensions = [4, 5], "
+        "rhs_contracting_dimensions = [6, 7]>"
+    )
+    assert str(dot_dims) == expected
+    assert dot_dims.lhs_batching_dimensions == [0, 1]
+    assert dot_dims.rhs_batching_dimensions == [2, 3]
+    assert dot_dims.lhs_contracting_dimensions == [4, 5]
+    assert dot_dims.rhs_contracting_dimensions == [6, 7]
+ | |
+ | |
+@run
+def test_fft_type_attr():
+    """Checks FftTypeAttr construction, printing and `.value`."""
+    fft_type = stablehlo.FftTypeAttr.get("FFT")
+    assert fft_type is not None
+    assert str(fft_type) == "#stablehlo<fft_type FFT>"
+    assert fft_type.value == "FFT"
+ | |
+ | |
+@run
+def test_gather_dimension_numbers():
+    """Checks GatherDimensionNumbers construction, printing and field accessors."""
+    gather_dims = stablehlo.GatherDimensionNumbers.get(
+        offset_dims=[1, 2],
+        collapsed_slice_dims=[3, 4, 5],
+        operand_batching_dims=[6, 7],
+        start_indices_batching_dims=[8, 9],
+        start_index_map=[10],
+        index_vector_dim=11,
+    )
+    assert gather_dims is not None
+    expected = (
+        "#stablehlo.gather<offset_dims = [1, 2], "
+        "collapsed_slice_dims = [3, 4, 5], "
+        "operand_batching_dims = [6, 7], "
+        "start_indices_batching_dims = [8, 9], "
+        "start_index_map = [10], "
+        "index_vector_dim = 11>"
+    )
+    assert str(gather_dims) == expected
+    assert gather_dims.offset_dims == [1, 2]
+    assert gather_dims.collapsed_slice_dims == [3, 4, 5]
+    assert gather_dims.operand_batching_dims == [6, 7]
+    assert gather_dims.start_indices_batching_dims == [8, 9]
+    assert gather_dims.start_index_map == [10]
+    assert gather_dims.index_vector_dim == 11
+ | |
+ | |
+@run
+def test_output_operand_alias():
+    """Checks OutputOperandAlias construction, printing and field accessors."""
+    alias = stablehlo.OutputOperandAlias.get(
+        output_tuple_indices=[0],
+        operand_index=0,
+        operand_tuple_indices=[1],
+    )
+    assert alias is not None
+    expected = (
+        "#stablehlo.output_operand_alias<output_tuple_indices = [0], "
+        "operand_index = 0, "
+        "operand_tuple_indices = [1]>"
+    )
+    assert str(alias) == expected
+    assert alias.output_tuple_indices == [0]
+    assert alias.operand_index == 0
+    assert alias.operand_tuple_indices == [1]
+ | |
+ | |
+@run
+def test_precision_attr():
+    """Checks PrecisionAttr construction, printing and `.value`."""
+    precision = stablehlo.PrecisionAttr.get("DEFAULT")
+    assert precision is not None
+    assert str(precision) == "#stablehlo<precision DEFAULT>"
+    assert precision.value == "DEFAULT"
+ | |
+ | |
+@run
+def test_rng_algorithm_attr():
+    """Checks RngAlgorithmAttr construction, printing and `.value`."""
+    algorithm = stablehlo.RngAlgorithmAttr.get("DEFAULT")
+    assert algorithm is not None
+    assert str(algorithm) == "#stablehlo<rng_algorithm DEFAULT>"
+    assert algorithm.value == "DEFAULT"
+ | |
+ | |
+@run
+def test_rng_distribution_attr():
+    """Checks RngDistributionAttr construction, printing and `.value`."""
+    distribution = stablehlo.RngDistributionAttr.get("UNIFORM")
+    assert distribution is not None
+    assert str(distribution) == "#stablehlo<rng_distribution UNIFORM>"
+    assert distribution.value == "UNIFORM"
+ | |
+ | |
+@run
+def test_scatter_dimension_numbers():
+    """Checks ScatterDimensionNumbers construction, printing and field accessors."""
+    scatter_dims = stablehlo.ScatterDimensionNumbers.get(
+        update_window_dims=[1, 2, 3],
+        inserted_window_dims=[4, 5],
+        input_batching_dims=[6, 7],
+        scatter_indices_batching_dims=[8, 9],
+        scattered_dims_to_operand_dims=[10, 11],
+        index_vector_dim=12,
+    )
+    assert scatter_dims is not None
+    # The printed form spells the field "scatter_dims_to_operand_dims".
+    expected = (
+        "#stablehlo.scatter<update_window_dims = [1, 2, 3], "
+        "inserted_window_dims = [4, 5], "
+        "input_batching_dims = [6, 7], "
+        "scatter_indices_batching_dims = [8, 9], "
+        "scatter_dims_to_operand_dims = [10, 11], "
+        "index_vector_dim = 12>"
+    )
+    assert str(scatter_dims) == expected
+    assert scatter_dims.update_window_dims == [1, 2, 3]
+    assert scatter_dims.inserted_window_dims == [4, 5]
+    assert scatter_dims.input_batching_dims == [6, 7]
+    assert scatter_dims.scatter_indices_batching_dims == [8, 9]
+    assert scatter_dims.scattered_dims_to_operand_dims == [10, 11]
+    assert scatter_dims.index_vector_dim == 12
+ | |
+ | |
+@run
+def test_transpose_attr():
+    """Checks TransposeAttr construction, printing and `.value`."""
+    transpose = stablehlo.TransposeAttr.get("TRANSPOSE")
+    assert transpose is not None
+    assert str(transpose) == "#stablehlo<transpose TRANSPOSE>"
+    assert transpose.value == "TRANSPOSE"
+ | |
+ | |
+@run
+def test_token_type():
+    """Checks that TokenType.get() prints as !stablehlo.token."""
+    token_type = stablehlo.TokenType.get()  # avoid shadowing builtin `type`
+    assert token_type is not None
+    assert str(token_type) == "!stablehlo.token"
+ | |
+ | |
+@run
+def test_type_extensions():
+    """Checks TypeExtensions with one static and one dynamic bound."""
+    dynamic = ir.ShapedType.get_dynamic_size()
+    extensions = stablehlo.TypeExtensions.get(bounds=[128, dynamic])
+    assert extensions is not None
+    assert extensions.bounds == [128, dynamic]
+ | |
+ | |
+@run
+def test_api_version():
+    """Checks that the StableHLO API version is a positive, exactly-int value."""
+    version = stablehlo.get_api_version()
+    assert type(version) is int
+    assert version > 0
+ | |
+ | |
+def is_semver_format(version_str):
+    """Return True if ``version_str`` is exactly MAJOR.MINOR.PATCH digits.
+
+    The pattern is a raw string (a plain literal triggers invalid-escape
+    warnings), and ``fullmatch`` rejects the trailing newline that ``$`` allows.
+    """
+    return re.fullmatch(r"\d+\.\d+\.\d+", version_str) is not None
+ | |
+ | |
+@run
+def test_current_version():
+    """Checks that get_current_version() returns a MAJOR.MINOR.PATCH string."""
+    assert is_semver_format(stablehlo.get_current_version())
+ | |
+ | |
+@run
+def test_minimum_version():
+    """Checks that get_minimum_version() returns a MAJOR.MINOR.PATCH string."""
+    min_version = stablehlo.get_minimum_version()
+    assert is_semver_format(min_version)
+ | |
+ | |
+@run
+def test_version_requirements():
+    """Checks that each compatibility requirement maps to a MAJOR.MINOR.PATCH version."""
+    requirement = stablehlo.StablehloCompatibilityRequirement
+    requirements = (
+        requirement.NONE,
+        requirement.WEEK_4,
+        requirement.WEEK_12,
+        requirement.MAX,
+    )
+    for req in requirements:
+        version = stablehlo.get_version_from_compatibility_requirement(req)
+        assert is_semver_format(version)
+ | |
+ | |
+# MLIR assembly template used by the tests below: a func.func that adds its
+# argument to itself. Fill in the tensor type with ASM_FORMAT.format("2xf32");
+# "{{" and "}}" are escaped literal braces for str.format.
+ASM_FORMAT = """ | |
+func.func @test(%arg0: tensor<{0}>) -> tensor<{0}> {{ | |
+ %0 = stablehlo.add %arg0, %arg0 : (tensor<{0}>, tensor<{0}>) -> tensor<{0}> | |
+ func.return %0 : tensor<{0}> | |
+}} | |
+""" | |
+ | |
+ | |
+# @run | |
+# def test_reference_api(): | |
+# # Formatted as (tensor_type, np_value) | |
+# # Program runs arg + arg, which is used for expected value | |
+# tests = [ | |
+# # No numpy types for f8 - skipping fp8 tests | |
+# ("f16", np.asarray(1, np.float16)), | |
+# ("f32", np.asarray(2, np.float32)), | |
+# ("f64", np.asarray(3, np.double)), | |
+# ("1xi8", np.asarray([4], np.int8)), | |
+# ("1xi16", np.asarray([5], np.int16)), | |
+# ("1xi32", np.asarray([-6], np.int32)), | |
+# # Numpy's uint treated as int by DenseElementsAttr, skipping np.uint tests | |
+# ("2x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)), | |
+# ("2x1x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,1,2)), | |
+# ("?x?xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)), | |
+# ("?x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)), | |
+# ] | |
+# for test in tests: | |
+# tensor_type, arg = test | |
+# with ir.Context() as context: | |
+# stablehlo.register_dialect(context) | |
+# m = ir.Module.parse(ASM_FORMAT.format(tensor_type)) | |
+# args = [ir.DenseIntElementsAttr.get(arg)] | |
+# | |
+# actual = np.array(stablehlo.eval_module(m, args)[0]) | |
+# expected = arg + arg | |
+# assert (actual == expected).all() | |
+# | |
+ | |
+@run
+def test_get_smaller_version():
+    """Checks that comparing current vs. minimum version yields the minimum."""
+    current = stablehlo.get_current_version()
+    minimum = stablehlo.get_minimum_version()
+    smaller = stablehlo.get_smaller_version(current, minimum)
+    assert smaller == minimum
+ | |
+ | |
+@run
+def test_serialization_apis():
+    """Round-trips a module through serialize/deserialize_portable_artifact."""
+    target_version = stablehlo.get_current_version()
+
+    with ir.Context() as ctx:
+        stablehlo.register_dialect(ctx)
+        module = ir.Module.parse(ASM_FORMAT.format("2xf32"))
+        assert module is not None
+        original_text = str(module)
+        artifact = stablehlo.serialize_portable_artifact(module, target_version)
+        round_tripped = stablehlo.deserialize_portable_artifact(ctx, artifact)
+        assert original_text == str(round_tripped)
+ | |
+ | |
+@run
+def test_str_serialization_apis():
+    """Round-trips module bytecode through the string-based portable artifact APIs."""
+    target_version = stablehlo.get_current_version()
+
+    def to_bytecode(module: ir.Module) -> bytes:
+        # Write the module's top-level operation as MLIR bytecode into memory.
+        buffer = io.BytesIO()
+        module.operation.write_bytecode(file=buffer)
+        return buffer.getvalue()
+
+    with ir.Context() as ctx:
+        stablehlo.register_dialect(ctx)
+        module = ir.Module.parse(ASM_FORMAT.format("2xf32"))
+        assert module is not None
+        original_text = str(module)
+        payload = to_bytecode(module)
+        artifact = stablehlo.serialize_portable_artifact_str(payload, target_version)
+        restored_text = stablehlo.deserialize_portable_artifact_str(artifact)
+        restored_module = ir.Module.parse(restored_text)
+        assert original_text == str(restored_module)
+ | |
+ | |
+@run
+def test_register_passes():
+    """Tests pass registration via a StableHLO -> VHLO -> StableHLO round trip."""
+    with ir.Context() as ctx:
+        stablehlo.register_dialect(ctx)
+        module = ir.Module.parse(ASM_FORMAT.format("2xf32"))
+        assert module is not None
+
+        stablehlo.register_stablehlo_passes()
+        pass_names = ",".join([
+            "stablehlo-legalize-to-vhlo",
+            "vhlo-legalize-to-stablehlo",
+        ])
+        pass_manager = pm.PassManager.parse(f"builtin.module({pass_names})")
+
+        # Run on a clone so the untouched original serves as the expected output.
+        clone = module.operation.clone()
+        pass_manager.run(clone.operation)
+        assert str(module) == str(clone)
Index: llvm-external-projects/iree-dialects/python/CMakeLists.txt | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
diff --git a/llvm-external-projects/iree-dialects/python/CMakeLists.txt b/llvm-external-projects/iree-dialects/python/CMakeLists.txt | |
--- a/llvm-external-projects/iree-dialects/python/CMakeLists.txt (revision be41632fdca0ba6102d48a8cf7108fca67297056) | |
+++ b/llvm-external-projects/iree-dialects/python/CMakeLists.txt (date 1731090900725) | |
@@ -48,6 +48,107 @@ | |
LLVMSupport | |
) | |
+#if (IREE_INPUT_STABLEHLO) | |
+# set(STABLEHLO_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/stablehlo") | |
+# set(STABLEHLO_PYTHON_SOURCE_DIR "${STABLEHLO_SOURCE_DIR}/stablehlo/integrations/python") | |
+# include_directories(${STABLEHLO_SOURCE_DIR}) | |
+# | |
+# declare_mlir_python_sources(CheckPythonSources.Dialects | |
+# ADD_TO_PARENT IREEDialectsPythonSources | |
+# ) | |
+# | |
+# declare_mlir_dialect_python_bindings( | |
+# ADD_TO_PARENT CheckPythonSources.Dialects | |
+# ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" | |
+# TD_FILE dialects/CheckOps.td | |
+# SOURCES dialects/check.py | |
+# DIALECT_NAME check) | |
+# | |
+# declare_mlir_python_sources(ChloPythonSources.Dialects | |
+# ADD_TO_PARENT IREEDialectsPythonSources | |
+# ) | |
+# | |
+# declare_mlir_dialect_python_bindings( | |
+# ADD_TO_PARENT ChloPythonSources.Dialects | |
+# ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" | |
+# TD_FILE dialects/ChloOps.td | |
+# SOURCES dialects/chlo.py | |
+# DIALECT_NAME chlo) | |
+# | |
+# declare_mlir_python_sources(StablehloPythonSources.Dialects | |
+# ADD_TO_PARENT IREEDialectsPythonSources | |
+# ) | |
+# | |
+# declare_mlir_dialect_python_bindings( | |
+# ADD_TO_PARENT StablehloPythonSources.Dialects | |
+# ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" | |
+# TD_FILE dialects/StablehloOps.td | |
+# SOURCES dialects/stablehlo.py | |
+# DIALECT_NAME stablehlo) | |
+# | |
+# declare_mlir_python_sources(VhloPythonSources.Dialects | |
+# ADD_TO_PARENT IREEDialectsPythonSources | |
+# ) | |
+# | |
+# declare_mlir_dialect_python_bindings( | |
+# ADD_TO_PARENT VhloPythonSources.Dialects | |
+# ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" | |
+# TD_FILE dialects/VhloOps.td | |
+# SOURCES dialects/vhlo.py | |
+# DIALECT_NAME vhlo) | |
+# | |
+# ################################################################################ | |
+# # Extensions | |
+# ################################################################################ | |
+# | |
+# set(STABLEHLO_PYTHON_SOURCE_DIR "/../../../third_party/stablehlo/stablehlo/integrations/python") | |
+# | |
+# declare_mlir_python_extension(CheckPythonExtensions.Main | |
+# MODULE_NAME _check | |
+# ADD_TO_PARENT IREEDialectsPythonExtensions | |
+# SOURCES | |
+# "${STABLEHLO_PYTHON_SOURCE_DIR}/CheckModule.cpp" | |
+# EMBED_CAPI_LINK_LIBS | |
+# CheckCAPI | |
+# PRIVATE_LINK_LIBS | |
+# LLVMSupport | |
+# ) | |
+# | |
+# declare_mlir_python_extension(ChloPythonExtensions.Main | |
+# MODULE_NAME _chlo | |
+# ADD_TO_PARENT IREEDialectsPythonExtensions | |
+# SOURCES | |
+# "${STABLEHLO_PYTHON_SOURCE_DIR}/ChloModule.cpp" | |
+# EMBED_CAPI_LINK_LIBS | |
+# ChloCAPI | |
+# PRIVATE_LINK_LIBS | |
+# LLVMSupport | |
+# ) | |
+# | |
+# declare_mlir_python_extension(StablehloPythonExtensions.Main | |
+# MODULE_NAME _stablehlo | |
+# ADD_TO_PARENT IREEDialectsPythonExtensions | |
+# SOURCES | |
+# "${STABLEHLO_PYTHON_SOURCE_DIR}/StablehloApi.cpp" | |
+# "${STABLEHLO_PYTHON_SOURCE_DIR}/StablehloModule.cpp" | |
+# EMBED_CAPI_LINK_LIBS | |
+# StablehloCAPI | |
+# PRIVATE_LINK_LIBS | |
+# LLVMSupport | |
+# ) | |
+# | |
+# declare_mlir_python_extension(VhloPythonExtensions.Main | |
+# MODULE_NAME _vhlo | |
+# ADD_TO_PARENT IREEDialectsPythonExtensions | |
+# SOURCES | |
+# "${STABLEHLO_PYTHON_SOURCE_DIR}/VhloModule.cpp" | |
+# EMBED_CAPI_LINK_LIBS | |
+# VhloCAPI | |
+# PRIVATE_LINK_LIBS | |
+# LLVMSupport | |
+# ) | |
+#endif() | |
+ | |
################################################################################ | |
# Generate packages and shared library | |
# Downstreams typically will not use these, but they are useful for local |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment