blob: ac417e7a29f6c1f9f3eea075c11e3b6bed0395e1 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "third_party/llvm/llvm/include/llvm/Support/Debug.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Dialect/StandardOps/Ops.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Support/LLVM.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Utils.h"
#include "third_party/mlir_edge/iree/compiler/IR/Ops.h"
#include "third_party/mlir_edge/iree/compiler/Utils/DispatchUtils.h"
#include "third_party/tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Chosen randomly for now. We can measure and see what makes sense.
constexpr int64_t kMaxRematerializedConstantSizeInBytes = 1 * 1024;
// Returns true if the constant value is under a certain threshold.
// This threshold is fixed for all backends as a value that is assumed small
// enough to be worth inlining possibly several times (at the cost of binary
// bloat).
bool isConstantSmall(ConstantOp constantOp) {
if (auto shapedType = constantOp.getType().dyn_cast<ShapedType>()) {
return shapedType.getSizeInBits() / 8 <=
kMaxRematerializedConstantSizeInBytes;
}
// Assume anything unshaped is small. This may not always be true in custom
// dialects but is in std for now.
return true;
}
// Returns true if the dispatch region is allowed to have constants inside.
// Certain regions that may get replaced or turned into kernel imports shouldn't
// have the constants moved into them as they'll just get lost.
bool canDispatchRegionContainConstants(
IREE::DispatchRegionOp dispatchRegionOp) {
for (auto &block : dispatchRegionOp.getBody()) {
for (auto &op : block) {
if (isa<xla_hlo::DotOp>(&op)) {
return false;
}
}
}
return true;
}
// Rematerializes a constant inside of all dispatch regions that use it.
// Afterward the constant is only removed if there are no other uses within the
// non-dispatch block (such as by sequencer ops).
LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) {
Value *constantValue = constantOp.getResult();
SmallVector<IREE::DispatchRegionOp, 4> usingRegionOps;
for (auto *user : constantValue->getUsers()) {
if (auto dispatchRegionOp = dyn_cast<IREE::DispatchRegionOp>(user)) {
// Ensure this isn't just the workload and is used as an arg.
if (std::find(dispatchRegionOp.arg_operand_begin(),
dispatchRegionOp.arg_operand_end(),
constantValue) != dispatchRegionOp.arg_operand_end()) {
if (canDispatchRegionContainConstants(dispatchRegionOp)) {
usingRegionOps.push_back(dispatchRegionOp);
}
}
}
}
for (auto &dispatchRegionOp : usingRegionOps) {
if (failed(inlineDispatchRegionOperandsUsingValue(dispatchRegionOp,
constantValue))) {
return failure();
}
}
// Remove if there are no other uses within the block.
if (constantOp.use_empty()) {
constantOp.erase();
}
return success();
}
} // namespace
// Finds constant arguments to dispatch regions that are too small to be worth
// putting into constant pools. This prevents things like a CSE'd scalar
// constant of 0.0 being passed by reference to a bunch of regions. Later
// backend-specific passes running on the dispatch regions may also be able to
// improve their constant propagation chances by having the full constant value
// available.
//
// Note that this currently only operates at the block level. Constants that are
// pushed across branches are assumed to have been rematerialized within blocks
// already, but if that isn't the case then this pass can be extended to do
// that.
class RematerializeDispatchConstantsPass
: public FunctionPass<RematerializeDispatchConstantsPass> {
public:
void runOnFunction() override {
for (auto &block : getFunction()) {
SmallVector<ConstantOp, 8> smallConstantOps;
for (auto constantOp : block.getOps<ConstantOp>()) {
if (isConstantSmall(constantOp)) {
smallConstantOps.push_back(constantOp);
}
}
// Note: we iterate in reverse so that the rematerialized constants appear
// in the same order they did originally (as insertion is at the top).
for (auto constantOp : llvm::reverse(smallConstantOps)) {
if (failed(rematerializeConstantInDispatchRegions(constantOp))) {
return signalPassFailure();
}
}
}
}
};
std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass() {
return std::make_unique<RematerializeDispatchConstantsPass>();
}
static PassRegistration<RematerializeDispatchConstantsPass> pass(
"iree-rematerialize-dispatch-constants",
"Rematerializes small previously-CSE'd constants into dispatch regions.");
} // namespace iree_compiler
} // namespace mlir