blob: 9acd9657279def2676996cf9b45bedf8e5e38915 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Given a #hal.match.* expression tree returns a boolean value indicating
// whether the expression evaluates to true.
static Value buildConditionExpression(Location loc, Value device,
Attribute conditionAttr,
OpBuilder funcBuilder) {
if (auto matchAttr = conditionAttr.dyn_cast<IREE::HAL::MatchAlwaysAttr>()) {
// #hal.match.always -> true
return funcBuilder.createOrFold<ConstantIntOp>(loc, 1, 1);
} else if (auto matchAttr =
conditionAttr.dyn_cast<IREE::HAL::MatchAnyAttr>()) {
// #hal.match.any<[a, b, c]> -> or(or(a, b), c)
auto conditionAttrs = matchAttr.conditions().cast<ArrayAttr>();
auto conditionValues =
llvm::to_vector<4>(llvm::map_range(conditionAttrs, [&](Attribute attr) {
return buildConditionExpression(loc, device, attr, funcBuilder);
}));
Value resultValue = conditionValues[0];
for (int i = 1; i < conditionValues.size(); ++i) {
resultValue =
funcBuilder.createOrFold<OrOp>(loc, resultValue, conditionValues[i]);
}
return resultValue;
} else if (auto matchAttr =
conditionAttr.dyn_cast<IREE::HAL::MatchAllAttr>()) {
// #hal.match.all<[a, b, c]> -> and(and(a, b), c)
auto conditionAttrs = matchAttr.conditions().cast<ArrayAttr>();
auto conditionValues =
llvm::to_vector<4>(llvm::map_range(conditionAttrs, [&](Attribute attr) {
return buildConditionExpression(loc, device, attr, funcBuilder);
}));
Value resultValue = conditionValues[0];
for (int i = 1; i < conditionValues.size(); ++i) {
resultValue =
funcBuilder.createOrFold<AndOp>(loc, resultValue, conditionValues[i]);
}
return resultValue;
} else if (auto matchAttr =
conditionAttr.dyn_cast<IREE::HAL::DeviceMatchIDAttr>()) {
// #hal.device.match.id<"pattern"> -> hal.device.match.id
return funcBuilder.createOrFold<IREE::HAL::DeviceMatchIDOp>(
loc, funcBuilder.getI1Type(), device, matchAttr.patternAttr());
}
llvm_unreachable("unhandled condition expression attribute");
return {};
}
// Outlines a condition region from a switch op into a standalone function.
static FuncOp outlineConditionRegion(StringRef funcName,
Region &conditionRegion,
ArrayRef<Type> resultTypes,
OpBuilder moduleBuilder) {
auto &entryBlock = conditionRegion.front();
auto funcType = moduleBuilder.getFunctionType(
llvm::to_vector<4>(
llvm::map_range(entryBlock.getArguments(),
[](BlockArgument arg) { return arg.getType(); })),
resultTypes);
auto funcOp = moduleBuilder.create<FuncOp>(
conditionRegion.getLoc(), funcName, funcType, ArrayRef<NamedAttribute>{});
funcOp.getBody().takeBody(conditionRegion);
SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);
// Replace hal.return statements with normal std.return. This ensures that
// normal matchers/inlining/etc works as we continue transformation.
for (auto &block : funcOp.getBlocks()) {
if (auto returnOp = dyn_cast<IREE::HAL::ReturnOp>(block.back())) {
OpBuilder builder(returnOp);
builder.create<mlir::ReturnOp>(
returnOp.getLoc(), llvm::to_vector<4>(returnOp.getOperands()));
returnOp.erase();
}
}
return funcOp;
}
// Outlines each switch condition region into its own function and replaces the
// switch op with conditioned calls to those functions.
//
// Since switch conditions are evaluated in the order they are defined we can
// trivially turn the switch into a chain of if-else blocks.
// if condition_0_match:
// call outlined_condition_0
// else
// if condition_1_match:
// call outlined_condition_1
// else ...
static void buildConditionDispatchTable(IREE::HAL::DeviceSwitchOp switchOp,
StringRef baseFuncName,
OpBuilder moduleBuilder,
OpBuilder funcBuilder) {
// Split the block containing the switch op such that all ops before the
// switch are before and the switch and the following ops are after.
// We'll have all of our outlined regions bounce over to the afterBlock with
// the results of the call and use that to replace the switch op.
auto *beforeBlock = funcBuilder.getBlock();
auto *afterBlock = beforeBlock->splitBlock(switchOp);
auto finalValues =
llvm::to_vector<4>(afterBlock->addArguments(switchOp.getResultTypes()));
// Create the blocks we'll use for all our conditions so that we can reference
// them when inserting the branch ops.
SmallVector<Block *, 4> conditionMatchBlocks(
switchOp.condition_regions().size());
SmallVector<Block *, 4> conditionFallthroughBlocks(
switchOp.condition_regions().size());
for (int i = 0; i < conditionMatchBlocks.size(); ++i) {
conditionMatchBlocks[i] = funcBuilder.createBlock(afterBlock);
conditionFallthroughBlocks[i] = funcBuilder.createBlock(afterBlock);
}
funcBuilder.setInsertionPoint(beforeBlock, beforeBlock->end());
int argOffset = 0;
for (auto condition : llvm::enumerate(llvm::zip(
switchOp.conditions().getValue(), switchOp.condition_regions()))) {
auto conditionAttr = std::get<0>(condition.value());
auto &conditionRegion = std::get<1>(condition.value());
// Get the arguments from the switch that we want to carry along in the
// block arguments.
auto regionOperands = conditionRegion.front().getArguments();
auto regionArgs = switchOp.args().slice(argOffset, regionOperands.size());
argOffset += regionOperands.size();
// Outline the region into a function.
std::string regionFuncName =
(baseFuncName + "_").str() + std::to_string(condition.index());
auto regionFuncOp = outlineConditionRegion(
regionFuncName, conditionRegion,
switchOp.getOperation()->getResultTypes(), moduleBuilder);
// Insert the branch based on the match. We either match and jump to a block
// that will call the function or don't match and need to fall through.
auto isMatch = buildConditionExpression(
switchOp.getLoc(), switchOp.device(), conditionAttr, funcBuilder);
auto *matchBlock = conditionMatchBlocks[condition.index()];
auto *fallthroughBlock = conditionFallthroughBlocks[condition.index()];
funcBuilder.create<CondBranchOp>(switchOp.getLoc(), isMatch, matchBlock,
fallthroughBlock);
// Block that calls the outlined function and then jumps out of the chain.
funcBuilder.setInsertionPointToStart(matchBlock);
auto matchResults =
funcBuilder.create<CallOp>(switchOp.getLoc(), regionFuncOp, regionArgs);
funcBuilder.create<BranchOp>(switchOp.getLoc(), afterBlock,
matchResults.getResults());
// Block that we enter to check the next condition.
funcBuilder.setInsertionPointToStart(fallthroughBlock);
if (condition.index() + 1 < conditionFallthroughBlocks.size()) {
// Just continue on - the next loop iteration for the following condition
// will add its IR to the block.
} else {
// Fallthrough of all expressions; die if we expected return values.
if (switchOp.getNumResults() > 0) {
funcBuilder.create<IREE::UnreachableOp>(switchOp.getLoc());
} else {
funcBuilder.create<BranchOp>(switchOp.getLoc(), afterBlock);
}
}
}
// Remove the switch op and replace its results with the final joined results.
switchOp.replaceAllUsesWith(finalValues);
switchOp.erase();
}
class OutlineDeviceSwitchesPass
: public PassWrapper<OutlineDeviceSwitchesPass, OperationPass<ModuleOp>> {
public:
void runOnOperation() override {
auto moduleOp = getOperation();
auto funcOps = llvm::to_vector<16>(moduleOp.getOps<FuncOp>());
for (auto funcOp : llvm::enumerate(funcOps)) {
OpBuilder moduleBuilder(funcOp.value());
moduleBuilder.setInsertionPointAfter(funcOp.value());
for (auto block : llvm::enumerate(funcOp.value())) {
auto switchOps = llvm::to_vector<4>(
block.value().getOps<IREE::HAL::DeviceSwitchOp>());
for (auto switchOp : llvm::enumerate(switchOps)) {
std::string baseFuncName =
(funcOp.value().getName() + "_switch_").str() +
std::to_string(funcOp.index()) + "_" +
std::to_string(block.index()) + "_" +
std::to_string(switchOp.index());
OpBuilder funcBuilder(switchOp.value());
buildConditionDispatchTable(switchOp.value(), baseFuncName,
moduleBuilder, funcBuilder);
}
}
}
}
};
std::unique_ptr<OperationPass<ModuleOp>> createOutlineDeviceSwitchesPass() {
return std::make_unique<OutlineDeviceSwitchesPass>();
}
static PassRegistration<OutlineDeviceSwitchesPass> pass(
"iree-hal-outline-device-switches",
"Outlines hal.device.switch condition regions");
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir