Fold flow.tensor.slice when all operands are constant (#2972)
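
For example (mirroring the @sliceConst2D test added below; the exact printed
form of the folded constant may differ), a slice whose source, start indices,
and lengths are all constant, such as

  %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %1 = flow.tensor.slice %0[%c0, %c1 for %c1, %c2] : tensor<2x3xi32> -> tensor<1x2xi32>

now folds to

  %1 = constant dense<[[1, 2]]> : tensor<1x2xi32>
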
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index fe130ff..3ea0f07 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
@@ -342,11 +343,51 @@
return operand();
}
+// Slices `tensor` along `dim`, keeping indices [start, start + length).
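+// For example, slicing dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> with
+// dim=1, start=1, length=2 yields dense<[[1, 2], [4, 5]]> : tensor<2x2xi32>.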
+static ElementsAttr tensorSlice(ElementsAttr tensor, uint64_t dim,
+ uint64_t start, uint64_t length) {
+ auto shape = llvm::to_vector<4>(tensor.getType().getShape());
+ if (length == shape[dim]) {
+ // No need to slice.
+ return tensor;
+ }
+ auto outputShape = shape;
+ outputShape[dim] = length;
+ auto outputType =
+ RankedTensorType::get(outputShape, getElementTypeOrSelf(tensor));
+ llvm::SmallVector<Attribute, 4> newContents;
+ newContents.reserve(outputType.getNumElements());
+ auto valuesBegin = tensor.getValues<Attribute>().begin();
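+  // `step` is the flat size of one sub-tensor spanning dims [dim, rank);
+  // advancing the offset by `step` moves to the next combination of indices
+  // in dims [0, dim).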
+  int64_t step =
+      std::accumulate(shape.rbegin(), shape.rbegin() + shape.size() - dim,
+                      /*init=*/int64_t{1}, /*op=*/std::multiplies<int64_t>());
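+  // Number of contiguous elements to keep per outer index combination:
+  // `length` inner sub-tensors of step / shape[dim] elements each.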
+ int64_t num = length * step / shape[dim];
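+  // For each outer index combination, skip `start` inner sub-tensors and copy
+  // the next `num` elements.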
+ for (int64_t offset = step / shape[dim] * start,
+ numElements = tensor.getType().getNumElements();
+ offset < numElements; offset += step) {
+ newContents.append(valuesBegin + offset, valuesBegin + offset + num);
+ }
+ return DenseElementsAttr::get(outputType, newContents);
+}
+
OpFoldResult TensorSliceOp::fold(ArrayRef<Attribute> operands) {
- if (operands[0] && operands[1] && operands[2]) {
+ if (llvm::count(operands, nullptr) == 0) {
// Fully constant arguments so we can perform the slice here.
- // TODO(benvanik): constant slice.
- return {};
+ auto tensor = operands[0].cast<ElementsAttr>();
+ int64_t rank = source().getType().cast<ShapedType>().getRank();
+    // Start indices are operands[1, 1 + rank); lengths are
+    // operands[1 + rank, 1 + 2 * rank).
+ auto start = llvm::to_vector<4>(llvm::map_range(
+ operands.drop_front(1).drop_back(rank), [](Attribute value) {
+ return value.cast<IntegerAttr>().getValue().getZExtValue();
+ }));
+ auto length = llvm::to_vector<4>(
+ llvm::map_range(operands.drop_front(1 + rank), [](Attribute value) {
+ return value.cast<IntegerAttr>().getValue().getZExtValue();
+ }));
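+    // Slice one dimension at a time; tensorSlice returns its input unchanged
+    // for dimensions kept in full.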
+ for (int64_t dim = 0; dim < rank; ++dim) {
+ tensor = tensorSlice(tensor, dim, start[dim], length[dim]);
+ }
+ return tensor;
}
return {};
}
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 3bd181a..24fc4c1 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -113,7 +113,85 @@
// -----
-// TODO(benvanik): const folder for slice.
+// CHECK-LABEL: @sliceConst0D
+func @sliceConst0D() -> tensor<i32> {
+ %0 = constant dense<0> : tensor<i32>
+ // CHECK-NEXT: %[[C:.+]] = constant dense<0> : tensor<i32>
+ %1 = flow.tensor.slice %0[for] : tensor<i32> -> tensor<i32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<i32>
+}
+
+// CHECK-LABEL: @sliceConst1D
+func @sliceConst1D() -> tensor<1xi32> {
+ %0 = constant dense<0> : tensor<1xi32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ // CHECK-NEXT: %[[C:.+]] = constant dense<0> : tensor<1xi32>
+ %1 = flow.tensor.slice %0[%c0 for %c1] : tensor<1xi32> -> tensor<1xi32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<1xi32>
+}
+
+// CHECK-LABEL: @sliceConst1DZeroLength
+func @sliceConst1DZeroLength() -> tensor<0xi32> {
+ %0 = constant dense<0> : tensor<1xi32>
+ %c0 = constant 0 : index
+ // CHECK-NEXT: %[[C:.+]] = constant dense<> : tensor<0xi32>
+ %1 = flow.tensor.slice %0[%c0 for %c0] : tensor<1xi32> -> tensor<0xi32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<0xi32>
+}
+
+// CHECK-LABEL: @sliceConst2D
+func @sliceConst2D() -> tensor<1x2xi32> {
+ %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ // CHECK-NEXT: %[[C:.+]] = constant dense<[
+ // CHECK-SAME: [1, 2]
+ // CHECK-SAME: ]> : tensor<1x2xi32>
+ %1 = flow.tensor.slice %0[%c0, %c1 for %c1, %c2] : tensor<2x3xi32> -> tensor<1x2xi32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @sliceConst2DZeroLength1
+func @sliceConst2DZeroLength1() -> tensor<1x0xi32> {
+ %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ // CHECK-NEXT: %[[C:.+]] = constant dense<> : tensor<1x0xi32>
+ %1 = flow.tensor.slice %0[%c0, %c0 for %c1, %c0] : tensor<2x3xi32> -> tensor<1x0xi32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<1x0xi32>
+}
+
+// CHECK-LABEL: @sliceConst2DZeroLength01
+func @sliceConst2DZeroLength01() -> tensor<0x0xi32> {
+ %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
+ %c0 = constant 0 : index
+ // CHECK-NEXT: %[[C:.+]] = constant dense<> : tensor<0x0xi32>
+ %1 = flow.tensor.slice %0[%c0, %c0 for %c0, %c0] : tensor<2x3xi32> -> tensor<0x0xi32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<0x0xi32>
+}
+
+// CHECK-LABEL: @sliceConst3D
+func @sliceConst3D() -> tensor<1x2x3xi32> {
+ %0 = constant dense<[[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[9, 10, 11], [12, 13, 14], [15, 16, 17]]]> : tensor<2x3x3xi32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ // CHECK-NEXT: %[[C:.+]] = constant dense<[
+ // CHECK-SAME: [
+ // CHECK-SAME: [3, 4, 5], [6, 7, 8]]]> : tensor<1x2x3xi32>
+ %1 = flow.tensor.slice %0[%c0, %c1, %c0 for %c1, %c2, %c3] : tensor<2x3x3xi32> -> tensor<1x2x3xi32>
+ // CHECK-NEXT: return %[[C]]
+ return %1 : tensor<1x2x3xi32>
+}
// -----