blob: 23575136343557a45cfec9cbaaa35c7d93b980f2 [file]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
namespace {
// Returns the shape of |input| with the dimension at index |windowDimension|
// removed; every other dimension is kept in its original order.
// |input| is expected to have a ShapedType.
SmallVector<int64_t, 4> calculateResultShape(Value *input,
                                             int windowDimension) {
  auto inputShape = input->getType().cast<ShapedType>().getShape();
  SmallVector<int64_t, 4> keptDims;
  for (size_t dim = 0, rank = inputShape.size(); dim < rank; ++dim) {
    // Only the dimension being reduced is dropped.
    if (dim == windowDimension) continue;
    keptDims.push_back(inputShape[dim]);
  }
  return keptDims;
}
// Converts a reduction region into a DispatchOp targeting the outlined entry
// point for a single reduction dimension.
//
// |regionOp| supplies the location, the result element types and the workload.
// Result i of the dispatch has the element type of region result i and the
// shape of inputs[i] with |reductionDimension| removed. The dispatch operands
// are |inputs| followed by |initialValues|, and the op is created with
// |dispatcherBuilder|.
//
// Returns the results of the dispatch, or an empty vector if the dispatch
// cannot be constructed (the region has no results, or there are fewer inputs
// than results to derive shapes from).
SmallVector<Value *, 4> convertToDispatchOp(
    Operation *regionOp, ExecutableOp executableOp, StringRef entryPointName,
    int reductionDimension, SmallVector<Value *, 4> initialValues,
    SmallVector<Value *, 4> inputs, OpBuilder &dispatcherBuilder) {
  // Each result is shaped from the input at the same index and the workload is
  // derived from the first result, so fail cleanly instead of indexing out of
  // bounds when that pairing does not exist.
  const size_t resultCount = regionOp->getNumResults();
  if (resultCount == 0 || inputs.size() < resultCount) {
    return {};
  }

  SmallVector<Type, 4> resultTypes;
  for (auto resultType : llvm::enumerate(regionOp->getResultTypes())) {
    // The dispatch result keeps the element type of the region result but
    // takes its shape from the matching input minus the reduced dimension.
    auto shapedType = resultType.value().cast<ShapedType>();
    auto reducedType = RankedTensorType::get(
        calculateResultShape(inputs[resultType.index()], reductionDimension),
        shapedType.getElementType());
    resultTypes.push_back(reducedType);
  }

  // Calculate workload from the (first) result shape.
  auto *workload =
      calculateWorkload(regionOp, resultTypes.front().cast<ShapedType>());

  // Operand order must match the entry point signature: inputs first, then
  // the initial values.
  std::vector<Value *> allOperands;
  allOperands.reserve(inputs.size() + initialValues.size());
  allOperands.insert(allOperands.end(), inputs.begin(), inputs.end());
  allOperands.insert(allOperands.end(), initialValues.begin(),
                     initialValues.end());

  auto dispatchOp = dispatcherBuilder.create<DispatchOp>(
      regionOp->getLoc(), executableOp.getName(), entryPointName, workload,
      resultTypes, allOperands);
  return llvm::to_vector<4>(dispatchOp.getResults());
}
// Builds the executable used to reduce a single dimension of |regionOp|.
//
// createRegionExecutable produces the executable holding the elemental
// reduction function, typed (in0, in1) -> out0 for each initial value of the
// region. A "<elemental name>_entry" function is then placed directly before
// the elemental function with the signature
//   (inputs..., initialValues...) -> inputs with |reductionDimension| removed
// and a ReductionEntryOp referencing both functions is created in the
// executable body.
//
// Returns the executable together with the new entry point op.
std::pair<ExecutableOp, ReductionEntryOp> createReductionExecutable(
    ReductionRegionOp regionOp, int outlinedRegionOrdinal,
    int separatedReductionIndex, int reductionDimension,
    SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs,
    llvm::StringMap<FuncOp> &dispatchableFuncOps) {
  auto *context = regionOp.getContext();
  auto loc = regionOp.getLoc();

  // Elemental function signature: two operands and one result per initial
  // value, all of the initial value's type.
  SmallVector<Type, 8> reducerOperandTypes;
  SmallVector<Type, 8> reducerResultTypes;
  for (auto *initialValue : regionOp.initial_values()) {
    auto valueType = initialValue->getType();
    reducerOperandTypes.push_back(valueType);  // in0
    reducerOperandTypes.push_back(valueType);  // in1
    reducerResultTypes.push_back(valueType);   // out0
  }
  auto reducerFuncType =
      FunctionType::get(reducerOperandTypes, reducerResultTypes, context);

  // Create the executable with the region cloned into it.
  const std::string symbolSuffix =
      "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" +
      std::to_string(separatedReductionIndex);
  ExecutableOp executableOp;
  FuncOp elementalFuncOp;
  std::tie(executableOp, elementalFuncOp) = createRegionExecutable(
      regionOp, reducerFuncType, symbolSuffix, dispatchableFuncOps);

  // Entry point operands: all inputs followed by all initial values.
  SmallVector<Type, 8> entryOperandTypes;
  for (Value *input : inputs) {
    entryOperandTypes.push_back(input->getType());
  }
  for (Value *initialValue : initialValues) {
    entryOperandTypes.push_back(initialValue->getType());
  }

  // Entry point results: region result element types over the input shapes
  // with the reduced dimension dropped.
  SmallVector<Type, 4> entryResultTypes;
  unsigned resultIndex = 0;
  for (auto regionResultType : regionOp.getResultTypes()) {
    auto elementType = regionResultType.cast<ShapedType>().getElementType();
    entryResultTypes.push_back(RankedTensorType::get(
        calculateResultShape(inputs[resultIndex], reductionDimension),
        elementType));
    ++resultIndex;
  }

  // Declare the entry function and position it right before the elemental
  // function inside the executable.
  auto entryFuncType =
      FunctionType::get(entryOperandTypes, entryResultTypes, context);
  const std::string entryFuncName =
      (elementalFuncOp.getName() + "_entry").str();
  auto entryFuncOp = FuncOp::create(loc, entryFuncName, entryFuncType);
  elementalFuncOp.getOperation()->getBlock()->push_back(entryFuncOp);
  entryFuncOp.getOperation()->moveBefore(elementalFuncOp);

  // Add dispatch export pointing at the function.
  OpBuilder builder(executableOp.body());
  auto exportNameAttr = builder.getStringAttr(entryFuncOp.getName());
  auto entryRefAttr = builder.getSymbolRefAttr(entryFuncOp);
  auto elementalRefAttr = builder.getSymbolRefAttr(elementalFuncOp);
  auto dimensionAttr = builder.getI32IntegerAttr(reductionDimension);
  auto entryPointOp = builder.create<ReductionEntryOp>(
      loc, exportNameAttr, entryRefAttr, elementalRefAttr, dimensionAttr);
  return std::make_pair(executableOp, entryPointOp);
}
// Outlines a reduction region into one or more executables.
//
// The reduction is separated into multiple dispatches, one for each reduction
// dimension (thankfully XLA's operation semantics state this is ok). The
// dimensions are processed in ascending order and every dispatch receives the
// region's initial values along with the results of the previous dispatch (or
// the region operands for the first one).
//
// On success all uses of the region results are replaced with the results of
// the final dispatch and |regionOp| is erased. Emits an op error and returns
// failure if any dispatch cannot be constructed.
LogicalResult outlineReductionRegion(
    ReductionRegionOp regionOp, int outlinedRegionOrdinal,
    llvm::StringMap<FuncOp> &dispatchableFuncOps) {
  // Insert at the same place as the original region.
  OpBuilder dispatcherBuilder(regionOp);

  SmallVector<Value *, 4> initialValues{regionOp.initial_values()};
  SmallVector<Value *, 4> temps{regionOp.operands()};

  // Create one dispatch per dimension being reduced.
  // We'll do this by chaining the original input through with the temporary
  // reduction results. The results we end up with will be the originally
  // requested shape and we can just substitute them.
  auto dimensions = regionOp.dimensions().getValue();
  SmallVector<int32_t, 4> sortedDimensions;
  for (uint32_t i = 0; i < dimensions.getNumElements(); ++i) {
    sortedDimensions.push_back(dimensions.getValue<IntegerAttr>({i}).getInt());
  }
  // The comparator must be a strict weak ordering. A qsort-style `a - b`
  // converts to `true` for every pair of distinct values, which makes the
  // sort undefined behavior.
  llvm::sort(sortedDimensions, [](int32_t a, int32_t b) { return a < b; });

  // NOTE(review): each dispatch drops the reduced dimension from the shape of
  // its results, and later dimensions are applied to those already-reduced
  // temps using their original indices - confirm this is the intended
  // indexing.
  for (auto dimension : llvm::enumerate(sortedDimensions)) {
    // Create the executable with the region cloned into it.
    ExecutableOp executableOp;
    ReductionEntryOp entryPointOp;
    std::tie(executableOp, entryPointOp) = createReductionExecutable(
        regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(),
        initialValues, temps, dispatchableFuncOps);

    // Finally convert the dispatch region into a dispatch to the outlined func.
    temps = convertToDispatchOp(regionOp, executableOp, entryPointOp.getName(),
                                dimension.value(), initialValues,
                                std::move(temps), dispatcherBuilder);
    if (temps.empty()) {
      return regionOp.emitOpError()
             << "failed to construct reduction for dimension "
             << dimension.value();
    }
  }

  // Replace uses of the existing results with the new results.
  for (unsigned i = 0, e = regionOp.getNumResults(); i < e; ++i) {
    regionOp.getResult(i)->replaceAllUsesWith(temps[i]);
  }

  // Erase original region.
  regionOp.erase();
  return success();
}
// Builds the executable used to reduce a single window dimension of
// |regionOp|.
//
// createRegionExecutable produces the executable holding the elemental
// reduction function, typed (in0, in1) -> out0 for each initial value of the
// region. A "<elemental name>_entry" function is then placed directly before
// the elemental function with the signature
//   (inputs..., initialValues...) -> inputs with |windowDimension| removed
// and a WindowedReductionEntryOp referencing both functions is created in the
// executable body. That op also records |windowDimension|, |windowStride|,
// |baseDilation|, |windowDilation| and the region's padding mode as i32
// attributes.
//
// Returns the executable together with the new entry point op.
std::pair<ExecutableOp, WindowedReductionEntryOp>
createWindowedReductionExecutable(
    WindowedReductionRegionOp regionOp, int outlinedRegionOrdinal,
    int separatedReductionIndex, int32_t windowDimension, int32_t windowStride,
    int32_t baseDilation, int32_t windowDilation,
    SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs,
    llvm::StringMap<FuncOp> &dispatchableFuncOps) {
  auto *context = regionOp.getContext();
  auto loc = regionOp.getLoc();

  // Elemental function signature: two operands and one result per initial
  // value, all of the initial value's type.
  SmallVector<Type, 8> reducerOperandTypes;
  SmallVector<Type, 8> reducerResultTypes;
  for (auto *initialValue : regionOp.initial_values()) {
    auto valueType = initialValue->getType();
    reducerOperandTypes.push_back(valueType);  // in0
    reducerOperandTypes.push_back(valueType);  // in1
    reducerResultTypes.push_back(valueType);   // out0
  }
  auto reducerFuncType =
      FunctionType::get(reducerOperandTypes, reducerResultTypes, context);

  // Create the executable with the region cloned into it.
  const std::string symbolSuffix =
      "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" +
      std::to_string(separatedReductionIndex);
  ExecutableOp executableOp;
  FuncOp elementalFuncOp;
  std::tie(executableOp, elementalFuncOp) = createRegionExecutable(
      regionOp, reducerFuncType, symbolSuffix, dispatchableFuncOps);

  // Entry point operands: all inputs followed by all initial values.
  SmallVector<Type, 8> entryOperandTypes;
  for (Value *input : inputs) {
    entryOperandTypes.push_back(input->getType());
  }
  for (Value *initialValue : initialValues) {
    entryOperandTypes.push_back(initialValue->getType());
  }

  // Entry point results: region result element types over the input shapes
  // with the window dimension dropped.
  SmallVector<Type, 4> entryResultTypes;
  unsigned resultIndex = 0;
  for (auto regionResultType : regionOp.getResultTypes()) {
    auto elementType = regionResultType.cast<ShapedType>().getElementType();
    entryResultTypes.push_back(RankedTensorType::get(
        calculateResultShape(inputs[resultIndex], windowDimension),
        elementType));
    ++resultIndex;
  }

  // Declare the entry function and position it right before the elemental
  // function inside the executable.
  auto entryFuncType =
      FunctionType::get(entryOperandTypes, entryResultTypes, context);
  const std::string entryFuncName =
      (elementalFuncOp.getName() + "_entry").str();
  auto entryFuncOp = FuncOp::create(loc, entryFuncName, entryFuncType);
  elementalFuncOp.getOperation()->getBlock()->push_back(entryFuncOp);
  entryFuncOp.getOperation()->moveBefore(elementalFuncOp);

  // Add dispatch export pointing at the function.
  OpBuilder builder(executableOp.body());
  auto exportNameAttr = builder.getStringAttr(entryFuncOp.getName());
  auto entryRefAttr = builder.getSymbolRefAttr(entryFuncOp);
  auto elementalRefAttr = builder.getSymbolRefAttr(elementalFuncOp);
  auto windowDimensionAttr = builder.getI32IntegerAttr(windowDimension);
  auto windowStrideAttr = builder.getI32IntegerAttr(windowStride);
  auto baseDilationAttr = builder.getI32IntegerAttr(baseDilation);
  auto windowDilationAttr = builder.getI32IntegerAttr(windowDilation);
  auto paddingModeAttr = builder.getI32IntegerAttr(
      static_cast<uint32_t>(regionOp.padding_mode()));
  auto entryPointOp = builder.create<WindowedReductionEntryOp>(
      loc, exportNameAttr, entryRefAttr, elementalRefAttr, windowDimensionAttr,
      windowStrideAttr, baseDilationAttr, windowDilationAttr, paddingModeAttr);
  return std::make_pair(executableOp, entryPointOp);
}
// Outlines a windowed reduction region into one or more executables.
//
// The reduction is separated into multiple dispatches, one for each entry of
// the region's window_dimensions attribute (thankfully XLA's operation
// semantics state this is ok). Each entry is paired with the stride, base
// dilation and window dilation at the same index, and the entries are
// processed in ascending order of their window dimension value. Every
// dispatch receives the region's initial values along with the results of the
// previous dispatch (or the region operands for the first one).
//
// On success all uses of the region results are replaced with the results of
// the final dispatch and |regionOp| is erased. Emits an op error and returns
// failure if any dispatch cannot be constructed.
LogicalResult outlineWindowedReductionRegion(
    WindowedReductionRegionOp regionOp, int outlinedRegionOrdinal,
    llvm::StringMap<FuncOp> &dispatchableFuncOps) {
  // Insert at the same place as the original region.
  OpBuilder dispatcherBuilder(regionOp);

  SmallVector<Value *, 4> initialValues{regionOp.initial_values()};
  SmallVector<Value *, 4> temps{regionOp.operands()};

  // Create one dispatch per dimension being reduced.
  // We'll do this by chaining the original input through with the temporary
  // reduction results. The results we end up with will be the originally
  // requested shape and we can just substitute them.
  // Tuple layout: (windowDimension, windowStride, baseDilation,
  // windowDilation).
  using WindowTuple = std::tuple<int32_t, int32_t, int32_t, int32_t>;
  auto windowDimensions = regionOp.window_dimensions();
  auto windowStrides = regionOp.window_strides();
  auto baseDilations = regionOp.base_dilations();
  auto windowDilations = regionOp.window_dilations();
  SmallVector<WindowTuple, 4> sortedWindowAttrs;
  for (uint32_t i = 0; i < windowDimensions.getNumElements(); ++i) {
    int32_t windowDimension =
        windowDimensions.getValue<IntegerAttr>({i}).getInt();
    int32_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt();
    int32_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt();
    int32_t windowDilation =
        windowDilations.getValue<IntegerAttr>({i}).getInt();
    sortedWindowAttrs.push_back(WindowTuple(windowDimension, windowStride,
                                            baseDilation, windowDilation));
  }
  // The comparator must be a strict weak ordering. A qsort-style subtraction
  // converts to `true` for every pair of distinct values, which makes the
  // sort undefined behavior.
  llvm::sort(sortedWindowAttrs,
             [](const WindowTuple &a, const WindowTuple &b) {
               return std::get<0>(a) < std::get<0>(b);
             });

  for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) {
    int32_t windowDimension = std::get<0>(windowAttrs.value());
    int32_t windowStride = std::get<1>(windowAttrs.value());
    int32_t baseDilation = std::get<2>(windowAttrs.value());
    int32_t windowDilation = std::get<3>(windowAttrs.value());

    // Create the executable with the region cloned into it.
    ExecutableOp executableOp;
    WindowedReductionEntryOp entryPointOp;
    std::tie(executableOp, entryPointOp) = createWindowedReductionExecutable(
        regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension,
        windowStride, baseDilation, windowDilation, initialValues, temps,
        dispatchableFuncOps);

    // Convert the region into a dispatch to the outlined entry point.
    temps = convertToDispatchOp(regionOp, executableOp, entryPointOp.getName(),
                                windowDimension, initialValues,
                                std::move(temps), dispatcherBuilder);
    if (temps.empty()) {
      return regionOp.emitOpError()
             << "failed to construct reduction for windowed dimension "
             << windowDimension;
    }
  }

  // Replace uses of the existing results with the new results.
  for (unsigned i = 0, e = regionOp.getNumResults(); i < e; ++i) {
    regionOp.getResult(i)->replaceAllUsesWith(temps[i]);
  }

  // Erase original region.
  regionOp.erase();
  return success();
}
} // namespace
// Module pass that outlines every ReductionRegionOp and
// WindowedReductionRegionOp found in the functions of the module into
// executables, replacing each region with dispatches to them.
//
// The map of dispatchable functions is handed to the outlining helpers and may
// be shared with other passes through the explicit constructor.
class OutlineReductionRegionsPass
    : public ModulePass<OutlineReductionRegionsPass> {
 public:
  // Creates the pass with its own private map. Without this the member would
  // be null and runOnModule would dereference a null pointer whenever the pass
  // is constructed without a shared map.
  OutlineReductionRegionsPass()
      : dispatchableFuncOps_(std::make_shared<llvm::StringMap<FuncOp>>()) {}

  // Creates the pass using the caller's shared map. A null pointer falls back
  // to a private map so that the pass never runs with a null member.
  explicit OutlineReductionRegionsPass(
      std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps)
      : dispatchableFuncOps_(
            dispatchableFuncOps
                ? std::move(dispatchableFuncOps)
                : std::make_shared<llvm::StringMap<FuncOp>>()) {}

  void runOnModule() override {
    // TODO(benvanik): replace with a pattern rewriter?
    // Snapshot the functions up front; the function list is not iterated
    // directly while outlining runs.
    auto funcOps = llvm::to_vector<32>(getModule().getOps<FuncOp>());
    for (auto funcOp : funcOps) {
      // Collect first, then outline: outlining erases the region ops, which
      // must not happen in the middle of the walk.
      SmallVector<ReductionRegionOp, 4> reductionRegionOps;
      funcOp.walk(
          [&](ReductionRegionOp op) { reductionRegionOps.push_back(op); });
      // The index within the function is used as the outlined region ordinal.
      for (int i = 0, e = static_cast<int>(reductionRegionOps.size()); i < e;
           ++i) {
        if (failed(outlineReductionRegion(reductionRegionOps[i], i,
                                          *dispatchableFuncOps_))) {
          return signalPassFailure();
        }
      }

      SmallVector<WindowedReductionRegionOp, 4> windowedReductionRegionOps;
      funcOp.walk([&](WindowedReductionRegionOp op) {
        windowedReductionRegionOps.push_back(op);
      });
      for (int i = 0,
               e = static_cast<int>(windowedReductionRegionOps.size());
           i < e; ++i) {
        if (failed(outlineWindowedReductionRegion(windowedReductionRegionOps[i],
                                                  i, *dispatchableFuncOps_))) {
          return signalPassFailure();
        }
      }
    }
  }

 private:
  // Never null after construction.
  std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps_;
};
// Factory for OutlineReductionRegionsPass; |dispatchableFuncOps| is moved into
// the new pass instance.
std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass(
    std::shared_ptr<llvm::StringMap<FuncOp>> dispatchableFuncOps) {
  std::unique_ptr<OpPassBase<ModuleOp>> outlinePass =
      std::make_unique<OutlineReductionRegionsPass>(
          std::move(dispatchableFuncOps));  // NOLINT
  return outlinePass;
}
// Static registration object for OutlineReductionRegionsPass. It associates
// the pass with the "iree-flow-outline-reduction-regions" argument string and
// the human-readable description that follows it.
static PassRegistration<OutlineReductionRegionsPass> pass(
"iree-flow-outline-reduction-regions",
"Outlines reduction regions into standalone functions");
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir