Enable concat, gather, and torch_index_select on the Linalg-on-tensors path. (#5053)
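
Bufferization previously ignored any op it did not recognize. As a
stopgap (see the TODO for GH-5013 in the diff below), it now also
remaps scalar (int/index/float) operands of such ops to their
bufferized values, and `finalizeBufferAllocation` now returns a
`LogicalResult`. Together with the existing `tensor.extract`-to-`load`
conversion, this is enough to bufferize the gather-style
`linalg.indexed_generic` pattern covered by the new test. A rough
sketch of that conversion (not the exact IR the pass produces):

    // Before bufferization: extract an element from a tensor
    // inside the linalg.indexed_generic region.
    %v = tensor.extract %input[%idx, %j] : tensor<?x?xf32>

    // After bufferization: a load from the buffer backing %input
    // (%input_buffer is a placeholder name for that buffer).
    %v = load %input_buffer[%idx, %j] : memref<?x?xf32>

For context, `torch_index_select` reaches this form through the
existing MHLO-to-Linalg lowering; a minimal, hypothetical input would
look like:

    %0 = "mhlo.torch_index_select"(%input, %indices)
           {dim = 0 : i64, batch_dims = 0 : i64}
         : (tensor<5x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>

With these fixes in place, the previously disabled concatenate, gather,
and torch_index_select e2e tests are re-enabled below.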

diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index 804ef61..f97158f 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -47,6 +47,10 @@
 namespace mlir {
 namespace iree_compiler {
 
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
 static MemRefType getMemrefTypeForTensor(RankedTensorType tensorType,
                                          ArrayRef<AffineMap> layout = {},
                                          unsigned memorySpace = 0) {
@@ -85,6 +89,10 @@
   return b.create<SubViewOp>(loc, src, offsets, sizes, strides);
 }
 
+//===----------------------------------------------------------------------===//
+// Bufferization helper functions using BlockAndValueMapping.
+//===----------------------------------------------------------------------===//
+
 // Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
 // Allocate the output buffers for the bufferized Linalg op to write into.
 // If the tensor is an init tensor, we additionally copy the original value into
@@ -148,9 +156,10 @@
 }
 
 // Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
-static void finalizeBufferAllocation(OpBuilder &b, linalg::LinalgOp op,
-                                     ValueRange inputs, ValueRange outputs,
-                                     BlockAndValueMapping &bvm) {
+static LogicalResult finalizeBufferAllocation(OpBuilder &b, linalg::LinalgOp op,
+                                              ValueRange inputs,
+                                              ValueRange outputs,
+                                              BlockAndValueMapping &bvm) {
   SmallVector<Value, 8> newOperands = inputs;
   newOperands.append(outputs.begin(), outputs.end());
   auto otherOperands =
@@ -158,7 +167,7 @@
                       [&bvm](Value v) { return bvm.lookupOrDefault(v); });
   newOperands.append(otherOperands.begin(), otherOperands.end());
   Location loc = op.getLoc();
-  op.cloneWithMapper(b, loc, /*resultTypes=*/TypeRange{}, newOperands, bvm);
+  op.clone(b, loc, {}, newOperands);
 
   // Replace the results of the old op with the new output buffers.
   for (auto result : llvm::enumerate(op.getOperation()->getResults())) {
@@ -168,12 +177,9 @@
       b.create<linalg::CopyOp>(loc, outputs[result.index()], resultBuffer);
     }
   }
+  return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization helper functions using BlockAndValueMapping.
-//===----------------------------------------------------------------------===//
-
 /// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
 /// template instantiating one pattern for each linalg::LinalgOp.
 static LogicalResult convertAnyLinalgOp(
@@ -195,13 +201,12 @@
 
   // Delegate to the linalg generic pattern.
   if (auto genericOp = dyn_cast<linalg::GenericOp>(op.getOperation())) {
-    finalizeBufferAllocation(b, genericOp, newInputBuffers, newOutputBuffers,
-                             bvm);
-    return success();
+    return finalizeBufferAllocation(b, genericOp, newInputBuffers,
+                                    newOutputBuffers, bvm);
   }
 
-  finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
-  return success();
+  return finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers,
+                                  bvm);
 }
 
 /// Constants that return tensor types can be handled natively by the
@@ -372,9 +377,8 @@
 }
 
 /// Converts a `tensor.extract` operation into a `load`.
-static LogicalResult convertTensorExtractOp(
-    OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn,
-    tensor::ExtractOp op, BlockAndValueMapping &bvm) {
+static LogicalResult convertTensorExtractOp(OpBuilder &b, tensor::ExtractOp op,
+                                            BlockAndValueMapping &bvm) {
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
   Value inputBuffer = bvm.lookup(op.tensor());
@@ -663,13 +667,25 @@
           return convertTensorReshapeOp(b, allocationFn, reshapeOp, bvm);
         })
         .Case<tensor::ExtractOp>([&](tensor::ExtractOp extractOp) {
-          return convertTensorExtractOp(b, allocationFn, extractOp, bvm);
+          return convertTensorExtractOp(b, extractOp, bvm);
         })
         .Case<VectorTransferOpInterface>(
             [&](VectorTransferOpInterface vectorTransferOp) {
               return convertTransferOp(b, allocationFn, vectorTransferOp, bvm);
             })
-        .Default([](Operation *) { return success(); });
+        .Default([&](Operation *op) {
+          // Replace any remapped scalar operands with their new values.
+          // TODO(GH-5013): This is a really hacky solution, but it gets us by
+          // for the time being. This should all be replaced by a pattern.
+          for (unsigned i : llvm::seq<unsigned>(0, op->getNumOperands())) {
+            Value operand = op->getOperand(i);
+            if (operand.getType().isIntOrIndexOrFloat()) {
+              Value remappedVal = bvm.lookupOrNull(operand);
+              if (remappedVal) op->setOperand(i, remappedVal);
+            }
+          }
+          return success();
+        });
   };
   if (funcOp.walk(conversionDispatch).wasInterrupted()) {
     return signalPassFailure();
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index 8dd0d6a..ccfd17f 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -684,3 +684,38 @@
 //       CHECK:       linalg.matmul
 //  CHECK-SAME:         ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]]
 //  CHECK-SAME:         outs(%[[RESULT_SUBVIEW]]
+
+// -----
+
+func @gather() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+  %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?xi32>
+  %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
+  %4 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+  %5 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor<readonly:?xi32> -> tensor<?xi32>
+  %d0 = dim %5, %c0 : tensor<?xi32>
+  %d1 = dim %4, %c1 : tensor<?x?xf32>
+  %3 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %7 = linalg.indexed_generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<?xi32>) outs(%3 : tensor<?x?xf32>) {
+  ^bb0(%arg0: index, %arg1: index, %arg2: i32, %arg3: f32):  // no predecessors
+    %8 = index_cast %arg2 : i32 to index
+    %9 = tensor.extract %4[%8, %arg1] : tensor<?x?xf32>
+    linalg.yield %9 : f32
+  } -> tensor<?x?xf32>
+  flow.dispatch.tensor.store %7, %2 : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+  return
+}
+hal.interface @legacy_io attributes {sym_visibility = "private"} {
+  hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+  hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+  hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @gather()
+//   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @legacy_io::@ret0
+//       CHECK:   linalg.indexed_generic
+//       CHECK:     %[[VAL:.+]] = load %[[ARG0]]
+//       CHECK:     linalg.yield %[[VAL]]
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 52cf165..4a68823 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -197,7 +197,7 @@
         "clamp.mlir",
         "compare.mlir",
         # https://github.com/google/iree/issues/4079
-        # "concatenate.mlir",
+        "concatenate.mlir",
         "constant.mlir",
         # https://github.com/google/iree/issues/4079
         # "convolution.mlir",
@@ -210,7 +210,7 @@
         "exponential_minus_one.mlir",
         "floor.mlir",
         # https://github.com/google/iree/issues/4692
-        # "gather.mlir",
+        "gather.mlir",
         "iota.mlir",
         "log.mlir",
         "log_plus_one.mlir",
@@ -234,7 +234,7 @@
         "subtract.mlir",
         "tanh.mlir",
         # https://github.com/google/iree/issues/4079
-        # "torch_index_select.mlir",
+        "torch_index_select.mlir",
         "transpose.mlir",
         "while.mlir",
     ],
@@ -258,7 +258,7 @@
         "clamp.mlir",
         "compare.mlir",
         # https://github.com/google/iree/issues/4079
-        # "concatenate.mlir",
+        "concatenate.mlir",
         "constant.mlir",
         # https://github.com/google/iree/issues/4079
         # "convolution.mlir",
@@ -271,7 +271,7 @@
         "exponential_minus_one.mlir",
         "floor.mlir",
         # https://github.com/google/iree/issues/4692
-        # "gather.mlir",
+        "gather.mlir",
         "iota.mlir",
         "log.mlir",
         "log_plus_one.mlir",
@@ -294,7 +294,7 @@
         "subtract.mlir",
         "tanh.mlir",
         # https://github.com/google/iree/issues/4079
-        # "torch_index_select.mlir",
+        "torch_index_select.mlir",
         "transpose.mlir",
         "while.mlir",
     ],
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index b8753bc..a04dd67 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -183,6 +183,7 @@
     "broadcast_in_dim.mlir"
     "clamp.mlir"
     "compare.mlir"
+    "concatenate.mlir"
     "constant.mlir"
     "cosine.mlir"
     "divide.mlir"
@@ -191,6 +192,7 @@
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"
+    "gather.mlir"
     "iota.mlir"
     "log.mlir"
     "log_plus_one.mlir"
@@ -209,6 +211,7 @@
     "sqrt.mlir"
     "subtract.mlir"
     "tanh.mlir"
+    "torch_index_select.mlir"
     "transpose.mlir"
     "while.mlir"
   TARGET_BACKEND
@@ -232,6 +235,7 @@
     "broadcast_in_dim.mlir"
     "clamp.mlir"
     "compare.mlir"
+    "concatenate.mlir"
     "constant.mlir"
     "cosine.mlir"
     "divide.mlir"
@@ -240,6 +244,7 @@
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"
+    "gather.mlir"
     "iota.mlir"
     "log.mlir"
     "log_plus_one.mlir"
@@ -258,6 +263,7 @@
     "sqrt.mlir"
     "subtract.mlir"
     "tanh.mlir"
+    "torch_index_select.mlir"
     "transpose.mlir"
     "while.mlir"
   TARGET_BACKEND