[LLVMGPUVectorDistribute] Refactor vector.contract distribute (#19631)

Currently, vector.contract distribution is implemented as a standalone
distribution closely following vector.multi_reduce. Therefore, we have
to duplicate code/effort when we improve either one.

This commit changes vector.contract to distribute only the "contract"
part of the operation. It then creates a new vector.multi_reduce to be
re-distributed with partial reduction semantics, thus allowing
improvements to vector.multi_reduce to be re-used by vector.contract.

Closes: https://github.com/iree-org/iree/issues/19620

---------

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 49981f1..17b9a5b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -442,13 +442,6 @@
     }
 
     Type elemTy = srcVector.getType().getElementType();
-    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
-    if (elemBitwidth != maxBitsPerShuffle) {
-      return rewriter.notifyMatchFailure(
-          multiReduceOp, llvm::formatv("unimplemented: packed shuffle",
-                                       elemBitwidth, maxBitsPerShuffle));
-    }
-
     VectorValue disSrc =
         getDistributed(rewriter, srcVector, signature[srcVector]);
 
@@ -770,24 +763,18 @@
   int64_t maxBitsPerShuffle;
 };
 
-/// The lowering for Contract is performed in three steps (similar to above
-/// multi_reduction):
-///   1. Local Contract: Each thread performs operations on its locally
-///   distributed elements.
-///   2. Subgroup Reduction: Threads in each subgroup reduce the results from
-///   step 1 across threads using a subgroup reduction if distribution occurs
-///   along the reduction dimension.
-///   3. Accumulator Reduction: Each thread combines its intermediate results
-///   with its held accumulator.
-///
-/// Currently, reduction across multiple warps is not supported.
+/// The distribution of contract is performed by doing a local contraction
+/// where each thread performs operations on its locally distributed elements.
+/// Then, the resulting vector is interpreted in the undistributed domain.
+/// Since the contraction has only been performed thread-locally, the said
+/// undistributed vector is a partial reduction. Therefore, a
+/// to-be-distributed vector.multi_reduce is added to complete the
+/// contraction.
 struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
   using OpDistributionPattern::OpDistributionPattern;
 
-  DistributeContract(MLIRContext *context, int64_t subgroupSize,
-                     int64_t maxBitsPerShuffle, int64_t benefit = 1)
-      : OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
-        maxBitsPerShuffle(maxBitsPerShuffle) {}
+  DistributeContract(MLIRContext *context, int64_t benefit = 1)
+      : OpDistributionPattern(context, benefit) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 DistributionSignature &signature,
@@ -817,6 +804,16 @@
       return rewriter.notifyMatchFailure(
           contractOp, "missing nested layout for contraction rhs");
     }
+    NestedLayoutAttr resLayout;
+    if (auto contractRes = dyn_cast<VectorValue>(contractOp.getResult())) {
+      resLayout = dyn_cast<NestedLayoutAttr>(signature[contractRes]);
+    } else {
+      // Create a zero-d layout because we
+      // are going to add reduction dims
+      // back to handle the partial reduction
+      resLayout = NestedLayoutAttr::get(
+          contractOp.getContext(), ArrayRef<int64_t>{}, {}, {}, {}, {}, {}, {});
+    }
 
     Value disLhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
     Value disRhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);
@@ -838,21 +835,10 @@
     Location loc = contractOp.getLoc();
 
     // Step 1: local contraction
+    Value localInit = getCombiningIdentityValue(
+        loc, rewriter, contractOp.getKind(), disAcc.getType());
     vector::ContractionOp localContractOp = doDistributedContraction(
-        rewriter, loc, ctx, contractOp, disLhs, disRhs, disAcc);
-
-    int64_t rank = lhsLayout.getRank();
-    SmallVector<bool> reducedDims(rank, false);
-
-    // Identify the reduction dimension and apply it for subgroup reduction.
-    for (auto [index, iteratorType] :
-         llvm::enumerate(contractOp.getIteratorTypes())) {
-      if (vector::isReductionIterator(iteratorType)) {
-        auto map = contractOp.getIndexingMapsArray()[0];
-        int64_t redIdx = *(map.getResultPosition(getAffineDimExpr(index, ctx)));
-        reducedDims[redIdx] = true;
-      }
-    }
+        rewriter, loc, ctx, contractOp, disLhs, disRhs, localInit);
 
     VectorValue localContractValue;
     if (accVector) {
@@ -865,46 +851,79 @@
 
     assert(localContractValue && "result should have been a vector");
 
-    // Flatten the locally result value.
-    VectorType shaped = localContractValue.getType();
-    int64_t numElements = shaped.getNumElements();
-    SmallVector<int64_t> flatShape(1, numElements);
-    VectorType flatVecType = VectorType::get(flatShape, accElemTy);
-    VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
-                                                            localContractValue);
-
-    // Step 2: Do subgroup reduction.
-    FailureOr<VectorValue> threadReduced = doThreadReduction(
-        rewriter, lhsLayout, flat, contractOp.getKind(), reducedDims);
-    if (failed(threadReduced)) {
-      return failure();
+    // Identify the reduction dimension and apply it for subgroup reduction.
+    auto lhsMap = contractOp.getIndexingMapsArray()[0];
+    SmallVector<int64_t> reductionSubGroupTile;
+    SmallVector<int64_t> reductionSubGroupStrides;
+    SmallVector<int64_t> reductionThreadTile;
+    SmallVector<int64_t> reductionThreadStrides;
+    SmallVector<int64_t> partialReductionDims;
+    for (auto [index, iteratorType] :
+         llvm::enumerate(contractOp.getIteratorTypes())) {
+      if (vector::isReductionIterator(iteratorType)) {
+        int64_t redLhsIdx =
+            *(lhsMap.getResultPosition(getAffineDimExpr(index, ctx)));
+        partialReductionDims.push_back(resLayout.getRank() +
+                                       reductionSubGroupTile.size());
+        reductionSubGroupTile.push_back(lhsLayout.getSubgroupTile()[redLhsIdx]);
+        reductionSubGroupStrides.push_back(
+            lhsLayout.getSubgroupStrides()[redLhsIdx]);
+        reductionThreadTile.push_back(lhsLayout.getThreadTile()[redLhsIdx]);
+        reductionThreadStrides.push_back(
+            lhsLayout.getThreadStrides()[redLhsIdx]);
+      }
     }
+    SmallVector<int64_t> unitBroadcastTile(reductionThreadTile.size(), 1);
 
-    // Do reduction against accumulator, which needs to be done after thread
-    // reduction.
-    VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
-        loc, shaped, threadReduced.value());
+    // Manually infer the layout of partial reduction
+    // We do this by appending the reduction dims on
+    // subgroup and thread tiles to the layout of the
+    // result.
+    IREE::VectorExt::NestedLayoutAttr reductionLayout =
+        IREE::VectorExt::NestedLayoutAttr::get(
+            contractOp.getContext(),
+            /*source=*/resLayout,
+            /*appendSubGroupLens=*/reductionSubGroupTile,
+            /*appendBatchLens=*/unitBroadcastTile,
+            /*appendOuterLens=*/unitBroadcastTile,
+            /*appendThreadLens=*/reductionThreadTile,
+            /*appendElementLens=*/unitBroadcastTile,
+            /*appendSubgroupStrides=*/reductionSubGroupStrides,
+            /*appendThreadStrides=*/reductionThreadStrides);
 
-    if (!accVector) {
-      disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
+    VectorType partialReducedDistributedType =
+        VectorType::get(reductionLayout.getDistributedShape(),
+                        localContractValue.getType().getElementType());
+    Value shapeCasted = rewriter.create<vector::ShapeCastOp>(
+        loc, partialReducedDistributedType, localContractValue);
+    VectorType unDistributedType =
+        VectorType::get(reductionLayout.getUndistributedShape(),
+                        localContractValue.getType().getElementType());
+    Value undistrLocalReduced = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+        loc, unDistributedType, shapeCasted);
+
+    // Create the partial reduction
+    auto partialReduction = rewriter.create<vector::MultiDimReductionOp>(
+        loc, contractOp.getKind(), undistrLocalReduced, acc,
+        partialReductionDims);
+    {
+      auto unitAttr = UnitAttr::get(rewriter.getContext());
+      auto reduceAttrs =
+          SmallVector<Attribute>(partialReduction->getNumOperands(), unitAttr);
+      reduceAttrs[0] = reductionLayout;
+      ArrayAttr reduceResultsAttr =
+          ArrayAttr::get(rewriter.getContext(), {unitAttr});
+      if (auto dstLayout =
+              dyn_cast_or_null<NestedLayoutAttr>(signature[resVector])) {
+        reduceAttrs[1] = dstLayout;
+        reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
+      }
+      ArrayAttr reduceOperandsAttr =
+          ArrayAttr::get(rewriter.getContext(), reduceAttrs);
+      setSignatureForRedistribution(rewriter, partialReduction.getOperation(),
+                                    reduceOperandsAttr, reduceResultsAttr);
     }
-
-    // Step 3: Accumulator Reduction
-    Value accReduction = vector::makeArithReduction(
-        rewriter, loc, contractOp.getKind(), unflattened, disAcc);
-    auto accReduced = dyn_cast<VectorValue>(accReduction);
-    if (!accReduced) {
-      return failure();
-    }
-
-    if (resVector) {
-      replaceOpWithDistributedValues(rewriter, contractOp, accReduced);
-    } else {
-      Value accReducedVal = rewriter.create<vector::ExtractOp>(
-          loc, accReduction, SmallVector<int64_t>{0});
-      replaceOpWithDistributedValues(rewriter, contractOp, accReducedVal);
-    }
-
+    rewriter.replaceOp(contractOp, partialReduction);
     return success();
   }
 
@@ -954,46 +973,6 @@
 
     return localContractOp;
   }
-
-  FailureOr<VectorValue> doThreadReduction(RewriterBase &rewriter,
-                                           NestedLayoutAttr layout,
-                                           VectorValue flat,
-                                           vector::CombiningKind kind,
-                                           ArrayRef<bool> reductionMask) const {
-    VectorType flatVecType = flat.getType();
-    int64_t numElements = flatVecType.getNumElements();
-    Location loc = flat.getLoc();
-
-    auto constOp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(flatVecType));
-    auto res = llvm::cast<VectorValue>(constOp.getResult());
-
-    for (unsigned i = 0; i < numElements; ++i) {
-      Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);
-
-      // Reduce across all reduction dimensions 1-by-1.
-      for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
-        if (reductionMask[i]) {
-          int64_t offset = getShuffleOffset(layout, i);
-          int64_t width = getShuffleWidth(layout, i);
-          assert(offset <= std::numeric_limits<uint32_t>::max() &&
-                 width <= std::numeric_limits<uint32_t>::max());
-
-          extracted = rewriter.create<gpu::SubgroupReduceOp>(
-              loc, extracted, combiningKindToAllReduce(kind),
-              /*uniform=*/false, /*cluster_size=*/width,
-              /*cluster_stride=*/offset);
-        }
-      }
-
-      res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
-    }
-
-    return res;
-  }
-
-  int64_t subgroupSize;
-  int64_t maxBitsPerShuffle;
 };
 
 struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
@@ -1344,8 +1323,7 @@
   patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
   patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
                                          maxBitsPerShuffle);
-  patterns.add<DistributeContract>(patterns.getContext(), subgroupSize,
-                                   maxBitsPerShuffle);
+  patterns.add<DistributeContract>(patterns.getContext());
   patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
   patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index d4062bb..e1bc3cc 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -112,6 +112,19 @@
   return shape;
 }
 
+SmallVector<int64_t> NestedLayoutAttr::getUndistributedShape() const {
+  int64_t rank = getRank();
+  SmallVector<int64_t> shape;
+  shape.reserve(rank);
+  for (int64_t i : llvm::seq<int64_t>(rank)) {
+    int64_t expectedDimLen = getSubgroupTile()[i] * getBatchTile()[i] *
+                             getOuterTile()[i] * getThreadTile()[i] *
+                             getElementTile()[i];
+    shape.push_back(expectedDimLen);
+  }
+  return shape;
+}
+
 // Gets the rank of the undistributed vector for this layout.
 int64_t NestedLayoutAttr::getRank() const {
   // The layout requires that all size lists are the same length and match
@@ -198,6 +211,42 @@
                    normalizedThreadStrides);
 }
 
+static SmallVector<int64_t> appendDims(ArrayRef<int64_t> tileLens,
+                                       ArrayRef<int64_t> appendLens) {
+  SmallVector<int64_t> tileLensResult = llvm::to_vector(tileLens);
+  tileLensResult.insert(tileLensResult.end(), appendLens.begin(),
+                        appendLens.end());
+  return tileLensResult;
+}
+
+NestedLayoutAttr NestedLayoutAttr::get(MLIRContext *context,
+                                       NestedLayoutAttr source,
+                                       ArrayRef<int64_t> appendSubGroupLens,
+                                       ArrayRef<int64_t> appendBatchLens,
+                                       ArrayRef<int64_t> appendOuterLens,
+                                       ArrayRef<int64_t> appendThreadLens,
+                                       ArrayRef<int64_t> appendElementLens,
+                                       ArrayRef<int64_t> appendSubgroupStrides,
+                                       ArrayRef<int64_t> appendThreadStrides) {
+  SmallVector<int64_t> subgroupTile =
+      appendDims(source.getSubgroupTile(), appendSubGroupLens);
+  SmallVector<int64_t> batchTile =
+      appendDims(source.getBatchTile(), appendBatchLens);
+  SmallVector<int64_t> outerTile =
+      appendDims(source.getOuterTile(), appendOuterLens);
+  SmallVector<int64_t> threadTile =
+      appendDims(source.getThreadTile(), appendThreadLens);
+  SmallVector<int64_t> elementTile =
+      appendDims(source.getElementTile(), appendElementLens);
+  SmallVector<int64_t> subgroupStrides =
+      appendDims(source.getSubgroupStrides(), appendSubgroupStrides);
+  SmallVector<int64_t> threadStrides =
+      appendDims(source.getThreadStrides(), appendThreadStrides);
+  return NestedLayoutAttr::get(context, subgroupTile, batchTile, outerTile,
+                               threadTile, elementTile, subgroupStrides,
+                               threadStrides);
+}
+
 LogicalResult NestedLayoutAttr::verify(
     llvm::function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index 01cf309..16bad7f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -218,7 +218,15 @@
                      "ArrayRef<int64_t>":$threadTile,
                      "ArrayRef<int64_t>":$elementTile,
                      "ArrayRef<int64_t>":$subgroupStrides,
-                     "ArrayRef<int64_t>":$threadStrides)>
+                     "ArrayRef<int64_t>":$threadStrides)>,
+    AttrBuilder<(ins "NestedLayoutAttr":$source,
+                     "ArrayRef<int64_t>":$appendSubGroupLens,
+                     "ArrayRef<int64_t>":$appendBatchLens,
+                     "ArrayRef<int64_t>":$appendOuterLens,
+                     "ArrayRef<int64_t>":$appendThreadLens,
+                     "ArrayRef<int64_t>":$appendElementLens,
+                     "ArrayRef<int64_t>":$appendSubgroupStrides,
+                     "ArrayRef<int64_t>":$appendThreadStrides)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td
index fdca63f..b5181b8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td
@@ -36,6 +36,12 @@
       /*args=*/(ins "::llvm::ArrayRef<bool>":$droppedDims)
     >,
     InterfaceMethod<
+      /*description=*/"Get the expected undistributed shape for the given vector type.",
+      /*retTy=*/"SmallVector<int64_t>",
+      /*methodName=*/"getUndistributedShape",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
       /*description=*/"Get the distributed shape for the given vector type.",
       /*retTy=*/"SmallVector<int64_t>",
       /*methodName=*/"getDistributedShape",