// blob: 332c29a5ea36c254f247912843010b4db04a0cfe [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-codegen-set-num-workgroups"
/// Upper bound on the number of per-workgroup workload entries consumed in
/// this file; workgroup counts are built as three-element arrays.
static constexpr unsigned kNumMaxParallelDims = 3;
namespace mlir {
namespace iree_compiler {
namespace {
/// Replaces each hal.interface.workgroup.size operation with the constant
/// value found in `workloadPerWorkgroup` at the index given by the op's
/// dimension.
///
/// Only the first `kNumMaxParallelDims` entries of the list passed to the
/// constructor are kept. Ops whose dimension has no corresponding entry are
/// left untouched (the pattern reports failure for them). It is assumed that
/// the inner-most loop is mapped to the fastest varying dimension in
/// flow.dispatch.workgroup_size.
class SetWorkgroupSizePattern
    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
 public:
  /// \param context                 Context the pattern is created in.
  /// \param workloadPerWorkgroupRef Constant to substitute per dimension;
  ///                                entries past `kNumMaxParallelDims` are
  ///                                dropped.
  /// \param benefit                 Pattern benefit forwarded to the base.
  SetWorkgroupSizePattern(MLIRContext *context,
                          ArrayRef<int64_t> workloadPerWorkgroupRef,
                          PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit),
        // `take_front` is only called when there are more than
        // `kNumMaxParallelDims` entries, so it never asks for more elements
        // than the array holds.
        workloadPerWorkgroup(llvm::to_vector<4>(
            workloadPerWorkgroupRef.size() > kNumMaxParallelDims
                ? workloadPerWorkgroupRef.take_front(kNumMaxParallelDims)
                : workloadPerWorkgroupRef)) {}

  /// Rewrites `workgroupSizeOp` into an `arith::ConstantIndexOp`. Returns
  /// failure (leaving the op in place) when no value is stored for the op's
  /// dimension.
  LogicalResult matchAndRewrite(
      IREE::HAL::InterfaceWorkgroupSizeOp workgroupSizeOp,
      PatternRewriter &rewriter) const override {
    const int64_t dim = workgroupSizeOp.dimension().getSExtValue();
    // Make the bound check explicit instead of relying on the implicit
    // signed-to-unsigned conversion in a mixed-sign comparison.
    if (dim < 0 ||
        static_cast<uint64_t>(dim) >= workloadPerWorkgroup.size()) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
        workgroupSizeOp, workloadPerWorkgroup[dim]);
    return success();
  }

 private:
  /// Per-dimension constants, holding at most `kNumMaxParallelDims` entries.
  SmallVector<int64_t, 4> workloadPerWorkgroup;
};
/// Pass over a HAL executable variant that folds workgroup size ops to
/// constants and defines the workgroup count region of entry points (see
/// `runOnOperation`).
class SetNumWorkgroupsPass : public SetNumWorkgroupsBase<SetNumWorkgroupsPass> {
 public:
  /// Creates the pass with an optional per-workgroup workload. When `ws` is
  /// empty, `runOnOperation` consults each entry point's translation info
  /// instead.
  SetNumWorkgroupsPass(ArrayRef<int64_t> ws = {})
      : workloadPerWorkgroup(ws.begin(), ws.end()) {}

  /// Copies the configured workload; the base class is default-constructed.
  SetNumWorkgroupsPass(const SetNumWorkgroupsPass &pass)
      : workloadPerWorkgroup(pass.workloadPerWorkgroup) {}

  /// Registers the dialects this pass depends on.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
    registry.insert<IREE::HAL::HALDialect>();
    registry.insert<linalg::LinalgDialect>();
  }

  void runOnOperation() override;

 private:
  /// Pass-level per-workgroup workload; empty when none was supplied.
  SmallVector<int64_t> workloadPerWorkgroup;
};
} // namespace
void SetNumWorkgroupsPass::runOnOperation() {
MLIRContext *context = &getContext();
IREE::HAL::ExecutableVariantOp variantOp = getOperation();
ModuleOp module = variantOp.getInnerModule();
llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
getAllEntryPoints(module);
for (auto funcOp : module.getOps<FuncOp>()) {
auto entryPointOp = entryPoints.lookup(funcOp.getName());
if (!entryPointOp) continue;
SmallVector<int64_t, 4> currWorkloadPerWorkgroup;
// First check if there is a per-workgroup workload provided.
if (!workloadPerWorkgroup.empty()) {
currWorkloadPerWorkgroup.assign(workloadPerWorkgroup.begin(),
workloadPerWorkgroup.end());
} else if (IREE::Codegen::TranslationInfoAttr translationInfo =
getTranslationInfo(entryPointOp)) {
currWorkloadPerWorkgroup = translationInfo.getWorkloadPerWorkgroupVals();
}
if (!currWorkloadPerWorkgroup.empty()) {
// Fold hal.workgroup.size ops.
OwningRewritePatternList patterns(funcOp.getContext());
patterns.insert<SetWorkgroupSizePattern>(funcOp.getContext(),
currWorkloadPerWorkgroup);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
// The workgroup count region might already be set by op-specific
// configuration logic. If so, just return to avoid overwriting that.
if (!entryPointOp.workgroup_count_region().empty()) continue;
WorkgroupCountRegionBuilder regionBuilder;
if (currWorkloadPerWorkgroup.empty()) {
// If no workgroup size is specified, leave the workgroup size as is, just
// set the number of workgroups to be 1, 1, 1 to have a single invocation.
regionBuilder = [](OpBuilder &b, Location loc,
std::array<Value, 3> workload) {
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
return std::array<Value, 3>{one, one, one};
};
} else {
assert(currWorkloadPerWorkgroup.size() <= kNumMaxParallelDims &&
"workloadPerWorkgroup size greater than max num parallel dims");
regionBuilder = [&currWorkloadPerWorkgroup](
OpBuilder &b, Location loc,
std::array<Value, 3> workload) {
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
std::array<Value, 3> returnValues = {one, one, one};
for (auto ts : llvm::enumerate(currWorkloadPerWorkgroup)) {
returnValues[ts.index()] = linalg::applyMapToValues(
b, loc,
AffineMap::get(0, 1,
b.getAffineSymbolExpr(0).ceilDiv(ts.value())),
workload[ts.index()])[0];
}
return returnValues;
};
}
OpBuilder builder(context);
if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder)))
return signalPassFailure();
}
// Apply post distribution canonicalization passes.
OwningRewritePatternList canonicalization(context);
AffineMinOp::getCanonicalizationPatterns(canonicalization, context);
populateAffineMinSCFCanonicalizationPattern(canonicalization);
IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
context);
(void)applyPatternsAndFoldGreedily(module, std::move(canonicalization));
}
/// Creates a `SetNumWorkgroupsPass`. `workgroupSize` is forwarded to the pass
/// constructor, where it is stored as the per-workgroup workload.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createSetNumWorkgroupsPass(ArrayRef<int64_t> workgroupSize) {
  std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>> pass =
      std::make_unique<SetNumWorkgroupsPass>(workgroupSize);
  return pass;
}
} // namespace iree_compiler
} // namespace mlir