Enable concat, gather, and torch_index_select on the Linalg-on-tensors path. (#5053)
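
Bufferization on the Linalg-on-tensors path now handles the patterns
these ops lower to: finalizeBufferAllocation returns a LogicalResult
and clones the op with operands that were already remapped to buffers,
convertTensorExtractOp no longer takes the (unused) allocation
callback, and scalar operands of otherwise-unhandled ops are remapped
to their bufferized values. This is enough to re-enable the
concatenate, gather, and torch_index_select e2e tests.

As a sketch (buffer name illustrative, following the new test's CHECK
lines), the body of a gather lowered to linalg.indexed_generic now
bufferizes from a tensor.extract into a direct load:

  %9 = tensor.extract %4[%8, %arg1] : tensor<?x?xf32>
  // after bufferization:
  %9 = load %arg0_buffer[%8, %arg1] : memref<?x?xf32>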
diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index 804ef61..f97158f 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -47,6 +47,10 @@
namespace mlir {
namespace iree_compiler {
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
static MemRefType getMemrefTypeForTensor(RankedTensorType tensorType,
ArrayRef<AffineMap> layout = {},
unsigned memorySpace = 0) {
@@ -85,6 +89,10 @@
return b.create<SubViewOp>(loc, src, offsets, sizes, strides);
}
+//===----------------------------------------------------------------------===//
+// Bufferization helper functions using BlockAndValueMapping.
+//===----------------------------------------------------------------------===//
+
// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
// Allocate the output buffers for the bufferized Linalg op to write into.
// If the tensor is an init tensor, we additionally copy the original value into
@@ -148,9 +156,10 @@
}
// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
-static void finalizeBufferAllocation(OpBuilder &b, linalg::LinalgOp op,
- ValueRange inputs, ValueRange outputs,
- BlockAndValueMapping &bvm) {
+static LogicalResult finalizeBufferAllocation(OpBuilder &b, linalg::LinalgOp op,
+ ValueRange inputs,
+ ValueRange outputs,
+ BlockAndValueMapping &bvm) {
SmallVector<Value, 8> newOperands = inputs;
newOperands.append(outputs.begin(), outputs.end());
auto otherOperands =
@@ -158,7 +167,7 @@
[&bvm](Value v) { return bvm.lookupOrDefault(v); });
newOperands.append(otherOperands.begin(), otherOperands.end());
Location loc = op.getLoc();
- op.cloneWithMapper(b, loc, /*resultTypes=*/TypeRange{}, newOperands, bvm);
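+ // All operands were already remapped to buffers above, so a plain
+ // clone (without a value mapper) suffices.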
+ op.clone(b, loc, {}, newOperands);
// Replace the results of the old op with the new output buffers.
for (auto result : llvm::enumerate(op.getOperation()->getResults())) {
@@ -168,12 +177,9 @@
b.create<linalg::CopyOp>(loc, outputs[result.index()], resultBuffer);
}
}
+ return success();
}
-//===----------------------------------------------------------------------===//
-// Bufferization helper functions using BlockAndValueMapping.
-//===----------------------------------------------------------------------===//
-
/// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
/// instantiating one pattern template for each linalg::LinalgOp.
static LogicalResult convertAnyLinalgOp(
@@ -195,13 +201,12 @@
// Delegate to the linalg generic pattern.
if (auto genericOp = dyn_cast<linalg::GenericOp>(op.getOperation())) {
- finalizeBufferAllocation(b, genericOp, newInputBuffers, newOutputBuffers,
- bvm);
- return success();
+ return finalizeBufferAllocation(b, genericOp, newInputBuffers,
+ newOutputBuffers, bvm);
}
- finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
- return success();
+ return finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers,
+ bvm);
}
/// Constants that return tensor types can be handled natively by the
@@ -372,9 +377,8 @@
}
/// Converts a `tensor.extract` operation into a `load`.
-static LogicalResult convertTensorExtractOp(
- OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn,
- tensor::ExtractOp op, BlockAndValueMapping &bvm) {
+static LogicalResult convertTensorExtractOp(OpBuilder &b, tensor::ExtractOp op,
+ BlockAndValueMapping &bvm) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Value inputBuffer = bvm.lookup(op.tensor());
@@ -663,13 +667,25 @@
return convertTensorReshapeOp(b, allocationFn, reshapeOp, bvm);
})
.Case<tensor::ExtractOp>([&](tensor::ExtractOp extractOp) {
- return convertTensorExtractOp(b, allocationFn, extractOp, bvm);
+ return convertTensorExtractOp(b, extractOp, bvm);
})
.Case<VectorTransferOpInterface>(
[&](VectorTransferOpInterface vectorTransferOp) {
return convertTransferOp(b, allocationFn, vectorTransferOp, bvm);
})
- .Default([](Operation *) { return success(); });
+ .Default([&](Operation *op) {
+ // Replace any remapped scalar operands with their new values.
+ // TODO(GH-5013): This is a really hacky solution, but it gets us by
+ // for the time being. It should all be replaced by a pattern.
+ for (unsigned i : llvm::seq<unsigned>(0, op->getNumOperands())) {
+ Value operand = op->getOperand(i);
+ if (operand.getType().isIntOrIndexOrFloat()) {
+ Value remappedVal = bvm.lookupOrNull(operand);
+ if (remappedVal) op->setOperand(i, remappedVal);
+ }
+ }
+ return success();
+ });
};
if (funcOp.walk(conversionDispatch).wasInterrupted()) {
return signalPassFailure();
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index 8dd0d6a..ccfd17f 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -684,3 +684,38 @@
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]]
// CHECK-SAME: outs(%[[RESULT_SUBVIEW]]
+
+// -----
+
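+// Gather-style indexed_generic: the tensor.extract in the body should
+// bufferize to a load from the input buffer.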
+func @gather() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?xi32>
+ %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
+ %4 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %5 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor<readonly:?xi32> -> tensor<?xi32>
+ %d0 = dim %5, %c0 : tensor<?xi32>
+ %d1 = dim %4, %c1 : tensor<?x?xf32>
+ %3 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %7 = linalg.indexed_generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<?xi32>) outs(%3 : tensor<?x?xf32>) {
+ ^bb0(%arg0: index, %arg1: index, %arg2: i32, %arg3: f32): // no predecessors
+ %8 = index_cast %arg2 : i32 to index
+ %9 = tensor.extract %4[%8, %arg1] : tensor<?x?xf32>
+ linalg.yield %9 : f32
+ } -> tensor<?x?xf32>
+ flow.dispatch.tensor.store %7, %2 : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+ return
+}
+hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @gather()
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @legacy_io::@ret0
+// CHECK: linalg.indexed_generic
+// CHECK: %[[VAL:.+]] = load %[[ARG0]]
+// CHECK: linalg.yield %[[VAL]]
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 52cf165..4a68823 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -197,7 +197,7 @@
"clamp.mlir",
"compare.mlir",
# https://github.com/google/iree/issues/4079
- # "concatenate.mlir",
+ "concatenate.mlir",
"constant.mlir",
# https://github.com/google/iree/issues/4079
# "convolution.mlir",
@@ -210,7 +210,7 @@
"exponential_minus_one.mlir",
"floor.mlir",
# https://github.com/google/iree/issues/4692
- # "gather.mlir",
+ "gather.mlir",
"iota.mlir",
"log.mlir",
"log_plus_one.mlir",
@@ -234,7 +234,7 @@
"subtract.mlir",
"tanh.mlir",
# https://github.com/google/iree/issues/4079
- # "torch_index_select.mlir",
+ "torch_index_select.mlir",
"transpose.mlir",
"while.mlir",
],
@@ -258,7 +258,7 @@
"clamp.mlir",
"compare.mlir",
# https://github.com/google/iree/issues/4079
- # "concatenate.mlir",
+ "concatenate.mlir",
"constant.mlir",
# https://github.com/google/iree/issues/4079
# "convolution.mlir",
@@ -271,7 +271,7 @@
"exponential_minus_one.mlir",
"floor.mlir",
# https://github.com/google/iree/issues/4692
- # "gather.mlir",
+ "gather.mlir",
"iota.mlir",
"log.mlir",
"log_plus_one.mlir",
@@ -294,7 +294,7 @@
"subtract.mlir",
"tanh.mlir",
# https://github.com/google/iree/issues/4079
- # "torch_index_select.mlir",
+ "torch_index_select.mlir",
"transpose.mlir",
"while.mlir",
],
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index b8753bc..a04dd67 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -183,6 +183,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "concatenate.mlir"
"constant.mlir"
"cosine.mlir"
"divide.mlir"
@@ -191,6 +192,7 @@
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"
+ "gather.mlir"
"iota.mlir"
"log.mlir"
"log_plus_one.mlir"
@@ -209,6 +211,7 @@
"sqrt.mlir"
"subtract.mlir"
"tanh.mlir"
+ "torch_index_select.mlir"
"transpose.mlir"
"while.mlir"
TARGET_BACKEND
@@ -232,6 +235,7 @@
"broadcast_in_dim.mlir"
"clamp.mlir"
"compare.mlir"
+ "concatenate.mlir"
"constant.mlir"
"cosine.mlir"
"divide.mlir"
@@ -240,6 +244,7 @@
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"
+ "gather.mlir"
"iota.mlir"
"log.mlir"
"log_plus_one.mlir"
@@ -258,6 +263,7 @@
"sqrt.mlir"
"subtract.mlir"
"tanh.mlir"
+ "torch_index_select.mlir"
"transpose.mlir"
"while.mlir"
TARGET_BACKEND