Mark MHLO and TOSA dialect ops illegal at input to IREE core compiler. (#6707)

Fixes #6692
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 9b0d0a6..a7d9f04 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -57,6 +57,7 @@
         "SimplifyVariableAccesses.cpp",
         "StripAndSplatConstantVariables.cpp",
         "TypeConverter.cpp",
+        "VerifyInputLegality.cpp",
     ],
     hdrs = [
         "DestructiveUpdateUtils.h",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 17924a2..6095dc8 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -54,6 +54,7 @@
     "SimplifyVariableAccesses.cpp"
     "StripAndSplatConstantVariables.cpp"
     "TypeConverter.cpp"
+    "VerifyInputLegality.cpp"
   DEPS
     ::PassesIncGen
     LLVMSupport
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 130c13f..b6745e5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -87,6 +87,10 @@
 // Expands dynamic !shapex.ranked_shape dimensions in variables.
 std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass();
 
+/// Verified if the input to the Flow transformation passes has operations from
+/// dialects that are expected to be legalized before this pass.
+std::unique_ptr<OperationPass<FuncOp>> createVerifyInputLegalityPass();
+
 //===----------------------------------------------------------------------===//
 // Dispatches (flow.dispatch.workgroups)
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 60f708a..8ac2564 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -146,4 +146,9 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createStripAndSplatConstantVariablesPass()";
 }
 
+def VerifyInputLegality: Pass<"iree-verify-input-legality", "FuncOp"> {
+  let summary = "Checks the legality of the IR at the start of IREE compilation flow.";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createVerifyInputLegalityPass()";
+}
+
 #endif  // IREE_DIALECT_FLOW_PASSES
diff --git a/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp b/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
new file mode 100644
index 0000000..30617e2
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
@@ -0,0 +1,46 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+class VerifyInputLegalityPass
+    : public VerifyInputLegalityBase<VerifyInputLegalityPass> {
+  void runOnOperation() override {
+    FuncOp funcOp = getOperation();
+    auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
+      StringRef opDialectName = op->getDialect()->getNamespace();
+      if (opDialectName == "mhlo" || opDialectName == "tosa") {
+        return op->emitOpError(
+                   "illegal operation in input to iree core compiler. Use "
+                   "-iree-input-type=")
+               << opDialectName << " to legalize this operation";
+      }
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted()) {
+      return signalPassFailure();
+    }
+  }
+};
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createVerifyInputLegalityPass() {
+  return std::make_unique<VerifyInputLegalityPass>();
+}
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index 12c0f37..7fb9c4b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -41,6 +41,7 @@
             "simplify_variable_accesses.mlir",
             "strip_and_splat_constant_variables.mlir",
             "transformation.mlir",
+            "verify_input_ir.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 83b449a..0c8ca09 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -38,6 +38,7 @@
     "simplify_variable_accesses.mlir"
     "strip_and_splat_constant_variables.mlir"
     "transformation.mlir"
+    "verify_input_ir.mlir"
   DATA
     iree::tools::IreeFileCheck
     iree::tools::iree-opt
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir b/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
new file mode 100644
index 0000000..1759162
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-opt -iree-verify-input-legality -verify-diagnostics %s -split-input-file
+
+func @check_no_mlir(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{illegal operation in input to iree core compiler. Use -iree-input-type=mhlo to legalize this operation}}
+  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @check_no_tosa(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{illegal operation in input to iree core compiler. Use -iree-input-type=tosa to legalize this operation}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}