Reapply "Add VHLO support for IREE" (#17162) (#17179)
This reverts commit d407c7814fa146060b4b36877503413522efd7fa.
ci-extra:
build_test_all_windows,build_test_all_macos_arm64,build_test_all_macos_x86_64
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 608ee01..39a295c 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -86,6 +86,12 @@
"@stablehlo//:broadcast_utils": [
"StablehloBroadcastUtils",
],
+ "@stablehlo//:stablehlo_passes": [
+ "StablehloPasses",
+ ],
+ "@stablehlo//:vhlo_ops": [
+ "VhloOps",
+ ],
# NCCL
"@nccl//:headers": [
"nccl::headers",
diff --git a/compiler/plugins/input/StableHLO/BUILD.bazel b/compiler/plugins/input/StableHLO/BUILD.bazel
index e40ee37..dfdb017 100644
--- a/compiler/plugins/input/StableHLO/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/BUILD.bazel
@@ -31,5 +31,6 @@
"@llvm-project//mlir:Transforms",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",
+ "@stablehlo//:vhlo_ops",
],
)
diff --git a/compiler/plugins/input/StableHLO/CMakeLists.txt b/compiler/plugins/input/StableHLO/CMakeLists.txt
index 11eb2a4..5f22917 100644
--- a/compiler/plugins/input/StableHLO/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/CMakeLists.txt
@@ -32,6 +32,7 @@
MLIRPass
MLIRTransforms
StablehloOps
+ VhloOps
iree::compiler::PluginAPI
iree::compiler::plugins::input::StableHLO::Conversion
PUBLIC
diff --git a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
index 13581c5..110c27a 100644
--- a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
@@ -61,6 +61,7 @@
iree_compiler_cc_library(
name = "StableHLOLegalization",
srcs = [
+ "CheckVHLOStableHloMixUsage.cpp",
"ConvertCollectives.cpp",
"LegalizeCHLO.cpp",
"LegalizeControlFlow.cpp",
@@ -121,6 +122,7 @@
"@stablehlo//:broadcast_utils",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",
+ "@stablehlo//:vhlo_ops",
],
)
@@ -152,5 +154,6 @@
"@llvm-project//mlir:ShapeToStandard",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:Transforms",
+ "@stablehlo//:stablehlo_passes",
],
)
diff --git a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
index 6635859..5afe3a2 100644
--- a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
@@ -47,6 +47,7 @@
NAME
StableHLOLegalization
SRCS
+ "CheckVHLOStableHloMixUsage.cpp"
"ConvertCollectives.cpp"
"LegalizeCHLO.cpp"
"LegalizeControlFlow.cpp"
@@ -99,6 +100,7 @@
MLIRVectorDialect
StablehloBroadcastUtils
StablehloOps
+ VhloOps
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Util::IR
@@ -130,6 +132,7 @@
MLIRShapeOpsTransforms
MLIRShapeToStandard
MLIRTransforms
+ StablehloPasses
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
diff --git a/compiler/plugins/input/StableHLO/Conversion/CheckVHLOStableHloMixUsage.cpp b/compiler/plugins/input/StableHLO/Conversion/CheckVHLOStableHloMixUsage.cpp
new file mode 100644
index 0000000..9659ceb
--- /dev/null
+++ b/compiler/plugins/input/StableHLO/Conversion/CheckVHLOStableHloMixUsage.cpp
@@ -0,0 +1,61 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h"
+#include "compiler/plugins/input/StableHLO/Conversion/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/ChloOps.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/VhloOps.h"
+
+namespace mlir::iree_compiler::stablehlo {
+#define GEN_PASS_DEF_CHECKVHLOSTABLEHLOMIXUSAGE
+#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc"
+
+namespace {
+struct CheckVHLOStableHloMixUsage final
+ : impl::CheckVHLOStableHloMixUsageBase<CheckVHLOStableHloMixUsage> {
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ auto moduleOp = getOperation();
+ Operation *lastStablehloOp = nullptr;
+ Operation *lastVhloOp = nullptr;
+ bool errorsFound = false;
+ const Dialect *stablehloDialect = ctx->getLoadedDialect("stablehlo");
+ const Dialect *vhloDialect = ctx->getLoadedDialect("vhlo");
+ auto emitError = [&](Operation *vhloOp, Operation *stablehloOp) {
+ vhloOp->emitOpError()
+ << "using VHLO and StableHLO Ops in the same module "
+ "is not supported. ";
+ stablehloOp->emitRemark() << "last StableHLO Op was found here: ";
+ errorsFound = true;
+ };
+ moduleOp->walk([&](Operation *op) {
+ auto opDialect = op->getDialect();
+ if (opDialect == stablehloDialect) {
+ if (lastVhloOp) {
+ emitError(lastVhloOp, op);
+ return WalkResult::interrupt();
+ }
+ lastStablehloOp = op;
+ } else if (opDialect == vhloDialect) {
+ if (lastStablehloOp) {
+ emitError(op, lastStablehloOp);
+ return WalkResult::interrupt();
+ }
+ lastVhloOp = op;
+ }
+ return WalkResult::advance();
+ });
+ if (errorsFound) {
+ signalPassFailure();
+ }
+ }
+};
+} // namespace
+} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
index 69d98e5..da76cae 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
@@ -19,6 +19,7 @@
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
+#include "stablehlo/transforms/Passes.h"
namespace mlir::iree_compiler::stablehlo {
namespace {
@@ -40,6 +41,11 @@
// Prepare HLO for use as an input to the Flow dialect.
void buildStableHLOInputConversionPassPipelineImpl(
OpPassManager &passManager, const StableHloOptions &options, bool detuple) {
+ // Having both StableHLO and VHLO in the same module is not supported.
+ // If the input is VHLO, then it is automatically converted to StableHLO.
+ // If the input is StableHLO, this pass is considered a NOP.
+ passManager.addPass(stablehlo::createCheckVHLOStableHloMixUsage());
+ ::mlir::stablehlo::createStablehloDeserializePipeline(passManager);
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.td b/compiler/plugins/input/StableHLO/Conversion/Passes.td
index 7855f8c..8f7aa2b 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.td
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.td
@@ -67,4 +67,11 @@
"Verifies that only supported IR constructs are passed to the compiler";
}
+def CheckVHLOStableHloMixUsage :
+ Pass<"iree-check-vhlostablehlo-mix-usage", "ModuleOp"> {
+ let summary =
+ "Check and report an error when VHLO and StableHLO are used in the same module";
+}
+
+
#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_PASSES
diff --git a/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp b/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp
index aa9bea7..c0751f2 100644
--- a/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/VerifyCompilerInputLegality.cpp
@@ -11,6 +11,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/VhloOps.h"
namespace mlir::iree_compiler::stablehlo {
@@ -32,6 +33,7 @@
// that we explicitly deny the dialects we know about.
conversionTarget.addIllegalDialect<mlir::stablehlo::StablehloDialect>();
conversionTarget.addIllegalDialect<mlir::chlo::ChloDialect>();
+ conversionTarget.addIllegalDialect<mlir::vhlo::VhloDialect>();
conversionTarget.addIllegalDialect<mlir::shape::ShapeDialect>();
// NOTE: It is not fully illegal to tunnel input dialect ops through to
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel
index c42d6f8..2229c05 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel
@@ -34,6 +34,7 @@
"stablehlo_to_linalg_reduce.mlir",
"stablehlo_to_linalg.mlir",
"verify_compiler_input_legality.mlir",
+ "vhlo_stablehlo_mix_invalid.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt
index 7b41288..6d9c9b1 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt
@@ -32,6 +32,7 @@
"stablehlo_to_linalg_random.mlir"
"stablehlo_to_linalg_reduce.mlir"
"verify_compiler_input_legality.mlir"
+ "vhlo_stablehlo_mix_invalid.mlir"
TOOLS
FileCheck
iree-compile
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir b/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir
index 08b556b..efb9b89 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/auto_input_conversion.mlir
@@ -8,3 +8,43 @@
%0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
+
+// -----
+
+// CHECK-LABEL: util.func public @vhlo_func
+vhlo.func_v1 @vhlo_func(%arg0: !vhlo.tensor_v1<!vhlo.i32_v1>) -> (!vhlo.tensor_v1<!vhlo.i32_v1>) {
+ // CHECK: arith.constant
+ %0 = "vhlo.constant_v1"() <{value = #vhlo.tensor_v1<dense<1> : tensor<i32>>}> : () -> !vhlo.tensor_v1<!vhlo.i32_v1>
+ // CHECK: return
+ "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
+
+// ----
+
+// CHECK-LABEL: util.func public @dot_vhlo_example
+vhlo.func_v1 @dot_vhlo_example(%arg0: !vhlo.tensor_v1<8x16x!vhlo.f32_v1>, %arg1: !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<8x8x!vhlo.f32_v1>) {
+ // CHECK: linalg.matmul
+ %0 = "vhlo.dot_v1"(%arg0, %arg1) <{precision_config = #vhlo.array_v1<[#vhlo<precision_v1 DEFAULT>, #vhlo<precision_v1 DEFAULT>]>}> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1>
+ // CHECK: return
+ "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<8x8x!vhlo.f32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
+
+// -----
+
+// CHECK-LABEL: util.func public @gather_vhlo_example
+vhlo.func_v1 @gather_vhlo_example(%arg0: !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>) {
+ // CHECK: flow.collective.all_gather
+ %0 = "vhlo.all_gather_v1"(%arg0) <{all_gather_dim = #vhlo.integer_v1<1 : i64>, channel_id = #vhlo.integer_v1<0 : i64>, replica_groups = #vhlo.tensor_v1<dense<[[0], [1]]> : tensor<2x1xi64>>, use_global_device_ids = #vhlo.bool_v1<false>}> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
+ // CHECK: return
+ "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
+
+// -----
+
+// CHECK-LABEL: util.func public @compare_vhlo_example
+vhlo.func_v1 @compare_vhlo_example(%arg0: !vhlo.tensor_v1<!vhlo.f32_v1>, %arg1: !vhlo.tensor_v1<!vhlo.f32_v1>) -> (!vhlo.tensor_v1<!vhlo.bool_v1>) {
+ // CHECK: arith.cmpf
+ %0 = "vhlo.compare_v1"(%arg0, %arg1) <{compare_type = #vhlo<comparison_type_v1 NOTYPE>, comparison_direction = #vhlo<comparison_direction_v1 EQ>}> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.bool_v1>
+ // CHECK: return
+ "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<!vhlo.bool_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/vhlo_stablehlo_mix_invalid.mlir b/compiler/plugins/input/StableHLO/Conversion/test/vhlo_stablehlo_mix_invalid.mlir
new file mode 100644
index 0000000..6f55b57
--- /dev/null
+++ b/compiler/plugins/input/StableHLO/Conversion/test/vhlo_stablehlo_mix_invalid.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt --iree-check-vhlostablehlo-mix-usage --split-input-file %s --verify-diagnostics
+
+vhlo.func_v1 @vhlo_stablehlo_func(%arg0: !vhlo.tensor_v1<!vhlo.i32_v1>) -> (!vhlo.tensor_v1<!vhlo.i32_v1>) {
+ // expected-error @+1 {{using VHLO and StableHLO Ops in the same module is not supported}}
+ %0 = "vhlo.constant_v1"() <{value = #vhlo.tensor_v1<dense<1> : tensor<i32>>}> : () -> !vhlo.tensor_v1<!vhlo.i32_v1>
+ %1 = builtin.unrealized_conversion_cast %0 : !vhlo.tensor_v1<!vhlo.i32_v1> to tensor<i32>
+ // expected-remark @+1 {{last StableHLO Op was found here}}
+ %2 = stablehlo.abs %1 : tensor<i32>
+ %3 = builtin.unrealized_conversion_cast %2 : tensor<i32> to !vhlo.tensor_v1<!vhlo.i32_v1>
+ "vhlo.return_v1"(%3) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> ()
+} {arg_attrs = #vhlo.array_v1<[]>, res_attrs = #vhlo.array_v1<[]>, sym_visibility = #vhlo.string_v1<"">}
diff --git a/compiler/plugins/input/StableHLO/PluginRegistration.cpp b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
index d714310..ec9058c 100644
--- a/compiler/plugins/input/StableHLO/PluginRegistration.cpp
+++ b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
@@ -11,6 +11,7 @@
#include "mlir/Pass/PassManager.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/VhloOps.h"
namespace mlir::iree_compiler::stablehlo {
@@ -91,6 +92,7 @@
void onRegisterDialects(DialectRegistry ®istry) override {
registry.insert<mlir::chlo::ChloDialect>();
registry.insert<mlir::stablehlo::StablehloDialect>();
+ registry.insert<mlir::vhlo::VhloDialect>();
}
bool extendCustomInputConversionPassPipeline(
@@ -100,7 +102,10 @@
stableHloOptions.demoteF64ToF32 = options.demoteF64ToF32;
stableHloOptions.promoteBF16ToF32 = options.promoteBF16ToF32;
- if (typeMnemonic == "stablehlo") {
+ // VHLO is converted to StableHLO. The conversion function is called
+ // automatically, and if the input is fully stablehlo the function
+ // acts as Nop.
+ if (typeMnemonic == "stablehlo" || typeMnemonic == "vhlo") {
buildStableHLOInputConversionPassPipeline(passManager, stableHloOptions);
return true;
} else if (typeMnemonic == "stablehlo_xla") {
@@ -115,6 +120,7 @@
void populateCustomInputConversionTypes(StringSet<> &typeMnemonics) override {
typeMnemonics.insert("stablehlo");
typeMnemonics.insert("stablehlo_xla");
+ typeMnemonics.insert("vhlo");
}
void populateDetectedCustomInputConversionTypes(
@@ -123,6 +129,7 @@
auto *ctx = module.getContext();
const Dialect *chloDialect = ctx->getLoadedDialect("chlo");
const Dialect *stablehloDialect = ctx->getLoadedDialect("stablehlo");
+ const Dialect *vhloDialect = ctx->getLoadedDialect("vhlo");
// stablehlo ops _with tuples_ --> only "stablehlo_xla" type
// stablehlo ops _without tuples_ --> only "stablehlo" type
@@ -132,7 +139,7 @@
bool hasTuples = false;
module.walk([&](Operation *op) {
Dialect *d = op->getDialect();
- if (d == chloDialect || d == stablehloDialect) {
+ if (d == chloDialect || d == stablehloDialect || d == vhloDialect) {
hasStableHLO = true;
if (checkOpForTuples(op)) {
hasTuples = true;
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index afc2bd4..421fdf2 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -44,8 +44,8 @@
// messages, so we err on the side of being helpful and populating plugin
// options here, even though it is a layering violation.
#ifdef IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_STABLEHLO
- " =stablehlo - Legalize from StableHLO ops.\n"
- " =stablehlo_xla - Legalize from StableHLO ops (with XLA cleanup preprocessing).\n"
+ " =stablehlo - Legalize from StableHLO ops (including VHLO deserialization).\n"
+ " =stablehlo_xla - Legalize from StableHLO ops (including VHLO deserialization and XLA de-tupling).\n"
#endif // IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_STABLEHLO
#ifdef IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_TOSA
" =tosa - Legalize from TOSA ops.\n"