Replace mhlo ops with core dialect ops in more tests. (#9838)
Progress on https://github.com/iree-org/iree/issues/9667, working towards removing MHLO and other input dialects from the "core" parts of the IREE compiler. Any tests using input dialects should be organized under the relevant `compiler/InputConversion/*` or `tests/e2e/*_ops/` directories.
Notable remaining tests using mhlo in the "core" of the compiler:
* [demote_f32_to_f16.mlir](https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/Util/Transforms/test/demote_f32_to_f16.mlir)
* [demote_f64_to_f32.mlir](https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/Util/Transforms/test/demote_f64_to_f32.mlir)
* [demote_i64_to_i32.mlir](https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/Util/Transforms/test/demote_i64_to_i32.mlir)
* [promote_f16_to_f32.mlir](https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/Util/Transforms/test/promote_f16_to_f32.mlir)
diff --git a/compiler/src/iree/compiler/API/python/test/tools/compiler_core_test.py b/compiler/src/iree/compiler/API/python/test/tools/compiler_core_test.py
index 68dba25..87b0c5d 100644
--- a/compiler/src/iree/compiler/API/python/test/tools/compiler_core_test.py
+++ b/compiler/src/iree/compiler/API/python/test/tools/compiler_core_test.py
@@ -15,8 +15,8 @@
SIMPLE_MUL_ASM = """
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- return %0 : tensor<4xf32>
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
}
"""
@@ -35,7 +35,6 @@
def testCompileStr(self):
binary = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
logging.info("Flatbuffer size = %d", len(binary))
self.assertTrue(binary)
@@ -45,7 +44,6 @@
# specifically. See: https://github.com/iree-org/iree/issues/4439
def testCompileStrLLVMAOT(self):
binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM,
- input_type="mhlo",
target_backends=["dylib-llvm-aot"])
logging.info("Flatbuffer size = %d", len(binary))
self.assertTrue(binary)
@@ -56,7 +54,6 @@
def testCompileMultipleBackends(self):
binary = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
target_backends=["dylib-llvm-aot", "vulkan-spirv"])
logging.info("Flatbuffer size = %d", len(binary))
self.assertTrue(binary)
@@ -68,7 +65,6 @@
f.close()
binary = iree.compiler.tools.compile_file(
f.name,
- input_type="mhlo",
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
finally:
os.remove(f.name)
@@ -81,7 +77,6 @@
f.close()
output = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
output_file=f.name,
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
self.assertIsNone(output)
@@ -96,7 +91,6 @@
def testOutputFbText(self):
text = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
output_format=iree.compiler.tools.OutputFormat.FLATBUFFER_TEXT,
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
"utf-8")
@@ -119,7 +113,6 @@
"FLATBUFFER_BINARY, FLATBUFFER_TEXT, MLIR_TEXT"):
_ = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
output_format="foobar",
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
@@ -136,7 +129,6 @@
def testOutputMlirText(self):
text = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
output_format=iree.compiler.tools.OutputFormat.MLIR_TEXT,
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
"utf-8")
@@ -148,7 +140,6 @@
with io.StringIO() as buf, contextlib.redirect_stderr(buf):
iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
extra_args=["--mlir-timing"],
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
stderr = buf.getvalue()
@@ -157,7 +148,6 @@
def testAllOptions(self):
binary = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
optimize=False,
strip_debug_ops=True,
strip_source_map=True,
@@ -179,7 +169,6 @@
with iree.compiler.tools.TempFileSaver(temp_dir.name):
output = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
output_file=output_file.name,
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
self.assertIsNone(output)
@@ -202,7 +191,6 @@
with iree.compiler.tools.TempFileSaver(temp_dir.name):
output = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
self.assertIsNotNone(output)
self.assertGreater(len(output), 0)
@@ -226,7 +214,6 @@
with iree.compiler.tools.TempFileSaver(temp_dir.name):
output = iree.compiler.tools.compile_str(
SIMPLE_MUL_ASM,
- input_type="mhlo",
target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
self.assertIsNotNone(output)
self.assertGreater(len(output), 0)
diff --git a/compiler/src/iree/compiler/API/python/test/transforms/ireec/compile_sample_module.py b/compiler/src/iree/compiler/API/python/test/transforms/ireec/compile_sample_module.py
index c2a8215..92ac660 100644
--- a/compiler/src/iree/compiler/API/python/test/transforms/ireec/compile_sample_module.py
+++ b/compiler/src/iree/compiler/API/python/test/transforms/ireec/compile_sample_module.py
@@ -5,7 +5,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import io
-import subprocess
from iree.compiler import ir
from iree.compiler import passmanager
@@ -33,16 +32,14 @@
input_module = ir.Module.parse(r"""
builtin.module {
- func.func @fabs(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
- %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
- %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
- return %1 : tensor<4x4xf32>
+ func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
}
}
""")
- options = ireec.CompilerOptions("--iree-hal-target-backends=cpu",
- "--iree-input-type=mhlo")
+ options = ireec.CompilerOptions("--iree-hal-target-backends=cpu")
print(options)
pm = passmanager.PassManager()
ireec.build_iree_vm_pass_pipeline(options, pm)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index e090148..e7b4e96 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -700,7 +700,7 @@
```mlir
%c = flow.tensor.constant tensor<2x2xf32> -> tensor<?x?xf32>
- %res = "mhlo.abs"(%c) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %res = math.abs %c : tensor<?x?xf32>
```
}];
let arguments = (ins ElementsAttr:$value);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
index 6722021..645a584 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
@@ -1,50 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-mhlo-input-transformation-pipeline --iree-flow-transformation-pipeline --iree-flow-export-benchmark-funcs --verify-diagnostics %s | FileCheck %s
-
-module {
- func.func @two_dispatch(%arg0: tensor<5x3xf32>, %arg1: tensor<3x5xf32>) -> (tensor<5x5xf32>, tensor<3x5xf32>) {
- %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
- %1 = "mhlo.dot"(%arg1, %0) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
- return %0, %1 : tensor<5x5xf32>, tensor<3x5xf32>
- }
-}
-
-// CHECK-DAG: util.global private @[[GLOBAL_ARG0:.+]] {noinline} = dense<{{.*}}> : tensor<5x3xf32>
-// CHECK-DAG: util.global private @[[GLOBAL_ARG1:.+]] {noinline} = dense<{{.*}}> : tensor<3x5xf32>
-
-// CHECK: func.func @two_dispatch_benchmark()
-// CHECK-SAME: attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "entry"}}
-// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : tensor<5x3xf32>
-// CHECK-DAG: %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : tensor<3x5xf32>
-// CHECK: %[[RET:.+]]:2 = call @two_dispatch(%[[ARG0]], %[[ARG1]])
-// CHECK-DAG: util.do_not_optimize(%[[RET]]#0) : tensor<5x5xf32>
-// CHECK-DAG: util.do_not_optimize(%[[RET]]#1) : tensor<3x5xf32>
-
-// -----
-
-func.func @while(%start: tensor<i32>, %bound: tensor<i32>) -> tensor<i32> {
- cf.br ^bb1(%start : tensor<i32>)
-^bb1(%0: tensor<i32>):
- %1 = "mhlo.compare"(%0, %bound) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
- %2 = tensor.extract %1[] : tensor<i1>
- cf.cond_br %2, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
-^bb2(%3: tensor<i32>):
- %4 = arith.addi %3, %3 : tensor<i32>
- cf.br ^bb1(%4 : tensor<i32>)
-^bb3(%5: tensor<i32>):
- return %5 : tensor<i32>
-}
-
-// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {noinline} = dense<0> : tensor<i32>
-// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {noinline} = dense<0> : tensor<i32>
-
-// CHECK: func.func @while_benchmark()
-// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : tensor<i32>
-// CHECK-DAG: %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : tensor<i32>
-// CHECK: %[[RET0:.+]] = call @while(%[[ARG0]], %[[ARG1]])
-// CHECK: util.do_not_optimize(%[[RET0]]) : tensor<i32>
-// CHECK: return
-
-// -----
+// RUN: iree-opt --split-input-file --iree-flow-transformation-pipeline --iree-flow-export-benchmark-funcs --verify-diagnostics %s | FileCheck %s
// Basic usage from the `--iree-native-bindings-support` flag.
@@ -69,6 +23,32 @@
// -----
+// Ensures that functions with multiple blocks are handled correctly.
+
+func.func @while(%start: i32, %bound: i32) -> i32 {
+ cf.br ^bb1(%start : i32)
+^bb1(%0: i32):
+ %1 = arith.cmpi slt, %0, %bound : i32
+ cf.cond_br %1, ^bb2(%0 : i32), ^bb3(%0 : i32)
+^bb2(%3: i32):
+ %4 = arith.addi %3, %3 : i32
+ cf.br ^bb1(%4 : i32)
+^bb3(%5: i32):
+ return %5 : i32
+}
+
+// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {noinline} = 0 : i32
+// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {noinline} = 0 : i32
+
+// CHECK: func.func @while_benchmark()
+// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : i32
+// CHECK-DAG: %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : i32
+// CHECK: %[[RET0:.+]] = call @while(%[[ARG0]], %[[ARG1]])
+// CHECK: util.do_not_optimize(%[[RET0]]) : i32
+// CHECK: return
+
+// -----
+
// Ensure the tensors we allocate are of the desired type after casting.
// CHECK-LABEL: func private @importBufferViewBitcasting
diff --git a/runtime/bindings/python/tests/system_api_test.py b/runtime/bindings/python/tests/system_api_test.py
index 081e739..7065bd4 100644
--- a/runtime/bindings/python/tests/system_api_test.py
+++ b/runtime/bindings/python/tests/system_api_test.py
@@ -22,12 +22,11 @@
"""
module @arithmetic {
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- return %0 : tensor<4xf32>
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
}
}
""",
- input_type="mhlo",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index ec47e71..5a26631 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -22,7 +22,6 @@
return %0 : i32
}
""",
- input_type="mhlo",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
@@ -33,11 +32,10 @@
binary = iree.compiler.compile_str(
"""
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- return %0 : tensor<4xf32>
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
}
""",
- input_type="mhlo",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
@@ -45,17 +43,14 @@
def create_simple_dynamic_abs_module():
- # TODO(laurenzo): Compile for more backends as dynamic shapes come online.
- target_backends = iree.compiler.DEFAULT_TESTING_BACKENDS
binary = iree.compiler.compile_str(
"""
- func.func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
+ func.func @dynamic_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = math.abs %arg0 : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
""",
- input_type="mhlo",
- target_backends=target_backends,
+ target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
return m
@@ -168,7 +163,7 @@
m = create_simple_dynamic_abs_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
- f = m.lookup_function("simple_mul")
+ f = m.lookup_function("dynamic_abs")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
result = finv(arg0)
diff --git a/runtime/src/iree/runtime/demo/README.md b/runtime/src/iree/runtime/demo/README.md
index b4b0f02..65ef51f 100644
--- a/runtime/src/iree/runtime/demo/README.md
+++ b/runtime/src/iree/runtime/demo/README.md
@@ -7,9 +7,8 @@
tensors and returns the result:
```mlir
-func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
- {
- %0 = "mhlo.multiply"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
```
diff --git a/samples/static_library/README.md b/samples/static_library/README.md
index 39420a1..97f3dc9 100644
--- a/samples/static_library/README.md
+++ b/samples/static_library/README.md
@@ -9,10 +9,8 @@
`simple_mul` that returns the multiplication of two tensors:
```mlir
-func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
-{
- %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>,
- tensor<4xf32>) -> tensor<4xf32>
+func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
```
diff --git a/samples/static_library/simple_mul.mlir b/samples/static_library/simple_mul.mlir
index e5f4d5b..0b8a2f0 100644
--- a/samples/static_library/simple_mul.mlir
+++ b/samples/static_library/simple_mul.mlir
@@ -1,5 +1,4 @@
-func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
-{
- %0 = "arith.mulf"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}