[CodeGen] Fix the argument replacements in scf.forall op lowering. (#18613)
The values that replace the scf.forall induction variables should be the
processor IDs scaled by the tile sizes (the loop steps). Otherwise, we
always access the same memory chunk.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index d7a81fc..f09d3a6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -211,6 +211,14 @@
Block *parentBlock = forallOp->getBlock();
Block *remainingBlock =
rewriter.splitBlock(parentBlock, Block::iterator(forallOp));
+ for (auto [id, step] : llvm::zip_equal(procId, mixedStep)) {
+ rewriter.setInsertionPointToEnd(parentBlock);
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ AffineExpr expr = s1 * s0;
+ id = affine::makeComposedFoldedAffineApply(rewriter, forallOp.getLoc(),
+ expr, {id, step});
+ }
auto argReplacements =
getValueOrCreateConstantIndexOp(rewriter, forallOp.getLoc(), procId);
Block *loopBody = forallOp.getBody();
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
index cc8aa9d..fa56c6d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
@@ -191,6 +191,8 @@
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 64)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: hal.executable.export public @scf_forall_2D layout
// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
@@ -203,7 +205,9 @@
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-NOT: scf.forall
-// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+// CHECK: %[[I:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
+// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[WG_ID_X]]]
+// CHECK: "use"(%[[I]], %[[J]])
// -----
@@ -236,6 +240,7 @@
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
// CHECK: hal.executable.export public @scf_forall_2D_dynamic_tile_size layout
// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
@@ -246,10 +251,14 @@
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
// CHECK: func @scf_forall_2D_dynamic_tile_size()
+// CHECK-DAG: %[[STEP_Y:.+]] = hal.interface.constant.load {{.+}} ordinal(2)
+// CHECK-DAG: %[[STEP_X:.+]] = hal.interface.constant.load {{.+}} ordinal(3)
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-NOT: scf.forall
-// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+// CHECK: %[[I:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]], %[[STEP_Y]]]
+// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]], %[[STEP_X]]]
+// CHECK: "use"(%[[I]], %[[J]])
// -----
@@ -305,6 +314,7 @@
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
// CHECK: hal.executable.export public @scf_forall_4D layout
// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
@@ -329,6 +339,8 @@
// CHECK-DAG: %[[UB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(5)
// CHECK-DAG: %[[STEP0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(8)
// CHECK-DAG: %[[STEP1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(9)
+// CHECK-DAG: %[[STEP2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(10)
+// CHECK-DAG: %[[STEP3:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(11)
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[NITERS1:.+]] = affine.apply #[[MAP0]]()[%[[LB1]], %[[UB1]], %[[STEP1]]]
@@ -336,7 +348,11 @@
// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
// CHECK-NOT: scf.forall
// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[WG_ID_Z]] into (%[[NITERS0]], %[[NITERS1]])
-// CHECK: "use"(%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[WG_ID_Y]], %[[WG_ID_X]])
+// CHECK: %[[I:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#0, %[[STEP0]]]
+// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#1, %[[STEP1]]]
+// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]], %[[STEP2]]]
+// CHECK: %[[L:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_X]], %[[STEP3]]]
+// CHECK: "use"(%[[I]], %[[J]], %[[K]], %[[L]])
// -----
@@ -364,6 +380,10 @@
}
}
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 5)>
// CHECK: hal.executable.export public @scf_forall_4D_static_interchange layout
// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
@@ -378,7 +398,11 @@
// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
// CHECK-NOT: scf.forall
// CHECK: %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]])
-// CHECK: "use"(%[[DELINEARIZE]]#2, %[[DELINEARIZE]]#0, %[[WG_ID_X]], %[[WG_ID_Y]], %[[DELINEARIZE]]#1)
+// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[DELINEARIZE]]#0]
+// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
+// CHECK: %[[L:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#1]
+// CHECK: "use"(%[[DELINEARIZE]]#2, %[[I]], %[[J]], %[[K]], %[[L]])
// -----