blob: 1ddeaf4c074f9ae3451c38f3aa4e53000bb00f9b [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- ResolveShapeOps.cpp - Pass to resolve shape related ops ------------===//
//
// This file implements functionalities to resolve shape related ops. For
// dynamic shape dimensions, this pass traces them back to the original HAL
// interface bindings.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Replaces `std.dim` on `shapex.tie_shape` with `shapex.ranked_shape` so that
/// later this shape dimension query can be folded away via canonicalization on
/// `shapex.ranked_shape`.
///
/// This pattern translates the following IR sequence:
///
/// ```mlir
/// %dynamic_dim = ... : index
/// %shape = shapex.make_ranked_shape %dynamic_dim ...
/// %tie_shape = shapex.tie_shape ..., %shape : ...
/// ...
/// %get_dim = std.dim %tie_shape ...
/// ```
///
/// Into:
///
/// ```mlir
/// %dynamic_dim = ... : index
/// %shape = shapex.make_ranked_shape %dynamic_dim ...
/// %tie_shape = shapex.tie_shape ..., %shape : ...
/// ...
/// %get_dim = shapex.ranked_dim %shape ...
/// ```
struct StdDimResolver final : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
PatternRewriter &rewriter) const override {
auto tieShapeOp =
dyn_cast<Shape::TieShapeOp>(dimOp.memrefOrTensor().getDefiningOp());
if (!tieShapeOp) return failure();
Optional<int64_t> index = dimOp.getConstantIndex();
assert(index.hasValue() && "expect constant index in `std.dim` operation");
rewriter.replaceOpWithNewOp<Shape::RankedDimOp>(dimOp, tieShapeOp.shape(),
index.getValue());
return success();
}
};
/// Elides all `shapex.tie_shape` ops.
struct TieShapeElider final : public OpRewritePattern<Shape::TieShapeOp> {
using OpRewritePattern<Shape::TieShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Shape::TieShapeOp tieOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOp(tieOp, tieOp.operand());
return success();
}
};
struct ResolveShapeOpsPass
: public PassWrapper<ResolveShapeOpsPass, FunctionPass> {
void runOnFunction() override;
};
} // namespace
void ResolveShapeOpsPass::runOnFunction() {
MLIRContext *context = &getContext();
OwningRewritePatternList dimPatterns;
dimPatterns.insert<StdDimResolver>(context);
// Set up a target to convert all std.dim ops. We need a conversion target
// here to error out early if some std.dim op cannot be converted.
ConversionTarget target(*context);
target.addIllegalOp<DimOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(applyFullConversion(getFunction(), target, dimPatterns))) {
return signalPassFailure();
}
OwningRewritePatternList shapePatterns;
shapePatterns.insert<TieShapeElider>(context);
Shape::RankedDimOp::getCanonicalizationPatterns(shapePatterns, context);
// Then elide all shapex.tie_shape ops and canonicalize shapex.ranked_dim
// given that we don't need the shape annotation anymore.
applyPatternsAndFoldGreedily(getFunction(), shapePatterns);
}
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOpsPass() {
return std::make_unique<ResolveShapeOpsPass>();
}
static PassRegistration<ResolveShapeOpsPass> pass("iree-codegen-resolve-shape",
"resolve shape");
} // namespace iree_compiler
} // namespace mlir