// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Conversion/Common/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace {
#include "iree/compiler/Conversion/Common/FoldTensorExtractOp.cpp.inc"
} // namespace
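// The folding patterns themselves are generated by TableGen (DRR) and pulled
// in through the .cpp.inc include above. A hand-written equivalent would look
// roughly like the sketch below; this is an illustrative outline only, and
// the class name and op accessors are assumptions rather than the generated
// code:
//
//   struct FoldExtractOfLoadOfGetGlobal
//       : public OpRewritePattern<tensor::ExtractOp> {
//     using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
//     LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
//                                   PatternRewriter &rewriter) const override {
//       // Match a tensor.extract fed by a tensor_load of a get_global_memref.
//       auto loadOp = extractOp.tensor().getDefiningOp<TensorLoadOp>();
//       if (!loadOp) return failure();
//       if (!loadOp.memref().getDefiningOp<GetGlobalMemrefOp>())
//         return failure();
//       // Rewrite to a direct load from the global's memref.
//       rewriter.replaceOpWithNewOp<LoadOp>(extractOp, loadOp.memref(),
//                                           extractOp.indices());
//       return success();
//     }
//   };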
namespace {
/// Upstream canonicalization passes fold
///
/// (load (tensor_to_memref $value), $indices) to
///
/// (tensor_extract $value, $indices)
///
/// In general this is ill-defined because it ignores potential writes to the
/// result of the tensor_to_memref before the load. The assumption is that
/// there shouldn't be any writes using the result of tensor_to_memref, but
/// that is almost impossible to enforce/verify. Nevertheless, in IREE we use
/// `tensor_to_memref` during bufferization of `std.constant`, assuming that
/// downstream passes can handle the lowering of the `std.constant`.
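///
/// As a hypothetical example (illustrative IR; %c, %i, and %j are placeholder
/// values), the upstream fold is unsafe when a write is interposed:
///
///   %0 = tensor_to_memref %t : memref<4xf32>
///   store %c, %0[%i] : memref<4xf32>
///   %1 = load %0[%j] : memref<4xf32>
///
/// Folding %1 into (tensor_extract %t, %j) would miss the store whenever
/// %i == %j.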
///
/// On the LLVM side, the `std.constant` is handled by the
/// `TensorConstantBufferizePass`, which creates a global object of `memref`
/// type; a `tensor_load` of that global is then used to get the tensor back.
/// If the above canonicalization pattern didn't exist, the `tensor_load`
/// would not be needed.
///
/// This pass specifically undoes that canonicalization by folding
///
/// (tensor_extract (tensor_load (get_global_memref:$value)), $indices) to
///
/// (load $value, $indices)
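///
/// For example (illustrative IR; @__constant is a placeholder for the name of
/// the global created by `TensorConstantBufferizePass`):
///
///   %0 = get_global_memref @__constant : memref<2x2xf32>
///   %1 = tensor_load %0 : memref<2x2xf32>
///   %2 = tensor.extract %1[%i, %j] : tensor<2x2xf32>
///
/// becomes
///
///   %0 = get_global_memref @__constant : memref<2x2xf32>
///   %2 = load %0[%i, %j] : memref<2x2xf32>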
///
/// In theory this could live upstream, but given the disagreement about the
/// validity of `tensor_to_memref` usage/canonicalizations, the pattern is
/// kept here.
class FoldTensorExtractOpPass
: public PassWrapper<FoldTensorExtractOpPass, OperationPass<>> {
void runOnOperation() override;
};
} // namespace
void FoldTensorExtractOpPass::runOnOperation() {
  OwningRewritePatternList patterns(&getContext());
  // Collect the TableGen-generated folding patterns and apply them greedily
  // until a fixed point is reached.
  populateWithGenerated(patterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
std::unique_ptr<OperationPass<>> createFoldTensorExtractOpPass() {
return std::make_unique<FoldTensorExtractOpPass>();
}
static PassRegistration<FoldTensorExtractOpPass> pass(
"iree-codegen-fold-tensor-extract-op",
"Fold `tensor.extract` operations prior to lowering to LLVM",
[] { return std::make_unique<FoldTensorExtractOpPass>(); });
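// With the registration above, the pass can be exercised in isolation through
// an opt-style driver (illustrative invocation; the exact binary depends on
// the build):
//
//   iree-opt -iree-codegen-fold-tensor-extract-op input.mlir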
} // namespace iree_compiler
} // namespace mlir