blob: 431f0e3f3d21aebe80fbc5fd72e8c44ecda58659 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
namespace {
// Returns true if the constant value is a splat constant and can be
// rematerialized in a dispatch region.
bool isSplatConstant(ConstantOp constantOp) {
if (constantOp.getValue().isa<SplatElementsAttr>()) {
// Splats are always small and can be much better handled by broadcasting
// within the dispatch regions.
return true;
} else if (auto value = constantOp.getValue().dyn_cast<DenseElementsAttr>()) {
return value.isSplat();
}
// Assume anything unshaped is small. This may not always be true in custom
// dialects but is in std for now.
return false;
}
// Returns true if the dispatch region is allowed to have constants inside.
// Certain regions that may get replaced or turned into kernel imports shouldn't
// have the constants moved into them as they'll just get lost.
bool canDispatchRegionContainConstants(DispatchRegionOp dispatchRegionOp) {
for (auto &block : dispatchRegionOp.body()) {
for (auto &op : block) {
// TODO(b/144530470): replace with tablegen attributes/interfaces.
if (isa<mhlo::DotOp>(&op) || isa<mhlo::ConvOp>(&op)) {
// These two generally result in a lot of generated code so we try to
// keep constants out such that can dedupe more. We may still want to
// allow some parameters in (shapes/etc).
return false;
}
}
}
return true;
}
// Recursively clones the given |sourceOp| and returns the newly cloned op.
Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder,
BlockAndValueMapping *mapping) {
// Note that we dedupe required operands in the case of multiple arguments
// coming from the same source operation.
SmallPtrSet<Operation *, 4> operandOps;
for (auto operand : sourceOp->getOperands()) {
operandOps.insert(operand.getDefiningOp());
}
for (auto *operandOp : operandOps) {
recursivelyCloneOp(operandOp, builder, mapping);
}
return builder.clone(*sourceOp, *mapping);
}
// Clones the |sourceValue| op tree into |targetBlock|.
// |mapping| is used to lookup existing values that may be present in the block
// such as block arguments or already cloned ancestor ops. |mapping| will be
// updated as the tree is cloned.
Value cloneOpTreeIntoBlock(Value sourceValue, Block *targetBlock,
BlockAndValueMapping *mapping) {
// If the op has already been cloned we can just reuse that.
// This happens if multiple arguments reference the same trees.
if (auto existingValue = mapping->lookupOrNull(sourceValue)) {
return existingValue;
}
OpBuilder builder = OpBuilder::atBlockEnd(targetBlock);
builder.setInsertionPointToStart(targetBlock);
auto *sourceOp = sourceValue.getDefiningOp();
auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping);
// Return only the result matching our source value (in the case of multiple
// results).
int resultIndex = std::distance(
sourceOp->result_begin(),
std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue));
return clonedOp->getResult(resultIndex);
}
// Inlines use of the given |value| from outside of a dispatch region to inside
// of it and removes the argument. Supports multiple arguments that reference
// |value| and will clone the entire value tree.
LogicalResult inlineDispatchRegionOperandsUsingValue(
DispatchRegionOp dispatchRegionOp, Value value) {
// Find all args that are using this value.
SmallVector<unsigned, 4> argIndices;
for (auto arg : llvm::enumerate(dispatchRegionOp.args())) {
if (arg.value() == value) {
argIndices.push_back(arg.index());
}
}
if (argIndices.empty()) {
// Not used? Wasteful call!
return success();
}
// Clone the value (and the ops required to create it) into the entry block.
auto &entryBlock = dispatchRegionOp.body().getBlocks().front();
BlockAndValueMapping mapping;
auto clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping);
// Replace all uses of the inner operand with the new value.
for (unsigned argIndex : argIndices) {
entryBlock.getArgument(argIndex).replaceAllUsesWith(clonedValue);
}
// Remove the dispatch region args and the block args that have been
// replaced.
for (unsigned argIndex : llvm::reverse(argIndices)) {
dispatchRegionOp.getOperation()->eraseOperand(
dispatchRegionOp.mapArgOperandToOpOperand(argIndex));
entryBlock.eraseArgument(argIndex);
}
return success();
}
// Rematerializes a constant inside of all dispatch regions that use it.
// Afterward the constant is only removed if there are no other uses within the
// non-dispatch block (such as by sequencer ops).
LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) {
Value constantValue = constantOp.getResult();
SmallVector<DispatchRegionOp, 4> usingRegionOps;
for (auto *user : constantValue.getUsers()) {
if (auto dispatchRegionOp = dyn_cast<DispatchRegionOp>(user)) {
// Ensure this isn't just the workload and is used as an arg.
if (std::find(dispatchRegionOp.args().begin(),
dispatchRegionOp.args().end(),
constantValue) != dispatchRegionOp.args().end()) {
if (canDispatchRegionContainConstants(dispatchRegionOp)) {
usingRegionOps.push_back(dispatchRegionOp);
}
}
}
}
for (auto &dispatchRegionOp : usingRegionOps) {
if (failed(inlineDispatchRegionOperandsUsingValue(dispatchRegionOp,
constantValue))) {
return failure();
}
}
// Remove if there are no other uses within the block.
if (constantOp.use_empty()) {
constantOp.erase();
}
return success();
}
} // namespace
// Finds constant arguments to dispatch regions that are too small to be worth
// putting into constant pools. This prevents things like a CSE'd scalar
// constant of 0.0 being passed by reference to a bunch of regions. Later
// backend-specific passes running on the dispatch regions may also be able to
// improve their constant propagation chances by having the full constant value
// available.
//
// Note that this currently only operates at the block level. Constants that are
// pushed across branches are assumed to have been rematerialized within blocks
// already, but if that isn't the case then this pass can be extended to do
// that.
class RematerializeDispatchConstantsPass
: public PassWrapper<RematerializeDispatchConstantsPass, FunctionPass> {
public:
void runOnFunction() override {
for (auto &block : getFunction()) {
SmallVector<ConstantOp, 8> smallConstantOps;
for (auto constantOp : block.getOps<ConstantOp>()) {
if (isSplatConstant(constantOp)) {
smallConstantOps.push_back(constantOp);
}
}
// Note: we iterate in reverse so that the rematerialized constants appear
// in the same order they did originally (as insertion is at the top).
for (auto constantOp : llvm::reverse(smallConstantOps)) {
if (failed(rematerializeConstantInDispatchRegions(constantOp))) {
return signalPassFailure();
}
}
}
}
};
std::unique_ptr<OperationPass<FuncOp>>
createRematerializeDispatchConstantsPass() {
return std::make_unique<RematerializeDispatchConstantsPass>();
}
static PassRegistration<RematerializeDispatchConstantsPass> pass(
"iree-flow-rematerialize-dispatch-constants",
"Rematerializes small previously-CSE'd constants into dispatch regions");
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir