Reapply "Add VHLO support for IREE" (#17162) (#17179)

This reverts commit d407c7814fa146060b4b36877503413522efd7fa.

ci-extra:
build_test_all_windows,build_test_all_macos_arm64,build_test_all_macos_x86_64
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 608ee01..39a295c 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -86,6 +86,12 @@
                 "@stablehlo//:broadcast_utils": [
                     "StablehloBroadcastUtils",
                 ],
+                "@stablehlo//:stablehlo_passes": [
+                    "StablehloPasses",
+                ],
+                "@stablehlo//:vhlo_ops": [
+                    "VhloOps",
+                ],
                 # NCCL
                 "@nccl//:headers": [
                     "nccl::headers",
diff --git a/compiler/plugins/input/StableHLO/BUILD.bazel b/compiler/plugins/input/StableHLO/BUILD.bazel
index e40ee37..dfdb017 100644
--- a/compiler/plugins/input/StableHLO/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/BUILD.bazel
@@ -31,5 +31,6 @@
         "@llvm-project//mlir:Transforms",
         "@stablehlo//:chlo_ops",
         "@stablehlo//:stablehlo_ops",
+        "@stablehlo//:vhlo_ops",
     ],
 )
diff --git a/compiler/plugins/input/StableHLO/CMakeLists.txt b/compiler/plugins/input/StableHLO/CMakeLists.txt
index 11eb2a4..5f22917 100644
--- a/compiler/plugins/input/StableHLO/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/CMakeLists.txt
@@ -32,6 +32,7 @@
     MLIRPass
     MLIRTransforms
     StablehloOps
+    VhloOps
     iree::compiler::PluginAPI
     iree::compiler::plugins::input::StableHLO::Conversion
   PUBLIC
diff --git a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
index 13581c5..110c27a 100644
--- a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
@@ -61,6 +61,7 @@
 iree_compiler_cc_library(
     name = "StableHLOLegalization",
     srcs = [
+        "CheckVHLOStableHloMixUsage.cpp",
         "ConvertCollectives.cpp",
         "LegalizeCHLO.cpp",
         "LegalizeControlFlow.cpp",
@@ -121,6 +122,7 @@
         "@stablehlo//:broadcast_utils",
         "@stablehlo//:chlo_ops",
         "@stablehlo//:stablehlo_ops",
+        "@stablehlo//:vhlo_ops",
     ],
 )
 
@@ -152,5 +154,6 @@
         "@llvm-project//mlir:ShapeToStandard",
         "@llvm-project//mlir:ShapeTransforms",
         "@llvm-project//mlir:Transforms",
+        "@stablehlo//:stablehlo_passes",
     ],
 )
diff --git a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
index 6635859..5afe3a2 100644
--- a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
@@ -47,6 +47,7 @@
   NAME
     StableHLOLegalization
   SRCS
+    "CheckVHLOStableHloMixUsage.cpp"
     "ConvertCollectives.cpp"
     "LegalizeCHLO.cpp"
     "LegalizeControlFlow.cpp"
@@ -99,6 +100,7 @@
     MLIRVectorDialect
     StablehloBroadcastUtils
     StablehloOps
+    VhloOps
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::Util::IR
@@ -130,6 +132,7 @@
     MLIRShapeOpsTransforms
     MLIRShapeToStandard
     MLIRTransforms
+    StablehloPasses
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Util::IR
     iree::compiler::Dialect::Util::Transforms
diff --git a/compiler/plugins/input/StableHLO/Conversion/CheckVHLOStableHloMixUsage.cpp b/compiler/plugins/input/StableHLO/Conversion/CheckVHLOStableHloMixUsage.cpp
new file mode 100644
index 0000000..9659ceb
--- /dev/null
+++ b/compiler/plugins/input/StableHLO/Conversion/CheckVHLOStableHloMixUsage.cpp
@@ -0,0 +1,61 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h"
+#include "compiler/plugins/input/StableHLO/Conversion/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/ChloOps.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/VhloOps.h"
+
+namespace mlir::iree_compiler::stablehlo {
+#define GEN_PASS_DEF_CHECKVHLOSTABLEHLOMIXUSAGE
+#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc"
+
+namespace {
+/// Fails the pass when a module contains both VHLO and StableHLO ops.
+struct CheckVHLOStableHloMixUsage final
+    : impl::CheckVHLOStableHloMixUsageBase<CheckVHLOStableHloMixUsage> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    const Dialect *stablehloDialect = ctx->getLoadedDialect("stablehlo");
+    const Dialect *vhloDialect = ctx->getLoadedDialect("vhlo");
+    // A null (unloaded) dialect would compare equal to unregistered ops.
+    if (!stablehloDialect || !vhloDialect) {
+      return;
+    }
+    Operation *lastStablehloOp = nullptr;
+    Operation *lastVhloOp = nullptr;
+    auto reportMix = [](Operation *vhloOp, Operation *stablehloOp) {
+      vhloOp->emitOpError() << "using VHLO and StableHLO Ops in the same "
+                               "module is not supported. ";
+      stablehloOp->emitRemark() << "last StableHLO Op was found here: ";
+      return WalkResult::interrupt();
+    };
+    WalkResult walkResult = getOperation()->walk([&](Operation *op) {
+      const Dialect *opDialect = op->getDialect();
+      if (opDialect == stablehloDialect) {
+        if (lastVhloOp) {
+          return reportMix(lastVhloOp, op);
+        }
+        lastStablehloOp = op;
+      } else if (opDialect == vhloDialect) {
+        if (lastStablehloOp) {
+          return reportMix(op, lastStablehloOp);
+        }
+        lastVhloOp = op;
+      }
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted()) {
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
index 69d98e5..da76cae 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Pass/PassOptions.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
+#include "stablehlo/transforms/Passes.h"
 
 namespace mlir::iree_compiler::stablehlo {
 namespace {
@@ -40,6 +41,11 @@
 // Prepare HLO for use as an input to the Flow dialect.
 void buildStableHLOInputConversionPassPipelineImpl(
     OpPassManager &passManager, const StableHloOptions &options, bool detuple) {
+  // Having both StableHLO and VHLO in the same module is not supported.
+  // If the input is VHLO, the deserialize pipeline converts it to StableHLO.
+  // If the input is already StableHLO, the deserialize pipeline is a no-op.
+  passManager.addPass(stablehlo::createCheckVHLOStableHloMixUsage());
+  ::mlir::stablehlo::createStablehloDeserializePipeline(passManager);
   passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
   passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
   passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.td b/compiler/plugins/input/StableHLO/Conversion/Passes.td
index 7855f8c..8f7aa2b 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.td
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.td
@@ -67,4 +67,11 @@
       "Verifies that only supported IR constructs are passed to the compiler";
 }
 
+// Module-level check that fails when VHLO and StableHLO ops are mixed.
+def CheckVHLOStableHloMixUsage :
+    Pass<"iree-check-vhlostablehlo-mix-usage", "ModuleOp"> {
+  let summary =
+      "Check and report an error when VHLO and StableHLO are used in the same module";
+}
+
 #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_PASSES
diff --git a/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp b/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp
index aa9bea7..c0751f2 100644
--- a/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/VhloOps.h"
 
 namespace mlir::iree_compiler::stablehlo {
 
@@ -32,6 +33,7 @@
     // that we explicitly deny the dialects we know about.
     conversionTarget.addIllegalDialect<mlir::stablehlo::StablehloDialect>();
     conversionTarget.addIllegalDialect<mlir::chlo::ChloDialect>();
+    conversionTarget.addIllegalDialect<mlir::vhlo::VhloDialect>();
     conversionTarget.addIllegalDialect<mlir::shape::ShapeDialect>();
 
     // NOTE: It is not fully illegal to tunnel input dialect ops through to
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel
index c42d6f8..2229c05 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel
@@ -34,6 +34,7 @@
             "stablehlo_to_linalg_reduce.mlir",
             "stablehlo_to_linalg.mlir",
             "verify_compiler_input_legality.mlir",
+            "vhlo_stablehlo_mix_invalid.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt
index 7b41288..6d9c9b1 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt
@@ -32,6 +32,7 @@
     "stablehlo_to_linalg_random.mlir"
     "stablehlo_to_linalg_reduce.mlir"
     "verify_compiler_input_legality.mlir"
+    "vhlo_stablehlo_mix_invalid.mlir"
   TOOLS
     FileCheck
     iree-compile
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir b/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir
index 08b556b..efb9b89 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir
@@ -8,3 +8,43 @@
   %0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   return %0 : tensor<2x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: util.func public @vhlo_func
+vhlo.func_v1 @vhlo_func(%arg0: !vhlo.tensor_v1<!vhlo.i32_v1>) -> (!vhlo.tensor_v1<!vhlo.i32_v1>) {
+  // CHECK: arith.constant
+  %0 = "vhlo.constant_v1"() <{value = #vhlo.tensor_v1<dense<1> : tensor<i32>>}> : () -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  // CHECK: return
+  "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
+
+// -----
+
+// CHECK-LABEL: util.func public @dot_vhlo_example
+vhlo.func_v1 @dot_vhlo_example(%arg0: !vhlo.tensor_v1<8x16x!vhlo.f32_v1>, %arg1: !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<8x8x!vhlo.f32_v1>) {
+  // CHECK: linalg.matmul
+  %0 = "vhlo.dot_v1"(%arg0, %arg1) <{precision_config = #vhlo.array_v1<[#vhlo<precision_v1 DEFAULT>, #vhlo<precision_v1 DEFAULT>]>}> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1>
+  // CHECK: return
+  "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<8x8x!vhlo.f32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
+
+// -----
+
+// CHECK-LABEL: util.func public @gather_vhlo_example
+vhlo.func_v1 @gather_vhlo_example(%arg0: !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>) {
+  // CHECK: flow.collective.all_gather
+  %0 = "vhlo.all_gather_v1"(%arg0) <{all_gather_dim = #vhlo.integer_v1<1 : i64>, channel_id = #vhlo.integer_v1<0 : i64>, replica_groups = #vhlo.tensor_v1<dense<[[0], [1]]> : tensor<2x1xi64>>, use_global_device_ids = #vhlo.bool_v1<false>}> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
+  // CHECK: return
+  "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
+
+// -----
+
+// CHECK-LABEL: util.func public @compare_vhlo_example
+vhlo.func_v1 @compare_vhlo_example(%arg0: !vhlo.tensor_v1<!vhlo.f32_v1>, %arg1: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.bool_v1>) {
+  // CHECK: arith.cmpf
+  %0 = "vhlo.compare_v1"(%arg0, %arg1) <{compare_type = #vhlo<comparison_type_v1 NOTYPE>, comparison_direction = #vhlo<comparison_direction_v1 EQ>}> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.bool_v1>
+  // CHECK: return
+  "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<!vhlo.bool_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/vhlo_stablehlo_mix_invalid.mlir b/compiler/plugins/input/StableHLO/Conversion/test/vhlo_stablehlo_mix_invalid.mlir
new file mode 100644
index 0000000..6f55b57
--- /dev/null
+++ b/compiler/plugins/input/StableHLO/Conversion/test/vhlo_stablehlo_mix_invalid.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt --iree-check-vhlostablehlo-mix-usage --split-input-file %s --verify-diagnostics
+
+vhlo.func_v1 @vhlo_stablehlo_func(%arg0: !vhlo.tensor_v1<!vhlo.i32_v1>) -> (!vhlo.tensor_v1<!vhlo.i32_v1>) {
+  // expected-error @+1 {{using VHLO and StableHLO Ops in the same module is not supported}}
+  %0 = "vhlo.constant_v1"() <{value = #vhlo.tensor_v1<dense<1> : tensor<i32>>}> : () -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %1 = builtin.unrealized_conversion_cast %0 : !vhlo.tensor_v1<!vhlo.i32_v1> to tensor<i32>
+  // expected-remark @+1 {{last StableHLO Op was found here}}
+  %2 = stablehlo.abs %1 : tensor<i32>
+  %3 = builtin.unrealized_conversion_cast %2 : tensor<i32> to !vhlo.tensor_v1<!vhlo.i32_v1>
+  "vhlo.return_v1"(%3) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
diff --git a/compiler/plugins/input/StableHLO/PluginRegistration.cpp b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
index d714310..ec9058c 100644
--- a/compiler/plugins/input/StableHLO/PluginRegistration.cpp
+++ b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/VhloOps.h"
 
 namespace mlir::iree_compiler::stablehlo {
 
@@ -91,6 +92,7 @@
   void onRegisterDialects(DialectRegistry &registry) override {
     registry.insert<mlir::chlo::ChloDialect>();
     registry.insert<mlir::stablehlo::StablehloDialect>();
+    registry.insert<mlir::vhlo::VhloDialect>();
   }
 
   bool extendCustomInputConversionPassPipeline(
@@ -100,7 +102,10 @@
     stableHloOptions.demoteF64ToF32 = options.demoteF64ToF32;
     stableHloOptions.promoteBF16ToF32 = options.promoteBF16ToF32;
 
-    if (typeMnemonic == "stablehlo") {
+    // VHLO is converted to StableHLO. The conversion function is called
+    // automatically, and if the input is fully stablehlo the function
+    // acts as Nop.
+    if (typeMnemonic == "stablehlo" || typeMnemonic == "vhlo") {
       buildStableHLOInputConversionPassPipeline(passManager, stableHloOptions);
       return true;
     } else if (typeMnemonic == "stablehlo_xla") {
@@ -115,6 +120,7 @@
   void populateCustomInputConversionTypes(StringSet<> &typeMnemonics) override {
     typeMnemonics.insert("stablehlo");
     typeMnemonics.insert("stablehlo_xla");
+    typeMnemonics.insert("vhlo");
   }
 
   void populateDetectedCustomInputConversionTypes(
@@ -123,6 +129,7 @@
     auto *ctx = module.getContext();
     const Dialect *chloDialect = ctx->getLoadedDialect("chlo");
     const Dialect *stablehloDialect = ctx->getLoadedDialect("stablehlo");
+    const Dialect *vhloDialect = ctx->getLoadedDialect("vhlo");
 
     // stablehlo ops _with tuples_    --> only "stablehlo_xla" type
     // stablehlo ops _without tuples_ --> only "stablehlo" type
@@ -132,7 +139,7 @@
     bool hasTuples = false;
     module.walk([&](Operation *op) {
       Dialect *d = op->getDialect();
-      if (d == chloDialect || d == stablehloDialect) {
+      if (d == chloDialect || d == stablehloDialect || d == vhloDialect) {
         hasStableHLO = true;
         if (checkOpForTuples(op)) {
           hasTuples = true;
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index afc2bd4..421fdf2 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -44,8 +44,8 @@
 // messages, so we err on the side of being helpful and populating plugin
 // options here, even though it is a layering violation.
 #ifdef IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_STABLEHLO
-          "  =stablehlo     - Legalize from StableHLO ops.\n"
-          "  =stablehlo_xla - Legalize from StableHLO ops (with XLA cleanup preprocessing).\n"
+          "  =stablehlo     - Legalize from StableHLO ops (including VHLO deserialization).\n"
+          "  =stablehlo_xla - Legalize from StableHLO ops (including VHLO deserialization and XLA de-tupling).\n"
 #endif // IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_STABLEHLO
 #ifdef IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_TOSA
           "  =tosa          - Legalize from TOSA ops.\n"