Add support for converting rank-reducing subtensor ops to Flow dialect (#5515)
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
index 865ce1b..b918e14 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -77,9 +77,29 @@
Value source = subTensorOp.source();
SmallVector<Value, 4> sourceSizesVals = sizesVals;
sourceSizesVals[0] = rewriter.createOrFold<memref::DimOp>(loc, source, 0);
- rewriter.replaceOpWithNewOp<TensorSliceOp>(
- subTensorOp, subTensorOp.getType(), subTensorOp.source(),
- sourceSizesVals, offsetVals, sizesVals, sizesVals);
+
+  // Unlike a SubTensorOp, a TensorSliceOp has no rank-reducing
+  // behavior: its result rank always matches the source rank.
+ Type type = SubTensorOp::inferResultType(subTensorOp.getSourceType(),
+ offsets, sizes, strides);
+ Value tensorSliceOp = rewriter.create<TensorSliceOp>(
+ loc, type, subTensorOp.source(), sourceSizesVals, offsetVals, sizesVals,
+ sizesVals);
+
+  if (type == subTensorOp.getType()) {
+    // Not a rank-reducing subtensor; the slice can replace it directly.
+ rewriter.replaceOp(subTensorOp, tensorSliceOp);
+ } else {
+    // Rank-reducing subtensor: a reshape op is needed to drop the unit dims. NOTE(review): below, getDynamicDimIndex(i) returns a dimension *position*, yet it is materialized as a ConstantIndexOp *value* for the dynamic sizes — verify this is intended for dynamically shaped slice results (the added test only covers a fully static slice type).
+ SmallVector<Value, 4> sourceDynSizes, resultDynSizes;
+ auto sourceType = tensorSliceOp.getType().cast<RankedTensorType>();
+ for (auto i : llvm::seq<unsigned>(0, sourceType.getNumDynamicDims())) {
+ sourceDynSizes.push_back(rewriter.create<ConstantIndexOp>(
+ loc, sourceType.getDynamicDimIndex(i)));
+ }
+ rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+ subTensorOp, subTensorOp.getType(), tensorSliceOp, sourceDynSizes);
+ }
return success();
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
index febeef4..e3212f0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
@@ -40,3 +40,24 @@
// CHECK-DAG: %[[UNMODIFIED2:.+]] = subtensor %[[SLICE2]][0, 0, 0] [%[[D1]], 12, 24] [1, 2, 2]
// CHECK-DAG: %[[UNMODIFIED3:.+]] = subtensor %[[ARG0]][0, %[[ARG1]], 0]
// CHECK: return %[[UNMODIFIED1]], %[[UNMODIFIED2]], %[[UNMODIFIED3]]
+
+// -----
+
+func @rank_reducing_subtensor(%arg0: tensor<2x513xi32>, %arg1: index,
+ %arg2: index) -> tensor<513xi32> {
+ %0 = subtensor %arg0[%arg1, %arg2] [1, 513] [1, 1] : tensor<2x513xi32> to tensor<513xi32>
+ return %0 : tensor<513xi32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C513:.+]] = constant 513 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]] for %[[C1]], %[[C513]]]
+// CHECK-SAME: : tensor<2x513xi32>{%[[C2]], %[[C513]]}
+// CHECK-SAME: -> tensor<1x513xi32>{%[[C1]], %[[C513]]}
+// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %[[SLICE]] : tensor<1x513xi32> -> tensor<513xi32>
+// CHECK: return %[[RESHAPE]] : tensor<513xi32>