[Codegen] Use linearize_index op when swapping slice and expand (#19730)
This PR replaces the `affine.apply` op used for index computation in
`SwapExpandShapeWithSlicePattern` with an `affine.linearize_index` op.
The linearized form is more canonical, and makes it easier for CSE to
merge this index computation with that generated by other, similar
patterns.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 99a46be..8dfe361 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -433,8 +433,8 @@
// THREAD-LABEL: func.func @swap_expand_shape_with_extract_slice
// THREAD: scf.forall (%[[X:[A-Za-z0-9]+]], %[[Y:[A-Za-z0-9]+]], %[[Z:[A-Za-z0-9]+]])
-// THREAD: %[[APPLY:.+]] = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 30 + d2 * 10)>(%[[Z]], %[[X]], %[[Y]])
-// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// THREAD: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
// THREAD: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
// THREAD: linalg.exp {{.*}} ins(%[[EXPAND]]
@@ -451,9 +451,10 @@
}
// THREAD-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+// THREAD: %[[C0:.+]] = arith.constant 0 : index
// THREAD: scf.forall (%[[X:[A-Za-z0-9]+]], %[[Y:[A-Za-z0-9]+]])
-// THREAD: %[[APPLY:.+]] = affine.apply affine_map<(d0, d1) -> (d0 * 40 + d1 * 10)>(%[[X]], %[[Y]])
-// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+// THREAD: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
// THREAD: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
// THREAD: linalg.exp {{.*}} ins(%[[EXPAND]]
@@ -488,10 +489,11 @@
}
// THREAD-LABEL: func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims
+// THREAD: %[[C0:.+]] = arith.constant 0 : index
// THREAD: scf.forall (%[[ID0:[A-Za-z0-9]+]], %[[ID1:[A-Za-z0-9]+]], %[[ID2:[A-Za-z0-9]+]], %[[ID3:[A-Za-z0-9]+]])
-// THREAD: %[[APPLY0:.+]] = affine.apply affine_map<(d0, d1) -> (d0 * 40 + d1 * 10)>(%[[ID0]], %[[ID1]])
-// THREAD: %[[APPLY1:.+]] = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%[[ID2]], %[[ID3]])
-// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY0]], %[[APPLY1]]] [20, 4] [1, 1]
+// THREAD: %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[ID0]], %[[ID1]], %[[C0]]] by (3, 4, 10)
+// THREAD: %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[ID2]], %[[ID3]]] by (7, 8)
+// THREAD: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1]
// THREAD: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
// THREAD: linalg.exp {{.*}} ins(%[[EXPAND]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index 79bf739..8deabcc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -57,11 +57,6 @@
return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
{v1, v2});
};
- auto mulAdd = [&](OpFoldResult v1, OpFoldResult v2, OpFoldResult v3) {
- auto mulMap = AffineMap::get(3, 0, {d0 * d1 + d2});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2, v3});
- };
SmallVector<OpFoldResult> outputShape =
getMixedValues(expandShapeOp.getStaticOutputShape(),
@@ -107,8 +102,8 @@
SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
for (const ReassociationIndices &indices :
expandShapeOp.getReassociationIndices()) {
- OpFoldResult newOffset = rewriter.getIndexAttr(0);
OpFoldResult newSize = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> basis, delinOffsets;
int64_t i = 0;
int64_t e = indices.size();
@@ -118,24 +113,32 @@
if (!isConstantIntValue(sizes[expandedDim], 1))
break;
- newOffset =
- mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
+ basis.push_back(outputShape[expandedDim]);
+ delinOffsets.push_back(offsets[expandedDim]);
}
if (i != e) {
int64_t expandedDim = indices[i];
- newOffset =
- mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
+ basis.push_back(outputShape[expandedDim]);
+ delinOffsets.push_back(offsets[expandedDim]);
newSize = sizes[expandedDim];
i++;
}
for (; i < e; ++i) {
OpFoldResult fullSize = outputShape[indices[i]];
- newOffset = mul(newOffset, fullSize);
+ basis.push_back(fullSize);
+ delinOffsets.push_back(rewriter.getIndexAttr(0));
newSize = mul(newSize, fullSize);
}
-
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+ });
+ OpFoldResult newOffset = rewriter
+ .create<affine::AffineLinearizeIndexOp>(
+ loc, offsetVals, basis, /*disjoint=*/true)
+ .getResult();
newOffsets.push_back(newOffset);
newLengths.push_back(newSize);