blob: a858bc2e5cf385ec79e230f9b89015a80f61f9d2 [file]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#define DEBUG_TYPE \
"iree-dispatch-creation-materialize-default-workgroup-count-region"
namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_MATERIALIZEDEFAULTWORKGROUPCOUNTREGIONPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
/// Creates the default workgroup count region of `workgroupsOp`, whose
/// materialized computation is derived as a program slice of the body of the
/// dispatch. This method
/// - Computes the `workload` to use for the `workgroupsOp`: every captured
///   argument of `IndexType` becomes a workload operand.
/// - Populates the workgroup count region with the placeholder op
///   `flow.dispatch.workgroups_count_from_body_slice`. This op is resolved in
///   the backends into the actual workgroup count computation.
/// - Inserts a `flow.dispatch.workload.ordinal` op for each captured workload
///   value in the body, mapping the captured operand to its position in the
///   workload list so it can be correlated back to the workload.
/// If the op already has a workgroup count region, it is left unchanged.
/// The rewriter's insertion point is restored before returning.
static void createDefaultWorkgroupCountRegion(
    RewriterBase &rewriter, IREE::Flow::DispatchWorkgroupsOp workgroupsOp) {
  Region &workgroupCountBody = workgroupsOp.getWorkgroupCount();
  if (!workgroupCountBody.empty()) {
    // Preserve pre-existing workgroup count region.
    return;
  }

  // Compute the `workload`. For now all `IndexType` arguments are treated as
  // workload. Their positions in the argument list are recorded once so the
  // body annotation below does not need to re-derive them.
  SmallVector<Value> workload;
  SmallVector<Type> workloadTypes;
  SmallVector<Location> workloadLocs;
  SmallVector<size_t> workloadArgIndices;
  for (auto [index, argument] : llvm::enumerate(workgroupsOp.getArguments())) {
    Type argumentType = argument.getType();
    if (!llvm::isa<IndexType>(argumentType))
      continue;
    workload.push_back(argument);
    workloadTypes.push_back(argumentType);
    workloadLocs.push_back(argument.getLoc());
    workloadArgIndices.push_back(index);
  }

  // The guard must be installed before `createBlock`, which moves the
  // insertion point into the new block; otherwise the caller's insertion
  // point would not be restored on return.
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = workgroupsOp.getLoc();

  // Populate the count region: its block arguments mirror the workload and it
  // returns the results of the placeholder slice op.
  Block *block =
      rewriter.createBlock(&workgroupCountBody, workgroupCountBody.end(),
                           workloadTypes, workloadLocs);
  rewriter.setInsertionPointToStart(block);
  auto defaultCountOp =
      rewriter.create<IREE::Flow::DispatchWorkgroupCountFromSliceOp>(
          loc, block->getArguments());
  rewriter.create<IREE::Flow::ReturnOp>(loc, defaultCountOp.getResults());

  // Update the workload operands of the op. This is the only in-place
  // mutation of `workgroupsOp` itself; the ops created below go through the
  // rewriter, which notifies its listener for them separately.
  rewriter.modifyOpInPlace(workgroupsOp, [&]() {
    workgroupsOp.getWorkloadMutable().assign(workload);
  });

  // Annotate the values captured as workload with their position in the
  // workload list.
  Region &body = workgroupsOp.getWorkgroupBody();
  if (body.empty())
    return;
  rewriter.setInsertionPointToStart(&body.front());
  for (auto [ordinal, argIndex] : llvm::enumerate(workloadArgIndices)) {
    BlockArgument arg = workgroupsOp.getInputBlockArgument(argIndex);
    auto ordinalOp = rewriter.create<IREE::Flow::DispatchWorkloadOrdinalOp>(
        loc, arg, rewriter.getIndexAttr(ordinal));
    // Route every other use of the block argument through the ordinal op.
    rewriter.replaceAllUsesExcept(arg, ordinalOp, ordinalOp);
  }
}
namespace {
/// Function-level pass driving `createDefaultWorkgroupCountRegion`; its body
/// is defined out of line in `runOnOperation`.
class MaterializeDefaultWorkgroupCountRegionPass
    : public impl::MaterializeDefaultWorkgroupCountRegionPassBase<
          MaterializeDefaultWorkgroupCountRegionPass> {
public:
  // Reuse the constructors of the TableGen-generated base class.
  using Base::Base;

  void runOnOperation() override;
};
} // namespace
// populates the workgroup count region.
void MaterializeDefaultWorkgroupCountRegionPass::runOnOperation() {
FunctionOpInterface funcOp = getOperation();
TensorDimTrackingRewriter rewriter(funcOp);
// Populate the workgroup_count region of flow.dispatch.workgroups operation
// that dont already have a region
funcOp.walk([&](IREE::Flow::DispatchWorkgroupsOp workgroupsOp) {
createDefaultWorkgroupCountRegion(rewriter, workgroupsOp);
});
}
} // namespace mlir::iree_compiler::DispatchCreation