[Codegen] Resolve constantOp with multiple layouts users. (#18354)

The main motivation is to handle distribution of a constantOp whose users
have different layouts.

The original use case is to ensure we can distribute attention when the
tile sizes for M, K1, and N are the same, which means the init of the 1st
contract and the IV's init use the same constantOp.

Since a constantOp can only hold a single layout, but there are multiple
to_layout ops with different layouts, there will be unresolved to_layout
op(s) for each user. Only one of the to_layout ops can be resolved
properly; the rest would be "non trivial" resolutions since the layouts
are different.

To solve this issue, we introduce a mechanism that detects these cases
and makes a copy of the arith.constant for the other users to use while
we are resolving the layout for the current constantOp.

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index d197d64..be2968d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -643,6 +643,44 @@
 
 // -----
 
+// This test is used to ensure that we are handling cases
+// where the same arith.constant has multiple users with different layouts.
+
+// The main motivation is to ensure we can distribute attention when the tile
+// sizes for M, K1, and N are the same, which means the init of the 1st
+// contract and the IV's init use the same constant.
+
+#layoutA = #iree_vector_ext.layout<<[BATCHY, LANEX], [2, 32]>, <[BATCHX,  LANEY,  VECTORX], [2, 4, 8]>>
+#layoutB = #iree_vector_ext.layout<<[BATCHY, LANEX], [2, 32]>, <[BATCHX,  VECTORY,  LANEY,  VECTORX], [2, 4, 2, 4]>>
+
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @resolve_constant_with_multiple_layout_uses(%A : vector<64x64xf16>, %B : vector<64x64xf16>) -> vector<64x64xf16> {
+    %a = iree_vector_ext.to_layout %A to #layoutA : vector<64x64xf16>
+    %b = iree_vector_ext.to_layout %B to #layoutB : vector<64x64xf16>
+    %zero = arith.constant dense<0.0> : vector<64x64xf16>
+    %add_0 = arith.addf %a, %zero : vector<64x64xf16>
+    %add_1 = arith.addf %b, %zero : vector<64x64xf16>
+    %layout_change = iree_vector_ext.to_layout %add_1 to #layoutA : vector<64x64xf16>
+    %out = arith.addf %layout_change, %add_0 : vector<64x64xf16>
+    func.return %out : vector<64x64xf16>
+  }
+// CHECK-LABEL: func.func @resolve_constant_with_multiple_layout_uses
+// CHECK-SAME: (%[[ARG0:.+]]: vector<64x64xf16>, %[[ARG1:.+]]: vector<64x64xf16>)
+// CHECK: %[[V0:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x16xf16>
+// CHECK: %[[V1:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf16>
+// CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[V1]]{{.*}} : vector<2x2x8xf16>
+// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[V0]]{{.*}} : vector<2x2x16xf16>
+// CHECK: arith.addf %{{.+}}, %[[ADD0]]{{.*}} : vector<2x2x8xf16>
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 #row_layout = #iree_vector_ext.per_dim_layout<[BATCHX, LANEY, VECTORX], [2, 4, 4]>
 #col_layout = #iree_vector_ext.per_dim_layout<[BATCHY, LANEX], [1, 16]>
 #layout0 = #iree_vector_ext.layout<#row_layout, #col_layout>
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 03ee001..fc1a1d2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -198,6 +198,19 @@
 
 ChangeResult DistributionLayout::resolveWithPossibleConflict(
     const VectorLayoutInterface &rhs, OpOperand &opOperand) {
+
+  IRRewriter builder(opOperand.getOwner());
+  // A constant whose users need different layouts is cloned for the other
+  // users; the clone sits at the original so it dominates all of them.
+  auto constOp = opOperand.get().getDefiningOp<arith::ConstantOp>();
+  if (constOp && !opOperand.get().hasOneUse() && !vectorLayout) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPoint(constOp);
+    Value copiedConst = builder.clone(*constOp.getOperation())->getResult(0);
+    builder.replaceAllUsesExcept(opOperand.get(), copiedConst,
+                                 opOperand.getOwner());
+  }
+
   ResolutionResult result = doResolution(rhs);
 
   // If there is no conflict, simply return.
@@ -210,7 +223,6 @@
 
   // Resolve conflict by create an operation that takes the input the conflicted
   // value and returns the resolved value.
-  OpBuilder builder(opOperand.getOwner());
   Value input = opOperand.get();
   // Create a resolution operation. This conflict should be handeled later by
   // someone else, not this analysis.