Add support for bufferizing LinalgExt ops. (#6377)

Since LinalgExtInterface is a subset of LinalgInterface, we can make
convertAnyLinalgOp a template and share one implementation between the
two interfaces. The analyseLinalg*Ops functions have separate
implementations because indexing maps are not defined on LinalgExt ops.
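
As a minimal sketch (abbreviated signatures, not the exact pass code),
the dispatch looks like this; each .Case instantiates the same templated
helper with its own interface type, mirroring the
.Case<linalg::LinalgOp, linalg_ext::LinalgExtOp> change in the diff
below:

```cpp
#include "llvm/ADT/TypeSwitch.h"
// Assumes the usual MLIR includes and `using namespace mlir;`.

// One templated helper serves both interfaces: LinalgExtOp exposes the
// accessors the conversion relies on (hasBufferSemantics, operand lists).
template <typename OpTy>
static LogicalResult convertAnyLinalgOp(OpBuilder &b, OpTy op /*, ...*/);

static LogicalResult dispatch(OpBuilder &b, Operation *op) {
  return TypeSwitch<Operation *, LogicalResult>(op)
      // The generic lambda is stamped out once per listed interface type.
      .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>(
          [&](auto typedOp) { return convertAnyLinalgOp(b, typedOp); })
      .Default([](Operation *) { return failure(); });
}
```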

Also adds an interface method: clone.
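
A hypothetical usage sketch (the helper name is illustrative, not from
this patch): during bufferization, clone() lets the pass recreate an op
on memref operands without naming the concrete op type.

```cpp
// `newOperands` holds the memref values already looked up for the op's
// inputs and outputs; an op with buffer semantics has no tensor results,
// so the result type list is empty.
static Operation *cloneOnBuffers(OpBuilder &b, linalg_ext::LinalgExtOp op,
                                 ValueRange newOperands) {
  return op.clone(b, op->getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
}
```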

This is a step towards https://github.com/google/iree/issues/6154
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index afd1b47..6e6b5c7 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -50,6 +50,7 @@
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Dialect/LinalgExt/IR",
         "//iree/compiler/Dialect/Shape/IR",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Affine",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index ae91a28..8466201 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -58,6 +58,7 @@
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::IREE::IR
+    iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::Shape::IR
   PUBLIC
 )
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 7a2d0e0..ad130d0 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -46,6 +46,7 @@
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -380,6 +381,26 @@
   return tiedOperands;
 }
 
+static LogicalResult analyseLinalgExtOps(linalg_ext::LinalgExtOp op,
+                                         BufferizationPlan &plan) {
+  if (!op.hasTensorSemantics()) return success();
+  // TODO(hanchung): Revisit whether we can tie op.getOutputOperands() to the
+  // corresponding op.getInputOperands(). For now there are only a few
+  // LinalgExt ops and no use case for this, so we leave them untied.
+  // Note: this is what should be done for LinalgOps as well, except for what
+  // is done for operand fusion today.
+  for (auto input : op.getInputOperands()) {
+    plan.insert(input->get());
+  }
+  for (auto output : op.getOutputOperands()) {
+    plan.insert(output->get());
+  }
+  for (auto result : op->getResults()) {
+    plan.insert(result);
+  }
+  return success();
+}
+
 /// Adds the corresponding `outs` and result tensors of the linalg op into the
 /// same equivalence class.
 static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
@@ -580,6 +601,10 @@
         .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
           return analyseLinalgOps(linalgOp, plan);
         })
+        .Case<linalg_ext::LinalgExtOp>(
+            [&](linalg_ext::LinalgExtOp linalgExtOp) {
+              return analyseLinalgExtOps(linalgExtOp, plan);
+            })
         .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
             [&](auto reshapeOp) {
               return analyseSingleOperandResultOp(reshapeOp.src(),
@@ -910,7 +935,8 @@
     resultBuffer =
         TypeSwitch<Operation *, Value>(op)
             .Case<scf::IfOp, scf::ForOp, linalg::LinalgOp,
-                  tensor::InsertSliceOp, vector::TransferWriteOp>(
+                  linalg_ext::LinalgExtOp, tensor::InsertSliceOp,
+                  vector::TransferWriteOp>(
                 [&](auto op) { return resultBuffer; })
             .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
                 [&](auto reshapeOp) {
@@ -1123,9 +1149,10 @@
-/// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
-/// template instantiating one pattern for each linalg::LinalgOp. The method
-/// expects all operands and results have already been mapped to memrefs.
+/// Generic conversion pattern that matches any LinalgOp or LinalgExtOp. Using
+/// a template avoids instantiating one pattern per concrete op. The method
+/// expects that all operands and results have already been mapped to memrefs.
+template <typename OpTy>
 static LogicalResult convertAnyLinalgOp(
-    OpBuilder &b, linalg::LinalgOp op, BlockAndValueMapping &bvm,
-    BufferizationPlan &plan, WorkgroupMemoryAllocationFn allocationFn) {
+    OpBuilder &b, OpTy op, BlockAndValueMapping &bvm, BufferizationPlan &plan,
+    WorkgroupMemoryAllocationFn allocationFn) {
   // Skip linalg ops inserted by this pass.
   if (op.hasBufferSemantics()) return success();
 
@@ -1539,12 +1566,12 @@
           }
           return convertPadTensorOp(b, padTensorOp, bvm);
         })
-        .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
-          if (failed(getOrAllocateResultBuffers(b, linalgOp.getOperation(), bvm,
-                                                plan, allocationFn))) {
+        .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>([&](auto op) {
+          if (failed(
+                  getOrAllocateResultBuffers(b, op, bvm, plan, allocationFn))) {
             return failure();
           }
-          return convertAnyLinalgOp(b, linalgOp, bvm, plan, allocationFn);
+          return convertAnyLinalgOp(b, op, bvm, plan, allocationFn);
         })
         .Case<tensor::InsertSliceOp>(
             [&](tensor::InsertSliceOp subTensorInsertOp) {
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index e21945f..07257b6 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -2384,3 +2384,23 @@
 //       CHECK:       scf.if
 //   CHECK-DAG:         memref.store %[[V1]], %[[INOUT]][%[[P1]]]
 //   CHECK-DAG:         memref.store %[[V2]], %[[INOUT]][%[[ARG1]]]
+
+// -----
+
+func @linalg_ext_sort_1d() {
+  %c0 = constant 0 : index
+  %0 = hal.interface.binding.subspan @io::@rw[%c0] : !flow.dispatch.tensor<readwrite:128xi32>
+  %1 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:128xi32> -> tensor<128xi32>
+  %2 = linalg_ext.sort {dimension = 0 : i64} outs(%1 : tensor<128xi32>) {
+  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
+    %3 = cmpi sgt, %arg0, %arg1 : i32
+    linalg_ext.yield %3 : i1
+  } -> tensor<128xi32>
+  flow.dispatch.tensor.store %2, %0, offsets = [], sizes = [], strides = [] : tensor<128xi32> -> !flow.dispatch.tensor<readwrite:128xi32>
+  return
+}
+// CHECK-LABEL: func @linalg_ext_sort_1d()
+//   CHECK-DAG:   %[[INOUT:.+]] = hal.interface.binding.subspan @io::@rw
+//       CHECK:   linalg_ext.sort
+//  CHECK-SAME:     dimension = 0 : i64
+//  CHECK-SAME:     outs(%[[INOUT]] : memref<128xi32>)
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
index f67a982..35584d8 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
@@ -7,6 +7,8 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
 
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index 19424d2..77d0ddf 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -438,6 +438,30 @@
             return opOperand->get().getType().template isa<RankedTensorType>();
           });
       }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Other static interface methods.
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/[{
+        Clone the current operation with the given location and operands. This
+        is used to abstract away the optional underlying region creation. This
+        does not change the balance between input, output_buffer and
+        init_tensors operands.
+      }],
+      /*retTy=*/"Operation *",
+      /*methodName=*/"clone",
+      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands),
+      [{
+        BlockAndValueMapping bvm;
+        OperationState state(
+          loc, ConcreteOp::getOperationName(), operands, resultTypes,
+          $_op->getAttrs());
+        for (Region &r : $_op->getRegions())
+          r.cloneInto(state.addRegion(), bvm);
+        return b.createOperation(state);
+      }]
     >
   ];