Split mhlo.optimization_barrier lowering into several util.optimization_barrier ops if there are multiple inputs (#13210)
`util.optimization_barrier` contains the `SameOperandsAndResultType`
verifier, which means that all the operand types must match each other
and the result type. Since `mhlo.optimization_barrier` only matches the
input-output pairs, this creates a situation where a valid op fails to
lower.
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 133fe98..3197903 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -228,8 +228,14 @@
LogicalResult matchAndRewrite(
mhlo::OptimizationBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(
- op, op.getOperands());
+ SmallVector<Value> outputs;
+ for (auto operand : adaptor.getOperands()) {
+ outputs.push_back(
+ rewriter
+ .create<IREE::Util::OptimizationBarrierOp>(op.getLoc(), operand)
+ .getResult(0));
+ }
+ rewriter.replaceOp(op, outputs);
return success();
}
};
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
index 649a1f1..a135925 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir
@@ -30,9 +30,10 @@
// -----
// CHECK: func.func @optimization_barrier
-// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %arg0 : tensor<3x4xf32
-// CHECK: return %[[BARRIER]]
-func.func @optimization_barrier(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
- %0 = "mhlo.optimization_barrier"(%arg0) : (tensor<3x4xf32>) -> (tensor<3x4xf32>)
- return %0 : tensor<3x4xf32>
+// CHECK: %[[RESULT1:.+]] = util.optimization_barrier %arg0 : tensor<3x4xf32
+// CHECK: %[[RESULT2:.+]] = util.optimization_barrier %arg1 : tensor<4xi32>
+// CHECK: return %[[RESULT1]], %[[RESULT2]]
+func.func @optimization_barrier(%arg0: tensor<3x4xf32>, %arg1: tensor<4xi32>) -> (tensor<3x4xf32>, tensor<4xi32>) {
+ %0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<4xi32>) -> (tensor<3x4xf32>, tensor<4xi32>)
+ return %0, %1 : tensor<3x4xf32>, tensor<4xi32>
}