[GlobalOpt] Do not set encoding if they have preset compilation info. (#15455)
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index 5d176a4..d036b9c 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -62,6 +62,7 @@ ":PassesIncGen", "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses", + "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 73aa647..88781e7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -75,6 +75,7 @@ MLIRTransforms iree::compiler::Codegen::Common iree::compiler::Codegen::Common::CPU::CommonCPUPasses + iree::compiler::Codegen::Dialect::IREECodegenDialect iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Flow::Transforms iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp index 6177897..560bfc9 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -12,6 +12,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" +#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" #include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -156,6 +157,12 @@ PatternRewriter &rewriter) const override { if (!matmulOp.hasTensorSemantics()) return failure(); + + if (getCompilationInfo(matmulOp)) { + return rewriter.notifyMatchFailure( + matmulOp, "the op has preset compilation strategy, skip SetEncoding"); + } + auto inputs = matmulOp.getDpsInputs(); auto outputs = matmulOp.getDpsInits(); auto hasEncoding = [](Value operand) -> bool { @@ -228,6 +235,12 @@ PatternRewriter &rewriter) const override { if (!matmulOp.hasTensorSemantics()) return failure(); + + if (getCompilationInfo(matmulOp)) { + return rewriter.notifyMatchFailure( + matmulOp, "the op has preset compilation strategy, skip SetEncoding"); + } + auto inputs = matmulOp.getDpsInputs(); auto outputs = matmulOp.getDpsInits(); auto hasEncoding = [](Value operand) -> bool {
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir index e51ec26..e675c55 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
@@ -639,3 +639,33 @@ // CHECK: %[[FILL:.+]] = linalg.fill // CHECK-SAME: outs(%[[EMPTY]] : // CHECK: return %[[FILL]] + +// ----- + +#compilation0 = #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0]]>, + translation_info = <CPUDefault>> + +#compilation1 = #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0, 0]]>, + translation_info = <CPUDefault>> + + +func.func @preset_compilation_info( + %arg0 : tensor<?x?xf32>, + %arg1 : tensor<?x?xf32>, + %arg2 : tensor<?x?xf32>, + %arg3 : tensor<?x?x?xf32>, + %arg4 : tensor<?x?x?xf32>, + %arg5 : tensor<?x?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?x?xf32>) { + %0 = linalg.matmul {compilation_info = #compilation0} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> + %1 = linalg.batch_matmul {compilation_info = #compilation1} ins(%arg3, %arg4 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs(%arg5 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0, %1 : tensor<?x?xf32>, tensor<?x?x?xf32> +} +// CHECK-LABEL: func.func @preset_compilation_info +// CHECK-NOT: set_encoding +// CHECK-NOT: unset_encoding +// CHECK: linalg.matmul +// CHECK: linalg.batch_matmul