[GPU][VectDist] Refactor multiReducOp lowering to reduce acc at the end. (#17974)

The main motivation behind this commit is to fix the numerics issue we
found in the Attention GPU C++ pipeline, where our output seems to be off
from the reference by some scale factor (i.e. out = k * ref). Through
experiments we determined that the cause of the issue is the multiReducOp
distribution, which is required for the scaling/merge computation.

Through analysis of the IR and experimentation, we found that reducing the
srcVector and a non-constant accumulator at the same time with a single
multiDimReduction outputs numerically wrong values.

This commit fixes the numerical issue by refactoring the distribution of
multiReductionOp into 3 steps. First, every thread locally reduces the
srcVector it holds, using the combining identity as the local init value.
Second, it performs a subgroup/warp reduction across threads. Finally,
each thread locally reduces the intermediate reduced data it holds with
its accumulator.

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index cf2da5e..6aa994c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -9,6 +9,7 @@
 #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -394,7 +395,8 @@
 ///      the reduction dimensions. This is the batch, outer and element dims.
 ///   2. Thread Reduce: Each thread reduces result of step 1 across threads
 ///      by doing a butterfly shuffle.
-///
+///   3. Accumulator Reduce: Each thread reduces its intermediate reduced
+///      results with the accumulator it holds.
 /// Currently, reduction across warps is not supported, but it would just add
 /// another step, Warp Reduce, where threads do an atomic addition on a buffer.
 struct DistributeMultiReduction final
@@ -454,9 +456,11 @@
     for (int i = 0; i < 3; ++i) {
       distributedReductionMask.append(reducedDims.begin(), reducedDims.end());
     }
-
+    Value localInit = getCombiningIdentityValue(
+        loc, rewriter, multiReduceOp.getKind(), disAcc.getType());
     auto localReduction = rewriter.create<vector::MultiDimReductionOp>(
-        loc, disSrc, disAcc, distributedReductionMask, multiReduceOp.getKind());
+        loc, disSrc, localInit, distributedReductionMask,
+        multiReduceOp.getKind());
     auto locallyReduced = dyn_cast<VectorValue>(localReduction.getResult());
 
     assert(locallyReduced && "result should have been a vector");
@@ -469,15 +473,24 @@
     VectorValue flat =
         rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
 
+    // Do inter-thread/warp reduce.
     FailureOr<VectorValue> threadReduced = doThreadReduction(
         rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
     if (failed(threadReduced)) {
       return failure();
     }
 
+    // Do reduction against accumulator, which needs to be done after thread
+    // reduction.
     VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
         loc, shaped, threadReduced.value());
-    replaceOpWithDistributedValues(rewriter, multiReduceOp, unflattened);
+    Value accReduction = vector::makeArithReduction(
+        rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc);
+    auto accReduced = dyn_cast<VectorValue>(accReduction);
+    if (!accReduced) {
+      return failure();
+    }
+    replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
 
     return failure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index ae05af9..c407be2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -965,16 +965,19 @@
 // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
 // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
 // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
+// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
 // CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
 // CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
 // Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[DARG1]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
+// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
 // Global reduction
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x1x1xf32> -> vector<32xf32>
+// Accumulator reduction
+// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
+// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
 
 // -----
 
@@ -1016,3 +1019,5 @@
 // CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
 // Global reduction
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// Accumulator reduction
+// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index dbf3b89..dd3ca2a 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -504,9 +504,8 @@
 }
 
 /// Emit identity variable.
-static Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
-                                       vector::CombiningKind kind,
-                                       Type identityType) {
+Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
+                                vector::CombiningKind kind, Type identityType) {
   auto vectorType = llvm::dyn_cast<VectorType>(identityType);
   Type elementType = identityType;
   if (vectorType) {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index 74ef772..943d6a3 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -120,6 +120,9 @@
 Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput,
                      VectorType targetVecType);
 
+/// Emit identity constant based on combiningKind and type.
+Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
+                                vector::CombiningKind kind, Type identityType);
 //===----------------------------------------------------------------------===//
 // GPU CodeGen op filter
 //===----------------------------------------------------------------------===//