Fold flow.tensor.slice when all operands are constant (#2972)

diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index fe130ff..3ea0f07 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -342,11 +343,51 @@
   return operand();
 }
 
+// Slices tensor from start to (start + length) exclusively at dim.
+static ElementsAttr tensorSlice(ElementsAttr tensor, uint64_t dim,
+                                uint64_t start, uint64_t length) {
+  auto shape = llvm::to_vector<4>(tensor.getType().getShape());
+  if (length == shape[dim]) {
+    // No need to slice.
+    return tensor;
+  }
+  auto outputShape = shape;
+  outputShape[dim] = length;
+  auto outputType =
+      RankedTensorType::get(outputShape, getElementTypeOrSelf(tensor));
+  llvm::SmallVector<Attribute, 4> newContents;
+  newContents.reserve(outputType.getNumElements());
+  auto valuesBegin = tensor.getValues<Attribute>().begin();
+  int64_t step =
+      std::accumulate(shape.rbegin(), shape.rbegin() + shape.size() - dim,
+                      /*init=*/1, /*op=*/std::multiplies<int64_t>());
+  int64_t num = length * step / shape[dim];
+  for (int64_t offset = step / shape[dim] * start,
+               numElements = tensor.getType().getNumElements();
+       offset < numElements; offset += step) {
+    newContents.append(valuesBegin + offset, valuesBegin + offset + num);
+  }
+  return DenseElementsAttr::get(outputType, newContents);
+}
+
 OpFoldResult TensorSliceOp::fold(ArrayRef<Attribute> operands) {
-  if (operands[0] && operands[1] && operands[2]) {
+  if (llvm::count(operands, nullptr) == 0) {
     // Fully constant arguments so we can perform the slice here.
-    // TODO(benvanik): constant slice.
-    return {};
+    auto tensor = operands[0].cast<ElementsAttr>();
+    int64_t rank = source().getType().cast<ShapedType>().getRank();
+    // start = operands[1:1+rank), and length = operands[1+rank:].
+    auto start = llvm::to_vector<4>(llvm::map_range(
+        operands.drop_front(1).drop_back(rank), [](Attribute value) {
+          return value.cast<IntegerAttr>().getValue().getZExtValue();
+        }));
+    auto length = llvm::to_vector<4>(
+        llvm::map_range(operands.drop_front(1 + rank), [](Attribute value) {
+          return value.cast<IntegerAttr>().getValue().getZExtValue();
+        }));
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      tensor = tensorSlice(tensor, dim, start[dim], length[dim]);
+    }
+    return tensor;
   }
   return {};
 }
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 3bd181a..24fc4c1 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -113,7 +113,85 @@
 
 // -----
 
-// TODO(benvanik): const folder for slice.
+// CHECK-LABEL: @sliceConst0D
+func @sliceConst0D() -> tensor<i32> {
+  %0 = constant dense<0> : tensor<i32>
+  // CHECK-NEXT: %[[C:.+]] = constant dense<0> : tensor<i32>
+  %1 = flow.tensor.slice %0[for] : tensor<i32> -> tensor<i32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<i32>
+}
+
+// CHECK-LABEL: @sliceConst1D
+func @sliceConst1D() -> tensor<1xi32> {
+  %0 = constant dense<0> : tensor<1xi32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  // CHECK-NEXT: %[[C:.+]] = constant dense<0> : tensor<1xi32>
+  %1 = flow.tensor.slice %0[%c0 for %c1] : tensor<1xi32> -> tensor<1xi32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<1xi32>
+}
+
+// CHECK-LABEL: @sliceConst1DZeroLength
+func @sliceConst1DZeroLength() -> tensor<0xi32> {
+  %0 = constant dense<0> : tensor<1xi32>
+  %c0 = constant 0 : index
+  // CHECK-NEXT: %[[C:.+]] = constant dense<> : tensor<0xi32>
+  %1 = flow.tensor.slice %0[%c0 for %c0] : tensor<1xi32> -> tensor<0xi32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<0xi32>
+}
+
+// CHECK-LABEL: @sliceConst2D
+func @sliceConst2D() -> tensor<1x2xi32> {
+  %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  // CHECK-NEXT: %[[C:.+]] = constant dense<[
+  // CHECK-SAME:     [1, 2]
+  // CHECK-SAME: ]> : tensor<1x2xi32>
+  %1 = flow.tensor.slice %0[%c0, %c1 for %c1, %c2] : tensor<2x3xi32> -> tensor<1x2xi32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @sliceConst2DZeroLength1
+func @sliceConst2DZeroLength1() -> tensor<1x0xi32> {
+  %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  // CHECK-NEXT: %[[C:.+]] = constant dense<> : tensor<1x0xi32>
+  %1 = flow.tensor.slice %0[%c0, %c0 for %c1, %c0] : tensor<2x3xi32> -> tensor<1x0xi32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<1x0xi32>
+}
+
+// CHECK-LABEL: @sliceConst2DZeroLength01
+func @sliceConst2DZeroLength01() -> tensor<0x0xi32> {
+  %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
+  %c0 = constant 0 : index
+  // CHECK-NEXT: %[[C:.+]] = constant dense<> : tensor<0x0xi32>
+  %1 = flow.tensor.slice %0[%c0, %c0 for %c0, %c0] : tensor<2x3xi32> -> tensor<0x0xi32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<0x0xi32>
+}
+
+// CHECK-LABEL: @sliceConst3D
+func @sliceConst3D() -> tensor<1x2x3xi32> {
+  %0 = constant dense<[[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[9, 10, 11], [12, 13, 14], [15, 16, 17]]]> : tensor<2x3x3xi32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  // CHECK-NEXT: %[[C:.+]] = constant dense<[
+  // CHECK-SAME:                             [
+  // CHECK-SAME:                              [3, 4, 5], [6, 7, 8]]]> : tensor<1x2x3xi32>
+  %1 = flow.tensor.slice %0[%c0, %c1, %c0 for %c1, %c2, %c3] : tensor<2x3x3xi32> -> tensor<1x2x3xi32>
+  // CHECK-NEXT: return %[[C]]
+  return %1 : tensor<1x2x3xi32>
+}
 
 // -----