[CodeGen] Fix the argument replacements in scf.forall op lowering. (#18613)

The values that replace the scf.forall block arguments must be the
workgroup IDs scaled by the tile sizes (the loop steps). Otherwise, we
always access the same memory chunk.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index d7a81fc..f09d3a6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -211,6 +211,14 @@
   Block *parentBlock = forallOp->getBlock();
   Block *remainingBlock =
       rewriter.splitBlock(parentBlock, Block::iterator(forallOp));
+  for (auto [id, step] : llvm::zip_equal(procId, mixedStep)) {
+    rewriter.setInsertionPointToEnd(parentBlock);
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    AffineExpr expr = s1 * s0;
+    id = affine::makeComposedFoldedAffineApply(rewriter, forallOp.getLoc(),
+                                               expr, {id, step});
+  }
   auto argReplacements =
       getValueOrCreateConstantIndexOp(rewriter, forallOp.getLoc(), procId);
   Block *loopBody = forallOp.getBody();
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
index cc8aa9d..fa56c6d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
@@ -191,6 +191,8 @@
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 32)>
 //      CHECK: hal.executable.export public @scf_forall_2D layout
 // CHECK-NEXT:     %[[ARG1:[a-zA-z0-9]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-z0-9]+]]: index
@@ -203,7 +205,9 @@
 //  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
 //  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
 //  CHECK-NOT:   scf.forall
-//      CHECK:   "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+//      CHECK:   %[[I:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
+//      CHECK:   %[[J:.+]] = affine.apply #[[MAP3]]()[%[[WG_ID_X]]]
+//      CHECK:   "use"(%[[I]], %[[J]])
 
 // -----
 
@@ -236,6 +240,7 @@
   }
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
 //      CHECK: hal.executable.export public @scf_forall_2D_dynamic_tile_size layout
 // CHECK-NEXT:     %[[ARG1:[a-zA-z0-9]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-z0-9]+]]: index
@@ -246,10 +251,14 @@
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //      CHECK:   hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
 //      CHECK: func @scf_forall_2D_dynamic_tile_size()
+//  CHECK-DAG:   %[[STEP_Y:.+]] = hal.interface.constant.load {{.+}} ordinal(2)
+//  CHECK-DAG:   %[[STEP_X:.+]] = hal.interface.constant.load {{.+}} ordinal(3)
 //  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
 //  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
 //  CHECK-NOT:   scf.forall
-//      CHECK:   "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+//      CHECK:   %[[I:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]], %[[STEP_Y]]]
+//      CHECK:   %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]], %[[STEP_X]]]
+//      CHECK:   "use"(%[[I]], %[[J]])
 
 // -----
 
@@ -305,6 +314,7 @@
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
 //      CHECK: hal.executable.export public @scf_forall_4D layout
 // CHECK-NEXT:     %[[ARG1:[a-zA-z0-9]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-z0-9]+]]: index
@@ -329,6 +339,8 @@
 //  CHECK-DAG:   %[[UB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(5)
 //  CHECK-DAG:   %[[STEP0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(8)
 //  CHECK-DAG:   %[[STEP1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(9)
+//  CHECK-DAG:   %[[STEP2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(10)
+//  CHECK-DAG:   %[[STEP3:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(11)
 //  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
 //  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
 //  CHECK-DAG:   %[[NITERS1:.+]] = affine.apply #[[MAP0]]()[%[[LB1]], %[[UB1]], %[[STEP1]]]
@@ -336,7 +348,11 @@
 //  CHECK-DAG:   %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
 //  CHECK-NOT:   scf.forall
 //      CHECK:   %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[WG_ID_Z]] into (%[[NITERS0]], %[[NITERS1]])
-//      CHECK:   "use"(%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[WG_ID_Y]], %[[WG_ID_X]])
+//      CHECK:   %[[I:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#0, %[[STEP0]]]
+//      CHECK:   %[[J:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#1, %[[STEP1]]]
+//      CHECK:   %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]], %[[STEP2]]]
+//      CHECK:   %[[L:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_X]], %[[STEP3]]]
+//      CHECK:   "use"(%[[I]], %[[J]], %[[K]], %[[L]])
 
 // -----
 
@@ -364,6 +380,10 @@
     }
   }
 }
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3)>
+//  CHECK-DAG:  #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+//  CHECK-DAG:  #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 5)>
 //      CHECK: hal.executable.export public @scf_forall_4D_static_interchange layout
 //  CHECK-DAG:   %[[C6:.+]] = arith.constant 6 : index
 //  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
@@ -378,7 +398,11 @@
 //  CHECK-DAG:   %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
 //  CHECK-NOT:   scf.forall
 //      CHECK:   %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]])
-//      CHECK:   "use"(%[[DELINEARIZE]]#2, %[[DELINEARIZE]]#0, %[[WG_ID_X]], %[[WG_ID_Y]], %[[DELINEARIZE]]#1)
+//      CHECK:   %[[I:.+]] = affine.apply #[[MAP0]]()[%[[DELINEARIZE]]#0]
+//      CHECK:   %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+//      CHECK:   %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
+//      CHECK:   %[[L:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#1]
+//      CHECK:   "use"(%[[DELINEARIZE]]#2, %[[I]], %[[J]], %[[K]], %[[L]])
 
 // -----