Add a default lowering config setting for custom_op. (#18737)

This adds a default lowering configuration setting for `custom_op`:
it defers to the normal lowering configuration logic for the ops
within the body of the `custom_op`, and derives the configuration for
the `custom_op` itself from the result.
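
For example (illustrative sketch only; the actual configurations
depend on the target):

    %1 = iree_linalg_ext.custom_op {...}
        ins(%lhs, %rhs : ...) outs(%init : ...) {
      ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>):
        %2 = linalg.matmul ins(%t0, %t1 : ...) outs(%t2 : ...) -> ...
        iree_linalg_ext.yield %2 : ...
    } -> ...

After configuration, the `linalg.matmul` in the body carries the
lowering config the backend would normally assign it, and the
`custom_op` itself carries a `lowering_config` derived from the root
op's workgroup-level tile sizes.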

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>

diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
index 6bc67b8..240a131 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_configs.mlir
@@ -31,3 +31,61 @@
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.matmul
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
+    "llvm-cpu", "embedded-elf-x86_64",
+    {cpu_features = "+avx512f",
+     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+     native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
+func.func @custom_op_compilation_info(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
+    %arg2 : tensor<128xf32>) -> tensor<384x128xf32>
+    attributes {hal.executable.target = #executable_target_embedded_elf_x86_64_} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<384x128xf32>
+  %1 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                       affine_map<(d0, d1)[s0] -> (d1)>,
+                       affine_map<(d0, d1)[s0] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                        #iree_linalg_ext.iterator_type<parallel>]}
+      attributes {
+        compilation_info = #iree_codegen.compilation_info<
+          lowering_config = #iree_codegen.lowering_config<tile_sizes = [[24, 32]]>,
+          translation_info = <CPUDefault>>
+      }
+      ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>)
+      outs(%0 : tensor<384x128xf32>) {
+    ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?xf32>, %t3 : tensor<?x?xf32>):
+      %2 = linalg.fill ins(%cst : f32) outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %3 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %4 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                           affine_map<(d0, d1) -> (d1)>,
+                           affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%3, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+          outs(%t3 : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %5 = arith.addf %b0, %b1 : f32
+          linalg.yield %5 : f32
+      } -> tensor<?x?xf32>
+      iree_linalg_ext.yield %4 : tensor<?x?xf32>
+  } -> tensor<384x128xf32>
+  return %1 : tensor<384x128xf32>
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[24, 32]]>
+//  CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<CPUDefault>
+//      CHECK: func @custom_op_compilation_info(
+// CHECK-SAME:     translation_info = #[[TRANSLATION_INFO]]
+//      CHECK:   iree_linalg_ext.custom_op
+// CHECK-SAME:       attributes {lowering_config = #[[CONFIG]]}
+//  CHECK-NOT:   compilation_info
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
index ebf7e3c..93106c0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
@@ -113,6 +113,21 @@
 void setLoweringConfig(Operation *op, Attribute config);
 
 /// Convenience function that sets the lowering configuration on the operation
+/// and the translation info on the entry point function.
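+/// Either attribute may be null, in which case setting it is skipped.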
+inline LogicalResult setOpConfigAndEntryPointFnTranslation(
+    mlir::FunctionOpInterface entryPointFn, Operation *op,
+    IREE::Codegen::LoweringConfigAttrInterface config,
+    IREE::Codegen::TranslationInfoAttr translationInfo) {
+  if (config) {
+    setLoweringConfig(op, config);
+  }
+  if (translationInfo) {
+    (void)setTranslationInfo(entryPointFn, translationInfo);
+  }
+  return success();
+}
+
+/// Convenience function that sets the lowering configuration on the operation
 /// and translation info for a generic lowering config, lowering pipeline,
 /// and optional workgroup/subgroup size.
 inline LogicalResult setOpConfigAndEntryPointFnTranslation(
@@ -123,11 +138,11 @@
     std::optional<int64_t> subgroupSize = {},
     DictionaryAttr pipelineConfig = DictionaryAttr()) {
   MLIRContext *context = entryPointFn.getContext();
-  setLoweringConfig(op, config);
   auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
       context, passPipeline, SymbolRefAttr(), workgroupSize, subgroupSize,
       pipelineConfig);
-  return setTranslationInfo(entryPointFn, translationInfo);
+  return setOpConfigAndEntryPointFnTranslation(entryPointFn, op, config,
+                                               translationInfo);
 }
 
 /// Convenience function that sets the lowering configuration on the operation
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index a1ee70b..c68c905 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -2557,6 +2557,10 @@
           return setRootConfig(entryPointFn, op, LinalgOpInfo(op),
                                targetMLTransInfo);
         })
+        .Case<IREE::LinalgExt::CustomOp>([&](auto op) {
+          return setDefaultCustomOpLoweringConfig(entryPointFn, op,
+                                                  initCPULaunchConfig);
+        })
         .Case<IREE::LinalgExt::AttentionOp, IREE::LinalgExt::FftOp,
               tensor::PackOp, tensor::PadOp, tensor::UnPackOp, linalg::Mmt4DOp,
               linalg::BatchMmt4DOp>(
@@ -3094,8 +3098,16 @@
       return failure();
     }
 
-    // Set vector level tile sizes for other operations individually.
-    if (failed(setLoweringConfigForComputeOps(entryPointFn, computeOps,
+    // Skip ops within a custom_op since those ops already have their
+    // configuration set.
+    auto prunedComputeOps =
+        llvm::to_vector(llvm::make_filter_range(computeOps, [](Operation *op) {
+          return !isa_and_nonnull<IREE::LinalgExt::CustomOp>(
+                     op->getParentOp()) ||
+                 getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(op) ==
+                     nullptr;
+        }));
+    if (failed(setLoweringConfigForComputeOps(entryPointFn, prunedComputeOps,
                                               rootOperation))) {
       return failure();
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
index 5a680fc..aeb2b64 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp
@@ -67,6 +67,9 @@
                             IREE::Codegen::TranslationInfoAttr translationInfo,
                             F verificationFn) {
   auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
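+    // The lowering config on a custom_op only encodes workgroup-level tile
+    // sizes; skip the pipeline-specific verification for it.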
+    if (isa<IREE::LinalgExt::CustomOp>(op)) {
+      return WalkResult::advance();
+    }
     auto loweringConfig =
         getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(op);
     if (!loweringConfig)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index 6533f23..8195726 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1830,3 +1830,113 @@
 // CHECK: #translation = #iree_codegen.translation_info<CPUDoubleTilingExpert, {enable_loop_peeling}>
 // CHECK-LABEL: @test_mod_vectorizing_strategy_peeling
 // CHECK-SAME: attributes {hal.executable.target = #executable_target_system_elf_x86_64_, translation_info = #translation}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
+    "llvm-cpu", "embedded-elf-x86_64",
+    {cpu_features = "+avx512f",
+     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+     native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
+func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
+    %arg2 : tensor<128xf32>) -> tensor<384x128xf32>
+    attributes {hal.executable.target = #executable_target_embedded_elf_x86_64_} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<384x128xf32>
+  %1 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                       affine_map<(d0, d1)[s0] -> (d1)>,
+                       affine_map<(d0, d1)[s0] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                        #iree_linalg_ext.iterator_type<parallel>]}
+      ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>)
+      outs(%0 : tensor<384x128xf32>) {
+    ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?xf32>, %t3 : tensor<?x?xf32>):
+      %2 = linalg.fill ins(%cst : f32) outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %3 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %4 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                           affine_map<(d0, d1) -> (d1)>,
+                           affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%3, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+          outs(%t3 : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %5 = arith.addf %b0, %b1 : f32
+          linalg.yield %5 : f32
+      } -> tensor<?x?xf32>
+      iree_linalg_ext.yield %4 : tensor<?x?xf32>
+  } -> tensor<384x128xf32>
+  return %1 : tensor<384x128xf32>
+}
+//  CHECK-DAG: #[[CONFIG0:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[48, 64]]>
+//  CHECK-DAG: #[[CONFIG1:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[48, 64], [8, 32], [0, 0], [0, 0]]>
+//  CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[48, 64, 0], [48, 64, 0], [0, 0, 0], [8, 32, 0], [0, 0, 16], [0, 0, 0]]>
+//  CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert, {enable_loop_peeling}>
+//      CHECK: func @custom_op(
+// CHECK-SAME:     translation_info = #[[TRANSLATION_INFO]]
+//      CHECK:   iree_linalg_ext.custom_op
+// CHECK-SAME:       attributes {lowering_config = #[[CONFIG0]]}
+//      CHECK:   ^bb
+//      CHECK:     linalg.fill
+// CHECK-SAME:         {lowering_config = #[[CONFIG1]]}
+//      CHECK:     linalg.matmul
+// CHECK-SAME:         {lowering_config = #[[CONFIG2]]}
+//      CHECK:     linalg.generic
+// CHECK-SAME:         {lowering_config = #[[CONFIG1]]}
+//      CHECK:   iree_linalg_ext.yield
+
+// -----
+
+#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",
+    {cpu_features = "+avx512f",
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>
+module {
+  func.func @custom_op_preset_config(%arg0: tensor<384x512xf32>, %arg1: tensor<512x128xf32>,
+      %arg2: tensor<128xf32>) -> tensor<384x128xf32>
+      attributes {hal.executable.target = #executable_target, translation_info = #iree_codegen.translation_info<CPUDefault>} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<384x128xf32>
+    %1 = iree_linalg_ext.custom_op {
+        indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+                         affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                         affine_map<(d0, d1)[s0] -> (d1)>,
+                         affine_map<(d0, d1)[s0] -> (d0, d1)>],
+        iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                          #iree_linalg_ext.iterator_type<parallel>]}
+        attributes {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[24, 32]]>}
+        ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>) outs(%0 : tensor<384x128xf32>) {
+    ^bb0(%arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?x?xf32>):
+      %2 = linalg.fill ins(%cst : f32) outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %3 = linalg.matmul ins(%arg3, %arg4 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %4 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                           affine_map<(d0, d1) -> (d1)>,
+                           affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%3, %arg5 : tensor<?x?xf32>, tensor<?xf32>) outs(%arg6 : tensor<?x?xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %5 = arith.addf %in, %in_0 : f32
+        linalg.yield %5 : f32
+      } -> tensor<?x?xf32>
+      iree_linalg_ext.yield %4 : tensor<?x?xf32>
+    } -> tensor<384x128xf32>
+    return %1 : tensor<384x128xf32>
+  }
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[24, 32]]>
+//  CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<CPUDefault>
+//      CHECK: func @custom_op_preset_config(
+// CHECK-SAME:     translation_info = #[[TRANSLATION_INFO]]
+//      CHECK:   iree_linalg_ext.custom_op
+// CHECK-SAME:       lowering_config = #[[CONFIG]]
+//  CHECK-NOT:   lowering_config
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 8e6eb45..807b9fd 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -2089,6 +2089,11 @@
         LDBG("Ukernel Config");
         return setUKernelConfig(entryPointFn, ukernelOp);
       })
+      .Case<IREE::LinalgExt::CustomOp>([&](auto customOp) {
+        LDBG("CustomOp Config");
+        return setDefaultCustomOpLoweringConfig(entryPointFn, customOp,
+                                                initGPULaunchConfig);
+      })
       .Default([&](auto op) {
         LDBG("Default Config");
         return setRootDefaultConfig(target, entryPointFn, computeOp);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 352f5c7..40945f2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -31,6 +31,7 @@
             "elementwise_pipeline.mlir",
             "cast_address_space_function.mlir",
             "cast_type_to_fit_mma.mlir",
+            "config_custom_op.mlir",
             "config_matvec.mlir",
             "config_winograd.mlir",
             "extract_address_computation_gpu.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 419d3eb..b771513 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -19,6 +19,7 @@
     "amdgpu_set_anchor_layouts.mlir"
     "cast_address_space_function.mlir"
     "cast_type_to_fit_mma.mlir"
+    "config_custom_op.mlir"
     "config_matvec.mlir"
     "config_winograd.mlir"
     "configure_tensor_layout.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
new file mode 100644
index 0000000..b74833f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
@@ -0,0 +1,86 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s
+
+func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
+    %arg2 : tensor<128xf32>) -> tensor<384x128xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<384x128xf32>
+  %1 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                       affine_map<(d0, d1)[s0] -> (d1)>,
+                       affine_map<(d0, d1)[s0] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                        #iree_linalg_ext.iterator_type<parallel>]}
+      ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>)
+      outs(%0 : tensor<384x128xf32>) {
+    ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?xf32>, %t3 : tensor<?x?xf32>):
+      %2 = linalg.fill ins(%cst : f32) outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %3 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %4 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                           affine_map<(d0, d1) -> (d1)>,
+                           affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%3, %t2 : tensor<?x?xf32>, tensor<?xf32>)
+          outs(%t3 : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %5 = arith.addf %b0, %b1 : f32
+          linalg.yield %5 : f32
+      } -> tensor<?x?xf32>
+      iree_linalg_ext.yield %4 : tensor<?x?xf32>
+  } -> tensor<384x128xf32>
+  return %1 : tensor<384x128xf32>
+}
+//      CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0]]>
+//      CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64,
+// CHECK-SAME:     mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 2, subgroup_n_count = 2>
+//      CHECK: func @custom_op
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   iree_linalg_ext.custom_op
+// CHECK-SAME:       lowering_config = #[[CONFIG]]
+//      CHECK:   ^bb
+//      CHECK:     linalg.matmul
+// CHECK-SAME:         lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 32], workgroup = [64, 64, 0]}>
+//      CHECK:   iree_linalg_ext.yield
+
+// -----
+
+func.func @custom_op_preset_config(%arg0: tensor<384x512xf32>, %arg1: tensor<512x128xf32>, %arg2: tensor<128xf32>) -> tensor<384x128xf32>
+  attributes {translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse>} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<384x128xf32>
+  %1 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                       affine_map<(d0, d1)[s0] -> (d1)>,
+                       affine_map<(d0, d1)[s0] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                        #iree_linalg_ext.iterator_type<parallel>]}
+      attributes {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[24, 32]]>}
+      ins(%arg0, %arg1, %arg2 : tensor<384x512xf32>, tensor<512x128xf32>, tensor<128xf32>) outs(%0 : tensor<384x128xf32>) {
+  ^bb0(%arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>, %arg5: tensor<?xf32>, %arg6: tensor<?x?xf32>):
+    %2 = linalg.fill ins(%cst : f32) outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %3 = linalg.matmul ins(%arg3, %arg4 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %4 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                         affine_map<(d0, d1) -> (d1)>,
+                         affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%3, %arg5 : tensor<?x?xf32>, tensor<?xf32>) outs(%arg6 : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %5 = arith.addf %in, %in_0 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+    iree_linalg_ext.yield %4 : tensor<?x?xf32>
+  } -> tensor<384x128xf32>
+  return %1 : tensor<384x128xf32>
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[24, 32]]>
+//  CHECK-DAG: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info<LLVMGPUTileAndFuse>
+//      CHECK: func @custom_op_preset_config(
+// CHECK-SAME:     translation_info = #[[TRANSLATION_INFO]]
+//      CHECK:   iree_linalg_ext.custom_op
+// CHECK-SAME:       lowering_config = #[[CONFIG]]
+//  CHECK-NOT:   lowering_config
diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
index e466c59..1773dc3 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
@@ -58,6 +58,7 @@
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:LinalgUtils",
         "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:MemRefTransforms",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:SideEffectInterfaces",
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
index d246048..37c1ff0 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -42,6 +42,7 @@
     MLIRLinalgTransforms
     MLIRLinalgUtils
     MLIRMemRefDialect
+    MLIRMemRefTransforms
     MLIRSCFDialect
     MLIRSCFTransforms
     MLIRSideEffectInterfaces
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h
index defca31..e228195 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h
@@ -11,7 +11,9 @@
 
 namespace mlir::iree_compiler {
 
-/// Find the root operation for the dispatch region. The priority is:
+/// Find the root operation for the dispatch region given `computeOps`, which
+/// are obtained by a post-order walk, i.e., in the presence of nested compute
+/// ops the outermost operations appear towards the end of the list. The
+/// priority is:
 ///   1. A Linalg operation that has reduction loops.
 ///   2. Any other Linalg op or LinalgExt op.
 ///   3. An operation that implements TilingInterface.
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 4dc91bc..544e055 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -6,19 +6,24 @@
 
 #include "iree/compiler/Codegen/Utils/Utils.h"
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
 #include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -26,6 +31,8 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #define DEBUG_TYPE "iree-codegen-utils"
 
@@ -275,6 +282,245 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Setting CustomOp Lowering config.
+//===----------------------------------------------------------------------===//
+
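+/// Splits the values used from above `region` into scalar integer/float
+/// constants (returned as their defining ops, so they can be cloned) and the
+/// remaining non-constant values.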
+static std::tuple<SmallVector<Operation *>, SetVector<Value>>
+getNonConstantValuesDefinedFromAbove(Region &region) {
+  llvm::SetVector<Value> valuesDefinedFromAbove;
+  mlir::getUsedValuesDefinedAbove(region, valuesDefinedFromAbove);
+  SmallVector<Operation *> constants;
+  SetVector<Value> erasedVals;
+  for (auto value : valuesDefinedFromAbove) {
+    Attribute constVal;
+    if (!matchPattern(value, m_Constant(&constVal))) {
+      continue;
+    }
+    if (!isa<IntegerAttr, FloatAttr>(constVal)) {
+      continue;
+    }
+    constants.push_back(value.getDefiningOp());
+    erasedVals.insert(value);
+  }
+  valuesDefinedFromAbove.set_subtract(erasedVals);
+  return {constants, valuesDefinedFromAbove};
+}
+
+/// Listener to track the mapping from operations in the body of a cloned
+/// custom op back to the operations in the body of the original custom op.
+class CustomOpConfigListener : public RewriterBase::Listener {
+public:
+  CustomOpConfigListener(IREE::LinalgExt::CustomOp origCustomOp,
+                         IREE::LinalgExt::CustomOp clonedCustomOp) {
+    for (auto [origOp, clonedOp] :
+         llvm::zip_equal(origCustomOp.getBody()->without_terminator(),
+                         clonedCustomOp.getBody()->without_terminator())) {
+      clonedOpToOrigOp[&clonedOp] = &origOp;
+    }
+  }
+  void notifyOperationErased(Operation *op) override {
+    clonedOpToOrigOp.erase(op);
+  }
+  void notifyOperationReplaced(Operation *op, Operation *replacement) override {
+    auto it = clonedOpToOrigOp.find(op);
+    if (it != clonedOpToOrigOp.end()) {
+      Operation *origOp = it->second;
+      clonedOpToOrigOp.erase(it);
+      clonedOpToOrigOp[replacement] = origOp;
+    }
+  }
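+  // For multi-value replacements, record the mapping only if all replacement
+  // values trace back (through tensor.cast ops) to a single defining op with
+  // the same name as `op`; otherwise drop the mapping.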
+  void notifyOperationReplaced(Operation *op,
+                               ValueRange replacements) override {
+    Operation *replacementOp = nullptr;
+    for (auto val : replacements) {
+      Operation *definingOp = getDefiningOp(val);
+      if (!definingOp) {
+        // One of the replacements is definitely not from an op. Bail
+        // immediately.
+        return;
+      }
+      if (replacementOp) {
+        if (definingOp != replacementOp) {
+          // No consistent replacementOp. Bail.
+          return;
+        }
+      } else {
+        replacementOp = definingOp;
+      }
+    }
+    if (replacementOp && replacementOp->getName() == op->getName()) {
+      notifyOperationReplaced(op, replacementOp);
+    }
+  }
+
+  // Helper method to retrieve the original op for a cloned op.
+  std::optional<Operation *> getOrigOp(Operation *clonedOp) {
+    auto it = clonedOpToOrigOp.find(clonedOp);
+    if (it == clonedOpToOrigOp.end()) {
+      return std::nullopt;
+    }
+    return it->second;
+  }
+
+private:
+  llvm::MapVector<Operation *, Operation *> clonedOpToOrigOp;
+
+  /// On cast propagation, the replacement value is not defined by the actual
+  /// op that serves as the replacement. Walk back the use-def chain of the
+  /// replacement value to get to the real replacement op. This is a bit of a
+  /// hack, but lowering config propagation is best effort anyway, so a missed
+  /// mapping is acceptable.
+  Operation *getDefiningOp(Value v) {
+    Operation *definingOp = v.getDefiningOp();
+    while (definingOp) {
+      if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
+        definingOp = castOp.getSource().getDefiningOp();
+        continue;
+      }
+      // Default is to break out of the loop.
+      break;
+    }
+    return definingOp;
+  }
+};
+
+LogicalResult setDefaultCustomOpLoweringConfig(
+    FunctionOpInterface funcOp, IREE::LinalgExt::CustomOp customOp,
+    std::function<LogicalResult(FunctionOpInterface)> configFn) {
+
+  MLIRContext *context = funcOp.getContext();
+  IRRewriter rewriter(context);
+  rewriter.setInsertionPoint(funcOp);
+
+  // 1. Get values captured from above in the custom op region.
+  llvm::SetVector<Value> valuesDefinedAbove;
+  SmallVector<Operation *> constantOps;
+  std::tie(constantOps, valuesDefinedAbove) =
+      getNonConstantValuesDefinedFromAbove(customOp.getRegion());
+
+  // 2. Create an empty function whose arguments are the operands of the
+  // custom op plus the non-constant values captured from above.
+  auto operandTypes = llvm::to_vector(customOp->getOperandTypes());
+  auto valuesDefinedAboveTypes =
+      llvm::map_range(valuesDefinedAbove, [](Value v) { return v.getType(); });
+  operandTypes.append(valuesDefinedAboveTypes.begin(),
+                      valuesDefinedAboveTypes.end());
+  auto dummyFuncType =
+      FunctionType::get(context, operandTypes, customOp->getResultTypes());
+  std::string dummyFuncName =
+      std::string("__") + funcOp.getName().str() + "_config_setting__";
+  auto dummyFuncOp = rewriter.create<func::FuncOp>(
+      customOp.getLoc(), dummyFuncName, dummyFuncType);
+  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+  if (targetAttr) {
+    dummyFuncOp->setAttr(IREE::HAL::ExecutableTargetAttr::name, targetAttr);
+  }
+
+  // 3. Clone the custom op into the function.
+  SmallVector<Location> locs = llvm::map_to_vector(
+      customOp->getOperands(), [](Value v) { return v.getLoc(); });
+  auto valuesDefinedAboveLocs =
+      llvm::map_range(valuesDefinedAbove, [](Value v) { return v.getLoc(); });
+  locs.append(valuesDefinedAboveLocs.begin(), valuesDefinedAboveLocs.end());
+  Block *body =
+      rewriter.createBlock(&dummyFuncOp.getRegion(),
+                           dummyFuncOp.getRegion().begin(), operandTypes, locs);
+  rewriter.setInsertionPointToStart(body);
+  IRMapping map;
+  map.map(customOp.getOperands(),
+          body->getArguments().take_front(customOp.getNumOperands()));
+  map.map(valuesDefinedAbove.getArrayRef(),
+          body->getArguments().take_back(valuesDefinedAbove.size()));
+  for (auto op : constantOps) {
+    rewriter.clone(*op, map);
+  }
+  auto clonedCustomOp = cast<IREE::LinalgExt::CustomOp>(
+      rewriter.clone(*customOp.getOperation(), map));
+  rewriter.create<func::ReturnOp>(customOp.getLoc(),
+                                  clonedCustomOp->getResults());
+  CustomOpConfigListener customOpConfigListener(customOp, clonedCustomOp);
+
+  // 4. Decompose the cloned custom op, inlining its body into the function.
+  rewriter.setInsertionPoint(clonedCustomOp);
+  FailureOr<SmallVector<Value>> replacements =
+      clonedCustomOp.decomposeOperation(rewriter);
+  if (failed(replacements)) {
+    return customOp.emitOpError(
+        "failed to decompose op during custom op configuration setting");
+  }
+  rewriter.replaceOp(clonedCustomOp, replacements.value());
+
+  // 5. Run canonicalization patterns on the created function to
+  // constant-propagate the static shapes.
+  RewritePatternSet patterns(context);
+  auto addCanonicalizationPatterns = [&context,
+                                      &patterns](StringRef dialectName) {
+    context->getLoadedDialect(dialectName)
+        ->getCanonicalizationPatterns(patterns);
+  };
+  addCanonicalizationPatterns(linalg::LinalgDialect::getDialectNamespace());
+  addCanonicalizationPatterns(
+      IREE::LinalgExt::IREELinalgExtDialect::getDialectNamespace());
+  tensor::CastOp::getCanonicalizationPatterns(patterns, context);
+  addCanonicalizationPatterns(tensor::TensorDialect::getDialectNamespace());
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  GreedyRewriteConfig config;
+  config.listener = &customOpConfigListener;
+  if (failed(applyPatternsAndFoldGreedily(dummyFuncOp, std::move(patterns),
+                                          config))) {
+    return customOp.emitOpError(
+        "failed to canonicalize during custom op configuration setting");
+  }
+
+  // 6. Run the configuration-setting function on the new dummy function.
+  if (failed(configFn(dummyFuncOp))) {
+    return customOp.emitOpError("failed to set configuration for custom op");
+  }
+
+  // 7. Set translation info and lowering config for the custom op.
+  IREE::Codegen::TranslationInfoAttr translationInfo =
+      getTranslationInfo(dummyFuncOp);
+  // Move lowering config from ops in the cloned function to the ops
+  // within the body of the custom op.
+  // TODO: This logic needs to be made more robust (by accounting for the
+  // indexing maps specified for operands on the custom op and the indexing
+  // maps of the operations within the region of the custom op). For now, just
+  // use the last operation that has a lowering config with non-empty
+  // workgroup tile sizes.
+  std::optional<SmallVector<int64_t>> workgroupTileSizes;
+  std::optional<SmallVector<int64_t>> workgroupInterchange;
+  for (Operation &op : dummyFuncOp.getBody().front()) {
+    auto currLoweringConfig =
+        getLoweringConfig<IREE::Codegen::LoweringConfigAttrInterface>(&op);
+    if (!currLoweringConfig)
+      continue;
+
+    // Translate the lowering config to the original operation.
+    if (std::optional<Operation *> originalOperation =
+            customOpConfigListener.getOrigOp(&op)) {
+      setLoweringConfig(originalOperation.value(), currLoweringConfig);
+    }
+
+    auto currWorkgroupTileSizes = currLoweringConfig.getWorkgroupTileSizes();
+    if (currWorkgroupTileSizes.empty())
+      continue;
+    workgroupTileSizes = currWorkgroupTileSizes;
+    workgroupInterchange = currLoweringConfig.getWorkgroupInterchange();
+  }
+  IREE::Codegen::LoweringConfigAttr loweringConfig;
+  if (workgroupTileSizes) {
+    loweringConfig = IREE::Codegen::LoweringConfigAttr::get(
+        context, workgroupTileSizes.value_or(SmallVector<int64_t>{}),
+        workgroupInterchange.value_or(SmallVector<int64_t>{}));
+  }
+  if (failed(setOpConfigAndEntryPointFnTranslation(
+          funcOp, customOp, loweringConfig, translationInfo))) {
+    return funcOp.emitOpError("failed to set custom op configuration");
+  }
+  rewriter.eraseOp(dummyFuncOp);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // Utility functions to set configurations
 //===----------------------------------------------------------------------===//
 
@@ -627,9 +873,14 @@
   return loopInfo;
 }
 
-SmallVector<Operation *> getComputeOps(mlir::FunctionOpInterface funcOp) {
+SmallVector<Operation *> getComputeOps(Operation *containingOp) {
+  if (containingOp->getNumRegions() == 0) {
+    return {};
+  }
+  assert(containingOp->getNumRegions() == 1 &&
+         "expected op with a single region");
   SmallVector<Operation *> computeOps;
-  funcOp.walk([&](Operation *op) {
+  containingOp->getRegion(0).walk([&](Operation *op) {
     if (isa<TilingInterface, IREE::Codegen::UKernelOpInterface>(op)) {
       computeOps.push_back(op);
     }
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index c211a33..7337549 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
 
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/TargetParser/Triple.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -107,6 +108,11 @@
 // Utility functions to set configurations
 //===----------------------------------------------------------------------===//
 
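+/// Sets the lowering configuration for `customOp` and the translation info
+/// for `funcOp` by outlining (a decomposed copy of) the custom op into a
+/// stand-alone function, running `configFn` on that function, and
+/// transferring the resulting configuration back to the custom op and the
+/// ops in its body.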
+LogicalResult setDefaultCustomOpLoweringConfig(
+    mlir::FunctionOpInterface funcOp,
+    IREE::LinalgExt::CustomOp customOp,
+    std::function<LogicalResult(mlir::FunctionOpInterface)> configFn);
+
 /// Information about a tiled and distributed loop.
 ///
 /// Right now distribution is happening as the same time when we tile the linalg
@@ -148,25 +154,11 @@
   unsigned processorDistributionDim;
 };
 
-/// Assuming that `funcOp` contains a single nested scf.for that represented the
-/// tiled+fused+distributed loops with the distribution being across workgroups,
-/// i.e.
-///
-/// scf.for ... {
-///   ...
-///   scf.for ... {
-///     ...
-///     filtered_op.
-///     ...
-///     filtered_op.
-///     ...
-///   }
-/// }
-///
-/// Returns the list of TilingInterface ops in the functions. If there are no
-/// `scf.for` operations in the function return the TilingInterface operations
-/// in the body of the function if it has a single basic block.
-SmallVector<Operation *> getComputeOps(mlir::FunctionOpInterface funcOp);
+/// Returns the list of TilingInterface and UKernelOpInterface ops in
+/// `containingOp`, obtained by a post-order walk. This implies that in the
+/// case of nested compute ops, the outermost compute ops appear towards the
+/// end of the list.
+SmallVector<Operation *> getComputeOps(Operation *containingOp);
 
 /// If the given `forOp` is a tiled and distributed loop, returns its tiling and
 /// distribution information.
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c7d6e98..24b3bd1 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1589,10 +1589,10 @@
   let assemblyFormat = [{
     `{` `indexing_maps` `=` $indexing_maps `,`
     `iterator_types` `=` $iterator_types `}`
+    attr-dict-with-keyword
     (`ins` `(` $inputs^ `:` type($inputs) `)`)?
     (`outs` `(` $outputs^ `:` type($outputs) `)`)?
-    $region
-    attr-dict (`->` type($results)^)?
+    $region (`->` type($results)^)?
   }];
 
   let extraClassDeclaration =[{