blob: f4bd35870fe2276066240cbdeba28a9f44c35948 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- Transforms.cpp - Transformations common to all backends ------------===//
//
// Implements transformations that are common to all backends.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/Transforms/Transforms.h"
#include "iree/compiler/Conversion/Utils/MarkerUtils.h"
#include "iree/compiler/Conversion/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
/// Upper bound on the number of `workloadPerWorkgroup` entries accepted by
/// materializeStaticLaunchInformation (one per distributed parallel loop).
static constexpr unsigned kMaxNumParallelDims{3};
namespace mlir {
namespace iree_compiler {
namespace {
/// Maximum number of workgroup dimensions this pattern materializes; entries of
/// `workloadPerWorkgroup` beyond this count are dropped.
constexpr size_t kMaxHALDimensions = 3;

/// Rewrites each hal.interface.workgroup.size operation into a constant index
/// taken from `workloadPerWorkgroup`: querying dimension `i` yields
/// `workloadPerWorkgroup[i]`. Only the first kMaxHALDimensions entries are
/// kept. An op that queries a dimension with no corresponding entry (or a
/// negative dimension) is left unchanged (the pattern fails to match).
/// Callers are assumed to map the inner-most loop to entry 0, the fastest
/// varying workgroup dimension.
class SetWorkgroupSizePattern
    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
 public:
  SetWorkgroupSizePattern(MLIRContext *context,
                          ArrayRef<int64_t> workloadPerWorkgroupRef,
                          PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit),
        workloadPerWorkgroup(llvm::to_vector<4>(
            workloadPerWorkgroupRef.size() > kMaxHALDimensions
                ? workloadPerWorkgroupRef.take_front(kMaxHALDimensions)
                : workloadPerWorkgroupRef)) {}

  LogicalResult matchAndRewrite(
      IREE::HAL::InterfaceWorkgroupSizeOp workgroupSizeOp,
      PatternRewriter &rewriter) const override {
    int64_t dim = workgroupSizeOp.dimension().getSExtValue();
    // Compare in signed space: reject negative dimensions explicitly instead
    // of relying on wrap-around in a signed/unsigned comparison.
    if (dim < 0 ||
        dim >= static_cast<int64_t>(workloadPerWorkgroup.size())) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<ConstantIndexOp>(workgroupSizeOp,
                                                 workloadPerWorkgroup[dim]);
    return success();
  }

 private:
  /// Per-dimension workgroup sizes, truncated to kMaxHALDimensions entries.
  SmallVector<int64_t, 4> workloadPerWorkgroup;
};
}  // namespace
/// Replaces the hal.executable.entry_point op associated with `funcOp` by a
/// new entry point carrying the same sym_name, ordinal, interface,
/// workgroup_size and workgroup_local_memory attributes, plus a single region
/// that computes the workgroup count.
///
/// The region's block receives three index arguments (the workload), passes
/// them to `regionBuilder`, and terminates with a hal.return of the three
/// values it produces. The original entry point op is erased afterwards.
/// `builder`'s insertion point is restored before returning. Emits an error on
/// `funcOp` and fails if no entry point op can be found.
LogicalResult defineWorkgroupCountRegion(
    OpBuilder &builder, FuncOp funcOp,
    WorkgroupCountRegionBuilder regionBuilder) {
  IREE::HAL::ExecutableEntryPointOp entryPointOp = getEntryPoint(funcOp);
  if (!entryPointOp)
    return funcOp.emitOpError("unable to find corresponding entry point op");

  OpBuilder::InsertionGuard insertionGuard(builder);
  Location entryLoc = entryPointOp.getLoc();

  // Build the replacement directly in front of the existing entry point,
  // copying its attributes but giving it a single region.
  builder.setInsertionPoint(entryPointOp);
  auto replacementOp = builder.create<IREE::HAL::ExecutableEntryPointOp>(
      entryLoc, entryPointOp.sym_nameAttr(), entryPointOp.ordinalAttr(),
      entryPointOp.interfaceAttr(), entryPointOp.workgroup_sizeAttr(),
      entryPointOp.workgroup_local_memoryAttr(), 1);

  // createBlock also moves the insertion point into the new block.
  Block *bodyBlock = builder.createBlock(replacementOp.getBody());

  // The workload is modelled as three index-typed block arguments, added in
  // order (x, y, z).
  IndexType indexType = builder.getIndexType();
  std::array<Value, 3> workload;
  for (Value &workloadArg : workload)
    workloadArg = bodyBlock->addArgument(indexType);

  std::array<Value, 3> countValues =
      regionBuilder(builder, entryLoc, workload);
  builder.create<IREE::HAL::ReturnOp>(entryLoc, countValues);

  entryPointOp.erase();
  return success();
}
/// Specializes `funcOp` for a statically known per-workgroup workload.
///
/// Every hal.interface.workgroup.size op in `funcOp` whose dimension has an
/// entry in `workloadPerWorkgroup` is replaced with that constant. The
/// workgroup count region of the matching hal.executable.entry_point is then
/// defined so that dimension `i` launches
/// ceildiv(workload[i], workloadPerWorkgroup[i]) workgroups. Dimensions with
/// no entry launch 1 workgroup.
///
/// `workloadPerWorkgroup` must have at most kMaxNumParallelDims entries. This
/// is asserted in debug builds; in release builds, extra entries are ignored.
/// Returns failure if applying the rewrite patterns or defining the region
/// fails.
LogicalResult materializeStaticLaunchInformation(
    FuncOp funcOp, ArrayRef<int64_t> workloadPerWorkgroup) {
  assert(workloadPerWorkgroup.size() <= kMaxNumParallelDims &&
         "workloadPerWorkgroup size greater than max num parallel dims");

  OwningRewritePatternList patterns(funcOp.getContext());
  patterns.insert<SetWorkgroupSizePattern>(funcOp.getContext(),
                                           workloadPerWorkgroup);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
    return failure();
  }

  // The count region works on fixed three-element arrays. Clamp the tile
  // sizes so release builds (where the assert above is compiled out) never
  // index past the end of `workload` / `returnValues` below.
  ArrayRef<int64_t> tileSizes =
      workloadPerWorkgroup.size() > kMaxNumParallelDims
          ? workloadPerWorkgroup.take_front(kMaxNumParallelDims)
          : workloadPerWorkgroup;

  // ArrayRef is a cheap view, so capture it by value.
  WorkgroupCountRegionBuilder regionBuilder =
      [tileSizes](OpBuilder &b, Location loc,
                  std::array<Value, 3> workload) -> std::array<Value, 3> {
    Value one = b.create<ConstantIndexOp>(loc, 1);
    std::array<Value, 3> returnValues = {one, one, one};
    for (auto ts : llvm::enumerate(tileSizes)) {
      // count[i] = ceildiv(workload[i], tileSizes[i]).
      // NOTE(review): a zero tile size would produce a ceildiv by zero here;
      // confirm callers never pass 0.
      returnValues[ts.index()] = linalg::applyMapToValues(
          b, loc,
          AffineMap::get(0, 1, b.getAffineSymbolExpr(0).ceilDiv(ts.value())),
          workload[ts.index()])[0];
    }
    return returnValues;
  };
  OpBuilder builder(funcOp.getContext());
  return defineWorkgroupCountRegion(builder, funcOp, regionBuilder);
}
/// Folds an integer or floating-point add into the accumulator of a
/// vector.contract that feeds it. Given IR such as:
///
/// ```mlir
/// %c0 = constant dense<0> : ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// this builds, immediately before `op`, the equivalent contraction:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// and returns it. `op` itself is left untouched. Either add operand may be
/// the contraction; the left operand is tried first. Returns null when `op`
/// is not an AddIOp/AddFOp, or when neither operand is produced by a
/// contraction whose accumulator is a zero constant.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
vector::ContractionOp canonicalizeContractionAdd(Operation *op) {
  if (!isa<AddIOp, AddFOp>(op)) return nullptr;

  OpBuilder builder(op);

  // Tries to absorb `addend` into the accumulator of the contraction (if any)
  // that defines `candidate`. On success, returns a clone of that contraction
  // whose accumulator is replaced by `addend`.
  auto foldInto = [](OpBuilder &b, Value candidate,
                     Value addend) -> vector::ContractionOp {
    auto contraction =
        dyn_cast_or_null<vector::ContractionOp>(candidate.getDefiningOp());
    if (!contraction) return nullptr;

    Value acc = contraction.acc();
    auto accConstant = dyn_cast_or_null<ConstantOp>(acc.getDefiningOp());
    if (!accConstant || accConstant.value() != b.getZeroAttr(acc.getType()))
      return nullptr;

    BlockAndValueMapping mapping;
    mapping.map(acc, addend);
    return cast<vector::ContractionOp>(
        b.clone(*contraction.getOperation(), mapping));
  };

  Value lhs = op->getOperand(0);
  Value rhs = op->getOperand(1);
  if (vector::ContractionOp fused = foldInto(builder, lhs, rhs)) return fused;
  return foldInto(builder, rhs, lhs);
}
} // namespace iree_compiler
} // namespace mlir