Lowering HLO gather to VMLA copies.
This largely mirrors the existing slice/dynamic-slice copy conversions, with minor tweaks for dynamic shapes.
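
For example, a gather that selects a dynamically indexed [1, 1, 5] slice
along dim 0 (the first case in the new gather.mlir test) lowers to a
vmla.buffer.load.i32 of the start index plus a single vmla.copy into the
allocated output buffer:

  %0 = "xla_hlo.gather"(%input, %start_indices) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
  } : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>

Gathers with batch dims, a transposed start_index_map, or a non-zero
index_vector_dim are rejected with a remark and left for a future, more
general lowering.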
PiperOrigin-RevId: 294519556
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
index e6195de..946dca7 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -337,7 +337,6 @@
void ConstRankedShapeOp::build(Builder *builder, OperationState &result,
Type type) {
- result.addAttribute("value", UnitAttr::get(builder->getContext()));
result.types.push_back(type);
}
diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
index c0456c5..400fe71 100644
--- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
@@ -21,7 +21,7 @@
// CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<1x2xf32>
func @staticGetRankedShapeToConst(%arg0: tensor<1x2xf32>) -> (!shapex.ranked_shape<[1,2],i32>) {
// CHECK-NOT: %[[T]]
- // CHECK: %[[S:.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[1,2],i32>
+ // CHECK: %[[S:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2],i32>
%0 = shapex.get_ranked_shape %arg0 : tensor<1x2xf32> -> !shapex.ranked_shape<[1,2],i32>
// CHECK: return %[[S]]
return %0 : !shapex.ranked_shape<[1,2],i32>
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index e03e99f..4fdf8c7 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -111,9 +111,8 @@
PatternMatchResult matchAndRewrite(
xla_hlo::ConcatenateOp srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto indexType = rewriter.getIntegerType(32);
auto zero = rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+ srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
auto dst = VMLAConversionTarget::allocateOutputBuffer(
srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
@@ -154,6 +153,124 @@
TypeConverter &typeConverter;
};
+// Lowers a subset of gathers along axis 0 that are really just a slice and
+// reshape.
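+// For example, gathering a [1, 1, 5] slice from a tensor<5x1x5xi32> at a
+// dynamic index along dim 0 and collapsing that dim to yield tensor<1x5xi32>
+// is expressible as a single vmla.copy.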
+struct GatherOpConversion : public OpConversionPattern<xla_hlo::GatherOp> {
+ GatherOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ // TODO(gcmn): This only handles a minimal number of cases. When XLA
+ // redefines gather to be simpler, lower it properly.
+ PatternMatchResult matchAndRewrite(
+ xla_hlo::GatherOp gatherOp, ArrayRef<Value> operandValues,
+ ConversionPatternRewriter &rewriter) const override {
+ xla_hlo::GatherOpOperandAdaptor operands(operandValues);
+ auto dimension_numbers = gatherOp.dimension_numbers();
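+ // Only gathers equivalent to copying one slice_sizes-shaped region from a
+ // dynamic offset along dim 0 (no batch dims, no transposes) are handled.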
+ if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
+ gatherOp.emitRemark()
+ << "couldn't lower gather with index_vector_dim != 0";
+ return matchFailure();
+ }
+ if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
+ dimension_numbers.start_index_map()
+ .getValue(0)
+ .cast<IntegerAttr>()
+ .getValue() != 0) {
+ gatherOp.emitRemark()
+ << "couldn't lower gather with start_index_map != [0]";
+ return matchFailure();
+ }
+ if (dimension_numbers.collapsed_slice_dims().getType().getRank() != 1 ||
+ dimension_numbers.collapsed_slice_dims()
+ .getValue(0)
+ .cast<IntegerAttr>()
+ .getValue() != 0) {
+ gatherOp.emitRemark()
+ << "couldn't lower gather with collapsed_slice_dims != [0]";
+ return matchFailure();
+ }
+
+ auto resultType = gatherOp.getResult().getType().cast<RankedTensorType>();
+ if (dimension_numbers.offset_dims().getType().getNumElements() !=
+ resultType.getRank()) {
+ gatherOp.emitRemark() << "couldn't lower gather with offset_dims != "
+ "[0,...,rank of output]";
+ return matchFailure();
+ }
+ for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
+ if (it.index() != it.value()) {
+ gatherOp.emitRemark() << "couldn't lower gather with offset_dims != "
+ "[0,...,rank of output]";
+ return matchFailure();
+ }
+ }
+
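+ // Check that the result shape equals slice_sizes with the collapsed
+ // leading dimension dropped; anything else needs a real gather.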
+ for (auto it : llvm::enumerate(resultType.getShape())) {
+ if (gatherOp.slice_sizes()
+ .getValue(it.index() + 1)
+ .cast<IntegerAttr>()
+ .getValue() != it.value()) {
+ gatherOp.emitRemark()
+ << "couldn't lower gather with slice_sizes not [1] + final shape";
+ return matchFailure();
+ }
+ }
+
+ auto srcShape = VMLAConversionTarget::getTensorShape(
+ gatherOp.getLoc(), gatherOp.operand(), typeConverter, rewriter);
+ auto dstShape = VMLAConversionTarget::getTensorShape(
+ gatherOp.getLoc(), gatherOp.getResult(), typeConverter, rewriter);
+
+ auto inputType = gatherOp.operand().getType().cast<RankedTensorType>();
+ auto startIndicesType =
+ gatherOp.start_indices().getType().cast<ShapedType>();
+ int rank = inputType.getRank();
+ SmallVector<Value, 4> srcIndices(rank);
+ SmallVector<Value, 4> dstIndices(rank);
+ SmallVector<Value, 4> lengths(rank);
+ Value zero = rewriter.createOrFold<mlir::ConstantOp>(
+ gatherOp.getLoc(), rewriter.getI32IntegerAttr(0));
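+ // start_indices values are i32s in a buffer: load one per indexed dim at
+ // 4-byte offsets, start any remaining dims at zero, and take the copy
+ // lengths directly from slice_sizes.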
+ for (int i = 0; i < rank; ++i) {
+ if (i < startIndicesType.getNumElements()) {
+ auto srcIndexByteOffset = rewriter.createOrFold<mlir::ConstantOp>(
+ gatherOp.getLoc(), rewriter.getI32IntegerAttr(i * sizeof(int32_t)));
+ srcIndices[i] = rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
+ gatherOp.getLoc(), rewriter.getIntegerType(32),
+ operands.start_indices(), srcIndexByteOffset);
+ } else {
+ // Pad missing dimensions to zero offsets.
+ srcIndices[i] = zero;
+ }
+ dstIndices[i] = zero;
+ lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
+ gatherOp.getLoc(),
+ rewriter.getI32IntegerAttr(gatherOp.slice_sizes().getValue<int64_t>(
+ {static_cast<uint64_t>(i)})));
+ }
+
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ gatherOp.getLoc(), gatherOp.getResult(), typeConverter, rewriter);
+ rewriter.create<IREE::VMLA::CopyOp>(
+ gatherOp.getLoc(), operands.operand(), srcShape, srcIndices, dst,
+ dstShape, dstIndices, lengths,
+ TypeAttr::get(inputType.getElementType()));
+ rewriter.replaceOp(gatherOp, {dst});
+ return matchSuccess();
+ }
+
+ TypeConverter &typeConverter;
+};
+
// Converts a static slice op to a copy (if the source must be preserved).
struct SliceOpConversion : public OpConversionPattern<xla_hlo::SliceOp> {
SliceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
@@ -178,24 +285,21 @@
srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
- auto indexType = rewriter.getIntegerType(32);
SmallVector<Value, 4> srcIndices(rank);
SmallVector<Value, 4> dstIndices(rank);
SmallVector<Value, 4> lengths(rank);
Value zero = rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+ srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
for (int i = 0; i < rank; ++i) {
uint64_t ui = static_cast<uint64_t>(i);
srcIndices[i] = rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType,
- rewriter.getI32IntegerAttr(
- srcOp.start_indices().getValue<int64_t>({ui})));
+ srcOp.getLoc(), rewriter.getI32IntegerAttr(
+ srcOp.start_indices().getValue<int64_t>({ui})));
dstIndices[i] = zero;
lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType,
- rewriter.getI32IntegerAttr(
- srcOp.limit_indices().getValue<int64_t>({ui}) -
- srcOp.start_indices().getValue<int64_t>({ui})));
+ srcOp.getLoc(), rewriter.getI32IntegerAttr(
+ srcOp.limit_indices().getValue<int64_t>({ui}) -
+ srcOp.start_indices().getValue<int64_t>({ui})));
}
auto dst = VMLAConversionTarget::allocateOutputBuffer(
@@ -229,21 +333,19 @@
srcOp.getLoc(), srcOp.result(), typeConverter, rewriter);
int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
- auto indexType = rewriter.getIntegerType(32);
SmallVector<Value, 4> srcIndices(rank);
SmallVector<Value, 4> dstIndices(rank);
SmallVector<Value, 4> lengths(rank);
Value zero = rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType, rewriter.getI32IntegerAttr(0));
+ srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
for (int i = 0; i < rank; ++i) {
srcIndices[i] = rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
srcOp.getLoc(), rewriter.getIntegerType(32), operands[1],
rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType,
- rewriter.getI32IntegerAttr(i * sizeof(int32_t))));
+ srcOp.getLoc(), rewriter.getI32IntegerAttr(i * sizeof(int32_t))));
dstIndices[i] = zero;
lengths[i] = rewriter.createOrFold<mlir::ConstantOp>(
- srcOp.getLoc(), indexType,
+ srcOp.getLoc(),
rewriter.getI32IntegerAttr(srcOp.slice_sizes().getValue<int64_t>(
{static_cast<uint64_t>(i)})));
}
@@ -355,6 +457,7 @@
// or transfers.
patterns.insert<BroadcastInDimOpConversion>(context, typeConverter);
patterns.insert<ConcatenateOpConversion>(context, typeConverter);
+ patterns.insert<GatherOpConversion>(context, typeConverter);
patterns.insert<SliceOpConversion>(context, typeConverter);
patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
index fe5271c..655eeb6 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
@@ -3,8 +3,8 @@
// CHECK-LABEL: @broadcast_in_dim_2D_3D
func @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> {
// CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
- // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[2,4],i32>
- // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+ // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[2,4],i32>
+ // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4],i32>
// CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
%input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
// CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
@@ -19,8 +19,8 @@
// CHECK-LABEL: @broadcast_in_dim_3D_scalar
func @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> {
// CHECK-DAG: [[SRC:%.+]] = "vmla.constant"
- // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[],i32>
- // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape {value} : !shapex.ranked_shape<[3,2,4],i32>
+ // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[],i32>
+ // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4],i32>
// CHECK-DAG: [[DST_SIZE:%.+]] = constant 96 : i32
%input = constant dense<42> : tensor<i32>
// CHECK-NEXT: [[DST:%.+]] = "vmla.buffer.alloc"([[DST_SIZE]])
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/gather.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/gather.mlir
new file mode 100644
index 0000000..6164e05
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/gather.mlir
@@ -0,0 +1,129 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s -verify-diagnostics | IreeFileCheck %s
+
+// CHECK-LABEL: @gather_scalar_indices
+// CHECK-SAME: [[SRC:%.+]]: !iree.ref<!vmla.buffer>,
+// CHECK-SAME: [[INDICES:%.+]]: !iree.ref<!vmla.buffer>)
+func @gather_scalar_indices(%input : tensor<5x1x5xi32>, %start_indices : tensor<i64>) -> tensor<1x5xi32> {
+ // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5,1,5],i32>
+ // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,5],i32>
+ // CHECK-DAG: [[INDEX0:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c0_i32)
+ // CHECK-DAG: [[DST:%.+]] = "vmla.buffer.alloc"(%c20_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[INDEX0]], %c0_i32, %c0_i32,
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32, %c0_i32,
+ // CHECK-SAME: %c1_i32, %c1_i32, %c5_i32
+ // CHECK-SAME: ) {element_type = i32}
+ %0 = "xla_hlo.gather"(%input, %start_indices) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1]> : tensor<2xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
+ } : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>
+ // CHECK-NEXT: return [[DST]]
+ return %0 : tensor<1x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @gather_fully_specified_indices
+// CHECK-SAME: [[SRC:%.+]]: !iree.ref<!vmla.buffer>,
+// CHECK-SAME: [[INDICES:%.+]]: !iree.ref<!vmla.buffer>)
+func @gather_fully_specified_indices(%input : tensor<5x2x3xf32>, %start_indices : tensor<3xi64>) -> tensor<2x3xf32> {
+ // CHECK-DAG: [[SRC_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5,2,3],i32>
+ // CHECK-DAG: [[DST_SHAPE:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[2,3],i32>
+ // CHECK-DAG: [[INDEX0:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c0_i32)
+ // CHECK-DAG: [[INDEX1:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c4_i32)
+ // CHECK-DAG: [[INDEX2:%.+]] = "vmla.buffer.load.i32"([[INDICES]], %c8_i32)
+ // CHECK-DAG: [[DST:%.+]] = "vmla.buffer.alloc"(%c24_i32)
+ // CHECK-NEXT: "vmla.copy"(
+ // CHECK-SAME: [[SRC]], [[SRC_SHAPE]], [[INDEX0]], [[INDEX1]], [[INDEX2]],
+ // CHECK-SAME: [[DST]], [[DST_SHAPE]], %c0_i32, %c0_i32, %c0_i32,
+ // CHECK-SAME: %c1_i32, %c2_i32, %c3_i32
+ // CHECK-SAME: ) {element_type = f32}
+ %0 = "xla_hlo.gather"(%input, %start_indices) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1]> : tensor<2xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ } : (tensor<5x2x3xf32>, tensor<3xi64>) -> tensor<2x3xf32>
+ // CHECK-NEXT: return [[DST]]
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_axis_1(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+ // expected-remark@+2 {{couldn't lower gather}}
+ // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+ %0 = "xla_hlo.gather"(%input, %start_indices) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 1 : i64,
+ offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+ return
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_collapse(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+ // expected-remark@+2 {{couldn't lower gather}}
+ // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+ %0 = "xla_hlo.gather"(%input, %start_indices) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<1> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+ return
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_transposes(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+ // expected-remark@+2 {{couldn't lower gather}}
+ // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+ %0 = "xla_hlo.gather"(%input, %start_indices) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+ start_index_map = dense<[1, 0]> : tensor<2xi64>
+ },
+ slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+ return
+}
+
+// -----
+
+// expected-error@-3 {{conversion to the VMLA dialect failed}}
+func @gather_not_lowered_batch_dims(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
+ // expected-remark@+2 {{couldn't lower gather}}
+ // expected-error@+1 {{failed to legalize operation 'xla_hlo.gather' that was explicitly marked illegal}}
+ %0 = "xla_hlo.gather"(%input, %start_indices) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<1> : tensor<1xi64>,
+ start_index_map = dense<[1, 0]> : tensor<2xi64>
+ },
+ slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+ } : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
+ return
+}