blob: 740a324db1b7baa9a27c2d39c6bdc07673bcbad7 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "third_party/llvm/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm/include/llvm/ADT/SetVector.h"
#include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm/include/llvm/ADT/iterator_range.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Dialect/StandardOps/Ops.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Utils.h"
#include "third_party/mlir_edge/iree/compiler/IR/Ops.h"
#include "third_party/mlir_edge/iree/compiler/IR/Sequencer/HLOps.h"
#include "third_party/mlir_edge/iree/compiler/Utils/MemRefUtils.h"
#include "third_party/mlir_edge/iree/compiler/Utils/OpUtils.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Produces the value used to address one side of a copy from |indices|.
//
// Behavior is keyed on the number of entries in |indices|:
//   0:  creates an IREE::ConstantOp at the builder's insertion point, typed as
//       an index memref of shape {1} and carrying the index attribute 0, and
//       returns its result.
//   1:  returns whatever loadAccessValue yields for that single index.
//   2+: emits an error on the operation at the builder's insertion point and
//       returns nullptr.
//
// |type| is not consulted by the current implementation.
Value *makeIndicesMemRef(const MemRefType &type,
                         Operation::operand_range indices, OpBuilder &builder) {
  // The operation at the insertion point supplies the location for new ops and
  // is the target of the diagnostic below.
  auto &anchorOp = *builder.getInsertionPoint();
  const size_t indexCount = std::distance(indices.begin(), indices.end());

  switch (indexCount) {
    case 0: {
      auto indexType = builder.getIndexType();
      auto zeroOp = builder.create<IREE::ConstantOp>(
          anchorOp.getLoc(), builder.getMemRefType({1}, indexType),
          builder.getIntegerAttr(indexType, 0));
      return zeroOp.getResult();
    }
    case 1:
      return loadAccessValue(anchorOp.getLoc(), *indices.begin(), builder);
    default:
      break;
  }

  // TODO(benvanik): support arbitrary indices.
  anchorOp.emitError() << "Multiple indices are not yet implemented";
  return nullptr;
}
// Produces the value describing the extent of |storedValue| for a copy.
//
// When |storedValue| has a ShapedType, an IREESeq::HL::ShapeOp over it is
// created. Otherwise an IREE::ConstantOp is created, typed as an index memref
// of shape {1} and carrying the index attribute 1. Either op is built with
// |builder| at |storedValue|'s location.
Value *makeLengthsMemRef(Value *storedValue, OpBuilder &builder) {
  const auto loc = storedValue->getLoc();

  // Non-shaped values get a constant length of one.
  if (!storedValue->getType().isa<ShapedType>()) {
    auto indexType = builder.getIndexType();
    return builder.create<IREE::ConstantOp>(
        loc, builder.getMemRefType({1}, indexType),
        builder.getIntegerAttr(indexType, 1));
  }

  // Shaped values have their lengths queried through a shape op.
  return builder.create<IREESeq::HL::ShapeOp>(loc, storedValue);
}
// Returns the LoadOp that directly defines |value|, or a null LoadOp when
// |value| is not the result of a load.
//
// Values that have no defining operation (getDefiningOp() returns null, e.g.
// for block arguments) are reported as "not a load" instead of being handed
// to dyn_cast, which must not be called with a null pointer.
LoadOp findOriginLoadOperation(Value *value) {
  // TODO(benvanik): follow through identity ops or something?
  auto *definingOp = value->getDefiningOp();
  if (!definingOp) {
    return nullptr;
  }
  // dyn_cast yields a null LoadOp when the defining op is some other op kind.
  return dyn_cast<LoadOp>(definingOp);
}
// Inserts a copy operation performing the same work as a store.
//
// Example:
// %0 = ... : memref<4xf32>
// %1 = load %0[%offset] : memref<4xf32>
// %2 = ... : memref<f32>
// store %1, %2[] : memref<f32>
// ->
// %0 = ... : memref<4xf32>
// %2 = ... : memref<f32>
// iree_hl_interp.copy %0[%offset], %2[], [%length]
void insertCopyForStore(LoadOp &loadOp, StoreOp &storeOp) {
OpBuilder builder(storeOp);
auto *srcIndices =
makeIndicesMemRef(loadOp.getMemRefType(), loadOp.getIndices(), builder);
auto *dstIndices =
makeIndicesMemRef(storeOp.getMemRefType(), storeOp.getIndices(), builder);
auto *lengths = makeLengthsMemRef(storeOp.getValueToStore(), builder);
builder.create<IREESeq::HL::CopyOp>(storeOp.getLoc(), loadOp.getMemRef(),
srcIndices, storeOp.getMemRef(),
dstIndices, lengths);
}
} // namespace
class LoadStoreDataFlowOptPass : public FunctionPass<LoadStoreDataFlowOptPass> {
public:
void runOnFunction() override {
auto func = getFunction();
// Find stores and attempt to optimize load+store pairs.
llvm::SetVector<Operation *> deadOperations;
func.walk([&](StoreOp storeOp) {
if (auto loadOp = findOriginLoadOperation(storeOp.getValueToStore())) {
insertCopyForStore(loadOp, storeOp);
deadOperations.insert(storeOp);
}
});
// Remove all the now-unused ops.
removeDeadOperations(deadOperations);
}
};
// Factory for the load/store data flow optimization pass. The caller owns the
// returned LoadStoreDataFlowOptPass instance.
std::unique_ptr<OpPassBase<FuncOp>> createLoadStoreDataFlowOptPass() {
  std::unique_ptr<OpPassBase<FuncOp>> optPass =
      std::make_unique<LoadStoreDataFlowOptPass>();
  return optPass;
}
// File-local registration object for LoadStoreDataFlowOptPass. It is
// constructed during static initialization with the pass argument name
// "iree-load-store-data-flow-opt" and the human-readable description below.
static PassRegistration<LoadStoreDataFlowOptPass> pass(
"iree-load-store-data-flow-opt",
"Optimize local load and store data flow by removing redundant accesses");
} // namespace iree_compiler
} // namespace mlir