// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
#include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#define DEBUG_TYPE "iree-codegen-utils"
namespace mlir::iree_compiler {
//===----------------------------------------------------------------------===//
// Utility functions to get entry points
//===----------------------------------------------------------------------===//
std::optional<IREE::HAL::ExecutableExportOp>
getEntryPoint(mlir::FunctionOpInterface funcOp) {
auto variantOp = funcOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
if (!variantOp) {
return std::nullopt;
}
for (auto op : variantOp.getExportOps()) {
if (op.getSymName() == funcOp.getName()) {
return op;
}
}
return std::nullopt;
}
bool isEntryPoint(mlir::FunctionOpInterface func) {
return func.isPublic() && getEntryPoint(func);
}
std::optional<StringAttr> getConfigStringAttr(Attribute srcAttr,
StringRef stringAttr) {
if (!srcAttr) {
return std::nullopt;
}
auto targetAttr = dyn_cast<IREE::HAL::ExecutableTargetAttr>(srcAttr);
DictionaryAttr config;
if (targetAttr) {
config = targetAttr.getConfiguration();
} else {
config = dyn_cast<DictionaryAttr>(srcAttr);
}
if (!config) {
return std::nullopt;
}
auto attr = config.getAs<StringAttr>(stringAttr);
if (!attr) {
return std::nullopt;
}
return attr;
}
std::optional<IntegerAttr> getConfigIntegerAttr(Attribute srcAttr,
StringRef integerAttr) {
if (!srcAttr) {
return std::nullopt;
}
auto targetAttr = dyn_cast<IREE::HAL::ExecutableTargetAttr>(srcAttr);
DictionaryAttr config;
if (targetAttr) {
config = targetAttr.getConfiguration();
} else {
config = dyn_cast<DictionaryAttr>(srcAttr);
}
if (!config) {
return std::nullopt;
}
auto attr = config.getAs<IntegerAttr>(integerAttr);
if (!attr) {
return std::nullopt;
}
return attr;
}
std::optional<BoolAttr> getConfigBoolAttr(Attribute srcAttr,
StringRef boolAttr) {
if (!srcAttr) {
return std::nullopt;
}
auto targetAttr = dyn_cast<IREE::HAL::ExecutableTargetAttr>(srcAttr);
DictionaryAttr config;
if (targetAttr) {
config = targetAttr.getConfiguration();
} else {
config = dyn_cast<DictionaryAttr>(srcAttr);
}
if (!config) {
return std::nullopt;
}
auto attr = config.getAs<BoolAttr>(boolAttr);
if (!attr) {
return std::nullopt;
}
return attr;
}
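// Illustrative usage of the three config lookups above (attribute values
// assumed, not taken from this file): given an executable target attribute
// such as
//   #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",
//     {target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}>
// getConfigStringAttr(targetAttr, "target_triple") returns the target_triple
// StringAttr. The helpers also accept the configuration DictionaryAttr
// directly in place of the full target attribute.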
std::optional<llvm::Triple> getTargetTriple(Attribute attr) {
auto triple = getConfigStringAttr(attr, "target_triple");
if (!triple) {
return std::nullopt;
}
return llvm::Triple(triple.value().str());
}
const char *getIreeArchNameForTargetTriple(llvm::Triple triple) {
if (triple.isX86()) {
return triple.isArch64Bit() ? "x86_64" : "x86_32";
}
if (triple.isWasm()) {
return triple.isArch64Bit() ? "wasm_64" : "wasm_32";
}
if (triple.isAArch64()) {
return "arm_64";
}
if (triple.isARM()) {
return "arm_32";
}
if (triple.isRISCV64()) {
return "riscv_64";
}
if (triple.isRISCV32()) {
return "riscv_32";
}
return "unknown";
}
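// Illustrative mappings (derived from the checks above): the triple
// "x86_64-unknown-linux-gnu" maps to "x86_64", "aarch64-none-linux-android"
// maps to "arm_64", and any unrecognized architecture maps to "unknown".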
bool isLLVMCPUBackend(IREE::HAL::ExecutableTargetAttr targetAttr) {
return targetAttr && targetAttr.getBackend().getValue() == "llvm-cpu";
}
bool isVMVXBackend(IREE::HAL::ExecutableTargetAttr targetAttr) {
return targetAttr && targetAttr.getBackend().getValue().starts_with("vmvx");
}
bool isROCMBackend(IREE::HAL::ExecutableTargetAttr targetAttr) {
return targetAttr && targetAttr.getBackend().getValue().starts_with("rocm");
}
static const char *getDefaultEnabledUkernels(Attribute attr) {
const char *kNone = "none";
if (!attr) {
return kNone;
}
auto targetAttr = dyn_cast<IREE::HAL::ExecutableTargetAttr>(attr);
if (!targetAttr) {
return kNone;
}
if (isX86_64(targetAttr)) {
return "mmt4d";
}
if (isAArch64(targetAttr)) {
if (hasFeature(targetAttr, "+sve") || hasFeature(targetAttr, "+sve2") ||
hasFeature(targetAttr, "+sme")) {
return kNone;
}
return "mmt4d";
}
return kNone;
}
bool hasUkernel(Attribute attr, StringRef ukernelName) {
auto enabledUkernels = getConfigStringAttr(attr, "ukernels");
StringRef enabledUkernelsStr;
if (enabledUkernels) {
enabledUkernelsStr = enabledUkernels->getValue();
} else {
enabledUkernelsStr = "default";
}
// Resolve `default`.
if (enabledUkernelsStr == "default") {
enabledUkernelsStr = getDefaultEnabledUkernels(attr);
}
// Resolve `none`.
if (enabledUkernelsStr == "none") {
return false;
}
// Resolve `all`.
if (enabledUkernelsStr == "all") {
return true;
}
// If `ukernelName` is empty, the question is "are ukernels enabled at all?"
// At this point, we already know that enabledUkernelsStr != "none".
if (ukernelName.empty()) {
return !enabledUkernelsStr.empty();
}
while (!enabledUkernelsStr.empty()) {
auto split = enabledUkernelsStr.split(',');
if (split.first == ukernelName) {
return true;
}
enabledUkernelsStr = split.second;
}
return false;
}
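// Illustrative resolution (example config values assumed): with
// {ukernels = "mmt4d,pack"}, hasUkernel(attr, "mmt4d") and
// hasUkernel(attr, "pack") return true while hasUkernel(attr, "argmax")
// returns false. With {ukernels = "default"} on an x86_64 target, only
// "mmt4d" is enabled (see getDefaultEnabledUkernels above).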
std::optional<StringRef> getCpuFeatures(Attribute attr) {
auto cpuFeatures = getConfigStringAttr(attr, "cpu_features");
if (!cpuFeatures) {
return std::nullopt;
}
return cpuFeatures->getValue();
}
// TODO(dcaballe): If we have to check a significantly larger number of
// features in the future, we may want to keep persistent state that carries
// over processed HAL information, or keep the TTI instance alive and query
// its subtarget features data structure.
bool hasFeature(Attribute attr, StringRef feature) {
std::optional<StringRef> features = getCpuFeatures(attr);
if (!features) {
return false;
}
// Find feature string in list of features, making sure that we don't match a
// sub-string.
std::stringstream sstream(features->str());
std::string str;
while (std::getline(sstream, str, ',')) {
if (str == feature) {
return true;
}
}
return false;
}
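// Illustrative example (feature string assumed): with
// cpu_features = "+sve,+sve2", hasFeature(attr, "+sve") returns true, while
// hasFeature(attr, "+sv") returns false, since only whole comma-separated
// tokens match, never substrings.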
bool isX86(Attribute attr) {
std::optional<llvm::Triple> triple = getTargetTriple(attr);
return triple && triple.value().isX86();
}
bool isX86_64(Attribute attr) {
std::optional<llvm::Triple> triple = getTargetTriple(attr);
return triple && triple.value().getArch() == llvm::Triple::x86_64;
}
bool isAArch64(Attribute attr) {
std::optional<llvm::Triple> triple = getTargetTriple(attr);
return triple && triple.value().isAArch64();
}
bool isRISCV(Attribute attr) {
std::optional<llvm::Triple> triple = getTargetTriple(attr);
return triple && triple.value().isRISCV();
}
bool isRISCV32(Attribute attr) {
std::optional<llvm::Triple> triple = getTargetTriple(attr);
return triple && triple.value().isRISCV32();
}
bool isReadOnly(Value v) {
Operation *definingOp = v.getDefiningOp();
if (!definingOp)
return false;
return TypeSwitch<Operation *, bool>(definingOp)
.Case<arith::ConstantOp>(
[&](arith::ConstantOp constantOp) { return true; })
.Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
[&](auto op) { return isReadOnly(op.getSrc()); })
.Case<tensor::CastOp, tensor::ExtractSliceOp>(
[&](auto op) { return isReadOnly(op.getSource()); })
.Case<IREE::Flow::DispatchTensorLoadOp>(
[&](IREE::Flow::DispatchTensorLoadOp loadOp) {
return llvm::cast<IREE::Flow::DispatchTensorType>(
loadOp.getSource().getType())
.getAccess() == IREE::Flow::TensorAccess::ReadOnly;
})
.Default([&](Operation *op) { return false; });
}
LogicalResult duplicateTensorEmptyOps(OpBuilder &b, tensor::EmptyOp emptyOp) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(emptyOp);
SmallVector<OpOperand *> uses = llvm::map_to_vector(
emptyOp->getUses(), [](OpOperand &use) { return &use; });
// Nothing to do for zero or one use; this also keeps the iterator arithmetic
// below in bounds.
if (uses.size() <= 1) {
return success();
}
for (auto use : llvm::make_range(std::next(uses.begin()), uses.end())) {
auto newOp = cast<tensor::EmptyOp>(b.clone(*emptyOp.getOperation()));
Operation *user = use->getOwner();
user->setOperand(use->getOperandNumber(), newOp);
}
return success();
}
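// Illustrative effect: for
//   %0 = tensor.empty() : tensor<4xf32>
// with three uses, the first use keeps %0 and each remaining use gets its own
// clone of the tensor.empty op, so every empty op ends up with a single use.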
//===----------------------------------------------------------------------===//
// Setting CustomOp Lowering config.
//===----------------------------------------------------------------------===//
static std::tuple<SmallVector<Operation *>, SetVector<Value>>
getNonConstantValuesDefinedFromAbove(Region &region) {
llvm::SetVector<Value> valuesDefinedFromAbove;
mlir::getUsedValuesDefinedAbove(region, valuesDefinedFromAbove);
SmallVector<Operation *> constants;
SetVector<Value> erasedVals;
for (auto value : valuesDefinedFromAbove) {
Attribute constVal;
if (!matchPattern(value, m_Constant(&constVal))) {
continue;
}
if (!isa<IntegerAttr, FloatAttr>(constVal)) {
continue;
}
constants.push_back(value.getDefiningOp());
erasedVals.insert(value);
}
valuesDefinedFromAbove.set_subtract(erasedVals);
return {constants, valuesDefinedFromAbove};
}
/// Listener to track the mapping from operations in the body of a cloned
/// custom op back to the corresponding operations in the body of the original
/// custom op.
class CustomOpConfigListener : public RewriterBase::Listener {
public:
CustomOpConfigListener(IREE::LinalgExt::CustomOp origCustomOp,
IREE::LinalgExt::CustomOp clonedCustomOp) {
for (auto [origOp, clonedOp] :
llvm::zip_equal(origCustomOp.getBody()->without_terminator(),
clonedCustomOp.getBody()->without_terminator())) {
clonedOpToOrigOp[&clonedOp] = &origOp;
}
}
void notifyOperationErased(Operation *op) override {
clonedOpToOrigOp.erase(op);
}
void notifyOperationReplaced(Operation *op, Operation *replacement) override {
auto it = clonedOpToOrigOp.find(op);
if (it != clonedOpToOrigOp.end()) {
Operation *origOp = it->second;
clonedOpToOrigOp.erase(it);
clonedOpToOrigOp[replacement] = origOp;
}
}
void notifyOperationReplaced(Operation *op,
ValueRange replacements) override {
Operation *replacementOp = nullptr;
for (auto val : replacements) {
Operation *definingOp = getDefiningOp(val);
if (!definingOp) {
// One of the replacements is definitely not from an op. Bail
// immediately.
return;
}
if (replacementOp) {
if (definingOp != replacementOp) {
// No consistent replacementOp. Bail.
return;
}
} else {
replacementOp = definingOp;
}
}
if (replacementOp && replacementOp->getName() == op->getName()) {
notifyOperationReplaced(op, replacementOp);
}
}
// Helper methods to get back the orig op for the cloned op.
std::optional<Operation *> getOrigOp(Operation *clonedOp) {
auto it = clonedOpToOrigOp.find(clonedOp);
if (it == clonedOpToOrigOp.end()) {
return std::nullopt;
}
return it->second;
}
private:
llvm::MapVector<Operation *, Operation *> clonedOpToOrigOp;
/// On cast propagation, the replacement value used is not the
/// actual op that is used for replacement. Walk back the replacement
/// value's use-def chain to get to the real replacement op. This is a
/// bit of a hack, but the lowering config propagation is best
/// effort anyway, so this is not incorrect.
Operation *getDefiningOp(Value v) {
Operation *definingOp = v.getDefiningOp();
while (definingOp) {
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
definingOp = castOp.getSource().getDefiningOp();
continue;
}
// Default is to break out of the loop.
break;
}
return definingOp;
}
};
LogicalResult setDefaultCustomOpLoweringConfig(
FunctionOpInterface funcOp, IREE::LinalgExt::CustomOp customOp,
std::function<LogicalResult(FunctionOpInterface)> configFn) {
MLIRContext *context = funcOp.getContext();
IRRewriter rewriter(context);
rewriter.setInsertionPoint(funcOp);
// 1. Get values captured from above in the custom op region.
llvm::SetVector<Value> valuesDefinedAbove;
SmallVector<Operation *> constantOps;
std::tie(constantOps, valuesDefinedAbove) =
getNonConstantValuesDefinedFromAbove(customOp.getRegion());
// 2. Create an empty function with arguments being the operands of the custom
// op and values captured from above in the custom op.
auto operandTypes = llvm::to_vector(customOp->getOperandTypes());
auto valuesDefinedAboveTypes =
llvm::map_range(valuesDefinedAbove, [](Value v) { return v.getType(); });
operandTypes.append(valuesDefinedAboveTypes.begin(),
valuesDefinedAboveTypes.end());
auto dummyFuncType =
FunctionType::get(context, operandTypes, customOp->getResultTypes());
std::string dummyFuncName =
std::string("__") + funcOp.getName().str() + "_config_setting__";
auto dummyFuncOp = rewriter.create<func::FuncOp>(
customOp.getLoc(), dummyFuncName, dummyFuncType);
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
if (targetAttr) {
dummyFuncOp->setAttr(IREE::HAL::ExecutableTargetAttr::name, targetAttr);
}
// 3. Clone the custom op into the function
SmallVector<Location> locs = llvm::map_to_vector(
customOp->getOperands(), [](Value v) { return v.getLoc(); });
auto valuesDefinedAboveLocs =
llvm::map_range(valuesDefinedAbove, [](Value v) { return v.getLoc(); });
locs.append(valuesDefinedAboveLocs.begin(), valuesDefinedAboveLocs.end());
Block *body =
rewriter.createBlock(&dummyFuncOp.getRegion(),
dummyFuncOp.getRegion().begin(), operandTypes, locs);
rewriter.setInsertionPointToStart(body);
IRMapping map;
map.map(customOp.getOperands(),
body->getArguments().take_front(customOp.getNumOperands()));
map.map(valuesDefinedAbove.getArrayRef(),
body->getArguments().take_back(valuesDefinedAbove.size()));
for (auto op : constantOps) {
rewriter.clone(*op, map);
}
auto clonedCustomOp = cast<IREE::LinalgExt::CustomOp>(
rewriter.clone(*customOp.getOperation(), map));
rewriter.create<func::ReturnOp>(customOp.getLoc(),
clonedCustomOp->getResults());
CustomOpConfigListener customOpConfigListener(customOp, clonedCustomOp);
// 4. Inline the cloned custom op.
rewriter.setInsertionPoint(clonedCustomOp);
FailureOr<SmallVector<Value>> replacements =
clonedCustomOp.decomposeOperation(rewriter);
if (failed(replacements)) {
return customOp.emitOpError(
"failed to decompose op during custom op configuration setting");
}
rewriter.replaceOp(clonedCustomOp, replacements.value());
// 5. Run canonicalizations on the created function to constant propagate the
// shape.
RewritePatternSet patterns(context);
auto addCanonicalizationPatterns = [&context,
&patterns](StringRef dialectName) {
context->getLoadedDialect(dialectName)
->getCanonicalizationPatterns(patterns);
};
addCanonicalizationPatterns(linalg::LinalgDialect::getDialectNamespace());
addCanonicalizationPatterns(
IREE::LinalgExt::IREELinalgExtDialect::getDialectNamespace());
tensor::CastOp::getCanonicalizationPatterns(patterns, context);
addCanonicalizationPatterns(tensor::TensorDialect::getDialectNamespace());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
GreedyRewriteConfig config;
config.listener = &customOpConfigListener;
if (failed(applyPatternsGreedily(dummyFuncOp, std::move(patterns), config))) {
return customOp.emitOpError(
"failed to canonicalize during custom op configuration setting");
}
// 6. Run set configuration on the new dummy function.
if (failed(configFn(dummyFuncOp))) {
return customOp.emitOpError("failed to set configuration for custom op");
}
// 7. Set translation info and lowering config for the custom op.
IREE::Codegen::TranslationInfoAttr translationInfo =
getTranslationInfo(dummyFuncOp);
// Move lowering config from ops in the cloned function to the ops
// within the body of the custom op.
// TODO: This logic needs to be made more robust (by accounting for the
// indexing maps specified for operands on the custom op and the indexing maps
// of the operations within the region of the custom op). For now, just use the
// first operation with a lowering config.
std::optional<SmallVector<int64_t>> workgroupTileSizes;
std::optional<SmallVector<int64_t>> workgroupInterchange;
for (Operation &op : dummyFuncOp.getBody().front()) {
auto currLoweringConfig =
getLoweringConfig<IREE::Codegen::LoweringConfigAttrInterface>(&op);
if (!currLoweringConfig)
continue;
// Translate the lowering config to the original operation.
if (std::optional<Operation *> originalOperation =
customOpConfigListener.getOrigOp(&op)) {
setLoweringConfig(originalOperation.value(), currLoweringConfig);
}
auto currWorkgroupTileSizes = currLoweringConfig.getWorkgroupTileSizes();
if (currWorkgroupTileSizes.empty())
continue;
workgroupTileSizes = currWorkgroupTileSizes;
workgroupInterchange = currLoweringConfig.getWorkgroupInterchange();
}
IREE::Codegen::LoweringConfigAttr loweringConfig;
if (workgroupTileSizes) {
loweringConfig = IREE::Codegen::LoweringConfigAttr::get(
context, workgroupTileSizes.value_or(SmallVector<int64_t>{}),
workgroupInterchange.value_or(SmallVector<int64_t>{}));
}
if (failed(setOpConfigAndEntryPointFnTranslation(
funcOp, customOp, loweringConfig, translationInfo))) {
return funcOp.emitOpError("failed to set custom op configuration");
}
rewriter.eraseOp(dummyFuncOp);
return success();
}
//===----------------------------------------------------------------------===//
// Utility functions to set configurations
//===----------------------------------------------------------------------===//
/// Returns the first of `exprs` which is of the type `T`.
template <typename T>
static AffineExpr getAffineExprOfType(ArrayRef<AffineExpr> exprs) {
if (auto it = llvm::find_if(exprs, llvm::IsaPred<T>); it != exprs.end())
return *it;
return nullptr;
}
/// Returns the Value bound to the given symbol or dim expression of the map in
/// the `applyOp`.
static Value getValueForDimOrSymbol(affine::AffineApplyOp applyOp,
AffineExpr expr) {
unsigned numDims = applyOp.getAffineMap().getNumDims();
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
return applyOp.getOperand(dimExpr.getPosition());
}
if (auto symbolExpr = dyn_cast<AffineSymbolExpr>(expr)) {
return applyOp.getOperand(numDims + symbolExpr.getPosition());
}
return nullptr;
}
static SmallVector<Value>
getValuesForDimsOrSymbols(affine::AffineApplyOp applyOp,
ArrayRef<AffineExpr> exprs) {
SmallVector<Value> vals;
for (auto expr : exprs) {
vals.push_back(getValueForDimOrSymbol(applyOp, expr));
}
return vals;
}
/// Returns the dimension for any operation that implements processor op
/// interfaces.
template <typename T>
static std::optional<unsigned> getDimension(Operation *op) {
if (auto tOp = dyn_cast<T>(op)) {
return tOp.getDimIndex();
}
return std::nullopt;
}
template <typename T1, typename T2, typename... T3>
static std::optional<unsigned> getDimension(Operation *op) {
if (!op)
return std::nullopt;
if (auto dimension = getDimension<T1>(op)) {
return dimension;
}
return getDimension<T2, T3...>(op);
}
/// Checks that all `vals` are defined by some processor id/count/size ops using
/// the same `dimension`. If any element of `vals` is not defined by one of
/// these ops, or the dimensions don't match, returns std::nullopt; otherwise,
/// returns the dimension. If `refDimension` is passed, checks that the
/// dimension matches the given value.
template <typename... T>
static std::optional<unsigned>
checkDimensions(ArrayRef<Value> vals,
std::optional<unsigned> refDimension = std::nullopt) {
for (auto v : vals) {
auto currDimension = getDimension<T...>(v.getDefiningOp());
if (!currDimension)
return std::nullopt;
if (refDimension) {
if (refDimension.value() != currDimension.value()) {
return std::nullopt;
}
} else {
refDimension = currDimension.value();
}
}
return refDimension;
}
namespace {
/// Visitor to walk the `lb` of a distributed loop. Expects the expression to be
/// of the form `a + b * c`, where `a` is the original `lb` and `b`, `c` are
/// either hal.interface.workgroup.id or hal.interface.workgroup.size.
class LowerBoundExprVisitor
: public AffineExprVisitor<LowerBoundExprVisitor, LogicalResult> {
public:
LowerBoundExprVisitor(affine::AffineApplyOp applyOp,
LoopTilingAndDistributionInfo &loopInfo)
: applyOp(applyOp), loopInfo(loopInfo) {}
LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }
LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
return failure();
}
LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
return failure();
}
LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
AffineExpr offsetExpr =
getAffineExprOfType<AffineBinaryOpExpr>({expr.getLHS(), expr.getRHS()});
if (!offsetExpr) {
// One of the expressions has to be a binary op expr.
return failure();
}
// The other expression must be the undistributed `lb`.
AffineExpr lbExpr =
(offsetExpr == expr.getLHS() ? expr.getRHS() : expr.getLHS());
if (isa<AffineDimExpr, AffineSymbolExpr>(lbExpr)) {
Value v = getValueForDimOrSymbol(applyOp, lbExpr);
if (!v) {
return failure();
}
loopInfo.untiledLowerBound = getAsOpFoldResult(v);
} else if (auto constExpr = dyn_cast<AffineConstantExpr>(lbExpr)) {
loopInfo.untiledLowerBound = IntegerAttr::get(
IndexType::get(applyOp.getContext()), constExpr.getValue());
} else {
return failure();
}
return visit(offsetExpr);
}
LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
SmallVector<Value> vals;
std::optional<unsigned> dimension;
// workgroupSizeOp may have been folded into a constant expression.
if (auto wgSize = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
vals = getValuesForDimsOrSymbols(applyOp, {expr.getLHS()});
if (vals.size() != 1 || !vals[0]) {
return failure();
}
loopInfo.tileSize = wgSize.getValue();
dimension = checkDimensions<ProcessorIDInterface>(vals);
} else {
vals = getValuesForDimsOrSymbols(applyOp, {expr.getLHS(), expr.getRHS()});
if (vals.size() != 2 || !vals[0] || !vals[1]) {
return failure();
}
IntegerAttr tileSizeAttr;
if (matchPattern(vals[1], m_Constant(&tileSizeAttr))) {
loopInfo.tileSize = tileSizeAttr.getInt();
dimension = checkDimensions<ProcessorIDInterface>(vals[0]);
} else {
dimension =
checkDimensions<ProcessorIDInterface, ProcessorTileSizeInterface>(
vals);
}
}
if (!dimension) {
return failure();
}
loopInfo.processorDistributionDim = dimension.value();
if (!loopInfo.untiledLowerBound) {
loopInfo.untiledLowerBound =
IntegerAttr::get(IndexType::get(applyOp.getContext()), 0);
}
return success();
}
private:
affine::AffineApplyOp applyOp;
LoopTilingAndDistributionInfo &loopInfo;
};
/// Visitor to walk the `step` of a distributed loop. Expects the expression to
/// be of the form `a * b * c`, where each factor is either the dynamic `step`
/// or is defined by a `hal.interface.workgroup.size`/
/// `hal.interface.workgroup.count` operation.
class StepExprVisitor
: public AffineExprVisitor<StepExprVisitor, LogicalResult> {
public:
StepExprVisitor(affine::AffineApplyOp applyOp,
LoopTilingAndDistributionInfo &loopInfo)
: applyOp(applyOp), loopInfo(loopInfo) {}
LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }
LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
return failure();
}
LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
return failure();
}
LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
// Check if one of the operands is a binary op expr.
SmallVector<AffineExpr> sentinels;
if (auto e = getAffineExprOfType<AffineBinaryOpExpr>(
{expr.getLHS(), expr.getRHS()})) {
AffineExpr otherExpr =
(e == expr.getLHS() ? expr.getRHS() : expr.getLHS());
if (failed(processSentinel(otherExpr, sentinels))) {
return failure();
}
expr = cast<AffineBinaryOpExpr>(e);
} else {
// Check if the workgroup tile size is folded into the affine map itself.
if (loopInfo.tileSize) {
if (auto stepCst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
loopInfo.untiledStep =
IntegerAttr::get(IndexType::get(applyOp.getContext()),
stepCst.getValue() / *loopInfo.tileSize);
} else {
auto stepValue = getValueForDimOrSymbol(applyOp, expr.getRHS());
IntegerAttr tileSizeAttr;
if (stepValue && matchPattern(stepValue, m_Constant(&tileSizeAttr))) {
loopInfo.untiledStep =
IntegerAttr::get(IndexType::get(applyOp.getContext()),
tileSizeAttr.getInt() / *loopInfo.tileSize);
}
}
} else {
loopInfo.untiledStep =
IntegerAttr::get(IndexType::get(applyOp.getContext()), 1);
}
}
if (failed(processSentinel(expr.getLHS(), sentinels)) ||
(!loopInfo.tileSize &&
failed(processSentinel(expr.getRHS(), sentinels)))) {
return failure();
}
// Either there are three sentinels and the step isn't set, or there are two
// sentinels and the step is set.
if (sentinels.size() == 3) {
if (loopInfo.untiledStep) {
return failure();
}
auto it = sentinels.begin();
for (auto ie = sentinels.end(); it != ie; ++it) {
Value v = getValueForDimOrSymbol(applyOp, *it);
if (!v.getDefiningOp<IREE::HAL::InterfaceWorkgroupSizeOp>() &&
!v.getDefiningOp<IREE::HAL::InterfaceWorkgroupCountOp>()) {
loopInfo.untiledStep = getAsOpFoldResult(v);
break;
}
}
if (it != sentinels.end()) {
sentinels.erase(it);
}
}
if ((sentinels.size() != 2 || !loopInfo.untiledStep) &&
(sentinels.size() != 1 || !loopInfo.tileSize)) {
return failure();
}
SmallVector<Value> vals = getValuesForDimsOrSymbols(applyOp, sentinels);
if ((loopInfo.tileSize && !checkDimensions<ProcessorCountInterface>(
vals, loopInfo.processorDistributionDim)) ||
(!loopInfo.tileSize &&
!checkDimensions<ProcessorCountInterface, ProcessorTileSizeInterface>(
vals, loopInfo.processorDistributionDim))) {
return failure();
}
return success();
}
private:
LogicalResult processSentinel(AffineExpr e,
SmallVectorImpl<AffineExpr> &sentinels) {
if (isa<AffineDimExpr, AffineSymbolExpr>(e)) {
sentinels.push_back(e);
return success();
} else if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
if (loopInfo.untiledStep) {
return failure();
}
loopInfo.untiledStep = IntegerAttr::get(
IndexType::get(applyOp.getContext()), constExpr.getValue());
return success();
}
return failure();
}
affine::AffineApplyOp applyOp;
LoopTilingAndDistributionInfo &loopInfo;
};
} // namespace
template <typename OpTy>
static std::optional<unsigned> getInterfaceWorkgroupOpDim(Value value) {
if (auto op = value.getDefiningOp<OpTy>()) {
return op.getDimension().getZExtValue();
}
return std::nullopt;
}
/// Checks if the `forOp` is a tiled + distributed op. Looks for the op of this
/// form
/// ```
/// %dim = arith.constant ... : index
/// %id = hal.interface.workgroup.id[%dim]
/// %count = hal.interface.workgroup.count[%dim]
/// %size = hal.interface.workgroup.size[%dim]
/// %offset = affine.apply
/// affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)>(%lb)[%id, %size]
/// %new_step = affine.apply
/// affine_map<(d0)[s0, s1] -> (d0 * s0 * s1)>(%step)[%count, %size]
/// scf.for %iv = %offset to %ub step %new_step { ... }
/// ```
std::optional<LoopTilingAndDistributionInfo>
isTiledAndDistributedLoop(scf::ForOp forOp) {
LoopTilingAndDistributionInfo loopInfo;
loopInfo.loop = forOp;
loopInfo.untiledUpperBound = getAsOpFoldResult(forOp.getUpperBound());
auto lbApplyOp = forOp.getLowerBound().getDefiningOp<affine::AffineApplyOp>();
auto stepApplyOp = forOp.getStep().getDefiningOp<affine::AffineApplyOp>();
if (!lbApplyOp || !stepApplyOp) {
// Try to see if this is a special case where we have:
// scf.for %iv = %id to %ub step %count
std::optional<unsigned> idDim;
if (auto ifx = dyn_cast_or_null<ProcessorIDInterface>(
forOp.getLowerBound().getDefiningOp())) {
idDim = ifx.getDimIndex();
}
std::optional<unsigned> countDim;
if (auto ifx = dyn_cast_or_null<ProcessorCountInterface>(
forOp.getStep().getDefiningOp())) {
countDim = ifx.getDimIndex();
}
if (!idDim || !countDim)
return std::nullopt;
Builder b(forOp.getContext());
loopInfo.untiledLowerBound = b.getIndexAttr(0);
loopInfo.untiledStep = b.getIndexAttr(1);
loopInfo.processorDistributionDim = idDim.value();
// For this special case, the tile size is 1.
loopInfo.tileSize = 1;
return loopInfo;
}
LowerBoundExprVisitor lbVisitor(lbApplyOp, loopInfo);
StepExprVisitor stepVisitor(stepApplyOp, loopInfo);
if (failed(lbVisitor.visit(lbApplyOp.getAffineMap().getResults()[0]))) {
return std::nullopt;
}
if (failed(stepVisitor.visit(stepApplyOp.getAffineMap().getResults()[0]))) {
return std::nullopt;
}
if (!loopInfo.untiledLowerBound || !loopInfo.untiledStep) {
return std::nullopt;
}
return loopInfo;
}
SmallVector<Operation *> getComputeOps(Operation *containingOp) {
if (containingOp->getNumRegions() == 0) {
return {};
}
assert(containingOp->getNumRegions() == 1 &&
"expected op with a single region");
SmallVector<Operation *> computeOps;
containingOp->getRegion(0).walk([&](Operation *op) {
if (isa<TilingInterface, IREE::Codegen::UKernelOpInterface>(op)) {
computeOps.push_back(op);
}
});
return computeOps;
}
SmallVector<LoopTilingAndDistributionInfo>
getTiledAndDistributedLoopInfo(mlir::FunctionOpInterface funcOp) {
SmallVector<LoopTilingAndDistributionInfo> info;
funcOp.walk([&](scf::ForOp forOp) {
if (auto tiledLoopInfo = isTiledAndDistributedLoop(forOp)) {
info.emplace_back(std::move(tiledLoopInfo.value()));
}
});
return info;
}
void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface op,
ArrayRef<int64_t> tileSizes,
ArrayRef<bool> tileScalableFlags) {
// scf::tileUsingSCFForOp expects the number of tile sizes to equal the number
// of loops.
int numLoops = op.getLoopIteratorTypes().size();
SmallVector<int64_t> fixedTileSizes(tileSizes);
fixedTileSizes.resize(numLoops, /*default=*/0);
SmallVector<bool> fixedTileScalableFlags(tileScalableFlags);
fixedTileScalableFlags.resize(numLoops, /*default=*/false);
if (!llvm::is_contained(fixedTileScalableFlags, true)) {
// Non-scalable case: All constant tile sizes.
options.setTileSizes(
getAsIndexOpFoldResult(op.getContext(), fixedTileSizes));
} else {
// Scalable case: Multiply scalable tile sizes by a vector.vscale op.
options.setTileSizeComputationFunction(
[=](OpBuilder &b, Operation *op) -> SmallVector<OpFoldResult> {
auto loc = op->getLoc();
return llvm::map_to_vector(
llvm::zip(fixedTileSizes, fixedTileScalableFlags),
[&](auto pair) -> OpFoldResult {
auto [t, isScalable] = pair;
Value size = b.create<arith::ConstantIndexOp>(loc, t);
if (isScalable) {
Value vscale = b.create<vector::VectorScaleOp>(loc);
size = b.create<arith::MulIOp>(loc, size, vscale);
}
return size;
});
});
}
}
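// Illustrative example (sizes assumed): for an op with 2 loops,
// tileSizes = {8, 4} with tileScalableFlags = {false, true} produces tile
// sizes [8, %c4 * vector.vscale], while an all-false scalable flag list simply
// sets the constant sizes [8, 4].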
/// Creates a linalg::GenericOp version of an n-D copy that can be further
/// tiled, lowered to loops, or vectorized, unlike the current implementation
/// of memref::CopyOp.
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, Value to,
ArrayRef<NamedAttribute> attributes) {
auto memrefTypeFrom = llvm::dyn_cast<MemRefType>(from.getType());
auto memrefTypeTo = llvm::dyn_cast<MemRefType>(to.getType());
if (!memrefTypeFrom || !memrefTypeTo ||
memrefTypeFrom.getRank() != memrefTypeTo.getRank()) {
mlir::emitError(
loc, "unable to generate copy op within bufferization from type ")
<< memrefTypeFrom << " to " << memrefTypeTo;
return nullptr;
}
AffineMap id =
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
utils::IteratorType::parallel);
return b.create<linalg::GenericOp>(
loc,
/*inputs=*/from,
/*outputs=*/to,
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args.front());
},
attributes);
}
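// Illustrative output (2-D case, types assumed): the builder above produces
// IR of the form
//   linalg.generic {indexing_maps = [#map, #map],
//                   iterator_types = ["parallel", "parallel"]}
//     ins(%from : memref<?x?xf32>) outs(%to : memref<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }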
template <typename OpTy>
static Value buildHALWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
}
linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions(
linalg::DistributionMethod distributionMethod,
int32_t maxWorkgroupParallelDims) {
return {[distributionMethod,
maxWorkgroupParallelDims](OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
auto numParallelDims = parallelLoopRanges.size();
SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
std::optional<OpFoldResult> splitDim;
for (size_t dim = 0; dim < numParallelDims; ++dim) {
if (numParallelDims > maxWorkgroupParallelDims &&
dim >= maxWorkgroupParallelDims - 1) {
if (!splitDim) {
splitDim = buildHALWorkgroupInfoOp<IREE::HAL::InterfaceWorkgroupIDOp>(
builder, maxWorkgroupParallelDims - 1);
}
OpFoldResult size = parallelLoopRanges[numParallelDims - dim - 1].size;
OpFoldResult offset =
parallelLoopRanges[numParallelDims - dim - 1].offset;
OpFoldResult step =
parallelLoopRanges[numParallelDims - dim - 1].stride;
AffineExpr d0, d1, d2;
bindSymbols(builder.getContext(), d0, d1, d2);
OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
builder, loc, (d1 - d0).ceilDiv(d2), {offset, size, step});
OpFoldResult dimValue;
if (dim == numParallelDims - 1)
dimValue = splitDim.value();
else {
dimValue = affine::makeComposedFoldedAffineApply(
builder, loc, (d0 % d1), {splitDim.value(), numTiles});
splitDim = affine::makeComposedFoldedAffineApply(
builder, loc, (d0).floorDiv(d1), {splitDim.value(), numTiles});
}
procInfo[numParallelDims - dim - 1] = {
getValueOrCreateConstantIndexOp(builder, loc, dimValue),
getValueOrCreateConstantIndexOp(builder, loc, numTiles),
distributionMethod};
continue;
}
procInfo[numParallelDims - dim - 1] = {
buildHALWorkgroupInfoOp<IREE::HAL::InterfaceWorkgroupIDOp>(builder,
dim),
buildHALWorkgroupInfoOp<IREE::HAL::InterfaceWorkgroupCountOp>(builder,
dim),
distributionMethod};
}
return procInfo;
}};
}
static constexpr char pipeliningDepthName[] = "pipeline_depth";
static constexpr char pipeliningStageName[] = "store_stage";
DictionaryAttr
getSoftwarePipeliningAttrDict(MLIRContext *context,
unsigned softwarePipelineDepth,
unsigned softwarePipelineStoreStage) {
SmallVector<NamedAttribute> attrs;
attrs.push_back(
{StringAttr::get(context, pipeliningDepthName),
IntegerAttr::get(IntegerType::get(context, 64), softwarePipelineDepth)});
attrs.push_back({StringAttr::get(context, pipeliningStageName),
IntegerAttr::get(IntegerType::get(context, 64),
softwarePipelineStoreStage)});
return DictionaryAttr::get(context, attrs);
}
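// Illustrative result: getSoftwarePipeliningAttrDict(ctx, /*depth=*/2,
// /*storeStage=*/1) produces the dictionary
//   {pipeline_depth = 2 : i64, store_stage = 1 : i64}
// which the two getters below read back as FailureOr<int64_t>.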
FailureOr<int64_t> getSoftwarePipelineDepth(DictionaryAttr config) {
if (!config) {
return failure();
}
Attribute depth = config.get(pipeliningDepthName);
if (!depth) {
return failure();
}
return llvm::cast<IntegerAttr>(depth).getInt();
}
FailureOr<int64_t> getSoftwarePipelineStoreStage(DictionaryAttr config) {
if (!config) {
return failure();
}
Attribute stage = config.get(pipeliningStageName);
if (!stage) {
return failure();
}
return llvm::cast<IntegerAttr>(stage).getInt();
}
/// Returns a small tiling factor for the given reduction `dimSize`.
/// Returns 1 (i.e. effectively no tiling) when no small factor is found.
int getReductionTilingFactor(int64_t dimSize) {
if (dimSize % 4 == 0)
return 4;
// Try to find the smallest prime factor as the tiling factor. As a trade-off
// between generated code size and compilation time, only look at prime
// numbers less than 50 right now.
static constexpr std::array<int, 15> primeNumbers = {
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47};
for (int n : primeNumbers) {
if (dimSize % n == 0)
return n;
}
return 1; // Otherwise just tile with size 1.
}
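// Worked examples: getReductionTilingFactor(64) == 4 (divisible by 4),
// getReductionTilingFactor(35) == 5 (smallest prime factor), and
// getReductionTilingFactor(53) == 1 (prime larger than 47, so tile by 1).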
int64_t getMinElementBitwidth(linalg::LinalgOp linalgOp) {
unsigned bitwidth = std::numeric_limits<unsigned>::max();
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
unsigned b =
IREE::Util::getTypeBitWidth(getElementTypeOrSelf(operand->get()));
bitwidth = std::min(bitwidth, b);
}
for (Value result : linalgOp.getDpsInits()) {
unsigned b = IREE::Util::getTypeBitWidth(getElementTypeOrSelf(result));
bitwidth = std::min(bitwidth, b);
}
return bitwidth;
}
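// Illustrative example (types assumed): for a linalg.matmul with f16 inputs
// accumulating into an f32 init, the minimum element bitwidth is 16.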
//===---------------------------------------------------------------------===//
// Misc. utility functions
//===---------------------------------------------------------------------===//
OpFoldResult convertByteOffsetToElementOffset(RewriterBase &rewriter,
Location loc,
OpFoldResult byteOffset,
Type elementType) {
if (isa<ComplexType, FloatType, IntegerType, VectorType>(elementType)) {
unsigned typeBitWidth = IREE::Util::getTypeBitWidth(elementType);
assert(llvm::isPowerOf2_32(typeBitWidth) &&
"unhandled non powers of 2 bit width while converting byte offset "
"to element offset");
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
return affine::makeComposedFoldedAffineApply(
rewriter, loc, (s0 * 8).floorDiv(typeBitWidth),
{byteOffset, rewriter.getIndexAttr(typeBitWidth)});
} else {
OpFoldResult elementByteSize =
rewriter.create<IREE::Util::SizeOfOp>(loc, elementType).getResult();
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
return affine::makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1),
{byteOffset, elementByteSize});
}
}
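// Worked example (values assumed): a byte offset of 16 with an f32 element
// type (32 bits wide) folds to (16 * 8) / 32 = 4 elements; for other element
// types the division goes through a util.sizeof op instead.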
LogicalResult isArgmaxOp(linalg::GenericOp genericOp) {
// Check for 2 results (value and index) and 1 input.
if (genericOp.getNumDpsInits() != 2) {
return failure();
}
if (genericOp.getNumDpsInputs() != 1) {
return failure();
}
// If max value is being used, it is not a pure argmax.
if (!genericOp.getResults()[0].use_empty()) {
return failure();
}
// Check that all loops except one are parallel; argmax requires exactly one
// (1-D) reduction loop.
unsigned numLoops = genericOp.getNumLoops();
unsigned numParallelLoops = genericOp.getNumParallelLoops();
if (numParallelLoops != (numLoops - 1)) {
return failure();
}
// TODO: Add better affine map checks.
auto indexingMaps = genericOp.getIndexingMapsArray();
if (!indexingMaps[0].isIdentity())
return failure();
// Check that the initial value is negative infinity.
// TODO: Move this check to the ukernel once we implement a
// variant that handles non-neg-inf initial values.
Value initVal = genericOp.getDpsInitOperand(0)->get();
auto fillOp = initVal.getDefiningOp<linalg::FillOp>();
if (!fillOp)
return failure();
Value fillVal = fillOp.getDpsInputOperand(0)->get();
if (!matchPattern(fillVal, m_NegInfFloat()))
return failure();
// Work back from linalg.yield and check the body of the genericOp.
// The genericOp should yield the result of an arith.maximumf (the max value)
// and of an arith.select (the index), with the select fed by an arith.cmpf.
auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
Value producerOutput;
Operation *producer;
// Producer of linalg.yield 1st arg is arith.maximumf
{
producerOutput = yieldOp->getOperand(0);
producer = producerOutput.getDefiningOp();
if (!producer || producer->getNumOperands() == 0) {
return failure();
}
if (!matchPattern(producer, m_Op<arith::MaximumFOp>())) {
return failure();
}
}
// Producer of linalg.yield op 2nd arg is arith.select
// TODO: Add check that select is selecting between linalg.index and index of
// current max.
{
producerOutput = yieldOp->getOperand(1);
producer = producerOutput.getDefiningOp();
if (!producer || producer->getNumOperands() == 0) {
return failure();
}
if (!matchPattern(producer, m_Op<arith::SelectOp>())) {
return failure();
}
}
// Producer of arith.select op is arith.cmpf
{
producerOutput = producer->getOperand(0);
producer = producerOutput.getDefiningOp();
if (!producer || producer->getNumOperands() == 0) {
return failure();
}
auto producerCmpFOp = dyn_cast<arith::CmpFOp>(producer);
if (!producerCmpFOp) {
return failure();
}
if (producerCmpFOp.getPredicate() != arith::CmpFPredicate::OGT) {
return failure();
}
// Check that the cmpf operands are the loop's block arguments.
// The check on the first operand is currently disabled because it may be of
// a mixed type, in which case it would be wrapped as extf(%arg0).
// TODO: Add a better mixed-type support check.
if (producer->getOperand(1) != genericOp.getBody()->getArgument(1)) {
return failure();
}
}
return success();
}
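// Illustrative body that matches the checks above (SSA names assumed):
//   ^bb0(%in: f32, %out_val: f32, %out_idx: i64):
//     %idx = linalg.index 1 : index
//     %cast = arith.index_cast %idx : index to i64
//     %max = arith.maximumf %in, %out_val : f32
//     %cmp = arith.cmpf ogt, %in, %out_val : f32
//     %sel = arith.select %cmp, %cast, %out_idx : i64
//     linalg.yield %max, %sel : f32, i64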
//===---------------------------------------------------------------------===//
// Replace Memref users (transitively)
//===---------------------------------------------------------------------===//
/// Replaces a `use` with the `replacement` for cases where a simple
/// substitution might lead to verification errors.
static std::optional<SmallVector<Value>>
replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use,
Value replacement) {
Operation *user = use.getOwner();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(user);
LLVM_DEBUG({
llvm::dbgs() << "\tReplacing in user by creating new user : ";
user->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
auto replacementType = llvm::cast<MemRefType>(replacement.getType());
auto currentResultType =
llvm::cast<MemRefType>(castOp.getResult().getType());
if (replacementType == currentResultType) {
// Cast is a no op, just return the replacement.
return SmallVector<Value>{replacement};
}
auto newResultType = MemRefType::get(
currentResultType.getShape(), currentResultType.getElementType(),
replacementType.getLayout(), replacementType.getMemorySpace());
auto newCastOp =
rewriter.create<memref::CastOp>(loc, newResultType, replacement);
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newCastOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return SmallVector<Value>(newCastOp->result_begin(),
newCastOp->result_end());
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
auto currResultType =
llvm::cast<MemRefType>(subviewOp.getResult().getType());
auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
SmallVector<OpFoldResult> offsets = subviewOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = subviewOp.getMixedSizes();
SmallVector<OpFoldResult> strides = subviewOp.getMixedStrides();
MemRefType newResultType =
(currResultType.getRank() != newSourceType.getRank()
? llvm::cast<MemRefType>(
memref::SubViewOp::inferRankReducedResultType(
currResultType.getShape(), newSourceType, offsets, sizes,
strides))
: nullptr);
auto newSubviewOp = rewriter.create<memref::SubViewOp>(
loc, newResultType, replacement, offsets, sizes, strides);
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newSubviewOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return llvm::to_vector_of<Value>(newSubviewOp->getResults());
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(user)) {
auto currResultType =
llvm::cast<MemRefType>(expandOp.getResult().getType());
auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
FailureOr<MemRefType> newResultType =
memref::ExpandShapeOp::computeExpandedType(
newSourceType, currResultType.getShape(),
expandOp.getReassociationIndices());
if (failed(newResultType)) {
return std::nullopt;
}
auto newExpandOp = rewriter.create<memref::ExpandShapeOp>(
loc, *newResultType, replacement, expandOp.getReassociation(),
expandOp.getOutputShape(), expandOp.getStaticOutputShape());
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newExpandOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(user)) {
auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
FailureOr<MemRefType> newResultType =
memref::CollapseShapeOp::computeCollapsedType(
newSourceType, collapseOp.getReassociationIndices());
if (failed(newResultType)) {
return std::nullopt;
}
auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
loc, *newResultType, replacement, collapseOp.getReassociation());
LLVM_DEBUG({
llvm::dbgs() << "\t\tNew user : ";
newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return llvm::to_vector_of<Value>(newCollapseOp->getResults());
}
return std::nullopt;
}
void replaceMemrefUsesAndPropagateType(RewriterBase &rewriter, Location loc,
Value origValue,
Value replacementValue) {
SmallVector<std::pair<Value, Value>> worklist;
SmallVector<Operation *> toDeleteUsers;
worklist.push_back({origValue, replacementValue});
while (!worklist.empty()) {
auto [original, replacement] = worklist.pop_back_val();
LLVM_DEBUG({
llvm::dbgs() << "//===------------------------------------------===//\n";
llvm::dbgs() << "Replacing : ";
original.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
llvm::SmallDenseSet<OpOperand *> preservedUses;
if (original.getType() != replacement.getType()) {
for (OpOperand &use : original.getUses()) {
Operation *user = use.getOwner();
// Some uses cannot be replaced.
if (user->hasTrait<OpTrait::ReturnLike>()) {
LLVM_DEBUG({
llvm::dbgs() << "\tUnhandled user : ";
user->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
preservedUses.insert(&use);
continue;
}
// Some uses might be replaceable but require creating new versions
// of the users to pass verification.
std::optional<SmallVector<Value>> nonTrivialUse =
replaceNonTrivialUse(rewriter, loc, use, replacement);
if (nonTrivialUse) {
// Add the results of the new users created as replacements
// for the old users, and push them back onto the worklist.
preservedUses.insert(&use);
for (auto [v1, v2] :
llvm::zip_equal(user->getResults(), nonTrivialUse.value())) {
worklist.push_back({v1, v2});
}
toDeleteUsers.push_back(user);
continue;
}
}
}
// Replace all non-preserved uses.
rewriter.replaceUsesWithIf(original, replacement, [&](OpOperand &use) {
if (!preservedUses.count(&use)) {
LLVM_DEBUG({
llvm::dbgs() << "\t\tReplacing use in :";
use.getOwner()->print(llvm::dbgs(),
OpPrintingFlags().assumeVerified());
llvm::dbgs() << "\n";
});
return true;
}
return false;
});
}
// Iterate over deletable operations in reverse and delete each one if it
// has no remaining users.
for (auto deleteOp : llvm::reverse(toDeleteUsers)) {
if (deleteOp->use_empty()) {
rewriter.eraseOp(deleteOp);
}
}
}
void sinkOpsInCFG(const SmallVector<Operation *> &allocs,
DominanceInfo &dominators) {
for (Operation *sinkOp : allocs) {
Block *dom = nullptr;
for (Operation *user : sinkOp->getUsers()) {
if (!dom) {
dom = user->getBlock();
// Find the block in the same region.
while (dom->getParent() != sinkOp->getParentRegion()) {
dom = dom->getParentOp()->getBlock();
}
continue;
}
dom = dominators.findNearestCommonDominator(dom, user->getBlock());
}
llvm::SmallDenseSet<Operation *> users;
for (Operation *user : sinkOp->getUsers()) {
while (user->getParentRegion() != sinkOp->getParentRegion()) {
user = user->getParentOp();
}
users.insert(user);
}
Operation *firstUse = dom->getTerminator();
for (Operation &op : dom->getOperations()) {
if (users.count(&op)) {
firstUse = &op;
break;
}
}
sinkOp->moveBefore(firstUse);
}
}
/// Infer the number of workgroups from exportOp.
SmallVector<int64_t> getStaticNumWorkgroups(mlir::FunctionOpInterface funcOp) {
SmallVector<int64_t> result;
std::optional<IREE::HAL::ExecutableExportOp> exportOp = getEntryPoint(funcOp);
if (!exportOp)
return result;
Block *body = exportOp->getWorkgroupCountBody();
if (!body)
return result;
auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
assert(returnOp.getNumOperands() == 3);
for (unsigned i = 0; i < 3; ++i) {
Operation *defOp = returnOp.getOperand(i).getDefiningOp();
if (auto indexOp = dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
result.push_back(indexOp.value());
} else {
result.push_back(ShapedType::kDynamic);
}
}
return result;
}
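// Illustrative example: if the export op's workgroup count region ends in
//   hal.return %c4, %c2, %c1 : index, index, index
// this returns {4, 2, 1}; any non-constant operand becomes
// ShapedType::kDynamic.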
bool hasFusedLeadingOp(linalg::LinalgOp rootOp) {
assert(rootOp.getNumDpsInputs() == 2 && "rootOp expected to have two inputs");
BackwardSliceOptions options;
options.inclusive = true;
// Get the backward slice of each input operand and take the union.
SetVector<Operation *> backwardSlice;
for (OpOperand *operand : rootOp.getDpsInputOperands()) {
SetVector<Operation *> tmpBackwardSlice;
getBackwardSlice(operand->get(), &tmpBackwardSlice, options);
backwardSlice.set_union(tmpBackwardSlice);
}
return llvm::any_of(backwardSlice, llvm::IsaPred<linalg::LinalgOp>);
}
std::optional<vector::VscaleRange>
getDefaultVscaleRange(IREE::HAL::ExecutableTargetAttr targetAttr) {
if (isAArch64(targetAttr)) {
// On AArch64 the scalable vector length will always be between 128-bit and
// 2048-bit. This works out as a vscale range of 1 to 16. See:
// https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions
return vector::VscaleRange{1, 16};
}
// TODO: Implement for other architectures.
return std::nullopt;
}
FailureOr<DimBoundSize>
computeDimUpperBound(Value shapedValue, unsigned dimNum,
std::optional<vector::VscaleRange> vscaleRange,
RoundUpVscaleMultiple roundUp) {
if (!vscaleRange.has_value()) {
FailureOr<int64_t> maybeDimBoundSize =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, {shapedValue, dimNum},
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (succeeded(maybeDimBoundSize))
return DimBoundSize{/*baseSize=*/*maybeDimBoundSize,
/*scalable=*/false};
return failure();
}
FailureOr<DimBound> maybeDimBound =
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
shapedValue, dimNum,
/*vscaleMin=*/vscaleRange->vscaleMin,
/*vscaleMax=*/vscaleRange->vscaleMax, presburger::BoundType::UB);
if (failed(maybeDimBound))
return failure();
auto boundSize = maybeDimBound->getSize();
if (succeeded(boundSize))
return boundSize;
if (roundUp == RoundUpVscaleMultiple::No)
return failure();
// If the upper bound map is of the form `add(subExpr, cst)` (cst <= 0),
// round it up to `subExpr` (and try matching the bound again).
auto binOp = dyn_cast<AffineBinaryOpExpr>(maybeDimBound->map.getResult(0));
if (!binOp || binOp.getKind() != AffineExprKind::Add)
return failure();
auto cst = dyn_cast<AffineConstantExpr>(binOp.getRHS());
if (!cst || cst.getValue() > 0)
return failure();
DimBound roundedDimBound{AffineMap::get(maybeDimBound->map.getNumDims(),
maybeDimBound->map.getNumSymbols(),
binOp.getLHS())};
return roundedDimBound.getSize();
}
static bool isFullSlice(ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides,
IREE::Flow::DispatchTensorType tensorType,
ValueRange dynamicDims) {
OpBuilder builder(tensorType.getContext());
SmallVector<int64_t> tensorShape = llvm::to_vector(tensorType.getShape());
SmallVector<OpFoldResult> mixedTensorShape =
mlir::getMixedValues(tensorShape, dynamicDims, builder);
return areAllConstantIntValue(mixedOffsets, 0) &&
areAllConstantIntValue(mixedStrides, 1) &&
mixedTensorShape == mixedSizes;
}
bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
IREE::Flow::DispatchTensorType tensorType,
ValueRange dynamicDims) {
return isFullSlice(
sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(),
sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims);
}
} // namespace mlir::iree_compiler