Basic conversion of `shape` to `shapex` dialect.
- New pass ConvertShapeToShapex lowering the `shape` dialect to `shapex`.
Mainly, this pass converts `!shape.shape` to `!shapex.ranked_shape`, but
it will need to do more in the future. Currently, due to limitations in
the conversion infra (detailed in a comment in the pass), this doesn't
work in general beyond a single basic block. The follow-up is to either
write custom conversion infra or improve MLIR core.
- add ConvertShapeToShapex to iree-tf-import-pipeline
- add ops needed for lowering BatchMatMul
- `shapex.gather_extents`: a powerful shape-shuffling op that subsumes
concat, slicing, permuting, etc. (see the example IR after this list)
- `shapex.to_extent_tensor` and `shapex.from_extent_tensor`: ops for
bridging between the "tensor of extents" world and
`!shapex.ranked_shape`
- ConvertHLOToShapeDialect: add support for lowering
`xla_hlo.dynamic_broadcast_in_dim` to `shapex` ops.
- add a canonicalization pattern to erase unused make_ranked_shape ops,
helpful during legalization
- lower shapex.gather_extents in MaterializeShapeCalculations
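
For illustration, a rough sketch of the new ops in action (hand-written IR
mirroring the lit tests added below, not an exact compiler dump):
gather_extents expressing a concat of two rank-1 shapes, and
to_extent_tensor/from_extent_tensor bridging to and from the "tensor of
extents" world:

  %2 = "shapex.gather_extents"(%0, %1) {indices = dense<[0, 1]> : tensor<2xi64>}
      : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?,?]>
  %t = "shapex.to_extent_tensor"(%2) : (!shapex.ranked_shape<[?,?]>) -> tensor<2xindex>
  %3 = "shapex.from_extent_tensor"(%t) : (tensor<2xindex>) -> !shapex.ranked_shape<[?,?]>
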
PiperOrigin-RevId: 308738215
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index f532100..1c26f93 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -37,9 +37,12 @@
"//iree/base:signature_mangle",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/Conversion",
+ "//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Shape",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
diff --git a/integrations/tensorflow/compiler/Passes.cpp b/integrations/tensorflow/compiler/Passes.cpp
index 50931e2..aa19c09 100644
--- a/integrations/tensorflow/compiler/Passes.cpp
+++ b/integrations/tensorflow/compiler/Passes.cpp
@@ -16,6 +16,8 @@
#include "integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.h"
#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.h"
+#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
@@ -27,6 +29,14 @@
// IREE core should go here.
void createIreeTfImportPipeline(OpPassManager &pm) {
////////////////////////////////////////////////////////////////////////////
+ // Lowering shape-related constructs.
+ ////////////////////////////////////////////////////////////////////////////
+ pm.addPass(Shape::createConvertHLOToShapePass());
+ pm.addPass(createConvertShapeToShapexPass());
+ // Clean up trivial redundancies.
+ pm.addPass(createCanonicalizerPass());
+
+ ////////////////////////////////////////////////////////////////////////////
// Lowering TensorList-related parts of tf dialect to tf_tensorlist dialect.
////////////////////////////////////////////////////////////////////////////
pm.addPass(tf_tensorlist::createConvertTfToTfTensorList());
diff --git a/iree/compiler/Dialect/Shape/Conversion/BUILD b/iree/compiler/Dialect/Shape/Conversion/BUILD
new file mode 100644
index 0000000..133cc09
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/BUILD
@@ -0,0 +1,27 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "ConvertShapeToShapex",
+ srcs = ["ConvertShapeToShapex.cpp"],
+ deps = [
+ "//iree/compiler/Dialect/Shape/IR",
+ "@llvm-project//mlir:Dialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:Transforms",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "Conversion",
+ hdrs = ["Passes.h"],
+ deps = [
+ ":ConvertShapeToShapex",
+ "@llvm-project//mlir:Pass",
+ ],
+)
diff --git a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
new file mode 100644
index 0000000..2434478
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
@@ -0,0 +1,212 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace Shape {
+namespace {
+
+// This conversion is currently quite limited (in particular, it does not
+// handle multiple basic blocks in general) because it performs a type
+// conversion that the MLIR core conversion infra doesn't handle well.
+//
+// In particular, we convert `!shape.shape` to `!shapex.ranked_shape<...>`, but
+// the contents of the `...` are context-dependent. Thus, one could say that
+// this pass does a context-dependent type conversion.
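+//
+// As a concrete illustration (hand-written IR, not taken from a test),
+// consider two producers of !shape.shape whose converted types differ:
+//
+//   %0 = "shape.shape_of"(%a) : (tensor<?x2xf32>) -> !shape.shape
+//        // converts to !shapex.ranked_shape<[?,2]>
+//   %1 = "shape.shape_of"(%b) : (tensor<3xf32>) -> !shape.shape
+//        // converts to !shapex.ranked_shape<[3]>
+//
+// A block argument fed by both %0 and %1 would have no single legal converted
+// type, which is exactly what the conversion infra would need to pick.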
+//
+// The current MLIR conversion infra doesn't handle context-dependent type
+// conversions.
+//
+// I can see two solutions:
+//
+// 1. Extend the MLIR conversion infra to better support context-dependent type
+// conversions. One way to do this would be for the conversion infra to convert
+// blocks in RPO and use the type of the converted successor operand in a
+// dominating predecessor as the type for the block argument when converting a
+// block. A similar thing could be done with an RPO traversal of the callgraph.
+// This algorithm wouldn't work in the presence of recursively dead cycles. And
+// of course linkage boundaries cannot have a context-dependent type conversion
+// (by definition).
+//
+// 2. Avoid needing to convert to !shapex.ranked_shape in the first place. This
+// could be accomplished by generalizing !shape.shape to be able to support the
+// use case of !shapex.ranked_shape. One important requirement here is that
+// !shapex.ranked_shape models a partially-specified shape (hardcoded for the
+// ranked case). !shape.shape could be extended to capture partially-specified
+// shapes in the type, such as allowing `!shape.shape<*>` to model an unranked
+// shape (which is the default; no information), `!shape.shape<?x?x5x?>` to
+// model a rank-4 shape with dimension 2 being of extent 5, etc.
+//
+// Once we have this, we could do this lowering from generic !shape.shape to
+// statically-known ranked shapes more progressively and treat it more like a
+// type refinement algorithm.
+//
+// The main risk is that we are trying to shove too much stuff into the
+// !shape.shape type. There's a risk that "progressive lowering" becomes "no
+// clear boundaries" and we end up with code deep in the compiler continuously
+// needing to double-check that the !shape.shape values at that point are in
+// fact statically known to be ranked, or silently making that assumption and
+// triggering assertions on verifier-valid IR. Pipelines and legalization
+// targets could keep these assertions from firing in practice, but it would
+// be a maintenance burden.
+
+class ConvertShapeOfOp : public OpConversionPattern<shape::ShapeOfOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ shape::ShapeOfOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto tensorType = operands[0].getType().dyn_cast<RankedTensorType>();
+ if (!tensorType) {
+ return failure();
+ }
+ auto resultType =
+ RankedShapeType::get(tensorType.getShape(), rewriter.getContext());
+ rewriter.replaceOpWithNewOp<Shape::GetRankedShapeOp>(op, resultType,
+ operands[0]);
+ return success();
+ }
+};
+
+class ConvertSplitAtOp : public OpConversionPattern<shape::SplitAtOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ shape::SplitAtOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ IntegerAttr indexAttr;
+ if (!matchPattern(op.index(), m_Constant(&indexAttr))) {
+ return rewriter.notifyMatchFailure(op, "requires constant `index`");
+ }
+ auto rank = operands[0].getType().cast<RankedShapeType>().getRank();
+ int64_t index = indexAttr.getInt();
+ if (index < 0) {
+ index += rank;
+ }
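+ // For example, splitting a rank-3 shape at index 1 gathers
+ // head_indices = [0] and tail_indices = [1, 2] below.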
+ auto head_indices = llvm::to_vector<4>(llvm::seq<int64_t>(0, index));
+ auto tail_indices = llvm::to_vector<4>(llvm::seq<int64_t>(index, rank));
+ Value head = rewriter.create<GatherExtentsOp>(
+ op.getLoc(), operands[0], rewriter.getI64TensorAttr(head_indices));
+ Value tail = rewriter.create<GatherExtentsOp>(
+ op.getLoc(), operands[0], rewriter.getI64TensorAttr(tail_indices));
+ rewriter.replaceOp(op, {head, tail});
+ return success();
+ }
+};
+
+class ConvertBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ shape::BroadcastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Value lhs = operands[0];
+ Value rhs = operands[1];
+ auto lhsType = lhs.getType().cast<RankedShapeType>();
+ auto rhsType = rhs.getType().cast<RankedShapeType>();
+ // Establish invariant that rank(lhs) <= rank(rhs)
+ if (lhsType.getRank() > rhsType.getRank()) {
+ std::swap(lhsType, rhsType);
+ std::swap(lhs, rhs);
+ }
+ SmallVector<int64_t, 6> resultShape;
+ OpTrait::util::getBroadcastedShape(lhsType.getAllDims(),
+ rhsType.getAllDims(), resultShape);
+ auto resultType = RankedShapeType::get(resultShape, rewriter.getContext());
+ auto iota = llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsType.getRank()));
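+ // For example, with a rank-1 lhs and a rank-3 rhs (after the swap above),
+ // iota = [0, 1, 2], lhs_broadcast_dimensions = [2], and
+ // rhs_broadcast_dimensions = [0, 1, 2]; the lhs aligns with the trailing
+ // dimensions of the rhs.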
+ rewriter.replaceOpWithNewOp<RankedBroadcastShapeOp>(
+ op, resultType, lhs, rhs,
+ /*lhs_broadcast_dimensions=*/
+ rewriter.getI64TensorAttr(makeArrayRef(iota).drop_front(
+ rhsType.getRank() - lhsType.getRank())),
+ /*rhs_broadcast_dimensions=*/
+ rewriter.getI64TensorAttr(iota));
+ return success();
+ }
+};
+
+class ConvertConcatOp : public OpConversionPattern<shape::ConcatOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ shape::ConcatOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultRank = operands[0].getType().cast<RankedShapeType>().getRank() +
+ operands[1].getType().cast<RankedShapeType>().getRank();
+ auto indices = llvm::to_vector<4>(llvm::seq<int64_t>(0, resultRank));
+ rewriter.replaceOpWithNewOp<Shape::GatherExtentsOp>(
+ op, ValueRange({operands[0], operands[1]}),
+ rewriter.getI64TensorAttr(indices));
+ return success();
+ }
+};
+
+class ConvertToExtentTensorOp
+ : public OpConversionPattern<shape::ToExtentTensorOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<Shape::ToExtentTensorOp>(op, op.getType(),
+ operands[0]);
+ return success();
+ }
+};
+
+class ConvertShapeToShapex
+ : public PassWrapper<ConvertShapeToShapex, OperationPass<ModuleOp>> {
+ void runOnOperation() override {
+ ModuleOp module = getOperation();
+ MLIRContext *context = &getContext();
+
+ // Conversion target definition.
+ ConversionTarget conversionTarget(*context);
+ conversionTarget.addIllegalDialect<shape::ShapeDialect>();
+ conversionTarget.addLegalDialect<iree_compiler::ShapeDialect>();
+
+ // Patterns.
+ OwningRewritePatternList patterns;
+ patterns.insert<ConvertShapeOfOp>(context);
+ patterns.insert<ConvertSplitAtOp>(context);
+ patterns.insert<ConvertBroadcastOp>(context);
+ patterns.insert<ConvertConcatOp>(context);
+ patterns.insert<ConvertToExtentTensorOp>(context);
+
+ if (failed(applyPartialConversion(module, conversionTarget, patterns))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+} // namespace Shape
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass() {
+ return std::make_unique<Shape::ConvertShapeToShapex>();
+}
+
+static PassRegistration<Shape::ConvertShapeToShapex> registration(
+ "convert-shape-to-shapex", "Convert `shape` dialect to `shapex` dialect");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Conversion/Passes.h b/iree/compiler/Dialect/Shape/Conversion/Passes.h
new file mode 100644
index 0000000..750bb86
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/Passes.h
@@ -0,0 +1,29 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_SHAPE_CONVERSION_PASSES_H_
+#define IREE_COMPILER_DIALECT_SHAPE_CONVERSION_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Convert `shape` dialect to `shapex` dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_SHAPE_CONVERSION_PASSES_H_
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/BUILD b/iree/compiler/Dialect/Shape/Conversion/test/BUILD
new file mode 100644
index 0000000..14281d1
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
new file mode 100644
index 0000000..02c9e9c
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
@@ -0,0 +1,83 @@
+// RUN: iree-opt -convert-shape-to-shapex -split-input-file -verify-diagnostics -allow-unregistered-dialect <%s | IreeFileCheck %s
+
+// -----
+// shape.shape_of
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>) {
+ // CHECK: shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+ "foo.use"(%0) : (!shape.shape) -> ()
+ return
+}
+
+// -----
+// shape.split_at
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>) {
+ %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+ %index = constant 0 : i32
+ // CHECK: %[[RS:.+]] = shapex.get_ranked_shape %arg0
+ // CHECK: %[[HEAD:.+]] = "shapex.gather_extents"(%[[RS]]) {indices = dense<[]> : tensor<0xi64>} : (!shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[]>
+ // CHECK: %[[TAIL:.+]] = "shapex.gather_extents"(%[[RS]]) {indices = dense<0> : tensor<1xi64>} : (!shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+ // CHECK: "foo.use"(%[[HEAD]], %[[TAIL]])
+ %head, %tail = "shape.split_at"(%0, %index) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+ "foo.use"(%head, %tail) : (!shape.shape, !shape.shape) -> ()
+ return
+}
+
+// -----
+// No conversion -- index is dynamic.
+func @f(%arg0: tensor<?xf32>, %index: i32) {
+ %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+ // expected-error @+1 {{failed to legalize operation}}
+ %head, %tail = "shape.split_at"(%0, %index) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+ "foo.use"(%head, %tail) : (!shape.shape, !shape.shape) -> ()
+ return
+}
+
+// -----
+// shape.broadcast
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %[[LHSRS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ // CHECK: %[[RHSRS:.+]] = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+ %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
+ // CHECK: %[[BROADCASTED:.+]] = "shapex.ranked_broadcast_shape"(%[[LHSRS]], %[[RHSRS]]) {
+ // CHECK-SAME: lhs_broadcast_dimensions = dense<0> : tensor<1xi64>,
+ // CHECK-SAME: rhs_broadcast_dimensions = dense<0> : tensor<1xi64>}
+ // CHECK-SAME: : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+ %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ // CHECK: "foo.use"(%[[BROADCASTED]])
+ "foo.use"(%2) : (!shape.shape) -> ()
+ return
+}
+
+// -----
+// shape.concat
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %[[LHSRS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ // CHECK: %[[RHSRS:.+]] = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+ %1 = "shape.shape_of"(%arg1) : (tensor<?xf32>) -> !shape.shape
+ // CHECK: %[[CONCATTED:.+]] = "shapex.gather_extents"(%[[LHSRS]], %[[RHSRS]]) {indices = dense<[0, 1]> : tensor<2xi64>}
+ %2 = "shape.concat"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ // CHECK: "foo.use"(%[[CONCATTED]])
+ "foo.use"(%2) : (!shape.shape) -> ()
+ return
+}
+
+// -----
+// shape.to_extent_tensor
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %[[RS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ %0 = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
+ // CHECK: %[[EXTENTS:.+]] = "shapex.to_extent_tensor"(%[[RS]])
+ %1 = "shape.to_extent_tensor"(%0) : (!shape.shape) -> tensor<1xindex>
+ // CHECK: "foo.use"(%[[EXTENTS]])
+ "foo.use"(%1) : (tensor<1xindex>) -> ()
+ return
+}
+
diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD
index 8d0bb7d..c7e0595 100644
--- a/iree/compiler/Dialect/Shape/IR/BUILD
+++ b/iree/compiler/Dialect/Shape/IR/BUILD
@@ -52,6 +52,7 @@
"//iree/compiler/Utils",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
@@ -72,6 +73,7 @@
":td_files",
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
],
@@ -88,6 +90,7 @@
":td_files",
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
],
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
index b96665e..c8550f2 100644
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp
@@ -56,7 +56,9 @@
// If the immediate predecessor is a TieShapeOp, then this op can be
// erased in favor of the input to the tie op.
auto tieOp = dyn_cast_or_null<TieShapeOp>(operands.operand().getDefiningOp());
- if (!tieOp) return failure();
+ if (!tieOp) {
+ return rewriter.notifyMatchFailure(op, "no associated tie_shape op");
+ }
rewriter.replaceOp(op, tieOp.shape());
return success();
@@ -135,6 +137,26 @@
return success();
}
+// TODO(silvasean): Better handling of "erase unused ops for legality".
+// Currently, the way that we legalize !shapex.ranked_shape into individual SSA
+// values per dimension is to iteratively reduce other ops to
+// shapex.ranked_dim/shapex.ranked_dims and shapex.make_ranked_shape, and then
+// have patterns that resolve the shapex.ranked_dim/shapex.ranked_dims ops to
+// scalar values by looking through the shapex.make_ranked_shape ops. The
+// eventual goal is for the shapex.make_ranked_shape op itself to have no uses,
+// with the main computation flow using the individual SSA values instead. This
+// naturally produces a lot of unused shapex.make_ranked_shape ops, which we
+// need to delete for legality reasons. This pattern allows conversions to
+// erase those ops.
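+//
+// For example (illustrative IR, mirroring the shapex tests), after
+// dynamicMakeRankedShapeDimPattern resolves the ranked_dim below:
+//   %rs = shapex.make_ranked_shape %d0 : (index) -> !shapex.ranked_shape<[?]>
+//   %e = shapex.ranked_dim %rs[0] : !shapex.ranked_shape<[?]> -> index
+// uses of %e are replaced with %d0 directly, leaving %rs unused; this pattern
+// then erases the dead make_ranked_shape op.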
+LogicalResult eraseUnusedMakeRankedShapeOp(
+ MakeRankedShapeOp op, MakeRankedShapeOpOperandAdaptor operands,
+ PatternRewriter &rewriter) {
+ if (!op.getResult().use_empty())
+ return rewriter.notifyMatchFailure(op, "op has uses");
+ rewriter.eraseOp(op);
+ return success();
+}
+
LogicalResult dynamicMakeRankedShapeDimPattern(
RankedDimOp op, RankedDimOpOperandAdaptor operands,
PatternRewriter &rewriter) {
@@ -268,6 +290,28 @@
}
//===----------------------------------------------------------------------===//
+// shapex.from_extent_tensor
+//===----------------------------------------------------------------------===//
+
+LogicalResult fromExtentTensorOfToExtentTensorIsIdentity(
+ FromExtentTensorOp op, FromExtentTensorOpOperandAdaptor operands,
+ PatternRewriter &rewriter) {
+ auto toOp =
+ dyn_cast_or_null<ToExtentTensorOp>(op.extent_tensor().getDefiningOp());
+ if (!toOp) {
+ return failure();
+ }
+ rewriter.replaceOp(op, toOp.shape());
+ return success();
+}
+
+void FromExtentTensorOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ insertGreedyPattern(patterns, context,
+ fromExtentTensorOfToExtentTensorIsIdentity);
+}
+
+//===----------------------------------------------------------------------===//
// Standard folding and canonicalization conversion patterns.
//===----------------------------------------------------------------------===//
@@ -294,6 +338,7 @@
void populateFoldConversionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<TieShapeTypeConversionPattern>(context);
+ insertConversionPattern(patterns, context, eraseUnusedMakeRankedShapeOp);
insertConversionPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
insertConversionPattern(patterns, context,
elideDuplicateGetRankedShapePattern);
@@ -303,6 +348,8 @@
insertConversionPattern(patterns, context, identityMakeRankedShapePattern);
insertConversionPattern(patterns, context, elideStaticGetRankedShapePattern);
insertConversionPattern(patterns, context, safeCastCompatibleShapePattern);
+ insertConversionPattern(patterns, context,
+ fromExtentTensorOfToExtentTensorIsIdentity);
}
} // namespace Shape
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeBase.td b/iree/compiler/Dialect/Shape/IR/ShapeBase.td
index 09029f1..654e8cf 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeBase.td
+++ b/iree/compiler/Dialect/Shape/IR/ShapeBase.td
@@ -52,4 +52,9 @@
// be a useful feature.
def Shape_DimType : AnyTypeOf<[Index, AnySignlessInteger]>;
+def Shape_ExtentTensor : ShapedContainerType<
+ [Index, AnySignlessInteger],
+ And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
+ "a 1D tensor of extents">;
+
#endif // IREE_DIALECT_SHAPE_BASE
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
index f03bca3..672c6cf 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -240,6 +240,115 @@
RankedDimsOp::build(builder, result, builder->getIndexType(), shape);
}
+//===----------------------------------------------------------------------===//
+// shape.gather_extents
+//===----------------------------------------------------------------------===//
+
+// Helper for accessing attributes in the inferReturnTypes callback.
+// That callback provides the attributes as an `ArrayRef<NamedAttribute>`,
+// which isn't the nicest form for looking up a specific attribute.
+template <typename Attr>
+static Attr getRequiredAttr(ArrayRef<NamedAttribute> attributes,
+ StringRef name) {
+ auto it = llvm::find_if(
+ attributes, [&](NamedAttribute attr) { return attr.first == name; });
+ assert(it != attributes.end());
+ return it->second.template cast<Attr>();
+}
+
+/*static*/ SmallVector<int64_t, 6> GatherExtentsOp::getConcatenatedExtents(
+ ValueRange values) {
+ SmallVector<int64_t, 6> ret;
+ for (auto type : values.getTypes()) {
+ auto rankedShape = type.cast<RankedShapeType>();
+ ret.append(rankedShape.getAllDims().begin(),
+ rankedShape.getAllDims().end());
+ }
+ return ret;
+}
+
+static LogicalResult verifyGatherExtentsOp(GatherExtentsOp op) {
+ int64_t totalExtents = 0;
+ for (Type type : op.shapes().getTypes()) {
+ totalExtents += type.cast<RankedShapeType>().getRank();
+ }
+
+ for (int64_t index : op.indices().getValues<int64_t>()) {
+ if (index >= totalExtents) {
+ return op.emitError() << "index " << index
+ << " exceeds total number of extents of operands ("
+ << totalExtents << ")";
+ }
+ }
+
+ return success();
+}
+
+LogicalResult GatherExtentsOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // We can't infer the DimType of the result if there are no operands.
+ // If a user requires this, then they should manually specify the return type.
+ // We could in theory use an index type here (the default).
+ assert(!operands.empty() && "inferring return type for empty operands");
+ auto indices = getRequiredAttr<DenseIntElementsAttr>(attributes, "indices")
+ .getValues<int64_t>();
+ auto inputExtents = getConcatenatedExtents(operands);
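+ // For example, operands with extents [3,-1] and [2,7] concatenate to
+ // [3,-1,2,7]; indices = [3, 0] would then yield result extents [7, 3].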
+ SmallVector<int64_t, 6> resultExtents;
+ for (auto index : indices) {
+ resultExtents.push_back(inputExtents[index]);
+ }
+ inferredReturnTypes.push_back(RankedShapeType::get(resultExtents, context));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// shapex.to_extent_tensor
+//===----------------------------------------------------------------------===//
+
+LogicalResult ToExtentTensorOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto inputType = operands[0].getType().cast<RankedShapeType>();
+ inferredReturnTypes.push_back(
+ RankedTensorType::get({inputType.getRank()}, IndexType::get(context)));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// shapex.from_extent_tensor
+//===----------------------------------------------------------------------===//
+
+static bool isValidTensorOfExtents(RankedTensorType type) {
+ // If the tensor of extents does not have a static shape, that would imply
+ // that the tensor whose shape it describes is unranked.
+ return type.getRank() == 1 && type.hasStaticShape();
+}
+
+LogicalResult FromExtentTensorOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto inputType = operands[0].getType().dyn_cast<RankedTensorType>();
+ if (!inputType || !isValidTensorOfExtents(inputType)) {
+ return failure();
+ }
+ SmallVector<int64_t, 6> extents(inputType.getDimSize(0),
+ static_cast<int64_t>(-1));
+ inferredReturnTypes.push_back(RankedShapeType::get(extents, context));
+ return success();
+}
+
+bool FromExtentTensorOp::isCompatibleReturnTypes(ArrayRef<Type> lhs,
+ ArrayRef<Type> rhs) {
+ auto lhsRs = lhs[0].cast<RankedShapeType>();
+ auto rhsRs = rhs[0].cast<RankedShapeType>();
+ return succeeded(
+ verifyCompatibleShape(lhsRs.getAllDims(), rhsRs.getAllDims()));
+}
+
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.h b/iree/compiler/Dialect/Shape/IR/ShapeOps.h
index e684ddd..348c33d 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.h
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.h
@@ -21,6 +21,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffects.h"
namespace mlir {
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
index 26ec1c8..5a4a57b 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
@@ -17,6 +17,7 @@
include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/SideEffects.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpAsmInterface.td"
//===----------------------------------------------------------------------===//
@@ -145,6 +146,53 @@
}];
}
+// TODO(silvasean): What if the shape is an error shape?
+def Shape_ToExtentTensorOp : Shape_PureOp<"to_extent_tensor",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Convert a ranked shape to a tensor of extents.";
+ let description = [{
+ Convert a !shapex.ranked_shape to a rank-1 tensor of integers.
+
+ Examples:
+ %t0 = "shapex.to_extent_tensor"(%rs0)
+ : (!shapex.ranked_shape<[3,?,5]>)
+ -> tensor<3xi32>
+ The resulting tensor will, for example, have elements [3,4,5] if the
+ dynamic dimension is 4 at runtime.
+ }];
+ let arguments = (ins Shape_RankedShape:$shape);
+ let results = (outs Shape_ExtentTensor:$extent_tensor);
+
+ // TODO: Custom parser/printer
+ let parser = ?;
+ let printer = ?;
+}
+
+def Shape_FromExtentTensorOp : Shape_PureOp<"from_extent_tensor",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Convert a tensor of extents to a ranked shape.";
+ let description = [{
+ Convert a rank-1 tensor of integers to a !shapex.ranked_shape.
+
+ Examples:
+ %rs1 = "shapex.from_extent_tensor"(%t0)
+ : (tensor<3xi32>)
+ -> !shapex.ranked_shape<[?,?,?],i32>
+ }];
+ let arguments = (ins Shape_ExtentTensor:$extent_tensor);
+ let results = (outs Shape_RankedShape:$shape);
+
+ let hasCanonicalizer = 1;
+ let extraClassDeclaration = [{
+ // Declaration for overridden method from InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
+ }];
+
+ // TODO: Custom parser/printer
+ let parser = ?;
+ let printer = ?;
+}
+
def Shape_ConstRankedShapeOp : Shape_PureOp<"const_ranked_shape",
[ConstantLike, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "A constant ranked_shape.";
@@ -315,4 +363,85 @@
let printer = ?;
}
+//===----------------------------------------------------------------------===//
+// Shape manipulations.
+//===----------------------------------------------------------------------===//
+
+def Shape_GatherExtentsOp : Shape_PureOp<"gather_extents",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Gather extents across shapes.";
+ let description = [{
+ Gathers extents across the !shapex.ranked_shape's in `shapes`.
+
+ This op conceptually performs the following operation:
+ 1. The extents of all shapes in `shapes` are concatenated together into
+ a single list.
+ 2. The resulting shape is constructed by extracting extents from the
+ combined list according to `indices`.
+ In pseudocode:
+ ```
+ shapes = ... # a list of lists of extents
+ # Example: shapes = [[3,-1],[2,7]]
+ extents = [extent for shape in shapes for extent in shape]
+ # Example: extents = [3,-1,2,7]
+ # or to use another terminology: `extents = flatmap(shapes)`
+ results = [extents[index] for index in indices]
+ ```
+
+ A large class of shape manipulations can be canonicalized into this op,
+ including:
+ - taking slices of shapes
+ - concatenating shapes
+ - permuting shapes
+ The intuition behind this op is that eventually each extent will be
+ exploded into its own SSA value, at which point this op merely becomes
+ an identification of each SSA value of the output extents with an
+ SSA value of the input extents.
+ This op has the useful property that it is closed under composition with
+ itself, thus allowing an arbitrarily complex subgraph consisting of just
+ this op to be folded together.
+
+ Some examples of shape transfer functions captured with this op:
+
+ - Taking the last two extents of a shape:
+ - [d0,d1,d2,d3] indices=[2,3] -> [d2,d3]
+ - Concatenating three shapes:
+ - [d0,d1] [d2,d3] [d4,d5] indices=[0,1,2,3,4,5] -> [d0,d1,d2,d3,d4,d5]
+ - Shape transfer function for transpose with permutation [0,2,1]:
+ - [d0,d1,d2] indices=[0,2,1] -> [d0,d2,d1]
+ - Shape transfer function for outer product of a vector with itself:
+ - Initial state: [d0] [d1] indices=[0,1] -> [d0,d1]
+ - Canonicalized to a single-operand op after observing that both inputs
+ are the same !shapex.ranked_shape value: [d0] indices=[0,0] -> [d0,d0]
+ - Shape transfer function for matmul with a batch dimension on the LHS:
+ - [d0,d1,d2] [d3,d4] indices=[0,1,4] -> [d0,d1,d4]
+
+ This op is somewhat inspired by the LLVM `shufflevector` instruction.
+
+ Possible future pretty syntax for single-arg case:
+ %rs = shapex.gather_extents %0[0,2,1] : !shapex.ranked_shape<[5,6,7]>
+ Consider a pretty syntax for "concat":
+ %rs = shapex.gather_extents concat(%0, %1) : !shapex.ranked_shape<[5,6,7]>, !shapex.ranked_shape<[8,9]>
+
+ }];
+ let arguments = (ins
+ Variadic<Shape_RankedShape>:$shapes,
+ I64ElementsAttr:$indices
+ );
+ let results = (outs
+ Shape_RankedShape:$result
+ );
+
+ let verifier = [{ return verify$cppClass(*this); }];
+
+ // TODO: Custom parser/printer
+ let parser = ?;
+ let printer = ?;
+
+ let extraClassDeclaration = [{
+ static SmallVector<int64_t, 6> getConcatenatedExtents(ValueRange values);
+ }];
+}
+
+
#endif // IREE_DIALECT_SHAPE_OPS
diff --git a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
index e0b0c1b..7f5395b 100644
--- a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
@@ -41,3 +41,10 @@
%0 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[2,4]> -> index
return
}
+
+// -----
+
+func @compatible_from_extent_tensor(%arg0: tensor<1xindex>) {
+ %0 = "shapex.from_extent_tensor"(%arg0) : (tensor<1xindex>) -> !shapex.ranked_shape<[3]>
+ return
+}
diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
index 11c9039..1531b4b 100644
--- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
@@ -143,6 +143,22 @@
}
};
+class ConvertDynamicBroadcastInDim
+ : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ xla_hlo::DynamicBroadcastInDimOpOperandAdaptor adapter(operands);
+ Value rankedShape = rewriter.create<Shape::FromExtentTensorOp>(
+ op.getLoc(), adapter.output_dimensions());
+ rewriter.replaceOpWithNewOp<Shape::RankedBroadcastInDimOp>(
+ op, op.getType(), adapter.operand(), rankedShape,
+ op.broadcast_dimensions());
+ return success();
+ }
+};
+
class ConvertHLOToShapePass
: public PassWrapper<ConvertHLOToShapePass, FunctionPass> {
void runOnFunction() override {
@@ -153,6 +169,9 @@
conversionTarget.addLegalDialect<StandardOpsDialect>();
conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
+ conversionTarget.addIllegalOp<xla_hlo::DynamicBroadcastInDimOp>();
+ conversionPatterns.insert<ConvertDynamicBroadcastInDim>(&getContext());
+
#define CONVERT_BINARY_ELEMENTWISE_OP(HloOpTy) \
conversionTarget.addDynamicallyLegalOp<HloOpTy>( \
[](HloOpTy op) { return IsSameRankedTypeBinaryElementwiseOp(op); }); \
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index f73c3b9..894f746 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -103,6 +103,44 @@
filteredResultExtents);
}
+LogicalResult expandGatherExtentsOp(GatherExtentsOp op,
+ GatherExtentsOp::OperandAdaptor operands,
+ PatternRewriter &rewriter) {
+ // Calculate cumulative sums of the ranks of each operand, which allows
+ // us to map each index to its corresponding operand easily.
+ SmallVector<int64_t, 6> cumsum;
+ cumsum.push_back(0);
+ for (auto operand : operands.shapes()) {
+ auto rank = operand.getType().cast<Shape::RankedShapeType>().getRank();
+ cumsum.push_back(cumsum.back() + rank);
+ }
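+ // For example, operands of ranks [2, 3] give cumsum = [0, 2, 5]; index 3
+ // then maps to operand 1 (cumsum[1] = 2 <= 3 < 5) and dimension 3 - 2 = 1.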
+
+ // For each index, extract the relevant extent from the operands.
+ SmallVector<Value, 6> extents;
+ for (auto index : op.indices().getValues<int64_t>()) {
+ auto it = llvm::upper_bound(cumsum, index) - 1;
+ auto operandNum = std::distance(cumsum.begin(), it);
+ auto dimNum = index - *it;
+ auto extent = rewriter.create<Shape::RankedDimOp>(
+ op.getLoc(), operands.shapes()[operandNum], dimNum);
+ extents.push_back(extent);
+ }
+
+ // MakeRankedShapeOp only takes operands for the dynamic dimensions, so
+ // filter the gathered extents down to those.
+ SmallVector<Value, 6> onlyDynamicExtents;
+ auto resultType = op.result().getType().cast<Shape::RankedShapeType>();
+ for (int i = 0, e = resultType.getRank(); i < e; i++) {
+ if (resultType.isDimDynamic(i)) {
+ onlyDynamicExtents.push_back(extents[i]);
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<Shape::MakeRankedShapeOp>(op, resultType,
+ onlyDynamicExtents);
+ return success();
+}
+
LogicalResult expandRankedBroadcastShapePattern(
RankedBroadcastShapeOp bcastOp,
RankedBroadcastShapeOp::OperandAdaptor operands,
@@ -234,6 +272,7 @@
// We explicitly want to convert these ops, eliminating them.
target.addIllegalOp<GetRankedShapeOp>();
target.addIllegalOp<RankedBroadcastShapeOp>();
+ target.addIllegalOp<GatherExtentsOp>();
}
void populateMaterializeShapeCalculationsConversionPatterns(
@@ -241,6 +280,8 @@
// Fallback patterns.
insertConversionPattern(patterns, context, expandRankedBroadcastShapePattern,
/*benefit=*/1);
+ insertConversionPattern(patterns, context, expandGatherExtentsOp,
+ /*benefit=*/1);
insertConversionPattern(patterns, context, materializeRankedShapePattern,
/*benefit=*/1);
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
index 96bdd6b..6c59951 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
@@ -67,6 +67,7 @@
GetRankedShapeOp::getCanonicalizationPatterns(patterns, context);
MakeRankedShapeOp::getCanonicalizationPatterns(patterns, context);
RankedDimOp::getCanonicalizationPatterns(patterns, context);
+ RankedDimsOp::getCanonicalizationPatterns(patterns, context);
TieShapeOp::getCanonicalizationPatterns(patterns, context);
applyPatternsAndFoldGreedily(getOperation(), patterns);
}
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir b/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir
index 9819e1f..369325c 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/convert-hlo-to-shape-dialect.mlir
@@ -14,3 +14,14 @@
// CHECK-DAG: return %[[SUM]]
return %0 : tensor<?x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[SHAPE:.+]] = "shapex.from_extent_tensor"(%arg1) : (tensor<2xindex>) -> !shapex.ranked_shape<[?,?]>
+ // CHECK-DAG: %[[BROADCASTED:.+]] = "shapex.ranked_broadcast_in_dim"(%arg0, %0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: return %[[BROADCASTED]]
+ %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<[1]> : tensor<1xi64>}: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
index 3fbc213..8a1d1aa 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
@@ -50,3 +50,38 @@
%2 = shapex.get_ranked_shape %1 : tensor<?x2xf32> -> !shapex.ranked_shape<[?,2]>
return %1, %2 : tensor<?x2xf32>, !shapex.ranked_shape<[?,2]>
}
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: index, %arg1: index) -> (index, index, index) {
+ %0 = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[?,?]>
+ %1 = "shapex.gather_extents"(%0) {indices = dense<[1, 1, 0]> : tensor<3xi64>} : (!shapex.ranked_shape<[?,?]>) -> !shapex.ranked_shape<[?,?,?]>
+ %2:3 = shapex.ranked_dims %1 : !shapex.ranked_shape<[?,?,?]> -> index, index, index
+ // CHECK: return %arg1, %arg1, %arg0
+ return %2#0, %2#1, %2#2 : index, index, index
+}
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = shapex.make_ranked_shape %arg0 : (index) -> !shapex.ranked_shape<[?]>
+ %1 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[?]>
+ %2 = shapex.make_ranked_shape %arg2 : (index) -> !shapex.ranked_shape<[?]>
+ %gathered = "shapex.gather_extents"(%0, %1, %2) {indices = dense<[2, 2, 1, 0]> : tensor<4xi64>} : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?,?,?,?]>
+ %extents:4 = shapex.ranked_dims %gathered : !shapex.ranked_shape<[?,?,?,?]> -> index, index, index, index
+ // CHECK: return %arg2, %arg2, %arg1, %arg0
+ return %extents#0, %extents#1, %extents#2, %extents#3 : index, index, index, index
+}
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index, index, index) {
+ %0 = shapex.make_ranked_shape %arg0 : (index) -> !shapex.ranked_shape<[?,2]>
+ %1 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[3,?]>
+ %2 = shapex.make_ranked_shape %arg2 : (index) -> !shapex.ranked_shape<[7,?]>
+ %gathered = "shapex.gather_extents"(%0, %1, %2) {indices = dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi64>} : (!shapex.ranked_shape<[?,2]>, !shapex.ranked_shape<[3,?]>, !shapex.ranked_shape<[7,?]>) -> !shapex.ranked_shape<[?,2,3,?,7,?]>
+ %extents:6 = shapex.ranked_dims %gathered : !shapex.ranked_shape<[?,2,3,?,7,?]> -> index, index, index, index, index, index
+ // CHECK: return %arg0, %c2{{.*}}, %c3{{.*}}, %arg1, %c7{{.*}}, %arg2
+ return %extents#0, %extents#1, %extents#2, %extents#3, %extents#4, %extents#5 : index, index, index, index, index, index
+}
+
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index c780e14..08a5c4e 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -100,6 +100,7 @@
"@llvm-project//mlir:SDBM",
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:SPIRVLowering",
+ "@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardToSPIRVConversions",
"@llvm-project//mlir:Transforms",
@@ -152,6 +153,7 @@
"//iree/compiler/Dialect/HAL/Transforms",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/IREE/Transforms",
+ "//iree/compiler/Dialect/Shape/Conversion",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/VM/Analysis",
diff --git a/iree/tools/init_dialects.h b/iree/tools/init_dialects.h
index a91128d..3ce37dc 100644
--- a/iree/tools/init_dialects.h
+++ b/iree/tools/init_dialects.h
@@ -35,6 +35,7 @@
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Dialect.h"
@@ -51,6 +52,7 @@
registerDialect<linalg::LinalgDialect>();
registerDialect<loop::LoopOpsDialect>();
registerDialect<quant::QuantizationDialect>();
+ registerDialect<shape::ShapeDialect>();
registerDialect<spirv::SPIRVDialect>();
registerDialect<StandardOpsDialect>();
registerDialect<vector::VectorDialect>();