blob: c8f80e60dd32c0fab4b56b0bc7a01d0281b12cc8 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/IR/Dialect.h"
#include "compiler/IR/Interpreter/HLDialect.h"
#include "compiler/IR/Interpreter/HLOps.h"
#include "compiler/IR/Ops.h"
#include "compiler/Transforms/ConversionUtils.h"
#include "compiler/Utils/MemRefUtils.h"
#include "compiler/Utils/OpCreationUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace {
LogicalResult convertReductionOp(FuncOp entryPoint, FuncOp applyFunc,
Operation *elementOp, OpBuilder &builder) {
// Ensure that this op is pass-through and does not interact with any other
// ops within the function.
// TODO(b/139313439): support fused reductions.
for (auto *operand : elementOp->getOperands()) {
if (operand->getDefiningOp() != nullptr) {
return elementOp->emitOpError()
<< "Fused reductions are not supported (operand not sourced from "
"block args)";
}
}
for (auto *result : elementOp->getResults()) {
for (auto *user : result->getUsers()) {
if (!user->isKnownTerminator()) {
return elementOp->emitOpError() << "Fused reductions are not supported "
"(result used by non-terminator)";
}
}
}
// Determine the index of the args we care about. We'll use these to match up
// the operands of the entry point with our application.
// Our arguments are expanded tuples like <lhs0, rhs0>, <lhs1, rhs1>, so this
// index gets the offset * 2.
auto &applyEntryBlock = applyFunc.getBlocks().front();
int setIndex = std::distance(applyEntryBlock.args_begin(),
llvm::find(applyEntryBlock.getArguments(),
elementOp->getOperand(0))) /
2;
// Map to the args from the entry point.
auto &entryPointEntryBlock = entryPoint.getBlocks().front();
Value *srcArg = entryPointEntryBlock.getArgument(setIndex);
Value *initArg = entryPointEntryBlock.getArgument(
applyFunc.getNumArguments() / 2 + setIndex);
Value *dstArg =
entryPointEntryBlock.getArgument(applyFunc.getNumArguments() + setIndex);
auto dstType = dstArg->getType().cast<ShapedType>();
Type elementType = dstType.getElementType();
auto loc = elementOp->getLoc();
auto dimensionAttr = entryPoint.getAttrOfType<IntegerAttr>(
"iree.executable.reduction.dimension");
Operation *expandedOp = nullptr;
if (isa<IREEInterp::HL::AddFOp>(elementOp) ||
isa<IREEInterp::HL::AddIOp>(elementOp)) {
if (elementType.isa<FloatType>()) {
expandedOp = builder.create<IREEInterp::HL::ReduceSumFOp>(
loc, dstType, srcArg, initArg, dimensionAttr);
} else {
expandedOp = builder.create<IREEInterp::HL::ReduceSumIOp>(
loc, dstType, srcArg, initArg, dimensionAttr);
}
} else if (isa<IREEInterp::HL::MinFOp>(elementOp) ||
isa<IREEInterp::HL::MinISOp>(elementOp) ||
isa<IREEInterp::HL::MinIUOp>(elementOp)) {
if (elementType.isa<FloatType>()) {
expandedOp = builder.create<IREEInterp::HL::ReduceMinFOp>(
loc, dstType, srcArg, initArg, dimensionAttr);
} else {
expandedOp = builder.create<IREEInterp::HL::ReduceMinIOp>(
loc, dstType, srcArg, initArg, dimensionAttr);
}
} else if (isa<IREEInterp::HL::MaxFOp>(elementOp) ||
isa<IREEInterp::HL::MaxISOp>(elementOp) ||
isa<IREEInterp::HL::MaxIUOp>(elementOp)) {
if (elementType.isa<FloatType>()) {
expandedOp = builder.create<IREEInterp::HL::ReduceMaxFOp>(
loc, dstType, srcArg, initArg, dimensionAttr);
} else {
expandedOp = builder.create<IREEInterp::HL::ReduceMaxIOp>(
loc, dstType, srcArg, initArg, dimensionAttr);
}
}
if (!expandedOp) {
return elementOp->emitOpError()
<< "No matching expanded reduction op for elemental op";
}
llvm::SmallVector<int64_t, 4> zeroOffset(dstType.getRank(), 0);
auto zeroIndices = createArrayConstant(builder, loc, zeroOffset);
auto lengths = createArrayConstant(builder, loc, dstType.getShape());
builder.create<IREEInterp::HL::CopyOp>(
loc, expandedOp->getResult(0), zeroIndices, dstArg, zeroIndices, lengths);
return success();
}
// Replaces the given elemental |funcOp| with a widened reduction.
LogicalResult expandReductionFunction(FuncOp entryFunc) {
if (!entryFunc.empty()) {
return entryFunc.emitError()
<< "Function has already been expanded or has existing contents";
} else if (!entryFunc.getAttr("iree.executable.reduction.dimension")) {
return entryFunc.emitError() << "Windowed reductions are not yet supported";
}
auto applySym =
entryFunc.getAttrOfType<SymbolRefAttr>("iree.executable.reduction.apply");
if (!applySym) {
return entryFunc.emitError() << "No reduction application function defined";
}
auto applyFunc = entryFunc.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
applySym.getValue());
if (!applyFunc) {
return entryFunc.emitError()
<< "Unable to find apply function " << applySym;
}
auto *entryBlock = entryFunc.addEntryBlock();
OpBuilder builder(entryBlock);
if (applyFunc.getBlocks()
.front()
.walk([&](Operation *op) {
if (!op->isKnownTerminator()) {
if (failed(
convertReductionOp(entryFunc, applyFunc, op, builder))) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
})
.wasInterrupted()) {
return applyFunc.emitError() << "Unable to convert apply func";
}
builder.create<IREE::ReturnOp>(builder.getUnknownLoc());
// Remove the apply function as we have inlined it.
applyFunc.erase();
entryFunc.removeAttr("iree.executable.reduction.apply");
entryFunc.removeAttr("iree.executable.reduction.dimension");
return success();
}
// Limited lowering of reductions to fat reduce_* ops.
//
// The specific subset this supports is:
// * 'min', 'max', and 'add' computations, with function names matching the
// computation
// * one op per reduction (no fusions yet).
// Note: computations and shapes are not validated.
//
// TODO(b/139410773): Implement more generally, supporting custom computations.
class ExpandReductionsToOpsPass : public ModulePass<ExpandReductionsToOpsPass> {
public:
void runOnModule() override {
auto module = getModule();
SmallVector<FuncOp, 4> reductionFuncs;
for (auto funcOp : module.getOps<FuncOp>()) {
if (funcOp.getAttr("iree.executable.reduction.apply")) {
reductionFuncs.push_back(funcOp);
}
}
for (auto funcOp : reductionFuncs) {
if (failed(expandReductionFunction(funcOp))) {
return signalPassFailure();
}
}
}
};
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> createExpandReductionsToOpsPass() {
return std::make_unique<ExpandReductionsToOpsPass>();
}
static PassRegistration<ExpandReductionsToOpsPass> pass(
"iree-expand-reductions-to-ops",
"Expands IREE reduction functions to their interpreter ops");
} // namespace iree_compiler
} // namespace mlir