Mark MHLO and TOSA dialect ops illegal at input to IREE core compiler. (#6707)
Fixes #6692
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 9b0d0a6..a7d9f04 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -57,6 +57,7 @@
"SimplifyVariableAccesses.cpp",
"StripAndSplatConstantVariables.cpp",
"TypeConverter.cpp",
+ "VerifyInputLegality.cpp",
],
hdrs = [
"DestructiveUpdateUtils.h",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 17924a2..6095dc8 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -54,6 +54,7 @@
"SimplifyVariableAccesses.cpp"
"StripAndSplatConstantVariables.cpp"
"TypeConverter.cpp"
+ "VerifyInputLegality.cpp"
DEPS
::PassesIncGen
LLVMSupport
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 130c13f..b6745e5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -87,6 +87,10 @@
// Expands dynamic !shapex.ranked_shape dimensions in variables.
std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass();
+/// Verified if the input to the Flow transformation passes has operations from
+/// dialects that are expected to be legalized before this pass.
+std::unique_ptr<OperationPass<FuncOp>> createVerifyInputLegalityPass();
+
//===----------------------------------------------------------------------===//
// Dispatches (flow.dispatch.workgroups)
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 60f708a..8ac2564 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -146,4 +146,9 @@
let constructor = "mlir::iree_compiler::IREE::Flow::createStripAndSplatConstantVariablesPass()";
}
+def VerifyInputLegality: Pass<"iree-verify-input-legality", "FuncOp"> {
+ let summary = "Checks the legality of the IR at the start of IREE compilation flow.";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createVerifyInputLegalityPass()";
+}
+
#endif // IREE_DIALECT_FLOW_PASSES
diff --git a/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp b/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
new file mode 100644
index 0000000..30617e2
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
@@ -0,0 +1,46 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+class VerifyInputLegalityPass
+ : public VerifyInputLegalityBase<VerifyInputLegalityPass> {
+ void runOnOperation() override {
+ FuncOp funcOp = getOperation();
+ auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
+ StringRef opDialectName = op->getDialect()->getNamespace();
+ if (opDialectName == "mhlo" || opDialectName == "tosa") {
+ return op->emitOpError(
+ "illegal operation in input to iree core compiler. Use "
+ "-iree-input-type=")
+ << opDialectName << " to legalize this operation";
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted()) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createVerifyInputLegalityPass() {
+ return std::make_unique<VerifyInputLegalityPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index 12c0f37..7fb9c4b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -41,6 +41,7 @@
"simplify_variable_accesses.mlir",
"strip_and_splat_constant_variables.mlir",
"transformation.mlir",
+ "verify_input_ir.mlir",
],
include = ["*.mlir"],
),
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 83b449a..0c8ca09 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -38,6 +38,7 @@
"simplify_variable_accesses.mlir"
"strip_and_splat_constant_variables.mlir"
"transformation.mlir"
+ "verify_input_ir.mlir"
DATA
iree::tools::IreeFileCheck
iree::tools::iree-opt
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir b/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
new file mode 100644
index 0000000..1759162
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-opt -iree-verify-input-legality -verify-diagnostics %s -split-input-file
+
+func @check_no_mlir(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{illegal operation in input to iree core compiler. Use -iree-input-type=mhlo to legalize this operation}}
+ %0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @check_no_tosa(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{illegal operation in input to iree core compiler. Use -iree-input-type=tosa to legalize this operation}}
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}