Add a default lowering config setting for custom_op. (#18737)
This adds a default lowering configuration setting for `custom_op`: ops
within the body of a `custom_op` defer back to the normal lowering
configuration logic, and the configuration for the `custom_op` itself is
derived from the configurations chosen for those body ops.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
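
For orientation, a condensed sketch (not itself part of the patch) of how a
backend opts into the new default: the CPU root-config switch in
KernelDispatch.cpp (see the hunk below) dispatches `custom_op` to the shared
helper `setDefaultCustomOpLoweringConfig`, which clones the op's body into a
dummy function, runs the backend's usual configuration logic
(`initCPULaunchConfig` here) on that function, copies each body op's
`lowering_config` back onto the original body, and derives the op-level config
for the `custom_op` from those body configs. The GPU side mirrors this with
`initGPULaunchConfig`.

    // Condensed from the KernelDispatch.cpp hunk below; shown for context only.
    .Case<IREE::LinalgExt::CustomOp>([&](auto op) {
      // Defers body ops to the normal CPU config logic and derives the
      // custom_op's own workgroup-level lowering_config from the result.
      return setDefaultCustomOpLoweringConfig(entryPointFn, op,
                                              initCPULaunchConfig);
    })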
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
index 6bc67b8..240a131 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
@@ -31,3 +31,61 @@
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
+ "llvm-cpu", "embedded-elf-x86_64",
+ {cpu_features = "+avx512f",
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
+func.func @custom_op_compilation_info(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
+ %arg2 : tensor<128xf32>) -> tensor<384x128xf32>
+ attributes {hal.executable.target = #executable_target_embedded_elf_x86_64_} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<384x128xf32>
+ %1 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d1)>,
+ affine_map<(d0, d1)[s0] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+ #iree_linalg_ext.iterator_type<parallel>]}
+ attributes {
+ compilation_info = #iree_codegen.compilation_info<
+ lowering_config = #iree_codegen.lowering_config<tile_sizes = [[24, 32]]>,
+ translation_info = <CPUDefault>>
+ }
+ ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>)
+ outs(%0 : tensor<384x128xf32>) {
+ ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?xf32>, %t3 : tensor<?x?xf32>):
+ %2 = linalg.fill ins(%cst : f32) outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%t3 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %5 = arith.addf %b0, %b1 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ iree_linalg_ext.yield %4 : tensor<?x?xf32>
+ } -> tensor<384x128xf32>
+ return %1 : tensor<384x128xf32>
+}
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[24, 32]]>
+// CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<CPUDefault>
+// CHECK: func @custom_op_compilation_info(
+// CHECK-SAME: translation_info = #translation
+// CHECK: iree_linalg_ext.custom_op
+// CHECK-SAME: attributes {lowering_config = #[[CONFIG]]}
+// CHECK-NOT: compilation_info
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
index ebf7e3c..93106c0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
@@ -113,6 +113,21 @@
void setLoweringConfig(Operation *op, Attribute config);
/// Convenience function that sets the lowering configuration on the operation
+/// and translation info.
+inline LogicalResult setOpConfigAndEntryPointFnTranslation(
+ mlir::FunctionOpInterface entryPointFn, Operation *op,
+ IREE::Codegen::LoweringConfigAttrInterface config,
+ IREE::Codegen::TranslationInfoAttr translationInfo) {
+ if (config) {
+ setLoweringConfig(op, config);
+ }
+ if (translationInfo) {
+ (void)setTranslationInfo(entryPointFn, translationInfo);
+ }
+ return success();
+}
+
+/// Convenience function that sets the lowering configuration on the operation
/// and translation info for a generic lowering config, lowering pipeline,
/// and optional workgroup/subgroup size.
inline LogicalResult setOpConfigAndEntryPointFnTranslation(
@@ -123,11 +138,11 @@
std::optional<int64_t> subgroupSize = {},
DictionaryAttr pipelineConfig = DictionaryAttr()) {
MLIRContext *context = entryPointFn.getContext();
- setLoweringConfig(op, config);
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
context, passPipeline, SymbolRefAttr(), workgroupSize, subgroupSize,
pipelineConfig);
- return setTranslationInfo(entryPointFn, translationInfo);
+ return setOpConfigAndEntryPointFnTranslation(entryPointFn, op, config,
+ translationInfo);
}
/// Convenience function that sets the lowering configuration on the operation
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index a1ee70b..c68c905 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -2557,6 +2557,10 @@
return setRootConfig(entryPointFn, op, LinalgOpInfo(op),
targetMLTransInfo);
})
+ .Case<IREE::LinalgExt::CustomOp>([&](auto op) {
+ return setDefaultCustomOpLoweringConfig(entryPointFn, op,
+ initCPULaunchConfig);
+ })
.Case<IREE::LinalgExt::AttentionOp, IREE::LinalgExt::FftOp,
tensor::PackOp, tensor::PadOp, tensor::UnPackOp, linalg::Mmt4DOp,
linalg::BatchMmt4DOp>(
@@ -3094,8 +3098,16 @@
return failure();
}
- // Set vector level tile sizes for other operations individually.
- if (failed(setLoweringConfigForComputeOps(entryPointFn, computeOps,
+  // Avoid this for ops within a custom_op since those ops already have their
+  // configuration set.
+ auto prunedComputeOps =
+ llvm::to_vector(llvm::make_filter_range(computeOps, [](Operation *op) {
+ return !isa_and_nonnull<IREE::LinalgExt::CustomOp>(
+ op->getParentOp()) ||
+ getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(op) ==
+ nullptr;
+ }));
+ if (failed(setLoweringConfigForComputeOps(entryPointFn, prunedComputeOps,
rootOperation))) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
index 5a680fc..aeb2b64 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
@@ -67,6 +67,9 @@
IREE::Codegen::TranslationInfoAttr translationInfo,
F verificationFn) {
auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
+ if (isa<IREE::LinalgExt::CustomOp>(op)) {
+ return WalkResult::advance();
+ }
auto loweringConfig =
getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(op);
if (!loweringConfig)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index 6533f23..8195726 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1830,3 +1830,113 @@
// CHECK: #translation = #iree_codegen.translation_info<CPUDoubleTilingExpert, {enable_loop_peeling}>
// CHECK-LABEL: @test_mod_vectorizing_strategy_peeling
// CHECK-SAME: attributes {hal.executable.target = #executable_target_system_elf_x86_64_, translation_info = #translation}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
+ "llvm-cpu", "embedded-elf-x86_64",
+ {cpu_features = "+avx512f",
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
+func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
+ %arg2 : tensor<128xf32>) -> tensor<384x128xf32>
+ attributes {hal.executable.target = #executable_target_embedded_elf_x86_64_} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<384x128xf32>
+ %1 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d1)>,
+ affine_map<(d0, d1)[s0] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+ #iree_linalg_ext.iterator_type<parallel>]}
+ ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>)
+ outs(%0 : tensor<384x128xf32>) {
+ ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?xf32>, %t3 : tensor<?x?xf32>):
+ %2 = linalg.fill ins(%cst : f32) outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%t3 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %5 = arith.addf %b0, %b1 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ iree_linalg_ext.yield %4 : tensor<?x?xf32>
+ } -> tensor<384x128xf32>
+ return %1 : tensor<384x128xf32>
+}
+// CHECK-DAG: #[[CONFIG0:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[48, 64]]>
+// CHECK-DAG: #[[CONFIG1:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[48, 64], [8, 32], [0, 0], [0, 0]]>
+// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[48, 64, 0], [48, 64, 0], [0, 0, 0], [8, 32, 0], [0, 0, 16], [0, 0, 0]]>
+// CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert, {enable_loop_peeling}>
+// CHECK: func @custom_op(
+// CHECK-SAME: translation_info = #translation
+// CHECK: iree_linalg_ext.custom_op
+// CHECK-SAME: attributes {lowering_config = #[[CONFIG0]]}
+// CHECK: ^bb
+// CHECK: linalg.fill
+// CHECK-SAME: {lowering_config = #[[CONFIG1]]}
+// CHECK: linalg.matmul
+// CHECK-SAME: {lowering_config = #[[CONFIG2]]}
+// CHECK: linalg.generic
+// CHECK-SAME: {lowering_config = #[[CONFIG1]]}
+// CHECK: iree_linalg_ext.yield
+
+// -----
+
+#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",
+ {cpu_features = "+avx512f",
+ data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+ native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
+module {
+ func.func @custom_op_preset_config(%arg0: tensor<384x512xf32>, %arg1: tensor<512x128xf32>,
+ %arg2: tensor<128xf32>) -> tensor<384x128xf32>
+ attributes {hal.executable.target = #executable_target, translation_info = #iree_codegen.translation_info<CPUDefault>} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<384x128xf32>
+ %1 = iree_linalg_ext.custom_op{
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d1)>,
+ affine_map<(d0, d1)[s0] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+ #iree_linalg_ext.iterator_type<parallel>]}
+ attributes {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[24, 32]]>}
+ ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>) outs(%0 : tensor<384x128xf32>) {
+ ^bb0(%arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?x?xf32>):
+ %2 = linalg.fill ins(%cst : f32) outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.matmul ins(%arg3, %arg4 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg5 : tensor<?x?xf32>, tensor<?xf32>) outs(%arg6 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %5 = arith.addf %in, %in_0 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ iree_linalg_ext.yield %4 : tensor<?x?xf32>
+ } -> tensor<384x128xf32>
+ return %1 : tensor<384x128xf32>
+ }
+}
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[24, 32]]>
+// CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<CPUDefault>
+// CHECK: func @custom_op_preset_config(
+// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
+// CHECK: iree_linalg_ext.custom_op
+// CHECK-SAME: lowering_config = #[[CONFIG]]
+// CHECK-NOT: lowering_config
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 8e6eb45..807b9fd 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -2089,6 +2089,11 @@
LDBG("Ukernel Config");
return setUKernelConfig(entryPointFn, ukernelOp);
})
+ .Case<IREE::LinalgExt::CustomOp>([&](auto customOp) {
+ LDBG("CustomOp Config");
+ return setDefaultCustomOpLoweringConfig(entryPointFn, customOp,
+ initGPULaunchConfig);
+ })
.Default([&](auto op) {
LDBG("Default Config");
return setRootDefaultConfig(target, entryPointFn, computeOp);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 352f5c7..40945f2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -31,6 +31,7 @@
"elementwise_pipeline.mlir",
"cast_address_space_function.mlir",
"cast_type_to_fit_mma.mlir",
+ "config_custom_op.mlir",
"config_matvec.mlir",
"config_winograd.mlir",
"extract_address_computation_gpu.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 419d3eb..b771513 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -19,6 +19,7 @@
"amdgpu_set_anchor_layouts.mlir"
"cast_address_space_function.mlir"
"cast_type_to_fit_mma.mlir"
+ "config_custom_op.mlir"
"config_matvec.mlir"
"config_winograd.mlir"
"configure_tensor_layout.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
new file mode 100644
index 0000000..b74833f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
@@ -0,0 +1,86 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s
+
+func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
+ %arg2 : tensor<128xf32>) -> tensor<384x128xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<384x128xf32>
+ %1 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d1)>,
+ affine_map<(d0, d1)[s0] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+ #iree_linalg_ext.iterator_type<parallel>]}
+ ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>)
+ outs(%0 : tensor<384x128xf32>) {
+ ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?xf32>, %t3 : tensor<?x?xf32>):
+ %2 = linalg.fill ins(%cst : f32) outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%t3 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %5 = arith.addf %b0, %b1 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ iree_linalg_ext.yield %4 : tensor<?x?xf32>
+ } -> tensor<384x128xf32>
+ return %1 : tensor<384x128xf32>
+}
+// CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0]]>
+// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64,
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 2, subgroup_n_count = 2>
+// CHECK: func @custom_op
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK: iree_linalg_ext.custom_op
+// CHECK-SAME: lowering_config = #[[CONFIG]]
+// CHECK: ^bb
+// CHECK: linalg.matmul
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 32], workgroup = [64, 64, 0]}>
+// CHECK: iree_linalg_ext.yield
+
+// -----
+
+func.func @custom_op_preset_config(%arg0: tensor<384x512xf32>, %arg1: tensor<512x128xf32>, %arg2: tensor<128xf32>) -> tensor<384x128xf32>
+ attributes {translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse>} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<384x128xf32>
+ %1 = iree_linalg_ext.custom_op{
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d1)>,
+ affine_map<(d0, d1)[s0] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+ #iree_linalg_ext.iterator_type<parallel>]}
+ attributes {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[24, 32]]>}
+ ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>) outs(%0 : tensor<384x128xf32>) {
+ ^bb0(%arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?x?xf32>):
+ %2 = linalg.fill ins(%cst : f32) outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.matmul ins(%arg3, %arg4 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%3, %arg5 : tensor<?x?xf32>, tensor<?xf32>) outs(%arg6 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %5 = arith.addf %in, %in_0 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ iree_linalg_ext.yield %4 : tensor<?x?xf32>
+ } -> tensor<384x128xf32>
+ return %1 : tensor<384x128xf32>
+}
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[24, 32]]>
+// CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<LLVMGPUTileAndFuse>
+// CHECK: func @custom_op_preset_config(
+// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
+// CHECK: iree_linalg_ext.custom_op
+// CHECK-SAME: lowering_config = #[[CONFIG]]
+// CHECK-NOT: lowering_config
diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
index e466c59..1773dc3 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
@@ -58,6 +58,7 @@
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:SideEffectInterfaces",
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
index d246048..37c1ff0 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -42,6 +42,7 @@
MLIRLinalgTransforms
MLIRLinalgUtils
MLIRMemRefDialect
+ MLIRMemRefTransforms
MLIRSCFDialect
MLIRSCFTransforms
MLIRSideEffectInterfaces
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h
index defca31..e228195 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h
@@ -11,7 +11,9 @@
namespace mlir::iree_compiler {
-/// Find the root operation for the dispatch region. The priority is:
+/// Find the root operation for the dispatch region given `computeOps` that are
+/// obtained by a post-order walk, i.e., in the presence of nested compute ops
+/// the outermost operations are towards the end of the list. The priority is:
/// 1. A Linalg operation that has reduction loops.
/// 2. Any other Linalg op or LinalgExt op.
/// 3. An operation that implements TilingInterface.
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 4dc91bc..544e055 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -6,19 +6,24 @@
#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
#include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -26,6 +31,8 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
#define DEBUG_TYPE "iree-codegen-utils"
@@ -275,6 +282,245 @@
}
//===----------------------------------------------------------------------===//
+// Setting CustomOp Lowering config.
+//===----------------------------------------------------------------------===//
+
+static std::tuple<SmallVector<Operation *>, SetVector<Value>>
+getNonConstantValuesDefinedFromAbove(Region &region) {
+ llvm::SetVector<Value> valuesDefinedFromAbove;
+ mlir::getUsedValuesDefinedAbove(region, valuesDefinedFromAbove);
+ SmallVector<Operation *> constants;
+ SetVector<Value> erasedVals;
+ for (auto value : valuesDefinedFromAbove) {
+ Attribute constVal;
+ if (!matchPattern(value, m_Constant(&constVal))) {
+ continue;
+ }
+ if (!isa<IntegerAttr, FloatAttr>(constVal)) {
+ continue;
+ }
+ constants.push_back(value.getDefiningOp());
+ erasedVals.insert(value);
+ }
+ valuesDefinedFromAbove.set_subtract(erasedVals);
+ return {constants, valuesDefinedFromAbove};
+}
+
+/// Listener to track mapping from operations in the body of a cloned custom op
+/// back to the original operations in the body of the original custom op.
+class CustomOpConfigListener : public RewriterBase::Listener {
+public:
+ CustomOpConfigListener(IREE::LinalgExt::CustomOp origCustomOp,
+ IREE::LinalgExt::CustomOp clonedCustomOp) {
+ for (auto [origOp, clonedOp] :
+ llvm::zip_equal(origCustomOp.getBody()->without_terminator(),
+ clonedCustomOp.getBody()->without_terminator())) {
+ clonedOpToOrigOp[&clonedOp] = &origOp;
+ }
+ }
+ void notifyOperationErased(Operation *op) override {
+ clonedOpToOrigOp.erase(op);
+ }
+ void notifyOperationReplaced(Operation *op, Operation *replacement) override {
+ auto it = clonedOpToOrigOp.find(op);
+ if (it != clonedOpToOrigOp.end()) {
+ Operation *origOp = it->second;
+ clonedOpToOrigOp.erase(it);
+ clonedOpToOrigOp[replacement] = origOp;
+ }
+ }
+ void notifyOperationReplaced(Operation *op,
+ ValueRange replacements) override {
+ Operation *replacementOp = nullptr;
+ for (auto val : replacements) {
+ Operation *definingOp = getDefiningOp(val);
+ if (!definingOp) {
+ // One of the replacements is definitely not from an op. Bail
+ // immediately.
+ return;
+ }
+ if (replacementOp) {
+ if (definingOp != replacementOp) {
+ // No consistent replacementOp. Bail.
+ return;
+ }
+ } else {
+ replacementOp = definingOp;
+ }
+ }
+ if (replacementOp && replacementOp->getName() == op->getName()) {
+ notifyOperationReplaced(op, replacementOp);
+ }
+ }
+
+ // Helper methods to get back the orig op for the cloned op.
+ std::optional<Operation *> getOrigOp(Operation *clonedOp) {
+ auto it = clonedOpToOrigOp.find(clonedOp);
+ if (it == clonedOpToOrigOp.end()) {
+ return std::nullopt;
+ }
+ return it->second;
+ }
+
+private:
+ llvm::MapVector<Operation *, Operation *> clonedOpToOrigOp;
+
+ /// On cast propagation, the replacement value used is not the
+ /// actual op that is used for replacement. Walk back the replacement
+  /// value's use-def chain to get to the real replacement. This is a
+  /// bit of a hack, but lowering config propagation is best effort
+  /// anyway, so this is not incorrect.
+ Operation *getDefiningOp(Value v) {
+ Operation *definingOp = v.getDefiningOp();
+ while (definingOp) {
+ if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
+ definingOp = castOp.getSource().getDefiningOp();
+ continue;
+ }
+ // Default is to break out of the loop.
+ break;
+ }
+ return definingOp;
+ }
+};
+
+LogicalResult setDefaultCustomOpLoweringConfig(
+ FunctionOpInterface funcOp, IREE::LinalgExt::CustomOp customOp,
+ std::function<LogicalResult(FunctionOpInterface)> configFn) {
+
+ MLIRContext *context = funcOp.getContext();
+ IRRewriter rewriter(context);
+ rewriter.setInsertionPoint(funcOp);
+
+ // 1. Get values captured from above in the custom op region.
+ llvm::SetVector<Value> valuesDefinedAbove;
+ SmallVector<Operation *> constantOps;
+ std::tie(constantOps, valuesDefinedAbove) =
+ getNonConstantValuesDefinedFromAbove(customOp.getRegion());
+
+ // 2. Create an empty function with arguments being the operands of the custom
+ // op and values captured from above in the custom op.
+ auto operandTypes = llvm::to_vector(customOp->getOperandTypes());
+ auto valuesDefinedAboveTypes =
+ llvm::map_range(valuesDefinedAbove, [](Value v) { return v.getType(); });
+ operandTypes.append(valuesDefinedAboveTypes.begin(),
+ valuesDefinedAboveTypes.end());
+ auto dummyFuncType =
+ FunctionType::get(context, operandTypes, customOp->getResultTypes());
+ std::string dummyFuncName =
+ std::string("__") + funcOp.getName().str() + "_config_setting__";
+ auto dummyFuncOp = rewriter.create<func::FuncOp>(
+ customOp.getLoc(), dummyFuncName, dummyFuncType);
+ auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+ if (targetAttr) {
+ dummyFuncOp->setAttr(IREE::HAL::ExecutableTargetAttr::name, targetAttr);
+ }
+
+ // 3. Clone the custom op into the function
+ SmallVector<Location> locs = llvm::map_to_vector(
+ customOp->getOperands(), [](Value v) { return v.getLoc(); });
+ auto valuesDefinedAboveLocs =
+ llvm::map_range(valuesDefinedAbove, [](Value v) { return v.getLoc(); });
+ locs.append(valuesDefinedAboveLocs.begin(), valuesDefinedAboveLocs.end());
+ Block *body =
+ rewriter.createBlock(&dummyFuncOp.getRegion(),
+ dummyFuncOp.getRegion().begin(), operandTypes, locs);
+ rewriter.setInsertionPointToStart(body);
+ IRMapping map;
+ map.map(customOp.getOperands(),
+ body->getArguments().take_front(customOp.getNumOperands()));
+ map.map(valuesDefinedAbove.getArrayRef(),
+ body->getArguments().take_back(valuesDefinedAbove.size()));
+ for (auto op : constantOps) {
+ rewriter.clone(*op, map);
+ }
+ auto clonedCustomOp = cast<IREE::LinalgExt::CustomOp>(
+ rewriter.clone(*customOp.getOperation(), map));
+ rewriter.create<func::ReturnOp>(customOp.getLoc(),
+ clonedCustomOp->getResults());
+ CustomOpConfigListener customOpConfigListener(customOp, clonedCustomOp);
+
+ // 4. Inline the cloned custom op.
+ rewriter.setInsertionPoint(clonedCustomOp);
+ FailureOr<SmallVector<Value>> replacements =
+ clonedCustomOp.decomposeOperation(rewriter);
+ if (failed(replacements)) {
+ return customOp.emitOpError(
+ "failed to decompose op during custom op configuration setting");
+ }
+ rewriter.replaceOp(clonedCustomOp, replacements.value());
+
+ // 5. Run canonicalizations on the created function to constant propagate the
+ // shape.
+ RewritePatternSet patterns(context);
+ auto addCanonicalizationPatterns = [&context,
+ &patterns](StringRef dialectName) {
+ context->getLoadedDialect(dialectName)
+ ->getCanonicalizationPatterns(patterns);
+ };
+ addCanonicalizationPatterns(linalg::LinalgDialect::getDialectNamespace());
+ addCanonicalizationPatterns(
+ IREE::LinalgExt::IREELinalgExtDialect::getDialectNamespace());
+ tensor::CastOp::getCanonicalizationPatterns(patterns, context);
+ addCanonicalizationPatterns(tensor::TensorDialect::getDialectNamespace());
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+ GreedyRewriteConfig config;
+ config.listener = &customOpConfigListener;
+ if (failed(applyPatternsAndFoldGreedily(dummyFuncOp, std::move(patterns),
+ config))) {
+ return customOp.emitOpError(
+ "failed to canonicalize during custom op configuration setting");
+ }
+
+ // 6. Run set configuration on the new dummy function.
+ if (failed(configFn(dummyFuncOp))) {
+ return customOp.emitOpError("failed to set configuration for custom op");
+ }
+
+ // 7. Set translation info and lowering config for the custom op.
+ IREE::Codegen::TranslationInfoAttr translationInfo =
+ getTranslationInfo(dummyFuncOp);
+ // Move lowering config from ops in the cloned function to the ops
+ // within the body of the custom op.
+  // TODO: This logic needs to be made more robust (by accounting for the
+  // indexing maps specified for operands on the custom op and the indexing
+  // maps of the operations within the region of the custom op). For now, just
+  // use the first operation with a lowering config.
+ std::optional<SmallVector<int64_t>> workgroupTileSizes;
+ std::optional<SmallVector<int64_t>> workgroupInterchange;
+ for (Operation &op : dummyFuncOp.getBody().front()) {
+ auto currLoweringConfig =
+ getLoweringConfig<IREE::Codegen::LoweringConfigAttrInterface>(&op);
+ if (!currLoweringConfig)
+ continue;
+
+ // Translate the lowering config to the original operation.
+ if (std::optional<Operation *> originalOperation =
+ customOpConfigListener.getOrigOp(&op)) {
+ setLoweringConfig(originalOperation.value(), currLoweringConfig);
+ }
+
+ auto currWorkgroupTileSizes = currLoweringConfig.getWorkgroupTileSizes();
+ if (currWorkgroupTileSizes.empty())
+ continue;
+ workgroupTileSizes = currWorkgroupTileSizes;
+ workgroupInterchange = currLoweringConfig.getWorkgroupInterchange();
+ }
+ IREE::Codegen::LoweringConfigAttr loweringConfig;
+ if (workgroupTileSizes) {
+ loweringConfig = IREE::Codegen::LoweringConfigAttr::get(
+ context, workgroupTileSizes.value_or(SmallVector<int64_t>{}),
+ workgroupInterchange.value_or(SmallVector<int64_t>{}));
+ }
+ if (failed(setOpConfigAndEntryPointFnTranslation(
+ funcOp, customOp, loweringConfig, translationInfo))) {
+ return funcOp.emitOpError("failed to set custom op configuration");
+ }
+ rewriter.eraseOp(dummyFuncOp);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Utility functions to set configurations
//===----------------------------------------------------------------------===//
@@ -627,9 +873,14 @@
return loopInfo;
}
-SmallVector<Operation *> getComputeOps(mlir::FunctionOpInterface funcOp) {
+SmallVector<Operation *> getComputeOps(Operation *containingOp) {
+ if (containingOp->getNumRegions() == 0) {
+ return {};
+ }
+ assert(containingOp->getNumRegions() == 1 &&
+ "expected op with a single region");
SmallVector<Operation *> computeOps;
- funcOp.walk([&](Operation *op) {
+ containingOp->getRegion(0).walk([&](Operation *op) {
if (isa<TilingInterface, IREE::Codegen::UKernelOpInterface>(op)) {
computeOps.push_back(op);
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index c211a33..7337549 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -107,6 +108,11 @@
// Utility functions to set configurations
//===----------------------------------------------------------------------===//
+LogicalResult setDefaultCustomOpLoweringConfig(
+    mlir::FunctionOpInterface funcOp,
+ IREE::LinalgExt::CustomOp customOp,
+ std::function<LogicalResult(mlir::FunctionOpInterface)> configFn);
+
/// Information about a tiled and distributed loop.
///
/// Right now distribution is happening as the same time when we tile the linalg
@@ -148,25 +154,11 @@
unsigned processorDistributionDim;
};
-/// Assuming that `funcOp` contains a single nested scf.for that represented the
-/// tiled+fused+distributed loops with the distribution being across workgroups,
-/// i.e.
-///
-/// scf.for ... {
-/// ...
-/// scf.for ... {
-/// ...
-/// filtered_op.
-/// ...
-/// filtered_op.
-/// ...
-/// }
-/// }
-///
-/// Returns the list of TilingInterface ops in the functions. If there are no
-/// `scf.for` operations in the function return the TilingInterface operations
-/// in the body of the function if it has a single basic block.
-SmallVector<Operation *> getComputeOps(mlir::FunctionOpInterface funcOp);
+/// Returns the list of TilingInterface ops in the operation, obtained by a
+/// post-order walk of the operation. This implies that in the case of
+/// nested compute ops, the outermost compute ops are towards the end of the
+/// list.
+SmallVector<Operation *> getComputeOps(Operation *containingOp);
/// If the given `forOp` is a tiled and distributed loop, returns its tiling and
/// distribution information.
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c7d6e98..24b3bd1 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1589,10 +1589,10 @@
let assemblyFormat = [{
`{` `indexing_maps` `=` $indexing_maps `,`
`iterator_types` `=` $iterator_types `}`
+ attr-dict-with-keyword
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
(`outs` `(` $outputs^ `:` type($outputs) `)`)?
- $region
- attr-dict (`->` type($results)^)?
+ $region (`->` type($results)^)?
}];
let extraClassDeclaration =[{