Lowering HLO gather to VMLA copies.
This is largely the existing implementation with minor tweaks for dynamic shapes.
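
For reference, the supported form requires index_vector_dim = 0,
start_index_map = [0], collapsed_slice_dims = [0], and offset_dims covering
the full output rank; e.g. (taken from the new gather.mlir test):

  %0 = "xla_hlo.gather"(%input, %start_indices) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
  } : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>

which the pattern rewrites into a vmla.buffer.load.i32 of the start index
followed by a single vmla.copy into the allocated output buffer.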

PiperOrigin-RevId: 294519556
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
index e6195de..946dca7 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -337,7 +337,6 @@
 
 void ConstRankedShapeOp::build(Builder *builder, OperationState &result,
                                Type type) {
-  result.addAttribute("value", UnitAttr::get(builder->getContext()));
   result.types.push_back(type);
 }
 
diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
index c0456c5..400fe71 100644
--- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
@@ -21,7 +21,7 @@
 // CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<1x2xf32>
 func @staticGetRankedShapeToConst(%arg0: tensor<1x2xf32>) -> (!shapex.ranked_shape<[1,2],i32>) {
   // CHECK-NOT: %[[T]]
-  // CHECK: %[[S:.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[1,2],i32>
+  // CHECK: %[[S:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2],i32>
   %0 = shapex.get_ranked_shape %arg0 : tensor<1x2xf32> -> !shapex.ranked_shape<[1,2],i32>
   // CHECK: return %[[S]]
   return %0 : !shapex.ranked_shape<[1,2],i32>
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index e03e99f..4fdf8c7 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -111,9 +111,8 @@
   PatternMatchResult matchAndRewrite(
       xla_hlo::ConcatenateOp srcOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    auto indexType = rewriter.getIntegerType(32);
     auto zero = rewriter.createOrFold<mlir::ConstantOp>(
-        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+        srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
 
     auto dst = VMLAConversionTarget::allocateOutputBuffer(
         srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
@@ -154,6 +153,114 @@
   TypeConverter &typeConverter;
 };
 
+// Lowers a subset of gathers along axis 0 that are really just a slice and
+// reshape.
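+// Specifically: index_vector_dim == 0, start_index_map == [0],
+// collapsed_slice_dims == [0], offset_dims == [0, ..., output rank - 1], and
+// slice_sizes == [1] + output shape; anything else emits a remark and fails
+// to match.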
+struct GatherOpConversion : public OpConversionPattern<xla_hlo::GatherOp> {
+  GatherOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  // TODO(gcmn): This only handles a minimal number of cases. When XLA
+  // redefines gather to be simpler, lower it properly.
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::GatherOp gatherOp, ArrayRef<Value> operandValues,
+      ConversionPatternRewriter &rewriter) const override {
+    xla_hlo::GatherOpOperandAdaptor operands(operandValues);
+    auto dimension_numbers = gatherOp.dimension_numbers();
+    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
+      gatherOp.emitRemark()
+          << "couldn't lower gather with index_vector_dim != 0";
+      return matchFailure();
+    }
+    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
+        dimension_numbers.start_index_map()
+                .getValue(0)
+                .cast<IntegerAttr>()
+                .getValue() != 0) {
+      gatherOp.emitRemark()
+          << "couldn't lower gather with start_index_map != [0]";
+      return matchFailure();
+    }
+    if (dimension_numbers.collapsed_slice_dims().getType().getRank() != 1 ||
+        dimension_numbers.collapsed_slice_dims()
+                .getValue(0)
+                .cast<IntegerAttr>()
+                .getValue() != 0) {
+      gatherOp.emitRemark()
+          << "couldn't lower gather with collapsed_dims != [0]";
+      return matchFailure();
+    }
+
+    auto resultType = gatherOp.getResult().getType().cast<RankedTensorType>();
+    if (dimension_numbers.offset_dims().getType().getNumElements() !=
+        resultType.getRank()) {
+      gatherOp.emitRemark() << "couldn't lower gather with offset_dims != "
+                               "[0,...,rank of output]";
+      return matchFailure();
+    }
+    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
+      if (it.index() != it.value()) {
+        gatherOp.emitRemark() << "couldn't lower gather with offset_dims != "
+                                 "[0,...,rank of output]";
+        return matchFailure();
+      }
+    }
+
+    for (auto it : llvm::enumerate(resultType.getShape())) {
+      if (gatherOp.slice_sizes()
+              .getValue(it.index() + 1)
+              .cast<IntegerAttr>()
+              .getValue() != it.value()) {
+        gatherOp.emitRemark()
+            << "couldn't lower gather with slice_sizes not [1] + final shape";
+        return matchFailure();
+      }
+    }
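+    // Note: slice_sizes[0] is never checked here; XLA gather semantics
+    // require collapsed slice dims to have slice size 1, and dim 0 was
+    // verified to be collapsed above.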
+
+    auto srcShape = VMLAConversionTarget::getTensorShape(
+        gatherOp.getLoc(), gatherOp.operand(), typeConverter, rewriter);
+    auto dstShape = VMLAConversionTarget::getTensorShape(
+        gatherOp.getLoc(), gatherOp.getResult(), typeConverter, rewriter);
+
+    auto inputType = gatherOp.operand().getType().cast<RankedTensorType>();
+    auto startIndicesType =
+        gatherOp.start_indices().getType().cast<ShapedType>();
+    int rank = inputType.getRank();
+    SmallVector<Value, 4> srcIndices(rank);
+    SmallVector<Value, 4> dstIndices(rank);
+    SmallVector<Value, 4> lengths(rank);
+    Value zero = rewriter.createOrFold<mlir::ConstantOp>(
+        gatherOp.getLoc(), rewriter.getI32IntegerAttr(0));
+    for (int i = 0; i < rank; ++i) {
+      if (i < startIndicesType.getNumElements()) {
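+        // Start indices are packed densely as i32 values, so the i'th index
+        // lives at byte offset i * sizeof(int32_t).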
+        auto srcIndexByteOffset = rewriter.createOrFold<mlir::ConstantOp>(
+            gatherOp.getLoc(), rewriter.getI32IntegerAttr(i * sizeof(int32_t)));
+        srcIndices[i] = rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
+            gatherOp.getLoc(), rewriter.getIntegerType(32),
+            operands.start_indices(), srcIndexByteOffset);
+      } else {
+        // Pad missing dimensions to zero offsets.
+        srcIndices[i] = zero;
+      }
+      dstIndices[i] = zero;
+      lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
+          gatherOp.getLoc(),
+          rewriter.getI32IntegerAttr(gatherOp.slice_sizes().getValue<int64_t>(
+              {static_cast<uint64_t>(i)})));
+    }
+
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        gatherOp.getLoc(), gatherOp.getResult(), typeConverter, rewriter);
+    rewriter.create<IREE::VMLA::CopyOp>(
+        gatherOp.getLoc(), operands.operand(), srcShape, srcIndices, dst,
+        dstShape, dstIndices, lengths,
+        TypeAttr::get(inputType.getElementType()));
+    rewriter.replaceOp(gatherOp, {dst});
+    return matchSuccess();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 // Converts a static slice op to a copy (if the source must be preserved).
 struct SliceOpConversion : public OpConversionPattern<xla_hlo::SliceOp> {
   SliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
@@ -178,24 +285,21 @@
         srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
 
     int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
-    auto indexType = rewriter.getIntegerType(32);
     SmallVector<Value, 4> srcIndices(rank);
     SmallVector<Value, 4> dstIndices(rank);
     SmallVector<Value, 4> lengths(rank);
     Value zero = rewriter.createOrFold<mlir::ConstantOp>(
-        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+        srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
     for (int i = 0; i < rank; ++i) {
       uint64_t ui = static_cast<uint64_t>(i);
       srcIndices[i] = rewriter.createOrFold<mlir::ConstantOp>(
-          srcOp.getLoc(), indexType,
-          rewriter.getI32IntegerAttr(
-              srcOp.start_indices().getValue<int64_t>({ui})));
+          srcOp.getLoc(), rewriter.getI32IntegerAttr(
+                              srcOp.start_indices().getValue<int64_t>({ui})));
       dstIndices[i] = zero;
       lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
-          srcOp.getLoc(), indexType,
-          rewriter.getI32IntegerAttr(
-              srcOp.limit_indices().getValue<int64_t>({ui}) -
-              srcOp.start_indices().getValue<int64_t>({ui})));
+          srcOp.getLoc(), rewriter.getI32IntegerAttr(
+                              srcOp.limit_indices().getValue<int64_t>({ui}) -
+                              srcOp.start_indices().getValue<int64_t>({ui})));
     }
 
     auto dst = VMLAConversionTarget::allocateOutputBuffer(
@@ -229,21 +333,19 @@
         srcOp.getLoc(), srcOp.result(), typeConverter, rewriter);
 
     int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
-    auto indexType = rewriter.getIntegerType(32);
     SmallVector<Value, 4> srcIndices(rank);
     SmallVector<Value, 4> dstIndices(rank);
     SmallVector<Value, 4> lengths(rank);
     Value zero = rewriter.createOrFold<mlir::ConstantOp>(
-        srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+        srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
     for (int i = 0; i < rank; ++i) {
       srcIndices[i] = rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
           srcOp.getLoc(), rewriter.getIntegerType(32), operands[1],
           rewriter.createOrFold<mlir::ConstantOp>(
-              srcOp.getLoc(), indexType,
-              rewriter.getI32IntegerAttr(i * sizeof(int32_t))));
+              srcOp.getLoc(), rewriter.getI32IntegerAttr(i * sizeof(int32_t))));
       dstIndices[i] = zero;
       lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
-          srcOp.getLoc(), indexType,
+          srcOp.getLoc(),
           rewriter.getI32IntegerAttr(srcOp.slice_sizes().getValue<int64_t>(
               {static_cast<uint64_t>(i)})));
     }
@@ -355,6 +457,7 @@
   // or transfers.
   patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
   patterns.insert<ConcatenateOpConversion>(context, typeConverter);
+  patterns.insert<GatherOpConversion>(context, typeConverter);
   patterns.insert<SliceOpConversion>(context, typeConverter);
   patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
 
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
index fe5271c..655eeb6 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
@@ -3,8 +3,8 @@
 // CHECK-LABEL: @broadcast_in_dim_2D_3D
 func @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> {
   // CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
-  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[2,4],i32>
-  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[2,4],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4],i32>
   // CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
   %input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
   // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
@@ -19,8 +19,8 @@
 // CHECK-LABEL: @broadcast_in_dim_3D_scalar
 func @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> {
   // CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
-  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[],i32>
-  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4],i32>
   // CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
   %input = constant dense<42> : tensor<i32>
   // CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/gather.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/gather.mlir
new file mode 100644
index 0000000..6164e05
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/gather.mlir
@@ -0,0 +1,129 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s -verify-diagnostics | IreeFileCheck %s
+
+// CHECK-LABEL: @gather_scalar_indices
+// CHECK-SAME: [[SRC:%.+]]: !iree.ref<!vmla.buffer>,
+// CHECK-SAME: [[INDICES:%.+]]: !iree.ref<!vmla.buffer>)
+func @gather_scalar_indices(%input : tensor<5x1x5xi32>, %start_indices : tensor<i64>) -> tensor<1x5xi32> {
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5,1,5],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,5],i32>
+  // CHECK-DAG: [[INDEX0:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c0_i32)
+  // CHECK-DAG: [[DST:%.+]] = "vmla.buffer.alloc"(%c20_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[INDEX0]], %c0_i32, %c0_i32,
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32, %c0_i32,
+  // CHECK-SAME: %c1_i32, %c1_i32, %c5_i32
+  // CHECK-SAME: ) {element_type = i32}
+  %0 = "xla_hlo.gather"(%input, %start_indices) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
+  } : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>
+  // CHECK-NEXT: return [[DST]]
+  return %0 : tensor<1x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @gather_fully_specified_indices
+// CHECK-SAME: [[SRC:%.+]]: !iree.ref<!vmla.buffer>,
+// CHECK-SAME: [[INDICES:%.+]]: !iree.ref<!vmla.buffer>)
+func @gather_fully_specified_indices(%input : tensor<5x2x3xf32>, %start_indices : tensor<3xi64>) -> tensor<2x3xf32> {
+  // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5,2,3],i32>
+  // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[2,3],i32>
+  // CHECK-DAG: [[INDEX0:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c0_i32)
+  // CHECK-DAG: [[INDEX1:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c4_i32)
+  // CHECK-DAG: [[INDEX2:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c8_i32)
+  // CHECK-DAG: [[DST:%.+]] = "vmla.buffer.alloc"(%c24_i32)
+  // CHECK-NEXT: "vmla.copy"(
+  // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[INDEX0]], [[INDEX1]], [[INDEX2]],
+  // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32, %c0_i32,
+  // CHECK-SAME: %c1_i32, %c2_i32, %c3_i32
+  // CHECK-SAME: ) {element_type = f32}
+  %0 = "xla_hlo.gather"(%input, %start_indices) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+  } : (tensor<5x2x3xf32>, tensor<3xi64>) -> tensor<2x3xf32>
+  // CHECK-NEXT: return [[DST]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_axis_1(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+  // expected-remark@+2 {{couldn't lower gather}}
+  // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+  %0 = "xla_hlo.gather"(%input, %start_indices) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 1 : i64,
+      offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+  } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+  return
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_collapse(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+  // expected-remark@+2 {{couldn't lower gather}}
+  // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+  %0 = "xla_hlo.gather"(%input, %start_indices) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<1> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+  } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+  return
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_transposes(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+  // expected-remark@+2 {{couldn't lower gather}}
+  // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+  %0 = "xla_hlo.gather"(%input, %start_indices) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+      start_index_map = dense<[1, 0]> : tensor<2xi64>
+    },
+    slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+  } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+  return
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_batch_dims(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+  // expected-remark@+2 {{couldn't lower gather}}
+  // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+  %0 = "xla_hlo.gather"(%input, %start_indices) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<1> : tensor<1xi64>,
+      start_index_map = dense<[1, 0]> : tensor<2xi64>
+    },
+    slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+  } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+  return
+}