[Codegen][Common] Resolve `scf.forall` that are used for workgroup distribution (#18368)
The `scf.forall` for workgroup distribution need to be resolved after
startegy lowering. Since this needs an update to the entry point region,
the `scf.forall` resolution is now added with
`ReconcileTranslationInfoPass`. If needed this could be moved out of
here.
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index e8e18f5..d7a81fc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -16,7 +16,10 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
namespace mlir::iree_compiler {
#define GEN_PASS_DEF_RECONCILETRANSLATIONINFOPASS
@@ -32,6 +35,270 @@
};
} // namespace
+//===---------------------------------------------------------------------===//
+// Resolve `scf.forall` operations
+//===---------------------------------------------------------------------===//
+
+/// Verify that the mapping attribute provided is of the right form.
+static FailureOr<SmallVector<IREE::Codegen::WorkgroupMappingAttr>>
+verifyWorkgroupMappingAttrArray(scf::ForallOp forallOp) {
+ std::optional<ArrayAttr> mappingAttr = forallOp.getMapping();
+ if (!mappingAttr) {
+ return forallOp.emitOpError("expected mapped for all op");
+ }
+ if (mappingAttr.value().empty()) {
+ return forallOp.emitOpError("mapping attribute cannot be empty");
+ }
+ if (failed(IREE::Codegen::WorkgroupMappingAttr::verifyAttrList(
+ forallOp.getContext(), forallOp.getLoc(), mappingAttr->getValue()))) {
+ return failure();
+ }
+ SmallVector<IREE::Codegen::WorkgroupMappingAttr> workgroupMappingAttrs =
+ llvm::map_to_vector(mappingAttr.value(), [](Attribute attr) {
+ return cast<IREE::Codegen::WorkgroupMappingAttr>(attr);
+ });
+ return workgroupMappingAttrs;
+}
+
+/// Get the permutation that represents the mapping of loop dimensions to
+/// process dimensions.
+SmallVector<int64_t>
+getMappingPermutation(ArrayRef<IREE::Codegen::WorkgroupMappingAttr> mapping) {
+ return invertPermutationVector(llvm::map_to_vector(mapping, [&](auto a) {
+ return static_cast<int64_t>(mapping.size() - 1) - a.getMappingId();
+ }));
+}
+
+/// Return the procId and nprocs to use for each of the distributed loops,
+/// derived from `hal.interface.workgroup.id/count`s.
+static FailureOr<
+ std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>>
+getProcIdsAndNprocs(
+ scf::ForallOp forallOp, RewriterBase &builder, Location loc,
+ SmallVector<IREE::Codegen::WorkgroupMappingAttr> workgroupMappings,
+ SmallVector<OpFoldResult> lowerBounds,
+ SmallVector<OpFoldResult> upperBounds, SmallVector<OpFoldResult> steps) {
+ if (workgroupMappings.size() != lowerBounds.size()) {
+ return forallOp.emitOpError(
+ "expected as many workgroup mapping attributes as number of loops");
+ }
+
+ auto permutation = getMappingPermutation(workgroupMappings);
+ applyPermutationToVector(workgroupMappings, permutation);
+ applyPermutationToVector(lowerBounds, permutation);
+ applyPermutationToVector(upperBounds, permutation);
+ applyPermutationToVector(steps, permutation);
+
+ SmallVector<OpFoldResult> procId(workgroupMappings.size(),
+ builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> nprocs(workgroupMappings.size(),
+ builder.getIndexAttr(1));
+
+ AffineExpr s0, s1, s2;
+ bindSymbols(builder.getContext(), s0, s1, s2);
+ AffineExpr extentExpr = (s1 - s0).ceilDiv(s2);
+ IREE::Codegen::WorkgroupMappingAttr baseZDim =
+ IREE::Codegen::WorkgroupMappingAttr::get(builder.getContext(),
+ IREE::Codegen::WorkgroupId::IdZ);
+ SmallVector<OpFoldResult> loopExtents;
+ if (workgroupMappings.size() > baseZDim.getMappingId()) {
+ loopExtents.resize(workgroupMappings.size() - baseZDim.getMappingId());
+ }
+ for (int index = workgroupMappings.size() - 1; index >= 0; --index) {
+ auto workgroupMapping = workgroupMappings[index];
+ auto lowerBound = lowerBounds[index];
+ auto upperBound = upperBounds[index];
+ auto step = steps[index];
+ switch (workgroupMapping.getId()) {
+ case IREE::Codegen::WorkgroupId::IdX:
+ procId[index] =
+ builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 0).getResult();
+ nprocs[index] =
+ builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0)
+ .getResult();
+ break;
+ case IREE::Codegen::WorkgroupId::IdY:
+ procId[index] =
+ builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 1).getResult();
+ nprocs[index] =
+ builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1)
+ .getResult();
+ break;
+ case IREE::Codegen::WorkgroupId::IdZ: {
+ OpFoldResult extent = affine::makeComposedFoldedAffineApply(
+ builder, loc, extentExpr, {lowerBound, upperBound, step});
+ loopExtents[index] = extent;
+ break;
+ }
+ }
+ }
+
+ // Delinearize the z-dim based on the loop extents.
+ if (!loopExtents.empty()) {
+ Value zDimId =
+ builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 2).getResult();
+ OpFoldResult zNprocs =
+ builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 2)
+ .getResult();
+
+ if (loopExtents.size() != 1) {
+ auto delinearizeOp = builder.create<affine::AffineDelinearizeIndexOp>(
+ loc, zDimId, loopExtents);
+ SmallVector<OpFoldResult> orderedDelinearizedDimIds =
+ llvm::map_to_vector(delinearizeOp.getResults(),
+ [](Value v) -> OpFoldResult { return v; });
+ SmallVector<OpFoldResult> orderedDelinearizedNprocs;
+ AffineMap minMap = AffineMap::get(0, 2, {s0, s1}, builder.getContext());
+ AffineExpr ceilDivExpr = s0.ceilDiv(s1);
+ for (int index = loopExtents.size() - 1; index >= 0; --index) {
+ auto extent = loopExtents[index];
+ procId[index] = delinearizeOp->getResult(index);
+ OpFoldResult currNprocs = affine::makeComposedFoldedAffineMin(
+ builder, loc, minMap, {extent, zNprocs});
+ nprocs[index] = currNprocs;
+ zNprocs = affine::makeComposedFoldedAffineApply(
+ builder, loc, ceilDivExpr, {zNprocs, currNprocs});
+ }
+ } else {
+ // If there is only one z-dim mapping, just use the ID directly.
+ procId[0] = zDimId;
+ nprocs[0] = zNprocs;
+ }
+ }
+
+ auto inversePermutation = invertPermutationVector(permutation);
+ applyPermutationToVector(procId, inversePermutation);
+ applyPermutationToVector(nprocs, inversePermutation);
+ return std::make_pair(procId, nprocs);
+}
+
+/// Resolve scf.forall operation by using the workgroup ID and counts.
+static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
+ scf::ForallOp forallOp) {
+ if (forallOp->getNumResults() != 0) {
+ return forallOp.emitOpError(
+ "cannot resolve for all ops with return values");
+ }
+ SmallVector<OpFoldResult> mixedLowerBound = forallOp.getMixedLowerBound();
+ SmallVector<OpFoldResult> mixedUpperBound = forallOp.getMixedUpperBound();
+ SmallVector<OpFoldResult> mixedStep = forallOp.getMixedStep();
+ FailureOr<SmallVector<IREE::Codegen::WorkgroupMappingAttr>> workgroupMapping =
+ verifyWorkgroupMappingAttrArray(forallOp);
+ if (failed(workgroupMapping)) {
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(forallOp);
+
+ SmallVector<OpFoldResult> procId;
+
+ {
+ FailureOr<std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>>
+ procInfo =
+ getProcIdsAndNprocs(forallOp, rewriter, forallOp.getLoc(),
+ workgroupMapping.value(), mixedLowerBound,
+ mixedUpperBound, mixedStep);
+ if (failed(procInfo)) {
+ return failure();
+ }
+ std::swap(procId, procInfo->first);
+ }
+
+ /// For now this is assuming that number of workgroups is exactly equal to
+ /// the iterations for each loop dimension. Just inline the forall body into
+ /// the parent.
+ Block *parentBlock = forallOp->getBlock();
+ Block *remainingBlock =
+ rewriter.splitBlock(parentBlock, Block::iterator(forallOp));
+ auto argReplacements =
+ getValueOrCreateConstantIndexOp(rewriter, forallOp.getLoc(), procId);
+ Block *loopBody = forallOp.getBody();
+ rewriter.eraseOp(loopBody->getTerminator());
+ rewriter.mergeBlocks(loopBody, parentBlock, argReplacements);
+ rewriter.mergeBlocks(remainingBlock, parentBlock, ValueRange{});
+ rewriter.eraseOp(forallOp);
+ return success();
+}
+
+static LogicalResult resolveWorkgroupCount(RewriterBase &rewriter,
+ mlir::FunctionOpInterface funcOp,
+ scf::ForallOp forAllOp) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(forAllOp);
+ SmallVector<OpFoldResult> lowerBounds = forAllOp.getMixedLowerBound();
+ SmallVector<OpFoldResult> upperBounds = forAllOp.getMixedUpperBound();
+ SmallVector<OpFoldResult> steps = forAllOp.getMixedStep();
+ SmallVector<OpFoldResult> workgroupCount(lowerBounds.size());
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr countExpr = (s1 - s0).ceilDiv(s2);
+ for (auto [index, lb, ub, step] :
+ llvm::enumerate(lowerBounds, upperBounds, steps)) {
+ workgroupCount[index] = affine::makeComposedFoldedAffineApply(
+ rewriter, forAllOp.getLoc(), countExpr, {lb, ub, step});
+ }
+ auto mappingAttr =
+ llvm::map_to_vector(forAllOp.getMapping().value(), [](auto a) {
+ return cast<IREE::Codegen::WorkgroupMappingAttr>(a);
+ });
+ auto permutation = getMappingPermutation(mappingAttr);
+ workgroupCount = applyPermutation(workgroupCount, permutation);
+ return lowerWorkgroupCountFromSliceOp(rewriter, funcOp, workgroupCount);
+}
+
+static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
+ FunctionOpInterface funcOp) {
+ Region &body = funcOp.getFunctionBody();
+
+ if (body.empty()) {
+ return success();
+ }
+
+ if (!llvm::hasSingleElement(body)) {
+ return funcOp.emitOpError("unhandled function with multiple blocks");
+ }
+
+ auto forAllOps = body.getOps<scf::ForallOp>();
+ SmallVector<scf::ForallOp> workgroupForAllOps = llvm::to_vector(
+ llvm::make_filter_range(forAllOps, [&](scf::ForallOp forAllOp) {
+ auto mapping = forAllOp.getMapping();
+ if (!mapping) {
+ return false;
+ }
+ if (!llvm::all_of(mapping.value(), [](Attribute attr) {
+ return isa<IREE::Codegen::WorkgroupMappingAttr>(attr);
+ })) {
+ return false;
+ }
+ return true;
+ }));
+
+ if (workgroupForAllOps.empty()) {
+ // If there are no workgroup distribution loops, set the default
+ // number of workgroups to {1, 1, 1}. Note: that this only kicks
+ // in if the export op region has
+ // `flow.dispatch.workgroup_count_from_slice
+ return lowerWorkgroupCountFromSliceOp(rewriter, funcOp,
+ ArrayRef<OpFoldResult>{});
+ }
+ if (!llvm::hasSingleElement(workgroupForAllOps)) {
+ return funcOp.emitOpError("unhandled resolution of zero/multiple "
+ "scf.forall ops withing the function");
+ }
+
+ scf::ForallOp forallOp = *forAllOps.begin();
+ if (failed(resolveWorkgroupCount(rewriter, funcOp, forallOp))) {
+ return failure();
+ }
+
+ return resolveWorkgroupForAll(rewriter, *forAllOps.begin());
+}
+
+//===---------------------------------------------------------------------===//
+// End Resolve `scf.forall` operations
+//===---------------------------------------------------------------------===//
+
// Reconcile workgroup sizes across all translation infos.
static FailureOr<SmallVector<int64_t>> reconcileWorkgroupSize(
ArrayRef<IREE::Codegen::TranslationInfoAttr> translationInfos) {
@@ -89,25 +356,37 @@
return signalPassFailure();
}
auto exportOp = *exportOps.begin();
- Builder builder(&getContext());
+ IRRewriter rewriter(&getContext());
SmallVector<IREE::Codegen::TranslationInfoAttr> translationInfos;
- innerModuleOp->walk([&](FunctionOpInterface funcOp) {
- auto translationInfo = getTranslationInfo(funcOp);
- if (!translationInfo) {
- return;
- }
+ auto walkResult =
+ innerModuleOp->walk([&](FunctionOpInterface funcOp) -> WalkResult {
+ // Resolve workgroup distribution related `scf.forall` ops.
+ if (failed(resolveWorkgroupForAll(rewriter, funcOp))) {
+ return failure();
+ }
- translationInfos.push_back(translationInfo);
- // The following is moving the target-func-attrs specification from
- // translation info into the func-like op. This is not the best
- // place to do this, but the intent is after this pass all the
- // lowering configs and translation infos will be deleted.
- DictionaryAttr targetFuncAttrs = getTargetFuncAttrs(translationInfo);
- if (targetFuncAttrs) {
- funcOp->setAttr("llvm_func_attrs", targetFuncAttrs);
- }
- });
+ auto translationInfo = getTranslationInfo(funcOp);
+ if (!translationInfo) {
+ return WalkResult::advance();
+ }
+
+ translationInfos.push_back(translationInfo);
+ // The following is moving the target-func-attrs specification from
+ // translation info into the func-like op. This is not the best
+ // place to do this, but the intent is after this pass all the
+ // lowering configs and translation infos will be deleted.
+ DictionaryAttr targetFuncAttrs = getTargetFuncAttrs(translationInfo);
+ if (targetFuncAttrs) {
+ funcOp->setAttr("llvm_func_attrs", targetFuncAttrs);
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted()) {
+ variantOp.emitOpError(
+ "failed in iree-codegen-reconcile-translation-info pass");
+ return signalPassFailure();
+ }
// Reconcile workgroup sizes.
FailureOr<SmallVector<int64_t>> reconciledWorkgroupSize =
@@ -125,7 +404,7 @@
for (auto [index, size] : llvm::enumerate(reconciledWorkgroupSize.value())) {
workgroupSize[index] = size;
}
- auto workgroupSizeArrayAttr = builder.getIndexArrayAttr(workgroupSize);
+ auto workgroupSizeArrayAttr = rewriter.getIndexArrayAttr(workgroupSize);
exportOp.setWorkgroupSizeAttr(workgroupSizeArrayAttr);
// Reconcile subgroup sizes.
@@ -137,7 +416,7 @@
}
if (reconciledSubgroupSize.value() != int64_t()) {
exportOp.setSubgroupSizeAttr(
- builder.getIndexAttr(reconciledSubgroupSize.value()));
+ rewriter.getIndexAttr(reconciledSubgroupSize.value()));
}
// Erase all the lowering configs and translation infos.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 66cfa40..3f19b22 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/IR/StorageUniquerSupport.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler {
@@ -279,8 +278,7 @@
SmallVector<Attribute> deviceMappingAttribute =
getMapping(context, tilingInfo->tileSizes);
if (failed(IREE::Codegen::WorkgroupMappingAttr::verifyAttrList(
- context, ::mlir::detail::getDefaultDiagnosticEmitFn(funcOp.getLoc()),
- deviceMappingAttribute))) {
+ context, funcOp.getLoc(), deviceMappingAttribute))) {
return signalPassFailure();
}
tilingOptions.setMapping(deviceMappingAttribute);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
index 18c1677..cc8aa9d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info)))" %s --verify-diagnostics | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info, canonicalize)))" %s --verify-diagnostics --allow-unregistered-dialect | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>
@@ -160,3 +160,272 @@
// CHECK-LABEL: hal.executable private @llvm_func_attrs
// CHECK: func.func @fn1() attributes {llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}
// CHECK: func.func @fn2() attributes {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
+ #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_2D {
+ hal.executable.variant public @scf_forall_2D target(#hal.executable.target<"", "", {}>) {
+ hal.executable.export public @scf_forall_2D layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @scf_forall_2D() {
+ %cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+ %cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+ %0 = flow.dispatch.workload.ordinal %cst0, 0 : index
+ %1 = flow.dispatch.workload.ordinal %cst1, 1 : index
+ %2 = flow.dispatch.workload.ordinal %cst2, 2 : index
+ scf.forall (%arg0, %arg1) = (0, 0) to (%0, %1) step(64, 32) {
+ "use"(%arg0, %arg1) : (index, index) -> ()
+ scf.forall.in_parallel {}
+ } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+ return
+ }
+ }
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
+// CHECK: hal.executable.export public @scf_forall_2D layout
+// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-z0-9]+]]: index
+// CHECK-DAG: %[[WG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+// CHECK-DAG: %[[WG_X:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
+// CHECK: func @scf_forall_2D()
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-NOT: scf.forall
+// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
+ #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_2D_dynamic_tile_size {
+ hal.executable.variant public @scf_forall_2D_dynamic_tile_size target(#hal.executable.target<"", "", {}>) {
+ hal.executable.export public @scf_forall_2D_dynamic_tile_size layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index, %arg4 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @scf_forall_2D_dynamic_tile_size() {
+ %cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+ %cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+ %cst3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
+ %0 = flow.dispatch.workload.ordinal %cst0, 0 : index
+ %1 = flow.dispatch.workload.ordinal %cst1, 1 : index
+ %2 = flow.dispatch.workload.ordinal %cst2, 2 : index
+ %3 = flow.dispatch.workload.ordinal %cst3, 3 : index
+ scf.forall (%arg0, %arg1) = (0, 0) to (%0, %1) step(%2, %3) {
+ "use"(%arg0, %arg1) : (index, index) -> ()
+ scf.forall.in_parallel {}
+ } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+ return
+ }
+ }
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)
+// CHECK: hal.executable.export public @scf_forall_2D_dynamic_tile_size layout
+// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-z0-9]+]]: index
+// CHECK-DAG: %[[WG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[WG_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
+// CHECK: func @scf_forall_2D_dynamic_tile_size()
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-NOT: scf.forall
+// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 12, bindings = [
+ #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_4D {
+ hal.executable.variant public @scf_forall_4D target(#hal.executable.target<"", "", {}>) {
+ hal.executable.export public @scf_forall_4D layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index, %arg4 : index,
+ %arg5 : index, %arg6 : index, %arg7 : index, %arg8 : index,
+ %arg9 : index, %arg10 : index, %arg11 : index, %arg12 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3,
+ %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @scf_forall_4D() {
+ %cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+ %cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+ %cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+ %cst3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
+ %cst4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : index
+ %cst5 = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : index
+ %cst6 = hal.interface.constant.load layout(#pipeline_layout) ordinal(6) : index
+ %cst7 = hal.interface.constant.load layout(#pipeline_layout) ordinal(7) : index
+ %cst8 = hal.interface.constant.load layout(#pipeline_layout) ordinal(8) : index
+ %cst9 = hal.interface.constant.load layout(#pipeline_layout) ordinal(9) : index
+ %cst10 = hal.interface.constant.load layout(#pipeline_layout) ordinal(10) : index
+ %cst11 = hal.interface.constant.load layout(#pipeline_layout) ordinal(11) : index
+ %0 = flow.dispatch.workload.ordinal %cst0, 0 : index
+ %1 = flow.dispatch.workload.ordinal %cst1, 1 : index
+ %2 = flow.dispatch.workload.ordinal %cst2, 2 : index
+ %3 = flow.dispatch.workload.ordinal %cst3, 3 : index
+ %4 = flow.dispatch.workload.ordinal %cst4, 4 : index
+ %5 = flow.dispatch.workload.ordinal %cst5, 5 : index
+ %6 = flow.dispatch.workload.ordinal %cst6, 6 : index
+ %7 = flow.dispatch.workload.ordinal %cst7, 7 : index
+ %8 = flow.dispatch.workload.ordinal %cst8, 8 : index
+ %9 = flow.dispatch.workload.ordinal %cst9, 9 : index
+ %10 = flow.dispatch.workload.ordinal %cst10, 10 : index
+ %11 = flow.dispatch.workload.ordinal %cst11, 11 : index
+ scf.forall (%arg0, %arg1, %arg2, %arg3) = (%0, %1, %2, %3) to (%4, %5, %6, %7) step(%8, %9, %10, %11) {
+ "use"(%arg0, %arg1, %arg2, %arg3) : (index, index, index, index) -> ()
+ scf.forall.in_parallel {}
+ } {mapping = [#iree_codegen.workgroup_mapping<z:1>,
+ #iree_codegen.workgroup_mapping<z:0>,
+ #iree_codegen.workgroup_mapping<y>,
+ #iree_codegen.workgroup_mapping<x>]}
+ return
+ }
+ }
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))>
+// CHECK: hal.executable.export public @scf_forall_4D layout
+// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG11:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG12:[a-zA-z0-9]+]]: index
+// CHECK-DAG: %[[WG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]], %[[ARG11]]]
+// CHECK-DAG: %[[WG_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG4]], %[[ARG8]], %[[ARG12]]]
+// CHECK-DAG: %[[WG_Z:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]], %[[ARG10]], %[[ARG1]], %[[ARG5]], %[[ARG9]]]
+// CHECK: hal.return %[[WG_X]], %[[WG_Y]], %[[WG_Z]]
+// CHECK: func @scf_forall_4D()
+// CHECK-DAG: %[[LB0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(0)
+// CHECK-DAG: %[[LB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(1)
+// CHECK-DAG: %[[UB0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(4)
+// CHECK-DAG: %[[UB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(5)
+// CHECK-DAG: %[[STEP0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(8)
+// CHECK-DAG: %[[STEP1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(9)
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[NITERS1:.+]] = affine.apply #[[MAP0]]()[%[[LB1]], %[[UB1]], %[[STEP1]]]
+// CHECK-DAG: %[[NITERS0:.+]] = affine.apply #[[MAP0]]()[%[[LB0]], %[[UB0]], %[[STEP0]]]
+// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
+// CHECK-NOT: scf.forall
+// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[WG_ID_Z]] into (%[[NITERS0]], %[[NITERS1]])
+// CHECK: "use"(%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[WG_ID_Y]], %[[WG_ID_X]])
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+ #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_4D_static_interchange {
+ hal.executable.variant public @scf_forall_4D_static_interchange target(#hal.executable.target<"", "", {}>) {
+ hal.executable.export public @scf_forall_4D_static_interchange layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @scf_forall_4D_static_interchange() {
+ scf.forall (%arg0, %arg1, %arg2, %arg3, %arg4) = (0, 1, 2, 3, 4) to (4, 10, 19, 29, 44) step(1, 2, 3, 4, 5) {
+ "use"(%arg0, %arg1, %arg2, %arg3, %arg4) : (index, index, index, index, index) -> ()
+ scf.forall.in_parallel {}
+ } {mapping = [#iree_codegen.workgroup_mapping<z:0>,
+ #iree_codegen.workgroup_mapping<z:2>,
+ #iree_codegen.workgroup_mapping<x>,
+ #iree_codegen.workgroup_mapping<y>,
+ #iree_codegen.workgroup_mapping<z:1>]}
+ return
+ }
+ }
+ }
+}
+// CHECK: hal.executable.export public @scf_forall_4D_static_interchange layout
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C160:.+]] = arith.constant 160 : index
+// CHECK: hal.return %[[C6]], %[[C7]], %[[C160]]
+// CHECK: func @scf_forall_4D_static_interchange()
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
+// CHECK-NOT: scf.forall
+// CHECK: %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]])
+// CHECK: "use"(%[[DELINEARIZE]]#2, %[[DELINEARIZE]]#0, %[[WG_ID_X]], %[[WG_ID_Y]], %[[DELINEARIZE]]#1)
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+ #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @no_loop_do_nothing {
+ hal.executable.variant public @no_loop_do_nothing target(#hal.executable.target<"", "", {}>) {
+ hal.executable.export public @no_loop_do_nothing layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1: index
+ %c2 = arith.constant 2: index
+ hal.return %c1, %c2, %c1 : index, index, index
+ }
+ builtin.module {
+ func.func @no_loop_do_nothing() {
+ return
+ }
+ }
+ }
+}
+// CHECK: hal.executable.export public @no_loop_do_nothing layout
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: hal.return %[[C1]], %[[C2]], %[[C1]]
+// CHECK: func @no_loop_do_nothing()
+// CHECK-NEXT: return
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+ #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @no_loop_default_workgroup_count {
+ hal.executable.variant public @no_loop_default_workgroup_count target(#hal.executable.target<"", "", {}>) {
+ hal.executable.export public @no_loop_default_workgroup_count layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %0:3 = flow.dispatch.workgroup_count_from_slice %arg1, %arg2
+ hal.return %0#1, %0#2, %0#0 : index, index, index
+ }
+ builtin.module {
+ func.func @no_loop_default_workgroup_count() {
+ return
+ }
+ }
+ }
+}
+// CHECK: hal.executable.export public @no_loop_default_workgroup_count layout
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: hal.return %[[C1]], %[[C1]], %[[C1]]
+// CHECK: func @no_loop_default_workgroup_count()
+// CHECK-NEXT: return
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index 5f4be34..3ba545e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StorageUniquerSupport.h"
#define GET_ATTRDEF_CLASSES
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp.inc"
@@ -395,14 +396,15 @@
return getMappingId() < rhs.getMappingId();
}
-LogicalResult WorkgroupMappingAttr::verifyAttrList(
- MLIRContext *context, function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<Attribute> attrs) {
+LogicalResult WorkgroupMappingAttr::verifyAttrList(MLIRContext *context,
+ Location loc,
+ ArrayRef<Attribute> attrs) {
if (attrs.empty()) {
return success();
}
SmallVector<IREE::Codegen::WorkgroupMappingAttr> mappingAttrs;
llvm::SmallDenseSet<IREE::Codegen::WorkgroupMappingAttr, 4> attrSet;
+ auto emitError = mlir::detail::getDefaultDiagnosticEmitFn(loc);
for (auto attr : attrs) {
auto typedAttr =
::mlir::dyn_cast_or_null<IREE::Codegen::WorkgroupMappingAttr>(attr);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 6190b6c..48c3917 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -188,8 +188,7 @@
let extraClassDeclaration = [{
// Checks that a list of attributes is well-defined.
static LogicalResult verifyAttrList(::mlir::MLIRContext *context,
- ::llvm::function_ref<::mlir::InFlightDiagnostic ()> emitError,
- ArrayRef<Attribute> attrs);
+ Location loc, ArrayRef<Attribute> attrs);
// Convert from mapping ID to attribute.
static ::mlir::iree_compiler::IREE::Codegen::WorkgroupMappingAttr
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f1c7ac1..c95c484 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1137,10 +1137,12 @@
void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager,
bool useROCM) {
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
- modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
- FunctionLikeNest(modulePassManager)
- .addPass(createLLVMGPULowerExecutableTargetPass);
+ {
+ OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+ modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
+ FunctionLikeNest(modulePassManager)
+ .addPass(createLLVMGPULowerExecutableTargetPass);
+ }
variantPassManager.addPass(createReconcileTranslationInfoPass());
//===--------------------------------------------------------------------===//
// Convert Linalg ops to LLVM+NVVM/ROCDL ops.
@@ -1149,7 +1151,7 @@
// - All Linalg/Loops/GPU/Affine/Standard ops are converted away.
// - The module contains the final llvm.module ready to be serialized.
//===--------------------------------------------------------------------===//
- addLowerToLLVMGPUPasses(modulePassManager, useROCM);
+ addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(), useROCM);
LLVM_DEBUG({
llvm::dbgs() << "Using LLVMGPU pass pipeline:\n";
@@ -1181,12 +1183,16 @@
}
void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
- OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
- modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
- FunctionLikeNest(modulePassManager)
- .addPass(createROCDLLowerExecutableTargetPass);
+ {
+ OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+ modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
+ FunctionLikeNest(modulePassManager)
+ .addPass(createROCDLLowerExecutableTargetPass);
+ }
variantPassManager.addPass(createReconcileTranslationInfoPass());
- addLowerToLLVMGPUPasses(modulePassManager, /*forROCDL=*/true);
+
+ addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(),
+ /*forROCDL=*/true);
LLVM_DEBUG({
llvm::dbgs() << "Using ROCDL pass pipeline:\n";
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 74763d6..89aed57 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -454,12 +454,11 @@
std::optional<IREE::HAL::ExecutableExportOp> exportOp =
getEntryPoint(entryPointFn);
if (!exportOp) {
- return entryPointFn.emitOpError(
- "expected function to be entry point function");
+ return success();
}
Block *body = exportOp->getWorkgroupCountBody();
if (!body) {
- return exportOp->emitOpError("unexpected empty workgroup count region");
+ return success();
}
auto countOps = body->getOps<IREE::Flow::DispatchWorkgroupCountFromSliceOp>();
if (countOps.empty()) {