blob: 407f014d2f507d69a4055fee4b7ad58279998c2d [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include "compiler/IR/Ops.h"
#include "compiler/IR/Sequencer/HLOps.h"
#include "compiler/IR/StructureOps.h"
#include "compiler/IR/Types.h"
#include "compiler/Utils/DispatchUtils.h"
#include "compiler/Utils/MemRefUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Returns the shape of `input` with the dimension at index `windowDimension`
// dropped, i.e. the shape that remains after reducing along that dimension.
// `input` must have a ShapedType (the cast asserts otherwise). If
// `windowDimension` does not name one of the input's dimensions the shape is
// returned unchanged.
SmallVector<int64_t, 4> calculateResultShape(Value *input,
                                             int windowDimension) {
  ArrayRef<int64_t> inputShape =
      input->getType().cast<ShapedType>().getShape();
  // Compared as size_t, exactly like the unsigned dimension index below.
  const size_t droppedDim = static_cast<size_t>(windowDimension);
  SmallVector<int64_t, 4> reducedShape;
  for (size_t dim = 0; dim < inputShape.size(); ++dim) {
    if (dim == droppedDim) continue;
    reducedShape.push_back(inputShape[dim]);
  }
  return reducedShape;
}
// Creates an executable that holds the given elemental reduction region for
// one separated reduction dimension.
//
// The region is cloned (via createRegionExecutable) into an elemental function
// with one `(in0, in1) -> out0` operand/result group per initial value. A
// second, exported function named "<elemental>_entry" is inserted directly
// before it, with the signature
//   (inputs..., initialValues..., outputs...) -> ()
// where each output is a memref shaped like the matching input with
// `reductionDimension` removed. The entry function is tagged as a reduction
// and references the elemental function via "iree.executable.reduction.apply".
//
// Returns the executable and the entry function.
std::pair<IREE::MultiArchExecutableOp, FuncOp> createReductionExecutable(
    IREE::ReductionRegionOp regionOp, int outlinedRegionOrdinal,
    int separatedReductionIndex, int reductionDimension,
    SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs) {
  MLIRContext *context = regionOp.getContext();
  Builder builder(context);

  // Elemental signature: every initial value of type T contributes
  // (T, T) -> T, mirroring the region signature 1:1.
  SmallVector<Type, 8> applyOperandTypes;
  SmallVector<Type, 8> applyResultTypes;
  for (Value *initialValue : regionOp.getInitialValueOperands()) {
    Type valueType = initialValue->getType();
    applyOperandTypes.append(2, valueType);
    applyResultTypes.push_back(valueType);
  }
  FunctionType applyFuncType =
      FunctionType::get(applyOperandTypes, applyResultTypes, context);

  // Clone the region into a fresh executable.
  IREE::MultiArchExecutableOp executableOp;
  FuncOp applyFunc;
  std::tie(executableOp, applyFunc) = createRegionExecutable(
      regionOp, applyFuncType,
      "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" +
          std::to_string(separatedReductionIndex));

  // Entry signature: inputs, then initial values, then one output buffer per
  // region result.
  SmallVector<Type, 8> entryOperandTypes;
  for (Value *input : inputs) {
    entryOperandTypes.push_back(input->getType());
  }
  for (Value *initialValue : initialValues) {
    entryOperandTypes.push_back(initialValue->getType());
  }
  size_t resultIndex = 0;
  for (auto resultType : regionOp.getResultTypes()) {
    Type elementType = resultType.cast<ShapedType>().getElementType();
    SmallVector<int64_t, 4> outputShape =
        calculateResultShape(inputs[resultIndex], reductionDimension);
    entryOperandTypes.push_back(MemRefType::get(outputShape, elementType));
    ++resultIndex;
  }
  FunctionType entryFuncType =
      FunctionType::get(entryOperandTypes, ArrayRef<Type>{}, context);
  auto entryFunc = FuncOp::create(
      regionOp.getLoc(), (applyFunc.getName() + "_entry").str(), entryFuncType);

  entryFunc.setAttr("iree.executable.export", UnitAttr::get(context));
  entryFunc.setAttr("iree.executable.reduction", UnitAttr::get(context));
  entryFunc.setAttr("iree.executable.reduction.apply",
                    builder.getSymbolRefAttr(applyFunc));

  // Place the entry function immediately before the elemental function.
  Block *executableBlock = applyFunc.getOperation()->getBlock();
  executableBlock->push_back(entryFunc);
  entryFunc.getOperation()->moveBefore(applyFunc);

  return {executableOp, entryFunc};
}
// Converts a reduction_region into a dispatch to the outlined region function
// for a single reduction dimension.
//
// For every region result an output buffer is allocated with
// allocateDispatchOutputBuffer, shaped like the corresponding input with
// `reductionDimension` removed. A dispatch to `entryFunc` in `executable` is
// then created with operands ordered (inputs..., initialValues...,
// outputs...), the same order used for the entry signature in
// createReductionExecutable. The dispatch workload is computed from the first
// output buffer.
//
// Returns the output buffers, or an empty vector on failure (an error has
// already been emitted on `regionOp` in that case).
SmallVector<Value *, 4> convertToDispatchOp(
    IREE::ReductionRegionOp regionOp, IREE::MultiArchExecutableOp executable,
    FuncOp entryFunc, int reductionDimension,
    SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs,
    OpBuilder &dispatcherBuilder) {
  // Allocate output args and replace the return values with those.
  SmallVector<Value *, 4> resultValues;
  for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
    // Allocate output buffer in the dispatcher to pass in to the region.
    auto shapedType = resultType.value().cast<ShapedType>();
    Value *allocatedValue = allocateDispatchOutputBuffer(
        regionOp.getLoc(),
        MemRefType::get(calculateResultShape(inputs[resultType.index()],
                                             reductionDimension),
                        shapedType.getElementType()),
        dispatcherBuilder);
    if (!allocatedValue) {
      regionOp.emitError("unable to allocate result value");
      return {};
    }
    resultValues.push_back(allocatedValue);
  }

  // The workload is derived from the first output buffer, so at least one
  // result is required (front() on an empty vector is undefined). This also
  // keeps the empty-vector return unambiguous as a failure signal.
  if (resultValues.empty()) {
    regionOp.emitError("reduction region has no results");
    return {};
  }

  // Calculate workload from the result shape.
  auto *workload =
      wrapAsMemRef(calculateWorkload(regionOp, resultValues.front()), regionOp,
                   dispatcherBuilder);

  // Create the reduce op to the executable function.
  std::vector<Value *> allOperands;
  allOperands.reserve(inputs.size() + initialValues.size() +
                      resultValues.size());
  allOperands.insert(allOperands.end(), inputs.begin(), inputs.end());
  allOperands.insert(allOperands.end(), initialValues.begin(),
                     initialValues.end());
  allOperands.insert(allOperands.end(), resultValues.begin(),
                     resultValues.end());
  dispatcherBuilder.create<IREESeq::HL::DispatchOp>(
      regionOp.getLoc(), executable.getName(), entryFunc.getName(), workload,
      ArrayRef<Type>{}, allOperands);
  return resultValues;
}
// Outlines a reduction region into one or more iree.multi_arch_executables.
// This separates the reduction into multiple dispatches, one for each reduction
// dimension (thankfully XLA's operation semantics state this is ok). The
// initial values are wrapped once and passed to every dispatch; each dispatch
// consumes the previous dispatch's output buffers ("temps") as its inputs.
//
// Dimensions are processed in descending index order. Each dispatch drops its
// dimension from the shape of the chained temps (see calculateResultShape), so
// reducing higher indices first keeps every remaining dimension index valid
// against the current temp shape.
//
// On success the region results are replaced via dispatcher loads and the
// region op is erased; on failure an error is emitted on the region op.
LogicalResult outlineReductionRegion(IREE::ReductionRegionOp regionOp,
                                     int outlinedRegionOrdinal) {
  // Insert at the same place as the original region.
  OpBuilder dispatcherBuilder(regionOp);

  // Wrap input operands in memrefs.
  SmallVector<Value *, 4> initialValues{llvm::map_range(
      regionOp.getInitialValueOperands(), [&](Value *originalArg) {
        return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
      })};
  SmallVector<Value *, 4> temps{
      llvm::map_range(regionOp.getReductionOperands(), [&](Value *originalArg) {
        return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
      })};

  // Create one dispatch per dimension being reduced.
  // We'll do this by chaining the original input through with the temporary
  // reduction results. The results we end up with will be the originally
  // requested shape and we can just substitute them.
  if (regionOp.isWindowed()) {
    auto windowDimensions = regionOp.window_dimensions().getValue();
    auto windowStrides = regionOp.window_strides().getValue();
    auto baseDilations = regionOp.base_dilations().getValue();
    auto windowDilations = regionOp.window_dilations().getValue();
    // (window_dimension, window_stride, base_dilation, window_dilation)
    using WindowAttrs = std::tuple<int64_t, int64_t, int64_t, int64_t>;
    SmallVector<WindowAttrs, 4> sortedWindowAttrs;
    for (uint64_t i = 0, e = windowDimensions.getNumElements(); i < e; ++i) {
      int64_t windowDimension =
          windowDimensions.getValue<IntegerAttr>({i}).getInt();
      int64_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt();
      int64_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt();
      int64_t windowDilation =
          windowDilations.getValue<IntegerAttr>({i}).getInt();
      sortedWindowAttrs.emplace_back(windowDimension, windowStride,
                                     baseDilation, windowDilation);
    }
    // Sort descending by dimension index (see the function comment). The
    // comparator must be a strict weak ordering for llvm::sort to be defined.
    llvm::sort(sortedWindowAttrs,
               [](const WindowAttrs &lhs, const WindowAttrs &rhs) {
                 return std::get<0>(lhs) > std::get<0>(rhs);
               });
    for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) {
      int64_t windowDimension = std::get<0>(windowAttrs.value());
      int64_t windowStride = std::get<1>(windowAttrs.value());
      int64_t baseDilation = std::get<2>(windowAttrs.value());
      int64_t windowDilation = std::get<3>(windowAttrs.value());
      IREE::MultiArchExecutableOp multiArchExecutable;
      FuncOp entryFunc;
      std::tie(multiArchExecutable, entryFunc) = createReductionExecutable(
          regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension,
          initialValues, temps);
      entryFunc.setAttr("iree.executable.reduction.padding_mode",
                        dispatcherBuilder.getI32IntegerAttr(
                            regionOp.padding_mode().getValue()));
      entryFunc.setAttr("iree.executable.reduction.window_dimension",
                        dispatcherBuilder.getI32IntegerAttr(windowDimension));
      entryFunc.setAttr("iree.executable.reduction.window_stride",
                        dispatcherBuilder.getI32IntegerAttr(windowStride));
      entryFunc.setAttr("iree.executable.reduction.base_dilation",
                        dispatcherBuilder.getI32IntegerAttr(baseDilation));
      entryFunc.setAttr("iree.executable.reduction.window_dilation",
                        dispatcherBuilder.getI32IntegerAttr(windowDilation));
      temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc,
                                  windowDimension, initialValues,
                                  std::move(temps), dispatcherBuilder);
      if (temps.empty()) {
        return regionOp.emitOpError()
               << "Failed to construct reduction for windowed dimension "
               << windowDimension;
      }
    }
  } else {
    auto dimensions = regionOp.dimensions().getValue();
    SmallVector<int64_t, 4> sortedDimensions;
    for (uint64_t i = 0, e = dimensions.getNumElements(); i < e; ++i) {
      sortedDimensions.push_back(
          dimensions.getValue<IntegerAttr>({i}).getInt());
    }
    // Sort descending by dimension index (see the function comment). The
    // comparator must be a strict weak ordering for llvm::sort to be defined.
    llvm::sort(sortedDimensions, [](int64_t a, int64_t b) { return a > b; });
    for (auto dimension : llvm::enumerate(sortedDimensions)) {
      IREE::MultiArchExecutableOp multiArchExecutable;
      FuncOp entryFunc;
      std::tie(multiArchExecutable, entryFunc) = createReductionExecutable(
          regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(),
          initialValues, temps);
      entryFunc.setAttr("iree.executable.reduction.dimension",
                        dispatcherBuilder.getI32IntegerAttr(dimension.value()));
      temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc,
                                  dimension.value(), initialValues,
                                  std::move(temps), dispatcherBuilder);
      if (temps.empty()) {
        return regionOp.emitOpError()
               << "Failed to construct reduction for dimension "
               << dimension.value();
      }
    }
  }

  // Replace each region result with a load from its final temp buffer.
  for (auto it : llvm::enumerate(regionOp.getResults())) {
    insertDispatcherLoad(regionOp, it.value(), temps[it.index()],
                         dispatcherBuilder);
  }

  // Erase original region.
  regionOp.erase();
  return success();
}
} // namespace
// Module pass that outlines every iree.reduction_region op found in each
// function of the module (see outlineReductionRegion). Regions are numbered
// per function in walk order, and that number is used as the outlined region
// ordinal. The pass signals failure and stops at the first region that fails
// to outline.
class OutlineReductionRegionsPass
    : public ModulePass<OutlineReductionRegionsPass> {
 public:
  void runOnModule() override {
    auto module = getModule();
    // NOTE(review): not referenced below; kept because its construction may
    // have side effects (e.g. symbol table setup) — confirm before removing.
    ModuleManager moduleManager(module);

    // Snapshot the function list up front so the outlining below (which
    // presumably adds ops to the module) cannot disturb the iteration.
    SmallVector<FuncOp, 4> funcOps;
    for (auto func : module.getOps<FuncOp>()) {
      funcOps.push_back(func);
    }

    for (auto func : funcOps) {
      // Collect first, then outline: outlining erases the region ops.
      std::vector<IREE::ReductionRegionOp> regionOps;
      func.walk([&](IREE::ReductionRegionOp op) { regionOps.push_back(op); });
      for (auto indexed : llvm::enumerate(regionOps)) {
        int ordinal = static_cast<int>(indexed.index());
        if (failed(outlineReductionRegion(indexed.value(), ordinal))) {
          signalPassFailure();
          return;
        }
      }
    }
  }
};
// Creates a new instance of the reduction-region outlining pass.
std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass() {
  std::unique_ptr<OpPassBase<ModuleOp>> outlinePass =
      std::make_unique<OutlineReductionRegionsPass>();  // NOLINT
  return outlinePass;
}

// Registers the pass with the global registry under the
// "iree-outline-reduction-regions" command-line name.
static PassRegistration<OutlineReductionRegionsPass> pass{
    "iree-outline-reduction-regions",
    "Outlines reduction regions into standalone functions"};
} // namespace iree_compiler
} // namespace mlir