[GlobalOpt] Do not set encoding if they have preset compilation info. (#15455)
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index 5d176a4..d036b9c 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -62,6 +62,7 @@
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
+ "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 73aa647..88781e7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -75,6 +75,7 @@
MLIRTransforms
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::CPU::CommonCPUPasses
+ iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
index 6177897..560bfc9 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -12,6 +12,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -156,6 +157,12 @@
PatternRewriter &rewriter) const override {
if (!matmulOp.hasTensorSemantics())
return failure();
+
+ if (getCompilationInfo(matmulOp)) {
+ return rewriter.notifyMatchFailure(
+ matmulOp, "the op has preset compilation strategy, skip SetEncoding");
+ }
+
auto inputs = matmulOp.getDpsInputs();
auto outputs = matmulOp.getDpsInits();
auto hasEncoding = [](Value operand) -> bool {
@@ -228,6 +235,12 @@
PatternRewriter &rewriter) const override {
if (!matmulOp.hasTensorSemantics())
return failure();
+
+ if (getCompilationInfo(matmulOp)) {
+ return rewriter.notifyMatchFailure(
+ matmulOp, "the op has preset compilation strategy, skip SetEncoding");
+ }
+
auto inputs = matmulOp.getDpsInputs();
auto outputs = matmulOp.getDpsInits();
auto hasEncoding = [](Value operand) -> bool {
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
index e51ec26..e675c55 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
@@ -639,3 +639,33 @@
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: return %[[FILL]]
+
+// -----
+
+#compilation0 = #iree_codegen.compilation_info<
+ lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0]]>,
+ translation_info = <CPUDefault>>
+
+#compilation1 = #iree_codegen.compilation_info<
+ lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0, 0]]>,
+ translation_info = <CPUDefault>>
+
+
+func.func @preset_compilation_info(
+ %arg0 : tensor<?x?xf32>,
+ %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>,
+ %arg3 : tensor<?x?x?xf32>,
+ %arg4 : tensor<?x?x?xf32>,
+ %arg5 : tensor<?x?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?x?xf32>) {
+ %0 = linalg.matmul {compilation_info = #compilation0} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.batch_matmul {compilation_info = #compilation1} ins(%arg3, %arg4 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg5 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @preset_compilation_info
+// CHECK-NOT: set_encoding
+// CHECK-NOT: unset_encoding
+// CHECK: linalg.matmul
+// CHECK: linalg.batch_matmul