More cleanup in `IREEGPUAttrs`. (#19161)
* Remove nearly 1,000 lines of code from `IREEGPUAttrs.cpp`, plus some
methods from `IREEGPUAttrs.td`. The bulk of that is code that was only
used in a single place and didn't need to live in this central location;
it is moved to its single user. A few helper functions were completely
dead code, and a few additional simplifications are made along the way.
* Write up-to-date comments for core data structures, particularly the
all-important `SingleSubgroupLayout` (see the sketch below).
* Within `IREEGPUAttrs.cpp`, group at the top of the file the code that
encodes the raw information about MMA intrinsics, i.e. the code that
needs to be updated when a new intrinsic is supported.
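
As a quick illustration of what `SingleSubgroupLayout` encodes: an intrinsic's
M/N/K shape is the per-dimension product of its `outer`, `thread`, and
`element` counts, which is what the `getOpaqueMMALayout` helper in this diff
computes. The sketch below is not part of the diff, and the numeric values are
illustrative assumptions rather than entries from the real intrinsic tables
(those live in `getSingleSubgroupLayout` in `IREEGPUAttrs.cpp`).

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-in for MMASingleSubgroupLayout: per-dimension counts of
// outer (non-contiguous per-thread) repeats, threads, and contiguous elements.
struct SingleSubgroupLayoutSketch {
  int64_t outer[2];
  int64_t thread[2];
  int64_t element[2];
};

// Size covered along dimension `i` is the product of the three counts,
// mirroring getOpaqueMMALayout in IREEGPUAttrs.cpp.
static int64_t sizeAlongDim(const SingleSubgroupLayoutSketch &l, int i) {
  return l.outer[i] * l.thread[i] * l.element[i];
}

int main() {
  // Assumed LHS layout of a 16x16xK-style intrinsic: 16 threads along M,
  // 4 thread groups along K, each thread holding 4 contiguous K elements.
  SingleSubgroupLayoutSketch lhs = {{1, 1}, {16, 4}, {1, 4}};
  assert(sizeAlongDim(lhs, 0) == 16); // mSize
  assert(sizeAlongDim(lhs, 1) == 16); // kSize
  return 0;
}
```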
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 0a27437..19a7b3a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -12,7 +12,6 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
-#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
@@ -45,169 +44,34 @@
#define GET_ATTRDEF_CLASSES
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc"
-using LayoutDimension = mlir::iree_compiler::IREE::VectorExt::LayoutDimension;
-using LayoutDimensionAttr =
- mlir::iree_compiler::IREE::VectorExt::LayoutDimensionAttr;
-using VectorLayoutInterface =
- mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface;
-using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr;
-using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr;
-using NestedLayoutAttr = mlir::iree_compiler::IREE::VectorExt::NestedLayoutAttr;
-
namespace mlir::iree_compiler::IREE::GPU {
-namespace {
-// Struct containing abstract MMA shape and type information.
-struct OpaqueMmaLayout {
- int64_t mSize = 0;
- int64_t nSize = 0;
- int64_t kSize = 0;
- Type aType;
- Type bType;
- Type cType;
-};
-
-// Struct containing concrete MMA shape, type, and layout information.
-struct ConcreteMmaLayout {
- OpaqueMmaLayout base;
- PerDimLayoutAttr aMLayout;
- PerDimLayoutAttr aKLayout;
- PerDimLayoutAttr bKLayout;
- PerDimLayoutAttr bNLayout;
- PerDimLayoutAttr cMLayout;
- PerDimLayoutAttr cNLayout;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
-// #iree_gpu.mma_vector_layout
+// MMA intrinsics semantics: shapes, layouts, operand element types.
//===----------------------------------------------------------------------===//
-static PerDimLayoutAttr getBatchedPerDimLayoutAttr(LayoutDimensionAttr batchDim,
- PerDimLayoutAttr baseLayout,
- int64_t problemSize,
- int64_t fragmentDimSize) {
- assert(problemSize % fragmentDimSize == 0 &&
- "invalid layout fragment for problem size");
-
- SmallVector<LayoutDimensionAttr, 3> dimAttrs(baseLayout.getLabels());
- dimAttrs.insert(dimAttrs.begin(), batchDim);
-
- SmallVector<int64_t, 3> shapes(baseLayout.getShapes());
- shapes.insert(shapes.begin(), problemSize / fragmentDimSize);
- auto layout =
- PerDimLayoutAttr::get(baseLayout.getContext(), dimAttrs, shapes);
- return layout;
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
+ // Not supporting any block size other than 1 at the moment.
+ return 1;
}
-// Get the batched layout attributes for the given fragment layouts, indexing
-// map, and problem shape. The canonical fragment map is used to compare against
-// the problem map |indexingMap|. For example, for mma fragment B (RHS):
-//
-// indexingMap = affine_map<(d0, d1, d2) -> (d1, d2) # Transposed B
-// fragmentMap = affine_map<(d0, d1, d2) -> (d2, d1)
-// problemShape = [32, 64]
-// fragmentSize = [16, 8]
-// fragmentLayouts = [kLayout, nLayout]
-//
-// Gives batched layout
-//
-// Dim0 Layout = [BATCHX, nLayoutLabels], [8, nLayoutShape]
-// Dim1 Layout = [BATCHY, kLayoutLabels], [2, kLayoutShape]
-static LayoutAttr
-getBatchedLayoutAttr(AffineMap indexingMap, AffineMap fragmentMap,
- ArrayRef<int64_t> problemShape,
- ArrayRef<int64_t> fragmentSize,
- ArrayRef<PerDimLayoutAttr> fragmentLayouts) {
- // Current distribution to MFMA operations does not support batched
- // contractions so that is reflected here.
- assert(indexingMap.getNumResults() == 2 &&
- "invalid indexing map to non-batched simple contraction");
-
- LayoutDimensionAttr batchX = LayoutDimensionAttr::get(
- indexingMap.getContext(), LayoutDimension::BATCHX);
- LayoutDimensionAttr batchY = LayoutDimensionAttr::get(
- indexingMap.getContext(), LayoutDimension::BATCHY);
-
- SmallVector<PerDimLayoutAttr, 2> perDimAttrs;
- for (auto [expr, batchType] :
- llvm::zip_equal(indexingMap.getResults(),
- SmallVector<LayoutDimensionAttr, 2>{batchX, batchY})) {
- auto maybeResultPosition = fragmentMap.getResultPosition(expr);
- assert(maybeResultPosition && "fragment map and problem map mismatch");
- int64_t idx = *maybeResultPosition;
- perDimAttrs.push_back(getBatchedPerDimLayoutAttr(
- batchType, fragmentLayouts[idx], problemShape[idx], fragmentSize[idx]));
- }
-
- return LayoutAttr::get(indexingMap.getContext(), perDimAttrs);
+static uint32_t getArchID(MMAIntrinsic intrinsic) {
+ return static_cast<int>(intrinsic) & 0xFF00;
}
-static FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
- VectorLayoutInterface>>
-getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {
- MLIRContext *context = contract.getContext();
- FailureOr<linalg::ContractionDimensions> maybeContractionDims =
- linalg::inferContractionDims(contract.getIndexingMapsArray());
- if (failed(maybeContractionDims)) {
- return failure();
- }
- auto contractionDims = *maybeContractionDims;
- // TODO: Relax this condition to strictly alignment requirements.
- if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 ||
- contractionDims.n.size() != 1) {
- return failure();
- }
- // TODO: Support batched contractions.
- if (contractionDims.batch.size() > 0) {
- return failure();
- }
- unsigned mDim = contractionDims.m[0];
- unsigned nDim = contractionDims.n[0];
- unsigned kDim = contractionDims.k[0];
-
- SmallVector<int64_t> iterationBounds;
- contract.getIterationBounds(iterationBounds);
-
- int64_t problemMSize = iterationBounds[mDim];
- int64_t problemNSize = iterationBounds[nDim];
- int64_t problemKSize = iterationBounds[kDim];
-
- int64_t mSize = layout.base.mSize;
- int64_t nSize = layout.base.nSize;
- int64_t kSize = layout.base.kSize;
-
- // The problem size currently must be strictly aligned to the size of the mma.
- // This is expected to succeed assuming the correct [masked] vector size was
- // set at strategy configuration time (for this mma).
- if (problemMSize % mSize != 0 || problemNSize % nSize ||
- problemKSize % kSize) {
- return failure();
- }
-
- LayoutAttr aLayout = getBatchedLayoutAttr(
- contract.getIndexingMapsArray()[0],
- AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, context),
- {problemMSize, problemKSize}, {mSize, kSize},
- {layout.aMLayout, layout.aKLayout});
- LayoutAttr bLayout = getBatchedLayoutAttr(
- contract.getIndexingMapsArray()[1],
- AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, context),
- {problemKSize, problemNSize}, {kSize, nSize},
- {layout.bKLayout, layout.bNLayout});
- LayoutAttr cLayout = getBatchedLayoutAttr(
- contract.getIndexingMapsArray()[2],
- AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, context),
- {problemMSize, problemNSize}, {mSize, nSize},
- {layout.cMLayout, layout.cNLayout});
-
- return std::make_tuple<VectorLayoutInterface, VectorLayoutInterface,
- VectorLayoutInterface>(aLayout, bLayout, cLayout);
+static bool is_AMD_MFMA(MMAIntrinsic intrinsic) {
+ return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF;
}
-//===----------------------------------------------------------------------===//
-// Layout Attribute Building Helpers
-//===----------------------------------------------------------------------===//
+static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
+ return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
+}
+
+static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
+  // Not using Wave64 at the moment, so the only case where the subgroup
+  // size is 64 is CDNA* architectures (MFMA intrinsics).
+ return is_AMD_MFMA(intrinsic) ? 64 : 32;
+}
static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
MMAIntrinsic intrinsic) {
@@ -263,233 +127,6 @@
return {};
}
-template <typename MMAIntrinsicType>
-static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
- MMAIntrinsicType intrinsic) {
- OpaqueMmaLayout o;
- std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
- auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
- auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
- o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
- o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
- o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
- return o;
-}
-
-static std::tuple<PerDimLayoutAttr, PerDimLayoutAttr>
-getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
- // Step 1: obtain the swizzled tile shape, but keeping track of the source
- // dimension indices.
- struct SrcIndexAndSwizzleDim {
- size_t srcIndex;
- TileSwizzle::Dim dim;
- };
- SmallVector<SrcIndexAndSwizzleDim> swizzledShape;
- for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) {
- for (TileSwizzle::Dim d : e) {
- swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d});
- }
- }
- applyPermutationToVector(swizzledShape, swizzle.permutation);
-
- // Step 2: collect the appropriate labels to use for the swizzled dims.
- LayoutDimension internalLabels[] = {LayoutDimension::VECTORZ,
- LayoutDimension::VECTORY,
- LayoutDimension::VECTORX};
- LayoutDimension crossThreadLabels[] = {
- LayoutDimension::LANEZ, LayoutDimension::LANEY, LayoutDimension::LANEX};
- auto internalLabelIter = std::end(internalLabels);
- auto crossThreadLabelIter = std::end(crossThreadLabels);
- for (SrcIndexAndSwizzleDim d : swizzledShape) {
- if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) {
- assert(internalLabelIter != std::begin(internalLabels));
- --internalLabelIter;
- } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) {
- assert(crossThreadLabelIter != std::begin(crossThreadLabels));
- --crossThreadLabelIter;
- } else {
- assert(false && "unexpected dimension kind in intrinsic swizzle");
- }
- }
-
- // Step 3: put together the result PerDimLayoutAttr'd for the two source dims.
- SmallVector<LayoutDimensionAttr> labels[2];
- SmallVector<int64_t> shape[2];
- for (SrcIndexAndSwizzleDim d : swizzledShape) {
- shape[d.srcIndex].push_back(d.dim.size);
- auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal)
- ? internalLabelIter
- : crossThreadLabelIter;
- labels[d.srcIndex].push_back(LayoutDimensionAttr::get(
- context, static_cast<LayoutDimension>(*labelIterRef++)));
- }
- return {PerDimLayoutAttr::get(context, labels[0], shape[0]),
- PerDimLayoutAttr::get(context, labels[1], shape[1])};
-};
-
-static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
- MMAIntrinsic intrinsic) {
- auto opaque = getOpaqueMMALayout(context, intrinsic);
- ConcreteMmaLayout concreteLayout;
- concreteLayout.base = opaque;
- auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs);
- auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Rhs);
- auto accSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Acc);
- std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) =
- getPerDimLayoutAttrs(context, lhsSwizzle);
- std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) =
- getPerDimLayoutAttrs(context, rhsSwizzle);
- std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) =
- getPerDimLayoutAttrs(context, accSwizzle);
- return concreteLayout;
-}
-
-//===----------------------------------------------------------------------===//
-// MmaInterface Attribute Helper Functions
-//===----------------------------------------------------------------------===//
-
-MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
- if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
- return mmaAttr.getASingleSubgroupLayout();
- } else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
- return vmmaAttr.getASingleSubgroupLayout();
- } else {
- assert(false && "unhandled MMA Interface type.");
- return {};
- }
-}
-
-MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
- if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
- return mmaAttr.getBSingleSubgroupLayout();
- } else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
- return vmmaAttr.getBSingleSubgroupLayout();
- } else {
- assert(false && "unhandled MMA Interface type.");
- return {};
- }
-}
-
-MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
- if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
- return mmaAttr.getCSingleSubgroupLayout();
- } else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
- return vmmaAttr.getCSingleSubgroupLayout();
- } else {
- assert(false && "unhandled MMA Interface type.");
- return {};
- }
-}
-
-//===----------------------------------------------------------------------===//
-// MFMA Attributes
-//===----------------------------------------------------------------------===//
-
-Attribute MMAAttr::parse(AsmParser &p, Type type) {
- if (failed(p.parseLess()))
- return {};
-
- FailureOr<MMAIntrinsicAttr> mmaIntrinsic =
- FieldParser<MMAIntrinsicAttr>::parse(p);
- if (failed(mmaIntrinsic)) {
- p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier");
- return {};
- }
-
- if (failed(p.parseGreater()))
- return {};
-
- return get(p.getContext(), mmaIntrinsic->getValue());
-}
-
-void MMAAttr::print(AsmPrinter &p) const {
- auto &os = p.getStream();
- os << "<";
- os << stringifyMMAIntrinsic(getIntrinsic().getValue());
- os << ">";
-}
-
-MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
- auto layout = getOpaqueMMALayout(context, type);
- return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
- layout.nSize, layout.kSize, layout.aType, layout.bType,
- layout.cType);
-}
-
-std::tuple<Type, Type, Type> MMAAttr::getABCElementTypes() const {
- return {getAType(), getBType(), getCType()};
-}
-
-std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
- return {getMSize(), getNSize(), getKSize()};
-}
-
-template <typename MMAIntrinsicType>
-static VectorType getVectorType(MLIRContext *context,
- MMAIntrinsicType intrinsic,
- MMAFragment fragment) {
- auto o = getOpaqueMMALayout(context, intrinsic);
- auto s = getSingleSubgroupLayout(intrinsic, fragment);
- Type elemType = (fragment == MMAFragment::Lhs) ? o.aType
- : (fragment == MMAFragment::Rhs) ? o.bType
- : o.cType;
- return VectorType::get(
- {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType);
-}
-
-std::tuple<VectorType, VectorType, VectorType>
-MMAAttr::getABCVectorTypes() const {
- MLIRContext *context = getContext();
- MMAIntrinsic intrinsic = getIntrinsic().getValue();
- VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
- VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
- VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
- return {aVecType, bVecType, cVecType};
-}
-
-FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
- VectorLayoutInterface>>
-MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
- ConcreteMmaLayout layout =
- getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue());
- return IREE::GPU::getContractionLayout(contract, layout);
-}
-
-static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
- // Not supporting any block size other than 1 at the moment.
- return 1;
-}
-
-int64_t MMAAttr::getBlockSize() const {
- return IREE::GPU::getBlockSize(getIntrinsic().getValue());
-}
-
-static uint32_t getArchID(MMAIntrinsic intrinsic) {
- return static_cast<int>(intrinsic) & 0xFF00;
-}
-
-static bool is_AMD_MFMA(MMAIntrinsic intrinsic) {
- return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF;
-}
-
-static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
- return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
-}
-
-static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
- // Not using Wave64 at all at the moment, so the only place where the
- // subgroup size is CDNA* architectures.
- return is_AMD_MFMA(intrinsic) ? 64 : 32;
-}
-
-int64_t MMAAttr::getSubgroupSize() const {
- return getIntrinsicSubgroupSize(getIntrinsic().getValue());
-}
-
-FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
- return IREE::GPU::MMAScope::Subgroup;
-}
-
MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
MMAFragment fragment) {
switch (intrinsic) {
@@ -637,6 +274,139 @@
return {};
}
+template <typename MMAIntrinsicType>
+static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+ MMAIntrinsicType intrinsic) {
+ OpaqueMmaLayout o;
+ std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
+ auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+ auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+ o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+ o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+ o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+ return o;
+}
+
+OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+ IREE::GPU::MMAIntrinsic intrinsic) {
+ return getOpaqueMMALayout<IREE::GPU::MMAIntrinsic>(context, intrinsic);
+}
+
+//===----------------------------------------------------------------------===//
+// MmaInterface Attribute Helper Functions
+//===----------------------------------------------------------------------===//
+
+MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
+ if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
+ return mmaAttr.getASingleSubgroupLayout();
+ }
+ if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
+ return vmmaAttr.getASingleSubgroupLayout();
+ }
+ assert(false && "unhandled MMA Interface type.");
+ return {};
+}
+
+MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
+ if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
+ return mmaAttr.getBSingleSubgroupLayout();
+ }
+ if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
+ return vmmaAttr.getBSingleSubgroupLayout();
+ }
+ assert(false && "unhandled MMA Interface type.");
+ return {};
+}
+
+MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
+ if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
+ return mmaAttr.getCSingleSubgroupLayout();
+ }
+ if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
+ return vmmaAttr.getCSingleSubgroupLayout();
+ }
+ assert(false && "unhandled MMA Interface type.");
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
+// MMA Attributes
+//===----------------------------------------------------------------------===//
+
+Attribute MMAAttr::parse(AsmParser &p, Type type) {
+ if (failed(p.parseLess()))
+ return {};
+
+ FailureOr<MMAIntrinsicAttr> mmaIntrinsic =
+ FieldParser<MMAIntrinsicAttr>::parse(p);
+ if (failed(mmaIntrinsic)) {
+ p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier");
+ return {};
+ }
+
+ if (failed(p.parseGreater()))
+ return {};
+
+ return get(p.getContext(), mmaIntrinsic->getValue());
+}
+
+void MMAAttr::print(AsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << "<";
+ os << stringifyMMAIntrinsic(getIntrinsic().getValue());
+ os << ">";
+}
+
+MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
+ auto layout = getOpaqueMMALayout(context, type);
+ return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
+ layout.nSize, layout.kSize, layout.aType, layout.bType,
+ layout.cType);
+}
+
+std::tuple<Type, Type, Type> MMAAttr::getABCElementTypes() const {
+ return {getAType(), getBType(), getCType()};
+}
+
+std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
+ return {getMSize(), getNSize(), getKSize()};
+}
+
+template <typename MMAIntrinsicType>
+static VectorType getVectorType(MLIRContext *context,
+ MMAIntrinsicType intrinsic,
+ MMAFragment fragment) {
+ auto o = getOpaqueMMALayout(context, intrinsic);
+ auto s = getSingleSubgroupLayout(intrinsic, fragment);
+ Type elemType = (fragment == MMAFragment::Lhs) ? o.aType
+ : (fragment == MMAFragment::Rhs) ? o.bType
+ : o.cType;
+ return VectorType::get(
+ {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType);
+}
+
+std::tuple<VectorType, VectorType, VectorType>
+MMAAttr::getABCVectorTypes() const {
+ MLIRContext *context = getContext();
+ MMAIntrinsic intrinsic = getIntrinsic().getValue();
+ VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
+ VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
+ VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
+ return {aVecType, bVecType, cVecType};
+}
+
+int64_t MMAAttr::getBlockSize() const {
+ return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
+int64_t MMAAttr::getSubgroupSize() const {
+ return getIntrinsicSubgroupSize(getIntrinsic().getValue());
+}
+
+FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
+ return IREE::GPU::MMAScope::Subgroup;
+}
+
MMASingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
}
@@ -781,22 +551,8 @@
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
- MMASingleSubgroupLayout subgroupLayout;
- switch (fragment) {
- case IREE::GPU::MMAFragment::Lhs: {
- subgroupLayout = getASingleSubgroupLayout();
- break;
- }
- case IREE::GPU::MMAFragment::Rhs: {
- subgroupLayout = getBSingleSubgroupLayout();
- break;
- }
- case IREE::GPU::MMAFragment::Acc: {
- subgroupLayout = getCSingleSubgroupLayout();
- break;
- }
- }
-
+ MMASingleSubgroupLayout subgroupLayout =
+ getSingleSubgroupLayout(getIntrinsic().getValue(), fragment);
SmallVector<OpFoldResult> canonicalOffsets;
SmallVector<OpFoldResult> canonicalSizes;
if (failed(populateCanonicalOffsetsSizesAndStrides(
@@ -810,82 +566,6 @@
return success();
}
-LogicalResult MMAAttr::materializeOperandConcreteShape(
- OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand,
- std::optional<ArrayRef<int64_t>> permutation,
- SmallVector<ReassociationIndices> &reassociations,
- RankedTensorType &resultType) const {
-
- SmallVector<int64_t, 2> outerSizes;
- SmallVector<int64_t, 2> opaqueSizes;
- auto [m, n, k] = getMNKShape();
- switch (fragment) {
- case IREE::GPU::MMAFragment::Lhs: {
- outerSizes = getASingleSubgroupLayout().outer;
- opaqueSizes.append({m, k});
- break;
- }
- case IREE::GPU::MMAFragment::Rhs: {
- outerSizes = getBSingleSubgroupLayout().outer;
- opaqueSizes.append({k, n});
- break;
- }
- case IREE::GPU::MMAFragment::Acc: {
- outerSizes = getCSingleSubgroupLayout().outer;
- opaqueSizes.append({m, n});
- break;
- }
- }
- if (permutation.has_value()) {
- if (permutation.value().size() != outerSizes.size()) {
- return failure();
- }
- applyPermutationToVector(opaqueSizes, permutation.value());
- applyPermutationToVector(outerSizes, permutation.value());
- }
-
- // Inner tile must have sizes matching the opaque layout.
- auto operandType = llvm::cast<RankedTensorType>(operand.getType());
- ArrayRef<int64_t> operandShape = operandType.getShape();
- SmallVector<int64_t, 2> innerShape(operandShape.end() - opaqueSizes.size(),
- operandShape.end());
- if (!llvm::equal(opaqueSizes, innerShape)) {
- return failure();
- }
-
- // Expand the shape of the inner tile to reflect the MMA thread layout.
- SmallVector<int64_t, 4> resultShape(operandShape.begin(),
- operandShape.end() - 2);
- SmallVector<ReassociationIndices> reInds =
- llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
- [](int64_t idx) -> ReassociationIndices {
- return ReassociationIndices({idx});
- });
- int idx = reInds.size();
- for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) {
- // Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives
- // a guarantee that the |element| counts are contiguous within the layout,
- // and a unit outer implies a single offset and size for that dimension.
- if (outer == 1) {
- resultShape.push_back(native);
- reInds.push_back(ReassociationIndices({idx++}));
- continue;
- }
-
- // Reshape to [outer, native / outer] == [outer, thread * element]. This
- // corresponds to |outer| repetitions of the thread/element sublayout.
- resultShape.push_back(outer);
- assert(native % outer == 0 && "invalid mma layout");
- resultShape.push_back(native / outer);
- reInds.push_back(ReassociationIndices{idx, idx + 1});
- idx += 2;
- }
-
- reassociations = reInds;
- resultType = operandType.clone(resultShape);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// DataTiledMMA Attributes
//===----------------------------------------------------------------------===//
@@ -1265,22 +945,8 @@
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
- MMASingleSubgroupLayout subgroupLayout;
- switch (fragment) {
- case IREE::GPU::MMAFragment::Lhs: {
- subgroupLayout = getASingleSubgroupLayout();
- break;
- }
- case IREE::GPU::MMAFragment::Rhs: {
- subgroupLayout = getBSingleSubgroupLayout();
- break;
- }
- case IREE::GPU::MMAFragment::Acc: {
- subgroupLayout = getCSingleSubgroupLayout();
- break;
- }
- }
-
+ MMASingleSubgroupLayout subgroupLayout =
+ getSingleSubgroupLayout(getIntrinsic().getValue(), fragment);
SmallVector<OpFoldResult> canonicalOffsets;
SmallVector<OpFoldResult> canonicalSizes;
if (failed(populateCanonicalOffsetsSizesAndStrides(
@@ -1445,348 +1111,6 @@
}
//===----------------------------------------------------------------------===//
-// MMA Schedule Attributes
-//===----------------------------------------------------------------------===//
-
-/// Gets a unit vector of the given rank, but fills in the given dimensions
-/// from the 2 element array |counts|. |dim0| is the position in the returned
-/// vector to put the first element of |counts|, and |dim1| is the position to
-/// put the second element. For example,
-///
-/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
-/// returns [1, 5, 7]
-SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
- ArrayRef<int64_t> counts,
- int64_t dim0, int64_t dim1) {
- assert(counts.size() == 2 &&
- "Unexpected non-rank 2 single subgroup dimension counts");
- SmallVector<int64_t> res(rank, 1);
- res[dim0] = counts[0];
- res[dim1] = counts[1];
- return res;
-}
-
-SmallVector<int64_t> getIdentityPerm(int64_t rank) {
- return llvm::to_vector(llvm::seq(static_cast<int64_t>(0), rank));
-}
-
-/// Constructs an identity permutation with the given rank, except it applies
-/// the given rank-2 |perm| to the two dimensions |dim0| and |dim1|, and then
-/// swaps the positions of dim0 and dim1 in the final permutation. For example,
-///
-/// rank = 3, perm = [1, 0], dim0 = 1, dim1 = 2
-/// returns [0, 1, 2]
-///
-/// This is essentially just applying two rank-2 permutations to two particular
-/// dimensions. First it applies |perm|, which corresponds to a permutation
-/// needed by the underlying intrinsic, then it does another permutation based
-/// on the order of actual dimensions for the MMA fragment. For example, for the
-/// B matrix, dim0 = K and dim1 = N, so for the element order of an MFMA
-/// 16x16x16, perm would be `[1, 0]`, however if the actual contraction is a
-/// matmul_transpose_b, then the element order needs to be [0, 1].
-SmallVector<int64_t> getIdentityPermWithSwap(int64_t rank,
- ArrayRef<int64_t> perm,
- int64_t dim0, int64_t dim1) {
- assert(perm.size() == 2 &&
- "Unexpected non-rank 2 single subgroup dimension order");
- SmallVector<int64_t> res = getIdentityPerm(rank);
- if (perm[0] > perm[1]) {
- std::swap(dim0, dim1);
- }
- if (dim0 > dim1) {
- res[dim0] = dim1;
- res[dim1] = dim0;
- }
- return res;
-}
-
-/// Constructs the nested layout given the layout for a single subgroup and the
-/// subgroup/batch counts and orders, as well as the dimensions along which to
-/// distribute the intrinsic's layout.
-///
-/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
-/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
-/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
-/// we would have:
-/// outerDim = 1 for the K dim
-/// innerDim = 0 for the N dim
-///
-/// For something like MK_NKN_MN with multiple N dims, it would typically be:
-/// outerDim = 1 for K
-/// innerDim = 2 for the second N dim
-///
-/// Importantly these two dimensions always refer to the actual dimension
-/// positions in the undistributed vector. For each fragment, this means:
-/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
-/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
-/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
-///
-/// And here inner most is referential to the iteration order, not the order
-/// they appear per fragment (because there is no relationship between the
-/// dimension order of M in A and in C, for example).
-NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
- int64_t outerDim, int64_t innerDim,
- SmallVector<int64_t> subgroupSizes,
- SmallVector<int64_t> subgroupStrides,
- SmallVector<int64_t> batchCount,
- MMASingleSubgroupLayout counts) {
-
- LLVM_DEBUG({
- llvm::errs() << "Creating Nested Layout for::";
- llvm::errs() << "\n outerDim = " << outerDim;
- llvm::errs() << "\n innerDim = " << innerDim;
- llvm::errs() << "\n subgroupSizes: ";
- llvm::interleaveComma(subgroupSizes, llvm::errs());
- llvm::errs() << "\n subgroupStrides: ";
- llvm::interleaveComma(subgroupStrides, llvm::errs());
- llvm::errs() << "\n batchCount: ";
- llvm::interleaveComma(batchCount, llvm::errs());
- llvm::errs() << "\n counts.outer: ";
- llvm::interleaveComma(counts.outer, llvm::errs());
- llvm::errs() << "\n counts.thread: ";
- llvm::interleaveComma(counts.thread, llvm::errs());
- llvm::errs() << "\n counts.element: ";
- llvm::interleaveComma(counts.element, llvm::errs());
- llvm::errs() << "\n counts.tstrides: ";
- llvm::interleaveComma(counts.tstrides, llvm::errs());
- llvm::errs() << "\n";
- });
-
- SmallVector<int64_t> outerCount =
- getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
- SmallVector<int64_t> threadCount =
- getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
- SmallVector<int64_t> threadStrides =
- getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
- SmallVector<int64_t> elementCount =
- getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
-
- auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
- outerCount, threadCount, elementCount,
- subgroupStrides, threadStrides);
- return layoutAttr;
-}
-
-FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
- VectorExt::VectorLayoutInterface,
- VectorExt::VectorLayoutInterface>>
-MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo,
- linalg::LinalgOp contractOp) const {
- LLVM_DEBUG({
- llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
- llvm::errs() << "For schedule: " << *this << "\n";
- });
-
- int64_t rank = contractOp.getIteratorTypesArray().size();
- auto mmaAttr = llvm::cast<MmaInterfaceAttr>(getIntrinsic());
- MLIRContext *context = getContext();
-
- SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
- if (llvm::any_of(bounds,
- [](int64_t x) { return x == ShapedType::kDynamic; })) {
- return failure();
- }
-
- if (!llvm::all_of(opInfo.getBatchDims(),
- [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
- LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
- return failure();
- }
-
- // Get the concrete nested layout for each matrix. Note that the struct
- // MMASingleSubgroupLayout contains the partial layout for the
- // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
- // contract op we are looking at right now may not be exactly in that form.
- // So here we need to permute/transpose the canonical layout to match with
- // the concrete contract op.
-
- // Note that no matter how we permute/transpose the input contraction
- // problem, the way we view the hardware warps remain the same--that is,
- // from the hardware's perspective, a single warp has the same warp ID no
- // matter what part of the contraction it works on. Similarly here, we are
- // delinearizing the linearized GPU hardware lane ID into a n-D concatenated
- // logical warp+thread using the subgroup/thread basis, so the subgroup
- // basis should remain the same for all A/B/C matrix.
-
- auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
-
- SmallVector<int64_t, 2> subgroupMBasis;
- SmallVector<int64_t, 2> batchMSizes;
- int64_t currMCount = getSubgroupMCount();
-
- auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
- int64_t minDimSize) -> std::pair<int64_t, int64_t> {
- int64_t dividableDim = dimSize / minDimSize;
- int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
- dividableDim /= subgroupsUsed;
- int64_t batchesUsed = dividableDim;
- return {subgroupsUsed, batchesUsed};
- };
-
- // Greedily break up the M subgroup and batch counts along the "M" iteration
- // bounds. We distribute as many residual subgroups as possible per M dim,
- // and then divide the remaining along batch dims. The inner most M dim is
- // always the one used for the intrinsic, meaning for a valid schedule, the
- // computed batch counts and subgroup basis will satisfy totalMSize /
- // intrinsicM = product(batchMSizes) * product(subgroupMBasis)
- for (auto dim : opInfo.getMDims()) {
- // Get the number of subgroups and batches used for this dimension based
- // on the intrinsic size and the bound size.
- int64_t subgroupsUsed, batchesUsed;
- if (dim == opInfo.getMDims().back()) {
- std::tie(subgroupsUsed, batchesUsed) =
- divideGreedily(currMCount, bounds[dim], intrinsicM);
- } else {
- std::tie(subgroupsUsed, batchesUsed) =
- divideGreedily(currMCount, bounds[dim], 1);
- }
- subgroupMBasis.push_back(subgroupsUsed);
- batchMSizes.push_back(batchesUsed);
- // Update available subgroup count.
- currMCount /= subgroupsUsed;
- }
-
- SmallVector<int64_t, 2> subgroupNBasis;
- SmallVector<int64_t, 2> batchNSizes;
- int64_t currNCount = getSubgroupNCount();
-
- // Do the same for N dims.
- for (auto dim : opInfo.getNDims()) {
- // Get the number of subgroups and batches used for this dimension based
- // on the intrinsic size and the bound size.
- int64_t subgroupsUsed, batchesUsed;
- if (dim == opInfo.getNDims().back()) {
- std::tie(subgroupsUsed, batchesUsed) =
- divideGreedily(currNCount, bounds[dim], intrinsicN);
- } else {
- std::tie(subgroupsUsed, batchesUsed) =
- divideGreedily(currNCount, bounds[dim], 1);
- }
- subgroupNBasis.push_back(subgroupsUsed);
- batchNSizes.push_back(batchesUsed);
- // Update available subgroup count.
- currNCount /= subgroupsUsed;
- }
-
- SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
- SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
-
- auto mDimVec = opInfo.getMDims();
- llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
- auto nDimVec = opInfo.getNDims();
- llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
- // Because we currently require all batch dimensions to be unit, the
- // subgroup basis can be constructed from the M and N bases. To keep things
- // simple, the current heuristic is to distribute the loop dimensions from
- // outer to inner.
- int64_t currStride = 1;
- int64_t currM = subgroupMStrides.size() - 1;
- int64_t currN = subgroupNStrides.size() - 1;
- for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
- if (mDims.contains(dim)) {
- subgroupMStrides[currM] = currStride;
- currStride *= subgroupMBasis[currM];
- currM--;
- continue;
- }
-
- if (nDims.contains(dim)) {
- subgroupNStrides[currN] = currStride;
- currStride *= subgroupNBasis[currN];
- currN--;
- continue;
- }
- }
-
- // C matrix layout
- auto [m, n] = opInfo.getResultMNIndex();
- int64_t cRank = opInfo.getCRank();
-
- // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
- // cNDims are the M and N dimensions of the C matrix in the order they are
- // iterated over in the contraction.
- SmallVector<int64_t> cMDims = opInfo.outMDims;
- SmallVector<int64_t> cNDims = opInfo.outNDims;
- SmallVector<int64_t> cBatchSizes(cRank, 1);
- SmallVector<int64_t> cSubgroupSizes(cRank, 1);
- SmallVector<int64_t> cSubgroupStrides(cRank, 0);
- for (auto [i, dim] : llvm::enumerate(cMDims)) {
- cBatchSizes[dim] = batchMSizes[i];
- cSubgroupSizes[dim] = subgroupMBasis[i];
- cSubgroupStrides[dim] = subgroupMStrides[i];
- }
- for (auto [i, dim] : llvm::enumerate(cNDims)) {
- cBatchSizes[dim] = batchNSizes[i];
- cSubgroupSizes[dim] = subgroupNBasis[i];
- cSubgroupStrides[dim] = subgroupNStrides[i];
- }
-
- auto cLayout = createNestedLayout(context, cRank, m, n,
- /*subgroupCount=*/cSubgroupSizes,
- /*subgroupStrides=*/cSubgroupStrides,
- /*batchCount=*/cBatchSizes,
- getCSingleSubgroupLayout(mmaAttr));
- LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; });
-
- // A matrix layout
- auto [afm, bfn] = opInfo.getOperandMNIndex();
- auto [afk, bfk] = opInfo.getOperandKIndex();
-
- int64_t aRank = opInfo.getARank();
-
- SmallVector<int64_t> aMDims = opInfo.lhsMDims;
- SmallVector<int64_t> aBatchSizes(aRank, 1);
- SmallVector<int64_t> aSubgroupSizes(aRank, 1);
- SmallVector<int64_t> aSubgroupStrides(aRank, 0);
- for (auto [i, dim] : llvm::enumerate(aMDims)) {
- aBatchSizes[dim] = batchMSizes[i];
- aSubgroupSizes[dim] = subgroupMBasis[i];
- aSubgroupStrides[dim] = subgroupMStrides[i];
- }
- for (auto [kDim, lhsKDim] :
- llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
- aBatchSizes[lhsKDim] = bounds[kDim];
- }
- aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
- auto aLayout = createNestedLayout(context, aRank, afm, afk,
- /*subgroupCount=*/aSubgroupSizes,
- /*subgroupStrides=*/aSubgroupStrides,
- /*batchCount=*/aBatchSizes,
- getASingleSubgroupLayout(mmaAttr));
- LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; });
-
- int64_t bRank = opInfo.getBRank();
-
- SmallVector<int64_t> bNDims = opInfo.rhsNDims;
- SmallVector<int64_t> bBatchSizes(bRank, 1);
- SmallVector<int64_t> bSubgroupSizes(bRank, 1);
- SmallVector<int64_t> bSubgroupStrides(bRank, 0);
- for (auto [i, dim] : llvm::enumerate(bNDims)) {
- bBatchSizes[dim] = batchNSizes[i];
- bSubgroupSizes[dim] = subgroupNBasis[i];
- bSubgroupStrides[dim] = subgroupNStrides[i];
- }
- for (auto [kDim, rhsKDim] :
- llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
- bBatchSizes[rhsKDim] = bounds[kDim];
- }
- bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
- auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
- /*subgroupCount=*/bSubgroupSizes,
- /*subgroupStrides=*/bSubgroupStrides,
- /*batchCount=*/bBatchSizes,
- getBSingleSubgroupLayout(mmaAttr));
- LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; });
-
- std::tuple<VectorLayoutInterface, VectorLayoutInterface,
- VectorLayoutInterface>
- result = {aLayout, bLayout, cLayout};
- return result;
-}
-
-//===----------------------------------------------------------------------===//
// Target Attributes
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
index 7942066..92ce4f4 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h
@@ -22,12 +22,40 @@
namespace mlir::iree_compiler::IREE::GPU {
-// Partial nested layout for an MMA intrinsic's matrix input/output inside
-// a single subgroup.
+// Struct describing the detailed subgroup-level layout of a MMA intrinsic.
+// Together with element type information and subgroup size, it completes the
+// full description of the semantics of a MMA intrinsic.
+//
+// Note: It is not possible to infer subgroup size from the information in this
+// struct. The product of the `thread` sizes here is often, but not always,
+// equal to the subgroup size. When the product of the `thread` sizes (call
+// that product `P`) is smaller than the subgroup size, it must be a divisor of
+// it, and the semantics in that case are that threads within the subgroup
+// whose thread-ids differ by a multiple of `P` access the same elements.
+//
+// Example observed in RDNA3 WMMA Wave64 intrinsics:
+// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, that
+// means that each element is being accessed by 2 threads (2 = 64/32), and the
+// threads accessing the same element are those whose tids are exactly 32 apart.
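+//
+// Illustrative example (hypothetical values; the authoritative per-intrinsic
+// tables are in getSingleSubgroupLayout in IREEGPUAttrs.cpp): a 16x16 LHS
+// fragment with 16 threads along M, 4 thread groups along K, and 4 contiguous
+// elements per thread along K would be described as
+//   outer = [1, 1], thread = [16, 4], tstrides = [1, 16], element = [1, 4]
+// and each dimension's size is the product outer * thread * element = 16.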
struct MMASingleSubgroupLayout {
+ // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
+ // outer-most in the layout. This happens when a MMA op, seen on a single
+ // thread, has an operand that consists of multiple elements, and these elems
+ // are NOT contiguous.
+ // This is not used by every MMA op; ops which don't use that simply have 1's.
SmallVector<int64_t, 2> outer;
+ // Cross-thread dimensions (as in TileSwizzle::Dim::Kind::CrossThread).
+ // This is the kind of dimension that is present in all GPU MMA ops, by
+ // definition of "SIMT". It is still possible for one of the `thread` dims to
+ // be 1, but not both.
SmallVector<int64_t, 2> thread;
+ // Strides corresponding to the cross-thread dimensions.
SmallVector<int64_t, 2> tstrides;
+  // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
+  // inner-most in the layout. This happens when a MMA op, seen on a single
+  // thread, has an operand that consists of multiple elements, and these elems
+  // ARE contiguous.
+  // This is not used by every MMA op; ops which don't use that simply have 1's.
SmallVector<int64_t, 2> element;
};
@@ -43,6 +71,22 @@
MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind);
+// Struct describing the shape of a MMA operation, but not the detailed layout.
+// TODO(bjacob): the only user outside of IREEGPUAttrs.cpp is
+// LLVMGPU/TransformExtensions, so maybe make that internal again if/when that
+// goes away.
+struct OpaqueMmaLayout {
+ int64_t mSize = 0;
+ int64_t nSize = 0;
+ int64_t kSize = 0;
+ Type aType;
+ Type bType;
+ Type cType;
+};
+
+OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+ IREE::GPU::MMAIntrinsic intrinsic);
+
} // namespace mlir::iree_compiler::IREE::GPU
// clang-format off
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index bc6f2d3..51674db 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -148,7 +148,6 @@
DeclareAttrInterfaceMethods<IREEGPU_MmaInterfaceAttr, [
"getABCElementTypes",
"getABCVectorTypes",
- "getContractionLayout",
"getMNKShape",
"getSubgroupSize",
"getMmaScope",
@@ -157,7 +156,6 @@
"getCSingleSubgroupLayout",
"buildMmaOperation",
"populateOperandOffsetsSizesStrides",
- "materializeOperandConcreteShape",
]>
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
@@ -259,19 +257,11 @@
|intrinsic| field specifies which particular MMA intrinsic is targeted by
the data-tiling.
- The tile swizzling already happens, so the attribute does not need to
- implement materializeOperandConcreteShape interface method. E.g., if the
- target intrinsic is MFMA_F32_16x16x4_F32:
- - The inner tile shape of LHS is 4x16.
- - The inner tile shape of RHS is 4x16.
- - The inner tile shape of ACC is 4x16x4.
-
- Furthermore, the unrolling and interleaving can be represented with the
- attribute. In the concept of data-tiling, we always unroll the parallel
- dimensions (i.e., M, N dimensions) to be outermost, and interleave the
- unrolled K dimension. I.e., the unrolled K dimension becomes the innermost
- dimension. The constraint can be relaxed based on data-tiling needs. The
- additional information can be added to `parameters`.
+ The other fields default to one, and that default results in a single
+ intrinsic equivalent to MMAAttr, while values greater than one result in
+ wider "kernels" consisting of multiple intrinsics, with the data layout
+ already swizzled into a tile layout that allows each intrinsic to access
+ data at an offset that's as simple as possible a mapping from the thread ID.
}];
let assemblyFormat = "`<` struct(params) `>`";
@@ -369,15 +359,6 @@
);
let assemblyFormat = "`<` struct(params) `>`";
-
- let extraClassDeclaration = [{
- // Returns the A/B/C matrix concrete layout targeting |contractOp|.
- ::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
- VectorExt::VectorLayoutInterface,
- VectorExt::VectorLayoutInterface>>
- getContractionLayout(::mlir::iree_compiler::VectorContractOpInfo &opInfo,
- ::mlir::linalg::LinalgOp contractOp) const;
- }];
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
index 94bcf3d..6b2e834 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
@@ -23,7 +23,80 @@
using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase;
void runOnOperation() override;
};
-} // namespace
+
+LogicalResult materializeOperandConcreteShape(
+ OpBuilder &builder, MMAAttr mma, IREE::GPU::MMAFragment fragment,
+ Value operand, std::optional<ArrayRef<int64_t>> permutation,
+ SmallVector<ReassociationIndices> &reassociations,
+ RankedTensorType &resultType) {
+
+ SmallVector<int64_t, 2> outerSizes;
+ SmallVector<int64_t, 2> opaqueSizes;
+ auto [m, n, k] = mma.getMNKShape();
+ switch (fragment) {
+ case IREE::GPU::MMAFragment::Lhs: {
+ outerSizes = mma.getASingleSubgroupLayout().outer;
+ opaqueSizes.append({m, k});
+ break;
+ }
+ case IREE::GPU::MMAFragment::Rhs: {
+ outerSizes = mma.getBSingleSubgroupLayout().outer;
+ opaqueSizes.append({k, n});
+ break;
+ }
+ case IREE::GPU::MMAFragment::Acc: {
+ outerSizes = mma.getCSingleSubgroupLayout().outer;
+ opaqueSizes.append({m, n});
+ break;
+ }
+ }
+ if (permutation.has_value()) {
+ if (permutation.value().size() != outerSizes.size()) {
+ return failure();
+ }
+ applyPermutationToVector(opaqueSizes, permutation.value());
+ applyPermutationToVector(outerSizes, permutation.value());
+ }
+
+ // Inner tile must have sizes matching the opaque layout.
+ auto operandType = llvm::cast<RankedTensorType>(operand.getType());
+ ArrayRef<int64_t> operandShape = operandType.getShape();
+ if (opaqueSizes != operandShape.take_back(opaqueSizes.size())) {
+ return failure();
+ }
+
+ // Expand the shape of the inner tile to reflect the MMA thread layout.
+ SmallVector<int64_t, 4> resultShape(operandShape.begin(),
+ operandShape.end() - 2);
+ SmallVector<ReassociationIndices> reInds =
+ llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
+ [](int64_t idx) -> ReassociationIndices {
+ return ReassociationIndices({idx});
+ });
+ int idx = reInds.size();
+ for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) {
+ // Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives
+ // a guarantee that the |element| counts are contiguous within the layout,
+ // and a unit outer implies a single offset and size for that dimension.
+ if (outer == 1) {
+ resultShape.push_back(native);
+ reInds.push_back(ReassociationIndices({idx++}));
+ continue;
+ }
+
+ // Reshape to [outer, native / outer] == [outer, thread * element]. This
+ // corresponds to |outer| repetitions of the thread/element sublayout.
+ resultShape.push_back(outer);
+ assert(native % outer == 0 && "invalid mma layout");
+ resultShape.push_back(native / outer);
+ reInds.push_back(ReassociationIndices{idx, idx + 1});
+ idx += 2;
+ }
+
+ reassociations = reInds;
+ resultType = operandType.clone(resultShape);
+ return success();
+}
struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
using OpRewritePattern::OpRewritePattern;
@@ -56,12 +129,15 @@
}
// Get the reassociation indices and result type of the expand_shape op.
- MmaInterfaceAttr kind = mmaOp.getKind();
+ MMAAttr kind = dyn_cast<MMAAttr>(mmaOp.getKind());
+ if (!kind) {
+ return failure();
+ }
SmallVector<ReassociationIndices> reassociations;
RankedTensorType concreteType;
- if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand,
- permutation, reassociations,
- concreteType))) {
+ if (failed(materializeOperandConcreteShape(rewriter, kind, fragment,
+ operand, permutation,
+ reassociations, concreteType))) {
return failure();
}
@@ -140,6 +216,8 @@
MMAFragment fragment;
};
+} // namespace
+
void ConcretizeMmaShapesPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 69333a1..9df22c7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -71,6 +72,310 @@
return *config.getSubgroupNCount();
}
+/// Gets a unit vector of the given rank, but fills in the given dimensions
+/// from the 2 element array |counts|. |dim0| is the position in the returned
+/// vector to put the first element of |counts|, and |dim1| is the position to
+/// put the second element. For example,
+///
+/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1
+/// returns [1, 5, 7]
+static SmallVector<int64_t> getUnitOfRankWithDims(int64_t rank,
+ ArrayRef<int64_t> counts,
+ int64_t dim0, int64_t dim1) {
+ assert(counts.size() == 2 &&
+ "Unexpected non-rank 2 single subgroup dimension counts");
+ SmallVector<int64_t> res(rank, 1);
+ res[dim0] = counts[0];
+ res[dim1] = counts[1];
+ return res;
+}
+
+/// Constructs the nested layout given the layout for a single subgroup and the
+/// subgroup/batch counts and orders, as well as the dimensions along which to
+/// distribute the intrinsic's layout.
+///
+/// |outerDim| and |innerDim| refer to which dimensions are the outermost and
+/// innermost for a canonical MK_KN_MN matrix multiply, for a particular
+/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply,
+/// we would have:
+/// outerDim = 1 for the K dim
+/// innerDim = 0 for the N dim
+///
+/// For something like MK_NKN_MN with multiple N dims, it would typically be:
+/// outerDim = 1 for K
+/// innerDim = 2 for the second N dim
+///
+/// Importantly these two dimensions always refer to the actual dimension
+/// positions in the undistributed vector. For each fragment, this means:
+/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim]
+/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim]
+/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim]
+///
+/// And here inner most is referential to the iteration order, not the order
+/// they appear per fragment (because there is no relationship between the
+/// dimension order of M in A and in C, for example).
+static NestedLayoutAttr createNestedLayout(
+ MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
+ ArrayRef<int64_t> subgroupSizes, ArrayRef<int64_t> subgroupStrides,
+ ArrayRef<int64_t> batchCount, IREE::GPU::MMASingleSubgroupLayout counts) {
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Creating Nested Layout for::";
+ llvm::dbgs() << "\n outerDim = " << outerDim;
+ llvm::dbgs() << "\n innerDim = " << innerDim;
+ llvm::dbgs() << "\n subgroupSizes: ";
+ llvm::interleaveComma(subgroupSizes, llvm::dbgs());
+ llvm::dbgs() << "\n subgroupStrides: ";
+ llvm::interleaveComma(subgroupStrides, llvm::dbgs());
+ llvm::dbgs() << "\n batchCount: ";
+ llvm::interleaveComma(batchCount, llvm::dbgs());
+ llvm::dbgs() << "\n counts.outer: ";
+ llvm::interleaveComma(counts.outer, llvm::dbgs());
+ llvm::dbgs() << "\n counts.thread: ";
+ llvm::interleaveComma(counts.thread, llvm::dbgs());
+ llvm::dbgs() << "\n counts.element: ";
+ llvm::interleaveComma(counts.element, llvm::dbgs());
+ llvm::dbgs() << "\n counts.tstrides: ";
+ llvm::interleaveComma(counts.tstrides, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ SmallVector<int64_t> outerCount =
+ getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim);
+ SmallVector<int64_t> threadCount =
+ getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
+ SmallVector<int64_t> threadStrides =
+ getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim);
+ SmallVector<int64_t> elementCount =
+ getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim);
+
+ auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount,
+ outerCount, threadCount, elementCount,
+ subgroupStrides, threadStrides);
+ return layoutAttr;
+}
+
+static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface,
+ IREE::VectorExt::VectorLayoutInterface>>
+getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
+ VectorContractOpInfo &opInfo,
+ linalg::LinalgOp contractOp) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n";
+ llvm::dbgs() << "For schedule: " << schedule << "\n";
+ });
+
+ int64_t rank = contractOp.getIteratorTypesArray().size();
+ auto mmaAttr =
+ llvm::cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic());
+ MLIRContext *context = schedule.getContext();
+
+ SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
+ if (llvm::any_of(bounds,
+ [](int64_t x) { return x == ShapedType::kDynamic; })) {
+ return failure();
+ }
+
+ if (!llvm::all_of(opInfo.getBatchDims(),
+ [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
+ LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; });
+ return failure();
+ }
+
+ // Get the concrete nested layout for each matrix. Note that the struct
+ // MMASingleSubgroupLayout contains the partial layout for the
+ // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
+ // contract op we are looking at right now may not be exactly in that form.
+ // So here we need to permute/transpose the canonical layout to match with
+ // the concrete contract op.
+
+ // Note that no matter how we permute/transpose the input contraction
+ // problem, the way we view the hardware warps remain the same--that is,
+ // from the hardware's perspective, a single warp has the same warp ID no
+ // matter what part of the contraction it works on. Similarly here, we are
+ // delinearizing the linearized GPU hardware lane ID into a n-D concatenated
+ // logical warp+thread using the subgroup/thread basis, so the subgroup
+ // basis should remain the same for all A/B/C matrix.
+
+ auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
+
+ SmallVector<int64_t, 2> subgroupMBasis;
+ SmallVector<int64_t, 2> batchMSizes;
+ int64_t currMCount = schedule.getSubgroupMCount();
+
+ auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
+ int64_t minDimSize) -> std::pair<int64_t, int64_t> {
+ int64_t dividableDim = dimSize / minDimSize;
+ int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
+ dividableDim /= subgroupsUsed;
+ int64_t batchesUsed = dividableDim;
+ return {subgroupsUsed, batchesUsed};
+ };
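+  // For instance (hypothetical numbers): availableSubgroups = 4,
+  // dimSize = 96, minDimSize = 16 gives dividableDim = 6, so gcd(4, 6) = 2
+  // subgroups are used and the remaining factor 3 becomes the batch count.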
+
+ // Greedily break up the M subgroup and batch counts along the "M" iteration
+ // bounds. We distribute as many residual subgroups as possible per M dim,
+ // and then divide the remaining along batch dims. The inner most M dim is
+ // always the one used for the intrinsic, meaning for a valid schedule, the
+ // computed batch counts and subgroup basis will satisfy totalMSize /
+ // intrinsicM = product(batchMSizes) * product(subgroupMBasis)
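+  //
+  // Hypothetical example: with M iteration bounds [2, 128], intrinsicM = 16
+  // and subgroupMCount = 4, the outer dim gets 2 subgroups and 1 batch, and
+  // the inner dim gets 2 subgroups and 128 / 16 / 2 = 4 batches, satisfying
+  // (2 * 128) / 16 = (2 * 2) * (1 * 4).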
+ for (auto dim : opInfo.getMDims()) {
+ // Get the number of subgroups and batches used for this dimension based
+ // on the intrinsic size and the bound size.
+ int64_t subgroupsUsed, batchesUsed;
+ if (dim == opInfo.getMDims().back()) {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currMCount, bounds[dim], intrinsicM);
+ } else {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currMCount, bounds[dim], 1);
+ }
+ subgroupMBasis.push_back(subgroupsUsed);
+ batchMSizes.push_back(batchesUsed);
+ // Update available subgroup count.
+ currMCount /= subgroupsUsed;
+ }
+
+ SmallVector<int64_t, 2> subgroupNBasis;
+ SmallVector<int64_t, 2> batchNSizes;
+ int64_t currNCount = schedule.getSubgroupNCount();
+
+ // Do the same for N dims.
+ for (auto dim : opInfo.getNDims()) {
+ // Get the number of subgroups and batches used for this dimension based
+ // on the intrinsic size and the bound size.
+ int64_t subgroupsUsed, batchesUsed;
+ if (dim == opInfo.getNDims().back()) {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currNCount, bounds[dim], intrinsicN);
+ } else {
+ std::tie(subgroupsUsed, batchesUsed) =
+ divideGreedily(currNCount, bounds[dim], 1);
+ }
+ subgroupNBasis.push_back(subgroupsUsed);
+ batchNSizes.push_back(batchesUsed);
+ // Update available subgroup count.
+ currNCount /= subgroupsUsed;
+ }
+
+ SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
+ SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
+
+ auto mDimVec = opInfo.getMDims();
+ llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
+ auto nDimVec = opInfo.getNDims();
+ llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
+ // Because we currently require all batch dimensions to be unit, the
+ // subgroup basis can be constructed from the M and N bases. To keep things
+ // simple, the current heuristic is to distribute the loop dimensions from
+ // outer to inner.
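+  //
+  // Hypothetical example: for a plain matmul with loop order (M, N, K),
+  // subgroupMBasis = [2] and subgroupNBasis = [4], the loop below visits N
+  // before M, so N gets subgroup stride 1 and M gets subgroup stride 4,
+  // distributing 2 * 4 = 8 subgroups.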
+ int64_t currStride = 1;
+ int64_t currM = subgroupMStrides.size() - 1;
+ int64_t currN = subgroupNStrides.size() - 1;
+ for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
+ if (mDims.contains(dim)) {
+ subgroupMStrides[currM] = currStride;
+ currStride *= subgroupMBasis[currM];
+ currM--;
+ continue;
+ }
+
+ if (nDims.contains(dim)) {
+ subgroupNStrides[currN] = currStride;
+ currStride *= subgroupNBasis[currN];
+ currN--;
+ continue;
+ }
+ }
+
+ // C matrix layout
+ auto [m, n] = opInfo.getResultMNIndex();
+ int64_t cRank = opInfo.getCRank();
+
+ // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
+ // cNDims are the M and N dimensions of the C matrix in the order they are
+ // iterated over in the contraction.
+ SmallVector<int64_t> cMDims = opInfo.outMDims;
+ SmallVector<int64_t> cNDims = opInfo.outNDims;
+ SmallVector<int64_t> cBatchSizes(cRank, 1);
+ SmallVector<int64_t> cSubgroupSizes(cRank, 1);
+ SmallVector<int64_t> cSubgroupStrides(cRank, 0);
+ for (auto [i, dim] : llvm::enumerate(cMDims)) {
+ cBatchSizes[dim] = batchMSizes[i];
+ cSubgroupSizes[dim] = subgroupMBasis[i];
+ cSubgroupStrides[dim] = subgroupMStrides[i];
+ }
+ for (auto [i, dim] : llvm::enumerate(cNDims)) {
+ cBatchSizes[dim] = batchNSizes[i];
+ cSubgroupSizes[dim] = subgroupNBasis[i];
+ cSubgroupStrides[dim] = subgroupNStrides[i];
+ }
+
+ auto cLayout = createNestedLayout(context, cRank, m, n,
+ /*subgroupCount=*/cSubgroupSizes,
+ /*subgroupStrides=*/cSubgroupStrides,
+ /*batchCount=*/cBatchSizes,
+ getCSingleSubgroupLayout(mmaAttr));
+ LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
+
+ // A matrix layout
+ auto [afm, bfn] = opInfo.getOperandMNIndex();
+ auto [afk, bfk] = opInfo.getOperandKIndex();
+
+ int64_t aRank = opInfo.getARank();
+
+ SmallVector<int64_t> aMDims = opInfo.lhsMDims;
+ SmallVector<int64_t> aBatchSizes(aRank, 1);
+ SmallVector<int64_t> aSubgroupSizes(aRank, 1);
+ SmallVector<int64_t> aSubgroupStrides(aRank, 0);
+ for (auto [i, dim] : llvm::enumerate(aMDims)) {
+ aBatchSizes[dim] = batchMSizes[i];
+ aSubgroupSizes[dim] = subgroupMBasis[i];
+ aSubgroupStrides[dim] = subgroupMStrides[i];
+ }
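+  // K dims are not distributed across subgroups: all but the innermost K dim
+  // become full batch dimensions, and the innermost one is divided by the
+  // intrinsic K size below.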
+ for (auto [kDim, lhsKDim] :
+ llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
+ aBatchSizes[lhsKDim] = bounds[kDim];
+ }
+ aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+ auto aLayout = createNestedLayout(context, aRank, afm, afk,
+ /*subgroupCount=*/aSubgroupSizes,
+ /*subgroupStrides=*/aSubgroupStrides,
+ /*batchCount=*/aBatchSizes,
+ getASingleSubgroupLayout(mmaAttr));
+ LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
+
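+  // B matrix layout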
+ int64_t bRank = opInfo.getBRank();
+
+ SmallVector<int64_t> bNDims = opInfo.rhsNDims;
+ SmallVector<int64_t> bBatchSizes(bRank, 1);
+ SmallVector<int64_t> bSubgroupSizes(bRank, 1);
+ SmallVector<int64_t> bSubgroupStrides(bRank, 0);
+ for (auto [i, dim] : llvm::enumerate(bNDims)) {
+ bBatchSizes[dim] = batchNSizes[i];
+ bSubgroupSizes[dim] = subgroupNBasis[i];
+ bSubgroupStrides[dim] = subgroupNStrides[i];
+ }
+ for (auto [kDim, rhsKDim] :
+ llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
+ bBatchSizes[rhsKDim] = bounds[kDim];
+ }
+ bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
+
+ auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
+ /*subgroupCount=*/bSubgroupSizes,
+ /*subgroupStrides=*/bSubgroupStrides,
+ /*batchCount=*/bBatchSizes,
+ getBSingleSubgroupLayout(mmaAttr));
+ LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
+
+ std::tuple<VectorLayoutInterface, VectorLayoutInterface,
+ VectorLayoutInterface>
+ result = {aLayout, bLayout, cLayout};
+ return result;
+}
+
static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
SmallVector<bool> promotedOperands,
RewriterBase &rewriter,
@@ -89,7 +394,7 @@
contract.getIndexingMapsArray());
assert(succeeded(opInfo) && "contraction should have been inferred");
- auto layouts = schedule.getContractionLayout(opInfo.value(), contract);
+ auto layouts = getContractionLayout(schedule, opInfo.value(), contract);
if (failed(layouts)) {
return contract->emitError("cannot get concrete layout for contraction");
}
@@ -176,7 +481,7 @@
assert(succeeded(opInfo) &&
"unit filter dim convolution should have been infered");
- auto layouts = schedule.getContractionLayout(opInfo.value(), conv);
+ auto layouts = getContractionLayout(schedule, opInfo.value(), conv);
if (failed(layouts)) {
return conv->emitError("cannot get concrete layout for convolution");
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 3dd0c12..a2c2187 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -1599,6 +1600,219 @@
return VectorExt::LayoutAttr::get(ctx, perDimLayouts);
}
+// Struct containing concrete MMA shape, type, and layout information.
+struct ConcreteMmaLayout {
+ GPU::OpaqueMmaLayout base;
+ VectorExt::PerDimLayoutAttr aMLayout;
+ VectorExt::PerDimLayoutAttr aKLayout;
+ VectorExt::PerDimLayoutAttr bKLayout;
+ VectorExt::PerDimLayoutAttr bNLayout;
+ VectorExt::PerDimLayoutAttr cMLayout;
+ VectorExt::PerDimLayoutAttr cNLayout;
+};
+
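+// Converts an intrinsic |swizzle| into a pair of PerDimLayoutAttr, one for
+// each of the two source dimensions of the MMA fragment.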
+static std::tuple<VectorExt::PerDimLayoutAttr, VectorExt::PerDimLayoutAttr>
+getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
+ // Step 1: obtain the swizzled tile shape, but keeping track of the source
+ // dimension indices.
+ struct SrcIndexAndSwizzleDim {
+ size_t srcIndex;
+ TileSwizzle::Dim dim;
+ };
+ SmallVector<SrcIndexAndSwizzleDim> swizzledShape;
+ for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) {
+ for (TileSwizzle::Dim d : e) {
+ swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d});
+ }
+ }
+ applyPermutationToVector(swizzledShape, swizzle.permutation);
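+  //
+  // Hypothetical example: with expandShape = [[16], [4, 4]] (the 16 being
+  // CrossThread and both 4s Internal) and permutation = [1, 0, 2], the
+  // swizzled shape is [4, 16, 4] with source indices [1, 0, 1]; steps 2 and 3
+  // below then label these dims VECTORY, LANEX, VECTORX, giving source dim 0
+  // the labels [LANEX] with shape [16] and source dim 1 the labels
+  // [VECTORY, VECTORX] with shape [4, 4].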
+
+ // Step 2: collect the appropriate labels to use for the swizzled dims.
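+  // The label arrays are consumed from the back so that the last swizzled dim
+  // of each kind receives VECTORX / LANEX; this pass only rewinds the
+  // iterators by the number of dims of each kind, and step 3 then consumes
+  // them in order.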
+ VectorExt::LayoutDimension internalLabels[] = {
+ VectorExt::LayoutDimension::VECTORZ, VectorExt::LayoutDimension::VECTORY,
+ VectorExt::LayoutDimension::VECTORX};
+ VectorExt::LayoutDimension crossThreadLabels[] = {
+ VectorExt::LayoutDimension::LANEZ, VectorExt::LayoutDimension::LANEY,
+ VectorExt::LayoutDimension::LANEX};
+ auto internalLabelIter = std::end(internalLabels);
+ auto crossThreadLabelIter = std::end(crossThreadLabels);
+ for (SrcIndexAndSwizzleDim d : swizzledShape) {
+ if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) {
+ assert(internalLabelIter != std::begin(internalLabels));
+ --internalLabelIter;
+ } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) {
+ assert(crossThreadLabelIter != std::begin(crossThreadLabels));
+ --crossThreadLabelIter;
+ } else {
+ assert(false && "unexpected dimension kind in intrinsic swizzle");
+ }
+ }
+
+ // Step 3: put together the result PerDimLayoutAttr'd for the two source dims.
+ SmallVector<VectorExt::LayoutDimensionAttr> labels[2];
+ SmallVector<int64_t> shape[2];
+ for (SrcIndexAndSwizzleDim d : swizzledShape) {
+ shape[d.srcIndex].push_back(d.dim.size);
+ auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal)
+ ? internalLabelIter
+ : crossThreadLabelIter;
+ labels[d.srcIndex].push_back(VectorExt::LayoutDimensionAttr::get(
+ context, static_cast<VectorExt::LayoutDimension>(*labelIterRef++)));
+ }
+ return {VectorExt::PerDimLayoutAttr::get(context, labels[0], shape[0]),
+ VectorExt::PerDimLayoutAttr::get(context, labels[1], shape[1])};
+};
+
+static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
+ GPU::MMAIntrinsic intrinsic) {
+ auto opaque = GPU::getOpaqueMMALayout(context, intrinsic);
+ ConcreteMmaLayout concreteLayout;
+ concreteLayout.base = opaque;
+ auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Lhs);
+ auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Rhs);
+ auto accSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Acc);
+ std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) =
+ getPerDimLayoutAttrs(context, lhsSwizzle);
+ std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) =
+ getPerDimLayoutAttrs(context, rhsSwizzle);
+ std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) =
+ getPerDimLayoutAttrs(context, accSwizzle);
+ return concreteLayout;
+}
+
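+// Prepends to |baseLayout| a batch dimension labeled |batchDim| and of size
+// problemSize / fragmentDimSize.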
+static VectorExt::PerDimLayoutAttr
+getBatchedPerDimLayoutAttr(VectorExt::LayoutDimensionAttr batchDim,
+ VectorExt::PerDimLayoutAttr baseLayout,
+ int64_t problemSize, int64_t fragmentDimSize) {
+ assert(problemSize % fragmentDimSize == 0 &&
+ "invalid layout fragment for problem size");
+
+ SmallVector<VectorExt::LayoutDimensionAttr, 3> dimAttrs(
+ baseLayout.getLabels());
+ dimAttrs.insert(dimAttrs.begin(), batchDim);
+
+ SmallVector<int64_t, 3> shapes(baseLayout.getShapes());
+ shapes.insert(shapes.begin(), problemSize / fragmentDimSize);
+ auto layout = VectorExt::PerDimLayoutAttr::get(baseLayout.getContext(),
+ dimAttrs, shapes);
+ return layout;
+}
+
+// Get the batched layout attributes for the given fragment layouts, indexing
+// map, and problem shape. The canonical fragment map is used to compare against
+// the problem map |indexingMap|. For example, for mma fragment B (RHS):
+//
+// indexingMap = affine_map<(d0, d1, d2) -> (d1, d2) # Transposed B
+// fragmentMap = affine_map<(d0, d1, d2) -> (d2, d1)
+// problemShape = [32, 64]
+// fragmentSize = [16, 8]
+// fragmentLayouts = [kLayout, nLayout]
+//
+// Gives batched layout
+//
+// Dim0 Layout = [BATCHX, nLayoutLabels], [8, nLayoutShape]
+// Dim1 Layout = [BATCHY, kLayoutLabels], [2, kLayoutShape]
+static VectorExt::LayoutAttr
+getBatchedLayoutAttr(AffineMap indexingMap, AffineMap fragmentMap,
+ ArrayRef<int64_t> problemShape,
+ ArrayRef<int64_t> fragmentSize,
+ ArrayRef<VectorExt::PerDimLayoutAttr> fragmentLayouts) {
+ // Current distribution to MFMA operations does not support batched
+ // contractions so that is reflected here.
+ assert(indexingMap.getNumResults() == 2 &&
+ "invalid indexing map to non-batched simple contraction");
+
+ VectorExt::LayoutDimensionAttr batchX = VectorExt::LayoutDimensionAttr::get(
+ indexingMap.getContext(), VectorExt::LayoutDimension::BATCHX);
+ VectorExt::LayoutDimensionAttr batchY = VectorExt::LayoutDimensionAttr::get(
+ indexingMap.getContext(), VectorExt::LayoutDimension::BATCHY);
+
+ SmallVector<VectorExt::PerDimLayoutAttr, 2> perDimAttrs;
+ for (auto [expr, batchType] : llvm::zip_equal(
+ indexingMap.getResults(),
+ SmallVector<VectorExt::LayoutDimensionAttr, 2>{batchX, batchY})) {
+ auto maybeResultPosition = fragmentMap.getResultPosition(expr);
+ assert(maybeResultPosition && "fragment map and problem map mismatch");
+ int64_t idx = *maybeResultPosition;
+ perDimAttrs.push_back(getBatchedPerDimLayoutAttr(
+ batchType, fragmentLayouts[idx], problemShape[idx], fragmentSize[idx]));
+ }
+
+ return VectorExt::LayoutAttr::get(indexingMap.getContext(), perDimAttrs);
+}
+
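+// Computes batched A/B/C layouts for |contract| from the concrete MMA
+// |layout|. Fails on batched contractions, on multiple M/N/K dimensions, and
+// when the problem shape is not aligned to the intrinsic shape.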
+static FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
+ VectorLayoutInterface>>
+getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) {
+ MLIRContext *context = contract.getContext();
+ FailureOr<linalg::ContractionDimensions> maybeContractionDims =
+ linalg::inferContractionDims(contract.getIndexingMapsArray());
+ if (failed(maybeContractionDims)) {
+ return failure();
+ }
+ auto contractionDims = *maybeContractionDims;
+ // TODO: Relax this condition to strictly alignment requirements.
+ if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 ||
+ contractionDims.n.size() != 1) {
+ return failure();
+ }
+ // TODO: Support batched contractions.
+ if (contractionDims.batch.size() > 0) {
+ return failure();
+ }
+ unsigned mDim = contractionDims.m[0];
+ unsigned nDim = contractionDims.n[0];
+ unsigned kDim = contractionDims.k[0];
+
+ SmallVector<int64_t> iterationBounds;
+ contract.getIterationBounds(iterationBounds);
+
+ int64_t problemMSize = iterationBounds[mDim];
+ int64_t problemNSize = iterationBounds[nDim];
+ int64_t problemKSize = iterationBounds[kDim];
+
+ int64_t mSize = layout.base.mSize;
+ int64_t nSize = layout.base.nSize;
+ int64_t kSize = layout.base.kSize;
+
+ // The problem size currently must be strictly aligned to the size of the mma.
+ // This is expected to succeed assuming the correct [masked] vector size was
+ // set at strategy configuration time (for this mma).
+ if (problemMSize % mSize != 0 || problemNSize % nSize ||
+ problemKSize % kSize) {
+ return failure();
+ }
+
+ VectorExt::LayoutAttr aLayout = getBatchedLayoutAttr(
+ contract.getIndexingMapsArray()[0],
+ AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, context),
+ {problemMSize, problemKSize}, {mSize, kSize},
+ {layout.aMLayout, layout.aKLayout});
+ VectorExt::LayoutAttr bLayout = getBatchedLayoutAttr(
+ contract.getIndexingMapsArray()[1],
+ AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, context),
+ {problemKSize, problemNSize}, {kSize, nSize},
+ {layout.bKLayout, layout.bNLayout});
+ VectorExt::LayoutAttr cLayout = getBatchedLayoutAttr(
+ contract.getIndexingMapsArray()[2],
+ AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, context),
+ {problemMSize, problemNSize}, {mSize, nSize},
+ {layout.cMLayout, layout.cNLayout});
+
+ return std::make_tuple<VectorLayoutInterface, VectorLayoutInterface,
+ VectorLayoutInterface>(aLayout, bLayout, cLayout);
+}
+
+FailureOr<std::tuple<
+ VectorLayoutInterface, VectorLayoutInterface,
+ VectorLayoutInterface>> static getContractionLayout(GPU::MMAAttr mma,
+ vector::ContractionOp
+ contract) {
+ ConcreteMmaLayout layout = getConcreteMMALayout(
+ contract->getContext(), mma.getIntrinsic().getValue());
+ return getContractionLayout(contract, layout);
+}
+
DiagnosedSilenceableFailure
transform_dialect::SetContractionLayoutAttributes::apply(
transform::TransformRewriter &rewriter,
@@ -1609,7 +1823,7 @@
return emitDefiniteFailure()
<< "invalid more than one attribute for contraction annotation";
}
- auto mmaType = llvm::dyn_cast<IREE::GPU::MmaInterfaceAttr>(typeList.front());
+ auto mmaType = llvm::dyn_cast<GPU::MMAAttr>(typeList.front());
if (!mmaType) {
return emitDefiniteFailure()
<< "invalid non-mma attribute for contraction annotation "
@@ -1623,7 +1837,7 @@
<< "invalid non-contraction annotation " << payload;
}
- auto maybeLayouts = mmaType.getContractionLayout(contract);
+ auto maybeLayouts = getContractionLayout(mmaType, contract);
if (failed(maybeLayouts)) {
return emitDefiniteFailure()
<< "invalid opaque mma layout for annotation " << mmaType;