[VectorDistribution] Split layout configuration and distribution (#18065)
This patch splits the LLVMGPUVectorDistribute pass into two separate
passes: one that sets the vector layouts and one that distributes. This
improves the debugging experience, and failing IR can now be inspected
for its layout anchors.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 70cd714..784a945 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -95,6 +95,7 @@
"KernelConfig.cpp",
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUCastTypeToFitMMA.cpp",
+ "LLVMGPUConfigureVectorLayouts.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPrefetching.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 361a8a9..a552962 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -80,6 +80,7 @@
"KernelConfig.cpp"
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUCastTypeToFitMMA.cpp"
+ "LLVMGPUConfigureVectorLayouts.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPrefetching.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp
new file mode 100644
index 0000000..e576102
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp
@@ -0,0 +1,369 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <algorithm>
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-configure-vector-layouts"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+// Sets an anchoring layout for the given contraction op. Looks for a
+// supported mma type from the cached list of mma types and populates the
+// necessary distribution pattern for those contractions.
+LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+ RewriterBase &rewriter,
+ vector::ContractionOp contract) {
+ // TODO: Add SIMT fallback.
+ if (!schedule) {
+ return contract->emitError("missing mma schedule for contraction");
+ }
+
+ auto layouts = schedule.getContractionLayout(contract);
+ if (failed(layouts)) {
+ return contract->emitError("cannot get concrete layout for contraction");
+ }
+
+ auto [aLayout, bLayout, cLayout] = *layouts;
+ Location loc = contract.getLoc();
+
+ // Set layouts for lhs, rhs and acc.
+ rewriter.setInsertionPoint(contract);
+ Value layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, contract.getLhsType(), contract.getLhs(), aLayout);
+ Value layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, contract.getRhsType(), contract.getRhs(), bLayout);
+ Value layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, contract.getAccType(), contract.getAcc(), cLayout);
+ contract->setOperand(0, layoutedLhs);
+ contract->setOperand(1, layoutedRhs);
+ contract->setOperand(2, layoutedAcc);
+
+ // Set layout for result.
+ rewriter.setInsertionPointAfter(contract);
+ auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, contract.getResultType(), contract.getResult(), cLayout);
+ rewriter.replaceAllUsesExcept(contract, toLayout.getResult(), toLayout);
+
+ // Set intrinsic kind.
+ contract->setAttr("iree.amdgpu.mma", schedule.getIntrinsic());
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "chosen a layout: " << aLayout << "\n";
+ llvm::dbgs() << "chosen b layout: " << bLayout << "\n";
+ llvm::dbgs() << "chosen c layout: " << cLayout << "\n";
+ llvm::dbgs() << "anchor set on contract: " << contract << "\n";
+ });
+
+ return success();
+}
+
+// Sets a layout anchor for reads from global memory.
+// The layout this generates is approximately the following:
+//
+// #layout = #iree_vector_ext.nested_layout<
+// subgroups_per_workgroup = [1, ..., 1]
+// batches_per_subgroup = [<remaining undistributed elements>]
+// outers_per_batch = [1, ..., 1]
+// threads_per_outer = [<greedy from innermost memref dim>]
+// elements_per_thread = [1, ..., 128/element_bitwidth, ..., 1]
+// innermost_memref_dimension ^^^^^^
+//
+// (All orders are the same)
+// *_order = [<broadcasted_dims>, <transfer_permutation>]>
+//
+// So for the following transfer_read with 64 threads:
+// vector.transfer_read ... : memref<16x256xf16>, vector<16x32xf16>
+//
+// We use the following layout:
+// #layout = #iree_vector_ext.nested_layout<
+// subgroups_per_workgroup = [1, 1]
+// batches_per_subgroup = [1, 1]
+// outers_per_batch = [1, 1]
+// threads_per_outer = [16, 4]
+// elements_per_thread = [1, 8]
+//
+// *_order = [0, 1]>
+LogicalResult setTransferReadAnchor(ArrayRef<int64_t> workgroupSize,
+ RewriterBase &rewriter,
+ vector::TransferReadOp transfer) {
+ MLIRContext *context = rewriter.getContext();
+
+ // Get the forward slice of the transfer to approximate whether it will take
+ // the layout of a contraction instead. Transfer_read ops used directly by a
+ // contraction (i.e. without a copy to shared memory in between) should take
+ // the layout of the contraction op. This is common for cases where the
+ // initial values of the accumulator in a linalg.matmul is read from memory
+ // instead of just being a zerofill.
+ ForwardSliceOptions forwardOptions;
+ forwardOptions.filter = [&](Operation *op) -> bool {
+ return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
+ };
+ BackwardSliceOptions backwardOptions;
+ backwardOptions.filter = [&](Operation *op) -> bool {
+ return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
+ };
+ SetVector<Operation *> slice =
+ getSlice(transfer, backwardOptions, forwardOptions);
+
+ if (llvm::any_of(slice, llvm::IsaPred<vector::ContractionOp>)) {
+ return success();
+ }
+
+ // Shared memory loads are expected to take the layout of the contraction.
+ auto sourceMemRefType = dyn_cast<MemRefType>(transfer.getSource().getType());
+ if (!sourceMemRefType || hasSharedMemoryAddressSpace(sourceMemRefType)) {
+ return success();
+ }
+
+ // Take on layout of broadcast.
+ if (transfer->hasOneUse() &&
+ dyn_cast<vector::BroadcastOp>(*transfer->getUsers().begin())) {
+ return success();
+ }
+
+ // TODO: Support masking.
+ if (transfer.getMask()) {
+ transfer->emitOpError(
+ "Anchoring on transfer_read with masks is not yet implemented.");
+ return failure();
+ }
+
+ int64_t bitWidth = IREE::Util::getTypeBitWidth(
+ getElementTypeOrSelf(transfer.getVectorType()));
+ if (!llvm::isPowerOf2_64(bitWidth) || bitWidth > 128) {
+ transfer->emitOpError(
+ "Anchoring on transfer_read with element type of bitwidth " +
+ std::to_string(bitWidth) + " is not yet implemented");
+ return failure();
+ }
+ int64_t numElementsPerThread = 128 / bitWidth;
+ int64_t flatNumElements =
+ ShapedType::getNumElements(transfer.getVectorType().getShape());
+ int64_t flatNumThreads = ShapedType::getNumElements(workgroupSize);
+ if (flatNumElements % flatNumThreads != 0) {
+ transfer->emitOpError()
+ << "Anchoring on transfer_read with unsupported number of elements "
+ "(not divisible by workgroup size)"
+ << ", number of elements: " << flatNumElements
+ << ", workgroup size: " << flatNumThreads;
+ return failure();
+ }
+ numElementsPerThread =
+ std::min(numElementsPerThread, flatNumElements / flatNumThreads);
+
+ AffineMap transferMap = transfer.getPermutationMap();
+ if (transferMap.getNumDims() == 0) {
+ transfer->emitOpError("Anchoring on transfer_read with zero-rank "
+ "permutation map is not supported.");
+ return failure();
+ }
+
+ // Select the innermost dim of the memref as the contiguous dim to load
+ // from.
+ int64_t transferRank = transfer.getVectorType().getRank();
+ std::optional<unsigned> maybeDim = transferMap.getResultPosition(
+ getAffineDimExpr(transferMap.getNumDims() - 1, context));
+ int64_t distXDim = maybeDim ? *maybeDim : transferRank - 1;
+
+ ArrayRef<int64_t> vectorShape = transfer.getVectorType().getShape();
+
+ // Limit the maximum inner vector read width to the innermost contiguous
+ // dimension. We could try to be clever and extend this to adjacent
+ // dimensions in cases where the innermost read vector dimension is small,
+ // but that requires comparing memref strides and is uncommon. For now
+ // prioritize warp contiguity over 128-bit read granularity.
+ numElementsPerThread = std::min(numElementsPerThread, vectorShape[distXDim]);
+
+ llvm::SetVector<unsigned> vectorDimDistributionOrder;
+ // Get the order in which to distribute vector dimensions to threads, going
+ // from innermost to outermost memref dimension. It's important to note
+ // that this heuristic only applies to matrix multiplication cases where
+ // we are promoting the operands of a contraction to shared memory and we
+ // have no producers fused with the matmul. In general there is no universal
+ // way to set an anchoring layout for reads without doing an analysis of how
+ // the read values are used.
+ for (int i = transferMap.getNumDims() - 1; i >= 0; --i) {
+ std::optional<unsigned> maybeDim =
+ transferMap.getResultPosition(getAffineDimExpr(i, context));
+ if (maybeDim) {
+ vectorDimDistributionOrder.insert(*maybeDim);
+ }
+ }
+ // Add all remaining (broadcasted) dimensions
+ for (auto dim : llvm::seq(static_cast<int64_t>(0), transferRank)) {
+ if (!vectorDimDistributionOrder.contains(dim))
+ vectorDimDistributionOrder.insert(dim);
+ }
+
+ int64_t residualThreads = flatNumThreads;
+ int64_t residualElements = numElementsPerThread;
+
+ SmallVector<int64_t> order(vectorDimDistributionOrder.rbegin(),
+ vectorDimDistributionOrder.rend());
+
+ // Distribute all threads in the workgroup to the "threads" dimension,
+ // meaning subgroup counts is unit here, even though the read is being
+ // distributed to multiple subgroups. This is in an attempt to do a
+ // workgroup contiguous load.
+ SmallVector<int64_t> subgroupCounts(transferRank, 1);
+ SmallVector<int64_t> batchSizes(transferRank, 1);
+ SmallVector<int64_t> outerSizes(transferRank, 1);
+ SmallVector<int64_t> threadCounts(transferRank, 1);
+ SmallVector<int64_t> elementSizes(transferRank, 1);
+
+ SmallVector<int64_t> subgroupStrides(transferRank, 1);
+ SmallVector<int64_t> threadStrides(transferRank, 1);
+
+ int64_t currStrides = 1;
+ for (auto dim : llvm::reverse(order)) {
+ int64_t vectorSize = vectorShape[dim];
+ // Set the element count for the innermost vector dimension.
+ if (residualElements != 1) {
+ elementSizes[dim] = residualElements;
+ vectorSize /= residualElements;
+ residualElements = 1;
+ }
+
+ assert((residualThreads % vectorSize == 0 ||
+ vectorSize % residualThreads == 0) &&
+ "dividing threads to incompatible vector");
+ if (residualThreads <= vectorSize) {
+ vectorSize /= residualThreads;
+ threadCounts[dim] = residualThreads;
+ threadStrides[dim] = currStrides;
+ currStrides *= residualThreads;
+ residualThreads = 1;
+ } else {
+ residualThreads /= vectorSize;
+ threadCounts[dim] = vectorSize;
+ threadStrides[dim] = currStrides;
+ currStrides *= vectorSize;
+ vectorSize = 1;
+ }
+
+ batchSizes[dim] = vectorSize;
+ }
+
+ auto layout = IREE::VectorExt::NestedLayoutAttr::get(
+ context, subgroupCounts, batchSizes, outerSizes, threadCounts,
+ elementSizes, subgroupStrides, threadStrides);
+
+ Location loc = transfer.getLoc();
+ rewriter.setInsertionPointAfter(transfer);
+ auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, transfer.getResult().getType(), transfer.getResult(), layout);
+ rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout);
+
+ return success();
+}
+
+struct LLVMGPUConfigureVectorLayoutsPass
+ : public LLVMGPUConfigureVectorLayoutsBase<
+ LLVMGPUConfigureVectorLayoutsPass> {
+public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
+ registry.insert<vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ auto func = getOperation();
+
+ std::array<int64_t, 3> workgroupSize;
+ if (func->hasAttr("workgroup_size")) {
+ auto tmpSizes =
+ llvm::cast<ArrayAttr>(func->getAttr("workgroup_size")).getValue();
+ for (auto [i, size] : llvm::enumerate(tmpSizes)) {
+ workgroupSize[i] = llvm::cast<IntegerAttr>(size).getInt();
+ }
+ } else {
+ std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
+ getWorkgroupSize(func);
+ if (!maybeWorkgroupSize) {
+ func->emitOpError()
+ << "unable to query workgroup_size information from entry point";
+ return signalPassFailure();
+ }
+ for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) {
+ workgroupSize[index] = value;
+ }
+ for (auto index : llvm::seq<size_t>(maybeWorkgroupSize->size(), 3)) {
+ workgroupSize[index] = 1;
+ }
+ }
+
+ llvm::StringLiteral scheduleAttrName =
+ IREE::GPU::MMAScheduleAttr::getMnemonic();
+ auto scheduleAttr =
+ func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
+ if (!scheduleAttr) {
+ DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
+ scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+ configDict.get(scheduleAttrName));
+ }
+
+ // Vector layout option setter aimed at contractions. Currently this only
+ // sets anchors for two types of operations; vector.contract and
+ // vector.transfer_read from non-shared memory. The assumption in this case
+ // is that all IR input to this pass has a leaf rooted on a transfer_read or
+ // includes a contraction in the program slice, meaning all operations
+ // should receive layouts. Layout setting for other problems like reductions
+ // is TODO.
+ SmallVector<vector::TransferReadOp> reads;
+ SmallVector<vector::ContractionOp> contracts;
+
+ func->walk([&](Operation *op) {
+ llvm::TypeSwitch<Operation *>(op)
+ .Case([&](vector::TransferReadOp transfer) {
+ reads.push_back(transfer);
+ })
+ .Case([&](vector::ContractionOp contract) {
+ contracts.push_back(contract);
+ });
+ });
+
+ IRRewriter rewriter(func);
+
+ for (vector::TransferReadOp read : reads) {
+ if (failed(setTransferReadAnchor(workgroupSize, rewriter, read))) {
+ return signalPassFailure();
+ }
+ }
+
+ for (vector::ContractionOp contract : contracts) {
+ if (failed(setContractionAnchor(scheduleAttr, rewriter, contract))) {
+ return signalPassFailure();
+ }
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMGPUConfigureVectorLayouts() {
+ return std::make_unique<LLVMGPUConfigureVectorLayoutsPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index 25adf1f..d9d5138 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -4,26 +4,15 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include <algorithm>
-
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
-#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/MathExtras.h"
-#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -33,347 +22,25 @@
#define DEBUG_TYPE "iree-llvmgpu-vector-distribute"
-using LayoutDimension = mlir::iree_compiler::IREE::VectorExt::LayoutDimension;
-using LayoutDimensionAttr =
- mlir::iree_compiler::IREE::VectorExt::LayoutDimensionAttr;
-using VectorLayoutInterface =
- mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface;
-using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr;
-using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr;
-
namespace mlir::iree_compiler {
namespace {
-// Vector layout option setter aimed at contractions. Currently this only sets
-// anchors for two types of operations; vector.contract and vector.transfer_read
-// from non-shared memory. The assumption in this case is that all IR input to
-// this pass has a leaf rooted on a transfer_read or includes a contraction in
-// the program slice, meaning all operations should receive layouts. Layout
-// setting for other problems like reductions is TODO.
class ContractionVectorLayoutOptions : public VectorLayoutOptions {
public:
- ContractionVectorLayoutOptions(Operation *root,
- ArrayRef<int64_t> workgroupSize,
- IREE::GPU::MMAScheduleAttr schedule,
- Value laneId, int64_t subgroupSize,
- bool printLayout)
- : VectorLayoutOptions(root, /*fullConversion=*/!printLayout),
- workgroupSize(workgroupSize), schedule(schedule),
- printLayout(printLayout), patterns(root->getContext()) {
+ ContractionVectorLayoutOptions(Operation *root, Value laneId,
+ int64_t subgroupSize)
+ : VectorLayoutOptions(root), patterns(root->getContext()) {
populateGPUDistributionPatterns(patterns);
populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
subgroupSize);
- }
-
- LogicalResult setAnchorOps(RewriterBase &rewriter) {
- MLIRContext *context = root->getContext();
- SmallVector<vector::TransferReadOp> reads;
- SmallVector<vector::ContractionOp> contracts;
-
- root->walk([&](Operation *op) {
- llvm::TypeSwitch<Operation *>(op)
- .Case([&](vector::TransferReadOp transfer) {
- reads.push_back(transfer);
- })
- .Case([&](vector::ContractionOp contract) {
- contracts.push_back(contract);
- });
- });
-
- for (vector::TransferReadOp read : reads) {
- if (failed(setTransferReadAnchor(context, rewriter, read))) {
- return failure();
- }
- }
-
- for (vector::ContractionOp contract : contracts) {
- if (failed(setContractionAnchor(context, rewriter, contract))) {
- return failure();
- }
- }
-
- return success();
+ populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
}
RewritePatternSet &getPatterns() { return patterns; }
private:
- // Sets an anchoring layout for the given contraction op. Looks for a
- // supported mma type from the cached list of mma types and populates the
- // necessary distribution pattern for those contractions.
- LogicalResult setContractionAnchor(MLIRContext *context,
- RewriterBase &rewriter,
- vector::ContractionOp contract) {
- // TODO: Add SIMT fallback.
- if (!schedule) {
- return contract->emitError("missing mma schedule for contraction");
- }
-
- auto layouts = schedule.getContractionLayout(contract);
- if (failed(layouts)) {
- return contract->emitError("cannot get concrete layout for contraction");
- }
-
- auto [aLayout, bLayout, cLayout] = *layouts;
- Location loc = contract.getLoc();
-
- // Set layouts for lhs, rhs and acc.
- rewriter.setInsertionPoint(contract);
- Value layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, contract.getLhsType(), contract.getLhs(), aLayout);
- Value layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, contract.getRhsType(), contract.getRhs(), bLayout);
- Value layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, contract.getAccType(), contract.getAcc(), cLayout);
- contract->setOperand(0, layoutedLhs);
- contract->setOperand(1, layoutedRhs);
- contract->setOperand(2, layoutedAcc);
-
- // Set layout for result.
- rewriter.setInsertionPointAfter(contract);
- auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, contract.getResultType(), contract.getResult(), cLayout);
- rewriter.replaceAllUsesExcept(contract, toLayout.getResult(), toLayout);
-
- // Set intrinsic kind.
- contract->setAttr("iree.amdgpu.mma", schedule.getIntrinsic());
-
- if (printLayout) {
- llvm::outs() << "contract A vector layout: " << aLayout << "\n";
- llvm::outs() << "contract B vector layout: " << bLayout << "\n";
- llvm::outs() << "contract C vector layout: " << cLayout << "\n";
- }
- LLVM_DEBUG({
- llvm::dbgs() << "chosen a layout: " << aLayout << "\n";
- llvm::dbgs() << "chosen b layout: " << bLayout << "\n";
- llvm::dbgs() << "chosen c layout: " << cLayout << "\n";
- llvm::dbgs() << "anchor set on contract: " << contract << "\n";
- });
-
- if (isa<IREE::GPU::MMAAttr>(schedule.getIntrinsic())) {
- if (!populatedMma) {
- populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
- populatedMma = true;
- }
- } else {
- llvm_unreachable("Unsupported mma type");
- }
- return success();
- }
-
- // Sets a layout anchor for reads from global memory.
- // The layout this generates is approximately the following:
- //
- // #layout = #iree_vector_ext.nested_layout<
- // subgroups_per_workgroup = [1, ..., 1]
- // batches_per_subgroup = [<remaining undistributed elements>]
- // outers_per_batch = [1, ..., 1]
- // threads_per_outer = [<greedy from innermost memref dim>]
- // elements_per_thread = [1, ..., 128/element_bitwidth, ..., 1]
- // innermost_memref_dimension ^^^^^^
- //
- // (All orders are the same)
- // *_order = [<broadcasted_dims>, <transfer_permutation>]>
- //
- // So for the following transfer_read with 64 threads:
- // vector.transfer_read ... : memref<16x256xf16>, vector<16x32xf16>
- //
- // We use the following layout:
- // #layout = #iree_vector_ext.nested_layout<
- // subgroups_per_workgroup = [1, 1]
- // batches_per_subgroup = [1, 1]
- // outers_per_batch = [1, 1]
- // threads_per_outer = [16, 4]
- // elements_per_thread = [1, 8]
- //
- // *_order = [0, 1]>
- LogicalResult setTransferReadAnchor(MLIRContext *context,
- RewriterBase &rewriter,
- vector::TransferReadOp transfer) {
-
- // Get the forward slice of the transfer to approximate whether it will take
- // the layout of a contraction instead. Transfer_read ops used directly by a
- // contraction (i.e. without a copy to shared memory in between) should take
- // the layout of the contraction op. This is common for cases where the
- // initial values of the accumulator in a linalg.matmul is read from memory
- // instead of just being a zerofill.
- ForwardSliceOptions forwardOptions;
- forwardOptions.filter = [&](Operation *op) -> bool {
- return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
- };
- BackwardSliceOptions backwardOptions;
- backwardOptions.filter = [&](Operation *op) -> bool {
- return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
- };
- SetVector<Operation *> slice =
- getSlice(transfer, backwardOptions, forwardOptions);
-
- if (llvm::any_of(slice, llvm::IsaPred<vector::ContractionOp>)) {
- return success();
- }
-
- // Shared memory loads are expected to take the layout of the contraction.
- auto sourceMemRefType =
- dyn_cast<MemRefType>(transfer.getSource().getType());
- if (!sourceMemRefType || hasSharedMemoryAddressSpace(sourceMemRefType)) {
- return success();
- }
-
- // Take on layout of broadcast.
- if (transfer->hasOneUse() &&
- dyn_cast<vector::BroadcastOp>(*transfer->getUsers().begin())) {
- return success();
- }
-
- // TODO: Support masking.
- if (transfer.getMask()) {
- transfer->emitOpError(
- "Anchoring on transfer_read with masks is not yet implemented.");
- return failure();
- }
-
- int64_t bitWidth = IREE::Util::getTypeBitWidth(
- getElementTypeOrSelf(transfer.getVectorType()));
- if (!llvm::isPowerOf2_64(bitWidth) || bitWidth > 128) {
- transfer->emitOpError(
- "Anchoring on transfer_read with element type of bitwidth " +
- std::to_string(bitWidth) + " is not yet implemented");
- return failure();
- }
- int64_t numElementsPerThread = 128 / bitWidth;
- int64_t flatNumElements =
- ShapedType::getNumElements(transfer.getVectorType().getShape());
- int64_t flatNumThreads = ShapedType::getNumElements(workgroupSize);
- if (flatNumElements % flatNumThreads != 0) {
- transfer->emitOpError()
- << "Anchoring on transfer_read with unsupported number of elements "
- "(not divisible by workgroup size)"
- << ", number of elements: " << flatNumElements
- << ", workgroup size: " << flatNumThreads;
- return failure();
- }
- numElementsPerThread =
- std::min(numElementsPerThread, flatNumElements / flatNumThreads);
-
- AffineMap transferMap = transfer.getPermutationMap();
- if (transferMap.getNumDims() == 0) {
- transfer->emitOpError("Anchoring on transfer_read with zero-rank "
- "permutation map is not supported.");
- return failure();
- }
-
- // Select the innermost dim of the memref as the contiguous dim to load
- // from.
- int64_t transferRank = transfer.getVectorType().getRank();
- std::optional<unsigned> maybeDim = transferMap.getResultPosition(
- getAffineDimExpr(transferMap.getNumDims() - 1, context));
- int64_t distXDim = maybeDim ? *maybeDim : transferRank - 1;
-
- ArrayRef<int64_t> vectorShape = transfer.getVectorType().getShape();
-
- // Limit the maximum inner vector read width to the innermost contiguous
- // dimension. We could try to be clever and extend this to adjacent
- // dimensions in cases where the innermost read vector dimension is small,
- // but that requires comparing memref strides and is uncommon. For now
- // prioritize warp contiguity over 128-bit read granularity.
- numElementsPerThread =
- std::min(numElementsPerThread, vectorShape[distXDim]);
-
- llvm::SetVector<unsigned> vectorDimDistributionOrder;
- // Get the order in which to distribute vector dimensions to threads, going
- // from innermost to outermost memref dimension. It's important to note
- // that this heuristic only applies to matrix multiplication cases where
- // we are promoting the operands of a contraction to shared memory and we
- // have no producers fused with the matmul. In general there is no universal
- // way to set an anchoring layout for reads without doing an analysis of how
- // the read values are used.
- for (int i = transferMap.getNumDims() - 1; i >= 0; --i) {
- std::optional<unsigned> maybeDim =
- transferMap.getResultPosition(getAffineDimExpr(i, context));
- if (maybeDim) {
- vectorDimDistributionOrder.insert(*maybeDim);
- }
- }
- // Add all remaining (broadcasted) dimensions
- for (auto dim : llvm::seq(static_cast<int64_t>(0), transferRank)) {
- if (!vectorDimDistributionOrder.contains(dim))
- vectorDimDistributionOrder.insert(dim);
- }
-
- int64_t residualThreads = flatNumThreads;
- int64_t residualElements = numElementsPerThread;
-
- SmallVector<int64_t> order(vectorDimDistributionOrder.rbegin(),
- vectorDimDistributionOrder.rend());
-
- // Distribute all threads in the workgroup to the "threads" dimension,
- // meaning subgroup counts is unit here, even though the read is being
- // distributed to multiple subgroups. This is in an attempt to do a
- // workgroup contiguous load.
- SmallVector<int64_t> subgroupCounts(transferRank, 1);
- SmallVector<int64_t> batchSizes(transferRank, 1);
- SmallVector<int64_t> outerSizes(transferRank, 1);
- SmallVector<int64_t> threadCounts(transferRank, 1);
- SmallVector<int64_t> elementSizes(transferRank, 1);
-
- SmallVector<int64_t> subgroupStrides(transferRank, 1);
- SmallVector<int64_t> threadStrides(transferRank, 1);
-
- int64_t currStrides = 1;
- for (auto dim : llvm::reverse(order)) {
- int64_t vectorSize = vectorShape[dim];
- // Set the element count for the innermost vector dimension.
- if (residualElements != 1) {
- elementSizes[dim] = residualElements;
- vectorSize /= residualElements;
- residualElements = 1;
- }
-
- assert((residualThreads % vectorSize == 0 ||
- vectorSize % residualThreads == 0) &&
- "dividing threads to incompatible vector");
- if (residualThreads <= vectorSize) {
- vectorSize /= residualThreads;
- threadCounts[dim] = residualThreads;
- threadStrides[dim] = currStrides;
- currStrides *= residualThreads;
- residualThreads = 1;
- } else {
- residualThreads /= vectorSize;
- threadCounts[dim] = vectorSize;
- threadStrides[dim] = currStrides;
- currStrides *= vectorSize;
- vectorSize = 1;
- }
-
- batchSizes[dim] = vectorSize;
- }
-
- auto layout = IREE::VectorExt::NestedLayoutAttr::get(
- context, subgroupCounts, batchSizes, outerSizes, threadCounts,
- elementSizes, subgroupStrides, threadStrides);
-
- Location loc = transfer.getLoc();
- rewriter.setInsertionPointAfter(transfer);
- auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, transfer.getResult().getType(), transfer.getResult(), layout);
- rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout);
-
- if (printLayout) {
- llvm::outs() << "transfer '" << transfer << "' vector layout: " << layout
- << "\n";
- }
- return success();
- }
-
- SmallVector<int64_t, 3> workgroupSize;
- IREE::GPU::MMAScheduleAttr schedule;
- // Whether to print the chosen layout for testing purposes
- bool printLayout;
-
- bool populatedMma = false;
RewritePatternSet patterns;
};
@@ -413,16 +80,6 @@
}
}
- llvm::StringLiteral scheduleAttrName =
- IREE::GPU::MMAScheduleAttr::getMnemonic();
- auto scheduleAttr =
- func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
- if (!scheduleAttr) {
- DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
- scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
- configDict.get(scheduleAttrName));
- }
-
AffineExpr x, y, z;
bindSymbols(func.getContext(), x, y, z);
// Construct the expression for linearizing the thread indices.
@@ -449,15 +106,8 @@
return signalPassFailure();
}
- ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr,
- linearThreadIdVal,
- subgroupSize.value(), testLayout);
-
- // Set anchor layouts.
- if (failed(options.setAnchorOps(rewriter))) {
- func->emitError() << "failed to set anchors";
- return signalPassFailure();
- }
+ ContractionVectorLayoutOptions options(func, linearThreadIdVal,
+ subgroupSize.value());
if (failed(distributeVectorOps(func, options.getPatterns(), options))) {
func->emitOpError() << "failed to distribute";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 8918fb2..f2a9759 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -789,6 +789,7 @@
funcPassManager.addPass(createAMDGPUPrepareForChainedMatmulPass());
// Vector SIMD -> Vector SIMT
+ funcPassManager.addPass(createLLVMGPUConfigureVectorLayouts());
funcPassManager.addPass(createLLVMGPUVectorDistribute());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 488705f..fb84275 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -152,6 +152,10 @@
createLLVMGPUPromoteMatmulToFitMMAPass(
LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims);
+// Pass to set layouts for vector distribution.
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createLLVMGPUConfigureVectorLayouts();
+
enum class GPUTensorCoreType {
WMMA = 0,
MMA_SYNC = 1,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index b4176a5..8ea7da4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -60,6 +60,12 @@
let constructor = "mlir::iree_compiler::createLLVMGPUCastTypeToFitMMAPass()";
}
+def LLVMGPUConfigureVectorLayouts :
+ InterfacePass<"iree-llvmgpu-configure-vector-layouts", "mlir::FunctionOpInterface"> {
+ let summary = "Pass to set layouts for vector distribution";
+ let constructor = "mlir::iree_compiler::createLLVMGPUConfigureVectorLayouts()";
+}
+
def LLVMGPULowerExecutableTarget :
InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> {
let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
@@ -125,11 +131,6 @@
InterfacePass<"iree-llvmgpu-vector-distribute", "mlir::FunctionOpInterface"> {
let summary = "Pass to distribute vectorized functions.";
let constructor = "mlir::iree_compiler::createLLVMGPUVectorDistribute()";
- let options = [
- Option<"testLayout", "test-layout", "bool", /*default=*/"false",
- "Annotate vector ops with deduced layouts without real conversion "
- "for testing purposes">
- ];
}
def LLVMGPUVectorLowering :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 1757b5c..057ff1c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -74,7 +74,7 @@
"transpose_pipeline_test.mlir",
"ukernel_pipeline_transform.mlir",
"vector_distribute_conversion.mlir",
- "vector_distribute_layout.mlir",
+ "configure_vector_layout.mlir",
"vector_lowering.mlir",
"vector_to_gpu.mlir",
"winograd_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 2ff84aa..5f603cf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -23,6 +23,7 @@
"cast_type_to_fit_mma.mlir"
"config_matvec.mlir"
"config_winograd.mlir"
+ "configure_vector_layout.mlir"
"conv_pipeline_test_cuda.mlir"
"conv_pipeline_test_rocm.mlir"
"convert_to_nvvm.mlir"
@@ -70,7 +71,6 @@
"transpose_pipeline_test.mlir"
"ukernel_pipeline_transform.mlir"
"vector_distribute_conversion.mlir"
- "vector_distribute_layout.mlir"
"vector_lowering.mlir"
"vector_to_gpu.mlir"
"winograd_pipeline_test.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir
new file mode 100644
index 0000000..e67c011
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir
@@ -0,0 +1,445 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-configure-vector-layouts, canonicalize, cse))' %s | FileCheck %s
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// Since CHECK-SAME doesn't work with CHECK-DAG, we cannot have prettier tests.
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], subgroup_strides = [0, 0], thread_strides = [32, 1]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], subgroup_strides = [0, 0], thread_strides = [32, 1]>
+
+// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm
+func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+ return %0 : vector<96x64xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], subgroup_strides = [0, 0], thread_strides = [32, 1]>
+
+// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmt
+func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf32>
+ return %0 : vector<96x64xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 3], outers_per_batch = [1, 4], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
+
+// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmtt
+func.func @mfma_matmul_96x64x16_mmtt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<64x96xf32>) -> vector<64x96xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, k) -> (n, m)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<64x96xf32>
+ return %0 : vector<64x96xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 2, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1]
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1]
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1]
+
+// CHECK-LABEL: func.func @matmul_192x64x16_mmt_multisubgroup
+func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: vector<16x64xf16>, %init: vector<192x64xf32>) -> vector<192x64xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<192x16xf16>, vector<16x64xf16> into vector<192x64xf32>
+ return %0 : vector<192x64xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8], subgroup_strides = [0, 0], thread_strides = [4, 1]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 8], subgroup_strides = [0, 0], thread_strides = [2, 1]>
+
+// CHECK-LABEL: func.func @matmul_16x16x256_read
+func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
+ %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
+ %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
+ attributes { translation_info = #translation } {
+ %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
+ %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
+ %cst = arith.constant 0.000000e+00 : f16
+ %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %c0 = arith.constant 0 : index
+ %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) {
+ %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf16>
+ %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x16xf16>
+ // CHECK: %[[READ0:.+]] = vector.transfer_read
+ // CHECK: to_layout %[[READ0]] to #[[$NESTED]]
+ // CHECK: %[[READ1:.+]] = vector.transfer_read
+ // CHECK: to_layout %[[READ1]] to #[[$NESTED1]]
+ vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
+ gpu.barrier
+ vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space<workgroup>>
+ gpu.barrier
+ %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<16x32xf16>
+ %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<32x16xf16>
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %10 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
+ scf.yield %10 : vector<16x16xf32>
+ }
+ vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
+ memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space<workgroup>>
+ memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space<workgroup>>
+ return
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8], subgroup_strides = [0, 0], thread_strides = [4, 1]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1], subgroup_strides = [0, 0], thread_strides = [1, 4]>
+
+// CHECK-LABEL: func.func @matmul_16x16x256_read_permute
+func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
+ %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
+ %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
+ attributes { translation_info = #translation } {
+ %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
+ %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
+ %cst = arith.constant 0.000000e+00 : f16
+ %cst_f32 = arith.constant 0.000000e+00 : f32
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %c0 = arith.constant 0 : index
+ %init_acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]}
+ : memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x16xf32>
+ // CHECK: scf.for
+ %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %init_acc) -> (vector<16x16xf32>) {
+ %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf16>
+ %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x16xf16>
+ // CHECK: %[[READ0:.+]] = vector.transfer_read
+ // CHECK: to_layout %[[READ0]] to #[[$NESTED]]
+ // CHECK: %[[READ1:.+]] = vector.transfer_read
+ // CHECK: to_layout %[[READ1]] to #[[$NESTED1]]
+ vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
+ gpu.barrier
+ vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space<workgroup>>
+ gpu.barrier
+ %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<16x32xf16>
+ %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<32x16xf16>
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %10 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
+ scf.yield %10 : vector<16x16xf32>
+ }
+ vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
+ memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space<workgroup>>
+ memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space<workgroup>>
+ return
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 1, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// We don't really care what layout we assign here, just that the only anchor
+// we set is on the contraction.
+
+// CHECK-LABEL: func.func @matmul_16x16x256_fused
+func.func @matmul_16x16x256_fused(%lhs: memref<16x32xf16>,
+ %rhs: memref<32x16xf16>,
+ %bias: memref<16x16xf32>,
+ %out: memref<16x16xf32>)
+ attributes { translation_info = #translation } {
+ %cst = arith.constant 0.000000e+00 : f16
+ %cst_f32 = arith.constant 0.000000e+00 : f32
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %c0 = arith.constant 0 : index
+ %acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %8 = vector.transfer_read %lhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16>, vector<16x32xf16>
+ %9 = vector.transfer_read %rhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16>, vector<32x16xf16>
+ // CHECK-DAG: %[[READA:.+]] = vector.transfer_read
+ // CHECK-DAG: %[[READB:.+]] = vector.transfer_read
+ // CHECK-DAG: %[[READC:.+]] = vector.transfer_read
+ // CHECK-NOT: to_layout %[[READA]]
+ // CHECK-NOT: to_layout %[[READB]]
+ // CHECK-NOT: to_layout %[[READC]]
+
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %10 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %9, %acc : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
+ %11 = vector.transfer_read %bias[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
+ %12 = arith.addf %10, %11 : vector<16x16xf32>
+ vector.transfer_write %12, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+ return
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [32, 1, 1]
+ subgroup_size = 32,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], subgroup_strides = [0, 0], thread_strides = [1, 0]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1], subgroup_strides = [0, 0], thread_strides = [0, 1]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>
+
+// CHECK-LABEL: func.func @wmma_matmul_48x32x32_mm
+func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
+ return %0 : vector<48x32xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [32, 1, 1]
+ subgroup_size = 32,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], subgroup_strides = [0, 0], thread_strides = [1, 0]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], subgroup_strides = [0, 0], thread_strides = [1, 0]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>
+
+// CHECK-LABEL: func.func @wmma_matmul_48x32x32_mmt
+func.func @wmma_matmul_48x32x32_mmt(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
+ return %0 : vector<48x32xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 2, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1, 1], batches_per_subgroup = [1, 4, 1], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 16, 4], elements_per_thread = [1, 1, 4], subgroup_strides = [1, 0, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 4], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1, 1], batches_per_subgroup = [1, 4, 4], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 4, 16], elements_per_thread = [1, 4, 1], subgroup_strides = [1, 0, 0], thread_strides = [0, 16, 1]>
+
+// CHECK-LABEL: func.func @matmul_192x64x16_mmt_multi_m
+func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
+ return %0 : vector<2x64x64xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [64, 2, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 4, subgroup_n_count = 1>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 2, 1], batches_per_subgroup = [1, 2, 1], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 16, 4], elements_per_thread = [1, 1, 4], subgroup_strides = [2, 1, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 4], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 2, 1], batches_per_subgroup = [1, 2, 4], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 4, 16], elements_per_thread = [1, 4, 1], subgroup_strides = [2, 1, 0], thread_strides = [0, 16, 1]>
+
+// CHECK-LABEL: func.func @matmul_192x64x16_mmt_multi_split_m
+func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
+ return %0 : vector<2x64x64xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [128, 2, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2>, workgroup_size = [128, 2, 1]}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1, 1], batches_per_subgroup = [2, 4, 1]
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1, 1], batches_per_subgroup = [1, 1, 4]
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 2, 1, 1], batches_per_subgroup = [2, 1, 4, 4]
+
+// CHECK-LABEL: func.func @matmul_192x64x16_mmt_multi_m_and_n
+func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes { translation_info = #translation } {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<4x64x16xf16>, vector<2x16x64xf16> into vector<4x2x64x64xf32>
+ return %0 : vector<4x2x64x64xf32>
+}
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [32, 4, 1]
+ subgroup_size = 32,
+ {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4>}>
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 4], elements_per_thread = [1, 32], subgroup_strides = [0, 0], thread_strides = [4, 1]>
+
+// CHECK-LABEL: func.func @dequant_anchors_on_quant_only
+func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
+ %scale: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>,
+ %zp: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>)
+ attributes { translation_info = #translation } {
+ %alloc = memref.alloc() : memref<128x128xf16, #gpu.address_space<workgroup>>
+ %cst = arith.constant 0.000000e+00 : f16
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %c0_i4 = arith.constant 0 : i4
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_read %quant[%c0, %c0], %c0_i4 {in_bounds = [true, true]} : memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128x128xi4>
+ // CHECK: %[[READ:.+]] = vector.transfer_read
+ // CHECK: to_layout %[[READ]] to #[[$NESTED]]
+ %1 = vector.transfer_read %scale[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128xf16>
+ %2 = vector.broadcast %1 : vector<128xf16> to vector<128x128xf16>
+ %3 = vector.transpose %2, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
+ %4 = vector.transfer_read %zp[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128xf16>
+ %5 = vector.broadcast %4 : vector<128xf16> to vector<128x128xf16>
+ %6 = vector.transpose %5, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
+ %7 = arith.extui %0 : vector<128x128xi4> to vector<128x128xi32>
+ %8 = arith.uitofp %7 : vector<128x128xi32> to vector<128x128xf16>
+ %9 = arith.subf %8, %6 : vector<128x128xf16>
+ %10 = arith.mulf %9, %3 : vector<128x128xf16>
+ vector.transfer_write %10, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<128x128xf16>, memref<128x128xf16, #gpu.address_space<workgroup>>
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 2, 1], batches_per_subgroup = [1, 2, 4], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 16, 4], elements_per_thread = [1, 1, 4], subgroup_strides = [0, 2, 0], thread_strides = [0, 1, 16]>
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1, 2], batches_per_subgroup = [1, 4, 4], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 4, 16], elements_per_thread = [1, 4, 1], subgroup_strides = [0, 0, 1], thread_strides = [0, 16, 1]>
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 2, 2], batches_per_subgroup = [1, 2, 4], outers_per_batch = [1, 1, 1], threads_per_outer = [1, 4, 16], elements_per_thread = [1, 4, 1], subgroup_strides = [0, 2, 1], thread_strides = [0, 16, 1]>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+ workgroup_size = [128, 2, 1]
+ subgroup_size = 64,
+ {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+// CHECK-LABEL: func.func @batch_matmul_unit_batch
+func.func @batch_matmul_unit_batch(%arg0: vector<1x64x64xf16>, %arg1: vector<1x64x128xf16>, %arg2: vector<1x64x128xf32>) -> vector<1x64x128xf32> attributes {translation_info = #translation} {
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: vector.contract
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2 : vector<1x64x64xf16>, vector<1x64x128xf16> into vector<1x64x128xf32>
+ return %0 : vector<1x64x128xf32>
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
index 0818851..9c9eab0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-configure-vector-layouts, iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
deleted file mode 100644
index a2519ab..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ /dev/null
@@ -1,470 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute{test-layout}, canonicalize, cse))' %s | FileCheck %s
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
- return %0 : vector<96x64xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [32, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [32, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf32>
- return %0 : vector<96x64xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [32, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @mfma_matmul_96x64x16_mmtt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<64x96xf32>) -> vector<64x96xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, k) -> (n, m)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<64x96xf32>
- return %0 : vector<64x96xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 3], outers_per_batch = [1, 4], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 2, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 1>}>
-
-func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: vector<16x64xf16>, %init: vector<192x64xf32>) -> vector<192x64xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<192x16xf16>, vector<16x64xf16> into vector<192x64xf32>
- return %0 : vector<192x64xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1]
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1]
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<subgroups_per_workgroup = [2, 1]
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
- %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
- %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
- attributes { translation_info = #translation } {
- %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
- %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
- %cst = arith.constant 0.000000e+00 : f16
- %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32>
- %c32 = arith.constant 32 : index
- %c256 = arith.constant 256 : index
- %c0 = arith.constant 0 : index
- %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) {
- %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf16>
- %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x16xf16>
- vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
- gpu.barrier
- vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space<workgroup>>
- gpu.barrier
- %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<16x32xf16>
- %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<32x16xf16>
- %10 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
- scf.yield %10 : vector<16x16xf32>
- }
- vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
- memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space<workgroup>>
- memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space<workgroup>>
- return
-}
-
-// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}<storage_buffer>>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [4, 1]>
-// CHECK: transfer '{{.+}} memref<256x16xf16{{.+}}<storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 8],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [2, 1]>
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 16]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-
-func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
- %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
- %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
- attributes { translation_info = #translation } {
- %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
- %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
- %cst = arith.constant 0.000000e+00 : f16
- %cst_f32 = arith.constant 0.000000e+00 : f32
- %c32 = arith.constant 32 : index
- %c256 = arith.constant 256 : index
- %c0 = arith.constant 0 : index
- %init_acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]}
- : memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x16xf32>
- %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %init_acc) -> (vector<16x16xf32>) {
- %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf16>
- %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x16xf16>
- vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
- gpu.barrier
- vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space<workgroup>>
- gpu.barrier
- %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<16x32xf16>
- %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<32x16xf16>
- %10 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
- scf.yield %10 : vector<16x16xf32>
- }
- vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
- memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space<workgroup>>
- memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space<workgroup>>
- return
-}
-
-// CHECK-NOT: transfer '{{.+}} memref<16x16xf16{{.+}}<storage_buffer>>, vector<16x16xf16>' vector layout
-// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}<storage_buffer>>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [4, 1]>
-// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 4]>
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 16]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @matmul_16x16x256_fused(%lhs: memref<16x32xf16>,
- %rhs: memref<32x16xf16>,
- %bias: memref<16x16xf32>,
- %out: memref<16x16xf32>)
- attributes { translation_info = #translation } {
- %cst = arith.constant 0.000000e+00 : f16
- %cst_f32 = arith.constant 0.000000e+00 : f32
- %c32 = arith.constant 32 : index
- %c256 = arith.constant 256 : index
- %c0 = arith.constant 0 : index
- %acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
- %8 = vector.transfer_read %lhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16>, vector<16x32xf16>
- %9 = vector.transfer_read %rhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16>, vector<32x16xf16>
- %10 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %8, %9, %acc : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
- %11 = vector.transfer_read %bias[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32>
- %12 = arith.addf %10, %11 : vector<16x16xf32>
- vector.transfer_write %12, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
- return
-}
-
-// We don't really care what layout we assign here, just that the only anchor
-// we set is on the contraction.
-// CHECK-NOT: transfer {{.*}} vector layout
-// CHECK: contract A vector layout
-// CHECK-NOT: transfer {{.*}} vector layout
-// CHECK: contract B vector layout
-// CHECK-NOT: transfer {{.*}} vector layout
-// CHECK: contract C vector layout
-// CHECK-NOT: transfer {{.*}} vector layout
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [32, 1, 1]
- subgroup_size = 32,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
- return %0 : vector<48x32xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 0]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [0, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [32, 1, 1]
- subgroup_size = 32,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-func.func @wmma_matmul_48x32x32_mmt(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
- iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
- return %0 : vector<48x32xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 0]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 0]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
-// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 2, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 1>}>
-
-
-func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
- return %0 : vector<2x64x64xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 4, 1],
-// CHECK-SAME: outers_per_batch = [1, 1, 1],
-// CHECK-SAME: threads_per_outer = [1, 16, 4],
-// CHECK-SAME: elements_per_thread = [1, 1, 4],
-// CHECK-SAME: subgroup_strides = [1, 0, 0],
-// CHECK-SAME: thread_strides = [0, 1, 16]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 4],
-// CHECK-SAME: outers_per_batch = [1, 1],
-// CHECK-SAME: threads_per_outer = [4, 16],
-// CHECK-SAME: elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0],
-// CHECK-SAME: thread_strides = [16, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 4, 4],
-// CHECK-SAME: outers_per_batch = [1, 1, 1],
-// CHECK-SAME: threads_per_outer = [1, 4, 16],
-// CHECK-SAME: elements_per_thread = [1, 4, 1],
-// CHECK-SAME: subgroup_strides = [1, 0, 0],
-// CHECK-SAME: thread_strides = [0, 16, 1]>
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [64, 2, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 4, subgroup_n_count = 1>}>
-
-func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<2x64x16xf16>, vector<16x64xf16> into vector<2x64x64xf32>
- return %0 : vector<2x64x64xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 2, 1],
-// CHECK-SAME: subgroup_strides = [2, 1, 0],
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 2, 4],
-// CHECK-SAME: subgroup_strides = [2, 1, 0],
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [128, 2, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2>, workgroup_size = [128, 2, 1]}>
-
-func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes { translation_info = #translation } {
- %0 = vector.contract {
- indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
- %lhs, %rhs, %init : vector<4x64x16xf16>, vector<2x16x64xf16> into vector<4x2x64x64xf32>
- return %0 : vector<4x2x64x64xf32>
-}
-
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [2, 4, 1],
-// CHECK-SAME: subgroup_strides = [2, 0, 0],
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 1, 4],
-// CHECK-SAME: subgroup_strides = [1, 0, 0],
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [2, 1, 4, 4],
-// CHECK-SAME: subgroup_strides = [2, 1, 0, 0],
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [32, 4, 1]
- subgroup_size = 32,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4>}>
-
-func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
- %scale: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>,
- %zp: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>)
- attributes { translation_info = #translation } {
- %alloc = memref.alloc() : memref<128x128xf16, #gpu.address_space<workgroup>>
- %cst = arith.constant 0.000000e+00 : f16
- %cst_0 = arith.constant 0.000000e+00 : f32
- %c32 = arith.constant 32 : index
- %c256 = arith.constant 256 : index
- %c0_i4 = arith.constant 0 : i4
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_read %quant[%c0, %c0], %c0_i4 {in_bounds = [true, true]} : memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128x128xi4>
- %1 = vector.transfer_read %scale[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128xf16>
- %2 = vector.broadcast %1 : vector<128xf16> to vector<128x128xf16>
- %3 = vector.transpose %2, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
- %4 = vector.transfer_read %zp[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128xf16>
- %5 = vector.broadcast %4 : vector<128xf16> to vector<128x128xf16>
- %6 = vector.transpose %5, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
- %7 = arith.extui %0 : vector<128x128xi4> to vector<128x128xi32>
- %8 = arith.uitofp %7 : vector<128x128xi32> to vector<128x128xf16>
- %9 = arith.subf %8, %6 : vector<128x128xf16>
- %10 = arith.mulf %9, %3 : vector<128x128xf16>
- vector.transfer_write %10, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<128x128xf16>, memref<128x128xf16, #gpu.address_space<workgroup>>
- return
-}
-// CHECK: transfer '{{.+}} memref<128x128xi4{{.+}}<storage_buffer>>, vector<128x128xi4>' vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 4], elements_per_thread = [1, 32], subgroup_strides = [0, 0], thread_strides = [4, 1]>
-// CHECK-NOT: transfer '{{.+}} memref<128xf16{{.+}}<storage_buffer>>, vector<128xf16>' vector layout
-
-// -----
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
- workgroup_size = [128, 2, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2>}>
-func.func @batch_matmul_unit_batch(%arg0: vector<1x64x64xf16>, %arg1: vector<1x64x128xf16>, %arg2: vector<1x64x128xf32>) -> vector<1x64x128xf32> attributes {translation_info = #translation} {
- %0 = vector.contract {
- indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "parallel", "reduction"],
- kind = #vector.kind<add>}
- %arg0, %arg1, %arg2 : vector<1x64x64xf16>, vector<1x64x128xf16> into vector<1x64x128xf32>
- return %0 : vector<1x64x128xf32>
-}
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 2, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 2, 4],
-// CHECK-SAME: outers_per_batch = [1, 1, 1]
-// CHECK-SAME: threads_per_outer = [1, 16, 4]
-// CHECK-SAME: elements_per_thread = [1, 1, 4]
-// CHECK-SAME: subgroup_strides = [0, 2, 0],
-// CHECK-SAME: thread_strides = [0, 1, 16]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1, 2]
-// CHECK-SAME: batches_per_subgroup = [1, 4, 4]
-// CHECK-SAME: outers_per_batch = [1, 1, 1]
-// CHECK-SAME: threads_per_outer = [1, 4, 16]
-// CHECK-SAME: elements_per_thread = [1, 4, 1]
-// CHECK-SAME: subgroup_strides = [0, 0, 1],
-// CHECK-SAME: thread_strides = [0, 16, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 2, 2]
-// CHECK-SAME: batches_per_subgroup = [1, 2, 4]
-// CHECK-SAME: outers_per_batch = [1, 1, 1]
-// CHECK-SAME: threads_per_outer = [1, 4, 16]
-// CHECK-SAME: elements_per_thread = [1, 4, 1]
-// CHECK-SAME: subgroup_strides = [0, 2, 1],
-// CHECK-SAME: thread_strides = [0, 16, 1]>