[Codegen] Use linearize_index op when swapping slice and expand (#19730)

This PR replaces the `affine.apply` ops used for index computation in
`SwapExpandShapeWithSlicePattern` with an `affine.linearize_index` op.
This is a more canonical form, and makes it easier to CSE against index
computations generated by other, similar patterns.
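
For illustration, a rough sketch of the change in the IR produced for the
first test case below, where `%x`, `%y`, and `%z` stand for the `scf.forall`
induction variables:

    // Before: the linearized slice offset was built with a composed affine.apply.
    %idx = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 30 + d2 * 10)>(%z, %x, %y)

    // After: the same row-major offset is expressed directly over the expanded
    // basis (2, 3, 10), so other patterns emitting the same linearization can
    // CSE with it.
    %idx = affine.linearize_index disjoint [%x, %y, %z] by (2, 3, 10) : index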

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 99a46be..8dfe361 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -433,8 +433,8 @@
 
 // THREAD-LABEL: func.func @swap_expand_shape_with_extract_slice
 //       THREAD:   scf.forall (%[[X:[A-Za-z0-9]+]], %[[Y:[A-Za-z0-9]+]], %[[Z:[A-Za-z0-9]+]])
-//       THREAD:     %[[APPLY:.+]] = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 30 + d2 * 10)>(%[[Z]], %[[X]], %[[Y]])
-//       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+//       THREAD:     %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+//       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
 //       THREAD:     %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
 //       THREAD:     linalg.exp {{.*}} ins(%[[EXPAND]]
 
@@ -451,9 +451,10 @@
 }
 
 // THREAD-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+//       THREAD:   %[[C0:.+]] = arith.constant 0 : index
 //       THREAD:   scf.forall (%[[X:[A-Za-z0-9]+]], %[[Y:[A-Za-z0-9]+]])
-//       THREAD:     %[[APPLY:.+]] = affine.apply affine_map<(d0, d1) -> (d0 * 40 + d1 * 10)>(%[[X]], %[[Y]])
-//       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+//       THREAD:     %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+//       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
 //       THREAD:     %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
 //       THREAD:     linalg.exp {{.*}} ins(%[[EXPAND]]
 
@@ -488,10 +489,11 @@
 }
 
 // THREAD-LABEL: func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims
+//       THREAD:   %[[C0:.+]] = arith.constant 0 : index
 //       THREAD:   scf.forall (%[[ID0:[A-Za-z0-9]+]], %[[ID1:[A-Za-z0-9]+]], %[[ID2:[A-Za-z0-9]+]], %[[ID3:[A-Za-z0-9]+]])
-//       THREAD:     %[[APPLY0:.+]] = affine.apply affine_map<(d0, d1) -> (d0 * 40 + d1 * 10)>(%[[ID0]], %[[ID1]])
-//       THREAD:     %[[APPLY1:.+]] = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%[[ID2]], %[[ID3]])
-//       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[APPLY0]], %[[APPLY1]]] [20, 4] [1, 1]
+//       THREAD:     %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[ID0]], %[[ID1]], %[[C0]]] by (3, 4, 10)
+//       THREAD:     %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[ID2]], %[[ID3]]] by (7, 8)
+//       THREAD:     %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1]
 //       THREAD:     %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
 //       THREAD:     linalg.exp {{.*}} ins(%[[EXPAND]]
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index 79bf739..8deabcc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -57,11 +57,6 @@
     return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                  {v1, v2});
   };
-  auto mulAdd = [&](OpFoldResult v1, OpFoldResult v2, OpFoldResult v3) {
-    auto mulMap = AffineMap::get(3, 0, {d0 * d1 + d2});
-    return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
-                                                 {v1, v2, v3});
-  };
 
   SmallVector<OpFoldResult> outputShape =
       getMixedValues(expandShapeOp.getStaticOutputShape(),
@@ -107,8 +102,8 @@
   SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
   for (const ReassociationIndices &indices :
        expandShapeOp.getReassociationIndices()) {
-    OpFoldResult newOffset = rewriter.getIndexAttr(0);
     OpFoldResult newSize = rewriter.getIndexAttr(1);
+    SmallVector<OpFoldResult> basis, delinOffsets;
 
     int64_t i = 0;
     int64_t e = indices.size();
@@ -118,24 +113,32 @@
       if (!isConstantIntValue(sizes[expandedDim], 1))
         break;
 
-      newOffset =
-          mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
+      basis.push_back(outputShape[expandedDim]);
+      delinOffsets.push_back(offsets[expandedDim]);
     }
 
     if (i != e) {
       int64_t expandedDim = indices[i];
-      newOffset =
-          mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
+      basis.push_back(outputShape[expandedDim]);
+      delinOffsets.push_back(offsets[expandedDim]);
       newSize = sizes[expandedDim];
       i++;
     }
 
     for (; i < e; ++i) {
       OpFoldResult fullSize = outputShape[indices[i]];
-      newOffset = mul(newOffset, fullSize);
+      basis.push_back(fullSize);
+      delinOffsets.push_back(rewriter.getIndexAttr(0));
       newSize = mul(newSize, fullSize);
     }
-
+    SmallVector<Value> offsetVals =
+        llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+          return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+        });
+    OpFoldResult newOffset = rewriter
+                                 .create<affine::AffineLinearizeIndexOp>(
+                                     loc, offsetVals, basis, /*disjoint=*/true)
+                                 .getResult();
     newOffsets.push_back(newOffset);
     newLengths.push_back(newSize);