[vmla] Dynamic shape lowering for dot_general
- Cleanly separate VMLA/Transforms/PreConversionLowering.cpp, which does preparatory tensor->tensor rewrites.
- Restructure conversion so that xla_hlo.broadcast_in_dim and xla_hlo.dynamic_broadcast_in_dim are lowered via shapex.ranked_broadcast_in_dim.
- Add a shape transfer function plugin for VMLA (only vmla.batch.matmul.pseudo for now); also fix up docs on vmla.batch.matmul.pseudo.
- Add a shape transfer function for xla_hlo.dynamic_reshape.
- augment matrix_ops_test.py to test the cases all the way from Python!!! This is the end of the dynamic matmul! (though there are still some major issues, like error handling :/)
Also, fix up CMake build for convert-shape-to-shapex pass.
PiperOrigin-RevId: 309129583
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index ea38b2a..10293da 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -49,12 +49,26 @@
return tf.matmul(lhs, rhs)
@tf.function(input_signature=[
- tf.TensorSpec([1, 7, 4, 2], tf.float32),
- tf.TensorSpec([7, 1, 2, 4], tf.float32),
+ tf.TensorSpec([None, None, 4, 2], tf.float32),
+ tf.TensorSpec([None, None, 2, 4], tf.float32),
])
def matmul_high_rank_batch(self, lhs, rhs):
return tf.matmul(lhs, rhs)
+ @tf.function(input_signature=[
+ tf.TensorSpec([None, None, None], tf.float32),
+ tf.TensorSpec([None, None, None], tf.float32),
+ ])
+ def matmul_dynamic(self, lhs, rhs):
+ return tf.matmul(lhs, rhs)
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([None, None, None], tf.float32),
+ tf.TensorSpec([None, None], tf.float32),
+ ])
+ def matmul_dynamic_lhs_batch(self, lhs, rhs):
+ return tf.matmul(lhs, rhs)
+
@tf_test_utils.compile_modules(
backends=["tf", "iree_vmla"], mat=MatrixOpsModule)
@@ -63,31 +77,55 @@
def test_basic_matmul(self):
m = self.modules.mat.all
dst = m.basic_matmul(tf.random.uniform([4, 2]), tf.random.uniform([2, 4]))
- dst.print().assert_all_close()
+ dst.assert_all_close()
def test_matmul_lhs_batch(self):
m = self.modules.mat.all
dst = m.matmul_lhs_batch(
tf.random.uniform([3, 4, 2]), tf.random.uniform([2, 4]))
- dst.print().assert_all_close()
+ dst.assert_all_close()
def test_matmul_rhs_batch(self):
m = self.modules.mat.all
dst = m.matmul_rhs_batch(
tf.random.uniform([4, 2]), tf.random.uniform([3, 2, 4]))
- dst.print().assert_all_close()
+ dst.assert_all_close()
def test_matmul_broadcast_singleton_dimension(self):
m = self.modules.mat.all
dst = m.matmul_broadcast_singleton_dimension(
tf.random.uniform([1, 4, 2]), tf.random.uniform([3, 2, 4]))
- dst.print().assert_all_close()
+ dst.assert_all_close()
def test_matmul_high_rank_batch(self):
m = self.modules.mat.all
dst = m.matmul_high_rank_batch(
tf.random.uniform([1, 7, 4, 2]), tf.random.uniform([7, 1, 2, 4]))
- dst.print().assert_all_close()
+ dst.assert_all_close()
+
+ def test_matmul_dynamic_matching_batch(self):
+ m = self.modules.mat.all
+ dst = m.matmul_dynamic(
+ tf.random.uniform([2, 2, 3]), tf.random.uniform([2, 3, 4]))
+ dst.assert_all_close()
+
+ def test_matmul_dynamic_broadcast_lhs(self):
+ m = self.modules.mat.all
+ dst = m.matmul_dynamic(
+ tf.random.uniform([1, 2, 3]), tf.random.uniform([2, 3, 4]))
+ dst.assert_all_close()
+
+ def test_matmul_dynamic_broadcast_rhs(self):
+ m = self.modules.mat.all
+ dst = m.matmul_dynamic(
+ tf.random.uniform([2, 2, 3]), tf.random.uniform([1, 3, 4]))
+ dst.assert_all_close()
+
+ def test_matmul_dynamic_rank_broadcasting(self):
+ m = self.modules.mat.all
+ dst = m.matmul_dynamic_lhs_batch(
+ tf.random.uniform([7, 2, 3]), tf.random.uniform([3, 4]))
+ dst.assert_all_close()
if __name__ == "__main__":
diff --git a/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
new file mode 100644
index 0000000..85d92cd
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ ConvertShapeToShapex
+ SRCS
+ "ConvertShapeToShapex.cpp"
+ DEPS
+ MLIRDialect
+ MLIRIR
+ MLIRPass
+ MLIRShape
+ MLIRTransforms
+ iree::compiler::Dialect::Shape::IR
+ ALWAYSLINK
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ Conversion
+ HDRS
+ "Passes.h"
+ DEPS
+ ::ConvertShapeToShapex
+ MLIRPass
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/Shape/Conversion/Passes.h b/iree/compiler/Dialect/Shape/Conversion/Passes.h
index 750bb86..6104d40 100644
--- a/iree/compiler/Dialect/Shape/Conversion/Passes.h
+++ b/iree/compiler/Dialect/Shape/Conversion/Passes.h
@@ -23,6 +23,14 @@
// Convert `shape` dialect to `shapex` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass();
+namespace Shape {
+
+inline void registerShapeConversionPasses() {
+ createConvertShapeToShapexPass();
+}
+
+} // namespace Shape
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD
new file mode 100644
index 0000000..559e5c0
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD
@@ -0,0 +1,34 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "VMLAShapeBuilder",
+ srcs = [
+ "VMLAShapeBuilder.cpp",
+ ],
+ hdrs = [
+ "VMLAShapeBuilder.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/VMLA/IR",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ ],
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt
new file mode 100644
index 0000000..33e2494
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ VMLAShapeBuilder
+ HDRS
+ "VMLAShapeBuilder.h"
+ SRCS
+ "VMLAShapeBuilder.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::VMLA::IR
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp
new file mode 100644
index 0000000..ad71a21
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp
@@ -0,0 +1,68 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h"
+
+#include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Optional.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir::iree_compiler::Shape;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VMLA {
+namespace {
+
+Value rewriteBatchMatMulPseudoOp(RankedShapeType resultShape,
+ BatchMatMulPseudoOp op, OpBuilder &builder) {
+ auto lhsShape = builder.create<GetRankedShapeOp>(op.getLoc(), op.lhs());
+ auto rhsShape = builder.create<GetRankedShapeOp>(op.getLoc(), op.rhs());
+ SmallVector<Value, 6> extents;
+ // Batch dimension (already been established to match between both operands,
+ // so arbitrarily use the LHS).
+ extents.push_back(builder.create<RankedDimOp>(op.getLoc(), lhsShape, 0));
+ // RHS free dimension.
+ extents.push_back(builder.create<RankedDimOp>(op.getLoc(), rhsShape, 1));
+ // LHS free dimension.
+ extents.push_back(builder.create<RankedDimOp>(op.getLoc(), lhsShape, 1));
+ // Due to a quirk of MakeRankedShapeOp, we only pass in the dynamic dims.
+ // So prune them down here.
+ SmallVector<Value, 6> onlyDynamicExtents;
+ for (int i = 0; i < 3; i++) {
+ if (resultShape.isDimDynamic(i)) {
+ onlyDynamicExtents.push_back(extents[i]);
+ }
+ }
+ return builder.create<MakeRankedShapeOp>(op.getLoc(), resultShape,
+ onlyDynamicExtents);
+}
+
+} // namespace
+
+void populateVMLACustomOpShapeBuilder(CustomOpShapeBuilderList &builders) {
+ auto &b = builders.make<CallbackCustomOpShapeBuilder>();
+ b.insertOpRankedShapeBuilder<BatchMatMulPseudoOp>(rewriteBatchMatMulPseudoOp);
+}
+
+} // namespace VMLA
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h
new file mode 100644
index 0000000..a80de01
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h
@@ -0,0 +1,34 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
+#define IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VMLA {
+// Creates a custom op shape builder for VMLA ops that are not otherwise
+// supported through traits or other declarative means.
+void populateVMLACustomOpShapeBuilder(
+ iree_compiler::Shape::CustomOpShapeBuilderList &builders);
+
+} // namespace VMLA
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD
new file mode 100644
index 0000000..14281d1
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir
new file mode 100644
index 0000000..70ba1cb
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir
@@ -0,0 +1,19 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-shape-materialize-calculations %s | IreeFileCheck %s
+
+// -----
+// CHECK-LABEL: func @batch.matmul.pseudo
+func @batch.matmul.pseudo(
+ %lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?x?xf32>,
+ %lhsShape: !shapex.ranked_shape<[?,?,?]>, %rhsShape: !shapex.ranked_shape<[?,?,?]>
+) -> !shapex.ranked_shape<[?,?,?]> {
+ %lhsTied = shapex.tie_shape %lhs, %lhsShape : tensor<?x?x?xf32>, !shapex.ranked_shape<[?,?,?]>
+ %rhsTied = shapex.tie_shape %rhs, %rhsShape : tensor<?x?x?xf32>, !shapex.ranked_shape<[?,?,?]>
+ %0 = "vmla.batch.matmul.pseudo"(%lhsTied, %rhsTied) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ // CHECK-DAG: %[[BATCH:.+]] = shapex.ranked_dim %arg2[0]
+ // CHECK-DAG: %[[FLHS:.+]] = shapex.ranked_dim %arg2[1]
+ // CHECK-DAG: %[[FRHS:.+]] = shapex.ranked_dim %arg3[1]
+ // CHECK-DAG: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[BATCH]], %[[FRHS]], %[[FLHS]]
+ // CHECK-DAG: return %[[SHAPE]]
+ %1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
+ return %1 : !shapex.ranked_shape<[?,?,?]>
+}
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index f20127f..142cccc 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -277,6 +277,12 @@
return builder.create<MakeRankedShapeOp>(loc, resultShape, outputExtents);
}
+Value rewriteDynamicReshape(RankedShapeType resultShape, DynamicReshapeOp op,
+ OpBuilder &builder) {
+ return builder.create<FromExtentTensorOp>(op.getLoc(), resultShape,
+ op.output_shape());
+}
+
} // namespace
// Creates a custom op shape builder for XLA-HLO ops that are not otherwise
@@ -306,6 +312,8 @@
b.insertOpRankedShapeBuilder<ReduceOp>(rewriteReduce);
b.insertOpRankedShapeBuilder<TransposeOp>(rewriteTranspose);
b.insertOpRankedShapeBuilder<xla_hlo::DotGeneralOp>(rewriteDotGeneral);
+ b.insertOpRankedShapeBuilder<xla_hlo::DynamicReshapeOp>(
+ rewriteDynamicReshape);
}
} // namespace xla_hlo
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
index 2200907..59882e5 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/custom_ops.mlir
@@ -35,3 +35,14 @@
%1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
return %1 : !shapex.ranked_shape<[?,?,?]>
}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_reshape
+func @dynamic_reshape(%arg0: tensor<?xf32>, %arg1: tensor<2xindex>) -> !shapex.ranked_shape<[?,?]> {
+ // CHECK-DAG: %[[SHAPE:.+]] = "shapex.from_extent_tensor"(%arg1)
+ // CHECK-DAG: return %[[SHAPE]]
+ %0 = "xla_hlo.dynamic_reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+ %1 = shapex.get_ranked_shape %0 : tensor<?x?xf32> -> !shapex.ranked_shape<[?,?]>
+ return %1 : !shapex.ranked_shape<[?,?]>
+}
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
index e389397..7de0280 100644
--- a/iree/compiler/Dialect/Shape/Transforms/BUILD
+++ b/iree/compiler/Dialect/Shape/Transforms/BUILD
@@ -34,6 +34,7 @@
],
deps = [
"//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/Shape/Plugins/VMLA:VMLAShapeBuilder",
"//iree/compiler/Dialect/Shape/Plugins/XLA:XlaHloShapeBuilder",
"//iree/compiler/Dialect/Shape/Utils:TypeConversion",
"//iree/compiler/Utils",
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
index d4aa7c8..0a6ee0a 100644
--- a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
@@ -37,6 +37,7 @@
MLIRSupport
MLIRTransforms
iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::Shape::Plugins::VMLA::VMLAShapeBuilder
iree::compiler::Dialect::Shape::Plugins::XLA::XlaHloShapeBuilder
iree::compiler::Dialect::Shape::Utils::TypeConversion
iree::compiler::Utils
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index 894f746..739e899 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h"
#include "iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.h"
#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
#include "iree/compiler/Utils/PatternUtils.h"
@@ -47,6 +48,7 @@
static CustomOpShapeBuilderList globalBuilders = ([]() {
CustomOpShapeBuilderList builders;
xla_hlo::populateXlaHloCustomOpShapeBuilder(builders);
+ IREE::VMLA::populateVMLACustomOpShapeBuilder(builders);
return builders;
})();
return &globalBuilders;
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
index 8d7a8a6..14fab6d 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
@@ -21,7 +21,6 @@
name = "HLOToVMLA",
srcs = [
"ConvertConvOps.cpp",
- "ConvertDotOps.cpp",
"ConvertHLOToVMLA.cpp",
"ConvertReductionOps.cpp",
],
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
index 0da3cca..b47ddce 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
@@ -21,7 +21,6 @@
"ConvertHLOToVMLA.h"
SRCS
"ConvertConvOps.cpp"
- "ConvertDotOps.cpp"
"ConvertHLOToVMLA.cpp"
"ConvertReductionOps.cpp"
DEPS
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertDotOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertDotOps.cpp
deleted file mode 100644
index cfbc07f..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertDotOps.cpp
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitVector.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Convert instances of `xla_hlo.dot` to `xla_hlo.dot_general`.
-//
-// TODO(silvasean): This logically is part of a future HLO client -> HLO server
-// type of pass in the xla_hlo dialect proper.
-struct CanonicalizeDotOp : public OpRewritePattern<xla_hlo::DotOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(xla_hlo::DotOp op,
- PatternRewriter &rewriter) const override {
- Value lhs = op.lhs();
- Value rhs = op.rhs();
- RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
- RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
- if (!lhsType || !rhsType) {
- return failure();
- }
- if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
- return failure();
- }
- // TODO(silvasean): Move this helper to MLIR core.
- auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
- auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
- rewriter.getIntegerType(64));
- return DenseIntElementsAttr::get(type, integers);
- };
- auto dimensionNumbers = xla_hlo::DotDimensionNumbers::get(
- /*lhs_batching_dimensions=*/make1DElementsAttr({}),
- /*rhs_batching_dimensions=*/make1DElementsAttr({}),
- /*lhs_contracting_dimensions=*/make1DElementsAttr({1}),
- /*rhs_contracting_dimensions=*/make1DElementsAttr({0}),
- rewriter.getContext());
- rewriter.replaceOpWithNewOp<xla_hlo::DotGeneralOp>(
- op, op.getType(), lhs, rhs, dimensionNumbers,
- op.precision_config().hasValue() ? op.precision_config().getValue()
- : nullptr);
- return success();
- }
-};
-
-// Inserts transposes on the operands of DotGeneralOp's such that the resulting
-// batch dimensions are all the leading dimensions and all the contracting
-// dimensions are all the trailing dimensions.
-//
-// Furthermore, all batch, contracting, and free dimensions are flattened into
-// single dimensions, with an appropriate reshape back to the original
-// dimensions.
-//
-// This results in a very simple corresponding VMLA op in the runtime.
-// [1 batch dimension, 1 free dimension, 1 contracting dimension].
-//
-// The result doesn't have a DotGeneralOp, but rather a
-// VMLA::BatchMatMulPseudoOp which represents this transformation.
-//
-// TODO(silvasean): Move this to a "prepare" pass and test separately.
-struct CanonicalizeDotGeneralOp
- : public OpRewritePattern<xla_hlo::DotGeneralOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(xla_hlo::DotGeneralOp op,
- PatternRewriter &rewriter) const override {
- Value lhs = op.lhs();
- Value rhs = op.rhs();
- RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
- RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
- Type elementType = lhsType.getElementType();
- if (!lhsType || !rhsType) {
- return failure();
- }
- // TODO(silvasean): Extend to support dynamic shapes.
- // This op is a really good case for testing our e2e dynamic shape support.
- // There's interesting questions at the TF2XLA level too.
- if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) {
- emitWarning(op.getLoc())
- << "unimplemented: vmla dynamic shapes for xla_hlo.dot_general";
- return failure();
- }
- xla_hlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
- auto extract1DVector = [](DenseIntElementsAttr elements) {
- SmallVector<int64_t, 6> ret;
- for (const APInt &element : elements) {
- ret.push_back(element.getLimitedValue());
- }
- return ret;
- };
- auto lhsBatchingDims =
- extract1DVector(dimNumbers.lhs_batching_dimensions());
- auto rhsBatchingDims =
- extract1DVector(dimNumbers.rhs_batching_dimensions());
- auto lhsContractingDims =
- extract1DVector(dimNumbers.lhs_contracting_dimensions());
- auto rhsContractingDims =
- extract1DVector(dimNumbers.rhs_contracting_dimensions());
- // TODO(silvasean): Move this helper to MLIR core.
- auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
- auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
- rewriter.getIntegerType(64));
- return DenseIntElementsAttr::get(type, integers);
- };
- auto totalElements = [](ArrayRef<int64_t> extents) {
- int64_t numElements = 1;
- for (int64_t extent : extents) {
- numElements *= extent;
- }
- return numElements;
- };
- auto handleOneSide = [&](ArrayRef<int64_t> batchingDims,
- ArrayRef<int64_t> contractingDims, Value &value,
- RankedTensorType &type,
- SmallVectorImpl<int64_t> &outFreeDims,
- SmallVectorImpl<int64_t> &outFreeDimExtents,
- SmallVectorImpl<int64_t> &outBatchingDimExtents) {
- outBatchingDimExtents.clear();
- RankedTensorType untransposedType = type;
- SmallVector<int64_t, 6> permutation;
- llvm::BitVector freeDims(untransposedType.getRank(), true);
- SmallVector<int64_t, 6> contractingDimExtents;
- for (auto dims : {batchingDims, contractingDims}) {
- for (int64_t dim : dims) {
- freeDims.reset(dim);
- }
- }
- for (int64_t dim : batchingDims) {
- permutation.push_back(dim);
- outBatchingDimExtents.push_back(untransposedType.getDimSize(dim));
- }
- for (int64_t dim : freeDims.set_bits()) {
- permutation.push_back(dim);
- outFreeDims.push_back(dim);
- outFreeDimExtents.push_back(untransposedType.getDimSize(dim));
- }
- for (int64_t dim : contractingDims) {
- permutation.push_back(dim);
- contractingDimExtents.push_back(untransposedType.getDimSize(dim));
- }
- // Construct the type that the transpose will result in.
- SmallVector<int64_t, 6> transposeShape;
- for (int64_t index : permutation) {
- transposeShape.push_back(type.getDimSize(index));
- }
- auto transposeType = RankedTensorType::get(transposeShape, elementType);
- auto transpose = rewriter.create<xla_hlo::TransposeOp>(
- op.getLoc(), transposeType, value, make1DElementsAttr(permutation));
-
- auto reshapeType =
- RankedTensorType::get({totalElements(outBatchingDimExtents),
- totalElements(outFreeDimExtents),
- totalElements(contractingDimExtents)},
- elementType);
- value = rewriter.create<xla_hlo::ReshapeOp>(op.getLoc(), reshapeType,
- transpose);
- };
- SmallVector<int64_t, 6> batchingDimExtents;
- SmallVector<int64_t, 6> lhsFreeDims;
- SmallVector<int64_t, 6> lhsFreeDimExtents;
- handleOneSide(lhsBatchingDims, lhsContractingDims, lhs, lhsType,
- lhsFreeDims, lhsFreeDimExtents, batchingDimExtents);
- SmallVector<int64_t, 6> rhsFreeDims;
- SmallVector<int64_t, 6> rhsFreeDimExtents;
- handleOneSide(rhsBatchingDims, rhsContractingDims, rhs, rhsType,
- rhsFreeDims, rhsFreeDimExtents, batchingDimExtents);
-
- auto dstShape = llvm::to_vector<6>(llvm::makeArrayRef(
- {totalElements(batchingDimExtents), totalElements(rhsFreeDimExtents),
- totalElements(lhsFreeDimExtents)}));
- auto dstType = RankedTensorType::get(dstShape, elementType);
- Value dst = rewriter.create<IREE::VMLA::BatchMatMulPseudoOp>(
- op.getLoc(), dstType, lhs, rhs);
- RankedTensorType transposeType = RankedTensorType::get(
- {dstShape[0], dstShape[2], dstShape[1]}, elementType);
- auto transpose = rewriter.create<xla_hlo::TransposeOp>(
- op.getLoc(), transposeType, dst, make1DElementsAttr({0, 2, 1}));
- auto reshapeShape = batchingDimExtents;
- reshapeShape.append(lhsFreeDimExtents.begin(), lhsFreeDimExtents.end());
- reshapeShape.append(rhsFreeDimExtents.begin(), rhsFreeDimExtents.end());
- auto reshapeType = RankedTensorType::get(reshapeShape, elementType);
- rewriter.replaceOpWithNewOp<xla_hlo::ReshapeOp>(op, reshapeType, transpose);
- return success();
- }
-};
-
-} // namespace
-
-void populateHLODotToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) {
- // Tensor-level preparation for lowering to the runtime BatchMatMul op.
- patterns.insert<CanonicalizeDotGeneralOp>(context);
- patterns.insert<CanonicalizeDotOp>(context);
-
- // Lowering from tensor ops to VMLA runtime ops.
- patterns.insert<VMLAOpConversion<IREE::VMLA::BatchMatMulPseudoOp,
- IREE::VMLA::BatchMatMulOp>>(context,
- typeConverter);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 1d2ba93..4aafac6 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -60,7 +60,10 @@
LogicalResult matchAndRewrite(
SRC srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (srcOp.getOperand().hasOneUse()) {
+ // xla_hlo::DynamicReshape has multiple operands, so we cannot just say
+ // `getOperand()`. But `getOperand(0)` doesn't work for the other
+ // single-operand ops. So use the raw Operation to get the operand.
+ if (srcOp.getOperation()->getOperand(0).hasOneUse()) {
// Can directly pass through the input buffer as we don't need to clone
// for other users.
rewriter.replaceOp(srcOp, operands[0]);
@@ -76,15 +79,22 @@
}
};
-// Converts a broadcast_in_dim op to either a broadcast or a tile depending on
-// the input shape.
+// Converts a shapex.ranked_broadcast_in_dim op to either a broadcast or a tile
+// depending on the input shape.
+//
+// We assume that xla_hlo.broadcast_in_dim and xla_hlo.dynamic_broadcast_in_dim
+// have been legalized into that op.
+//
+// Note that shapex.ranked_broadcast_in_dim is not strictly speaking an HLO op,
+// but we would like HLO to eventually have something like it, and the shapex
+// dialect is currently where we have it stuffed.
struct BroadcastInDimOpConversion
- : public OpConversionPattern<xla_hlo::BroadcastInDimOp> {
+ : public OpConversionPattern<Shape::RankedBroadcastInDimOp> {
BroadcastInDimOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
LogicalResult matchAndRewrite(
- xla_hlo::BroadcastInDimOp srcOp, ArrayRef<Value> operands,
+ Shape::RankedBroadcastInDimOp srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto srcShape = VMLAConversionTarget::getTensorShape(
srcOp.getLoc(), srcOp.operand(), typeConverter, rewriter);
@@ -454,8 +464,10 @@
// xla_hlo.reduce and xla_hlo.reduce_window.
populateHLOReductionToVMLAPatterns(context, patterns, typeConverter);
- // xla_hlo.dot and xla_hlo.dot_general.
- populateHLODotToVMLAPatterns(context, patterns, typeConverter);
+ // vmla.batch.matmul.pseudo
+ patterns.insert<VMLAOpConversion<IREE::VMLA::BatchMatMulPseudoOp,
+ IREE::VMLA::BatchMatMulOp>>(context,
+ typeConverter);
// Simple 1:1 conversion patterns using the automated trait-based converter.
// Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
@@ -534,6 +546,7 @@
patterns.insert<IdentityOpConversion<xla_hlo::BitcastConvertOp>>(context);
patterns.insert<IdentityOpConversion<xla_hlo::CopyOp>>(context);
patterns.insert<IdentityOpConversion<xla_hlo::ReshapeOp>>(context);
+ patterns.insert<IdentityOpConversion<xla_hlo::DynamicReshapeOp>>(context);
// Conversions that don't have a 1:1 mapping, mostly involving buffer views
// or transfers.
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
index 44b0a3d..aaccde9 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
@@ -2,6 +2,7 @@
// CHECK-LABEL: @broadcast_in_dim_2D_3D
func @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> attributes { sym_visibility = "private" } {
+ %rs3_2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
%input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
// CHECK-DAG: %[[SRC:.+]] = "vmla.constant"
// CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2,4]>
@@ -9,7 +10,7 @@
// CHECK-DAG: %[[DST_SIZE:.+]] = constant 96 : index
// CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
// CHECK-DAG: "vmla.tile"(%[[SRC]], %[[SRC_SHAPE]], %[[DST]], %[[DST_SHAPE]]) {element_type = i32}
- %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
+ %0 = "shapex.ranked_broadcast_in_dim"(%input, %rs3_2_4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>, !shapex.ranked_shape<[3,2,4]>) -> tensor<3x2x4xi32>
// CHECK-NEXT: return %[[DST]] : !vmla.buffer
return %0 : tensor<3x2x4xi32>
}
@@ -18,6 +19,7 @@
// CHECK-LABEL: @broadcast_in_dim_3D_scalar
func @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> attributes { sym_visibility = "private" } {
+ %rs3_2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
// CHECK-DAG: %[[SRC:.+]] = "vmla.constant"
// CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
// CHECK-DAG: %[[DST_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
@@ -25,7 +27,7 @@
%input = constant dense<42> : tensor<i32>
// CHECK-NEXT: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
// CHECK-NEXT: "vmla.broadcast"(%[[SRC]], %[[SRC_SHAPE]], %[[DST]], %[[DST_SHAPE]]) {element_type = i32}
- %0 = "xla_hlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
+ %0 = "shapex.ranked_broadcast_in_dim"(%input, %rs3_2_4) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, !shapex.ranked_shape<[3,2,4]>) -> tensor<3x2x4xi32>
// CHECK-NEXT: return %[[DST]] : !vmla.buffer
return %0 : tensor<3x2x4xi32>
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dot_general.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dot_general.mlir
deleted file mode 100644
index d7fc32b..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dot_general.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// -----
-
-func @f(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> attributes {sym_visibility = "private"} {
- // CHECK: vmla.batch.matmul
- %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {
- lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
- lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
- rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
- rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>
- }} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
- return %0 : tensor<3x5xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 5bfa951..8cefc91 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -379,7 +379,7 @@
All operands are rank-3 with the following dimension structure:
- lhs = [B, FLHS, C]
- rhs = [B, FRHS, C]
- - dst = [B, FLHS, FRHS]
+ - dst = [B, FRHS, FLHS]
Where:
- B = batch dimension
- C = contracting dimension
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
index 2981f40..7ba9688 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ b/iree/compiler/Dialect/VMLA/Transforms/BUILD
@@ -22,6 +22,7 @@
srcs = [
"Conversion.cpp",
"Passes.cpp",
+ "PreConversionLowering.cpp",
"UnrollReductions.cpp",
],
hdrs = [
@@ -37,6 +38,7 @@
"//iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA",
"//iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA",
"//iree/compiler/Dialect/VMLA/IR",
+ "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
index e980b46..80fad09 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@
SRCS
"Conversion.cpp"
"Passes.cpp"
+ "PreConversionLowering.cpp"
"UnrollReductions.cpp"
DEPS
LLVMSupport
@@ -39,6 +40,7 @@
iree::compiler::Dialect::VMLA::Conversion::HLOToVMLA
iree::compiler::Dialect::VMLA::Conversion::StandardToVMLA
iree::compiler::Dialect::VMLA::IR
+ iree::compiler::Dialect::VMLA::IR::VMLADialect
tensorflow::mlir_xla
PUBLIC
)
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index b9cf4dc..00f806e 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -99,6 +99,21 @@
});
conversionTarget.addIllegalOp<Shape::RankedDimOp>();
conversionTarget.addIllegalOp<Shape::RankedDimsOp>();
+ // XLA ops use tensors of extents, so we tend to launder back to
+ // !shapex.ranked_shape for most shape-related things. This is a problem
+ // because we don't have a lowering for the ops going back and forth between
+ // tensors of extents and !shapex.ranked_shape. So we mark this op as
+ // illegal and rely on our fold of `from_extent_tensor(to_extent_tensor(x))
+ // -> x` to eliminate these ops. Setting it illegal here triggers that fold.
+ // This is skating on thin ice.
+ // TODO(silvasean): Legalize ToExtentTensorOp and FromExtentTensorOp.
+ conversionTarget.addIllegalOp<Shape::FromExtentTensorOp>();
+  // RankedBroadcastInDimOp is logically something that should be an
+ // xla_hlo op (or in a dialect at a similar level of abstraction), but since
+ // it isn't technically in that dialect, we need to special-case mark it as
+ // illegal here.
+ // TODO(silvasean): Reconcile the dialect layering here.
+ conversionTarget.addIllegalOp<Shape::RankedBroadcastInDimOp>();
if (failed(applyPartialConversion(getOperation(), conversionTarget,
conversionPatterns, &typeConverter))) {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
index e0393f1..dc4cc6c 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
@@ -43,8 +43,10 @@
// ---------------------------------------------------------------------------
// Tensor-level rewrites.
- // At this point, the computation is in tensor-level CFG form with the ability
- // perform transformations that alter shapes.
+ // At this point, the computation is in tensor-level CFG form.
+ // There are no specific requirements on shape-related calculations at this
+ // point yet, so general tensor->tensor transformations in preparation
+ // for later conversion steps should go here.
// ---------------------------------------------------------------------------
// Legalize input types.
// TODO(benvanik): legalize input.
@@ -56,6 +58,12 @@
// Unroll multi-dimensional reductions to one reduction per dimension.
passManager.addNestedPass<FuncOp>(createUnrollReductionsPass());
+ // Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
+ passManager.addNestedPass<FuncOp>(createPreConversionLoweringPass());
+
+ // Clean up the IR before going into shape-materialized IR.
+ passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+
// ---------------------------------------------------------------------------
// Shape calculation.
// Pre-conditions:
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.h b/iree/compiler/Dialect/VMLA/Transforms/Passes.h
index dbb413e..7d7b74d 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.h
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.h
@@ -52,6 +52,9 @@
// dimension, from innermost to outermost.
std::unique_ptr<OperationPass<FuncOp>> createUnrollReductionsPass();
+// Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
+std::unique_ptr<OperationPass<FuncOp>> createPreConversionLoweringPass();
+
//===----------------------------------------------------------------------===//
// Dialect conversion
//===----------------------------------------------------------------------===//
@@ -66,6 +69,7 @@
inline void registerVMLAPasses() {
createUnrollReductionsPass();
createConversionPass();
+ createPreConversionLoweringPass();
}
} // namespace VMLA
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
new file mode 100644
index 0000000..610bd6c
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -0,0 +1,320 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VMLA {
+
+namespace {
+
+// Convert instances of `xla_hlo.dot` to `xla_hlo.dot_general`.
+//
+// TODO(silvasean): This logically is part of a future HLO client -> HLO server
+// type of pass in the xla_hlo dialect proper.
+struct LowerDotOp : public OpRewritePattern<xla_hlo::DotOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(xla_hlo::DotOp op,
+ PatternRewriter &rewriter) const override {
+ Value lhs = op.lhs();
+ Value rhs = op.rhs();
+ RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+ RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
+ if (!lhsType || !rhsType) {
+ return failure();
+ }
+ if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
+ return failure();
+ }
+ // TODO(silvasean): Move this helper to MLIR core.
+ auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
+ auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+ rewriter.getIntegerType(64));
+ return DenseIntElementsAttr::get(type, integers);
+ };
+ auto dimensionNumbers = xla_hlo::DotDimensionNumbers::get(
+ /*lhs_batching_dimensions=*/make1DElementsAttr({}),
+ /*rhs_batching_dimensions=*/make1DElementsAttr({}),
+ /*lhs_contracting_dimensions=*/make1DElementsAttr({1}),
+ /*rhs_contracting_dimensions=*/make1DElementsAttr({0}),
+ rewriter.getContext());
+ rewriter.replaceOpWithNewOp<xla_hlo::DotGeneralOp>(
+ op, op.getType(), lhs, rhs, dimensionNumbers,
+ op.precision_config().hasValue() ? op.precision_config().getValue()
+ : nullptr);
+ return success();
+ }
+};
+
+// Inserts transposes on the operands of DotGeneralOp's such that the resulting
+// batch dimensions are all the leading dimensions and all the contracting
+// dimensions are all the trailing dimensions.
+//
+// Furthermore, all batch, contracting, and free dimensions are flattened into
+// single dimensions, with an appropriate reshape back to the original
+// dimensions.
+//
+// This results in a very simple corresponding VMLA op in the runtime.
+// [1 batch dimension, 1 free dimension, 1 contracting dimension].
+//
+// The result doesn't have a DotGeneralOp, but rather a
+// VMLA::BatchMatMulPseudoOp which represents this transformation.
+//
+// TODO(silvasean): Move this to a "prepare" pass and test separately.
+struct LowerDotGeneralOp : public OpRewritePattern<xla_hlo::DotGeneralOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(xla_hlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
+    if (!lhsType || !rhsType) {
+      return rewriter.notifyMatchFailure(op, "requires ranked types");
+    }
+    Type elementType = lhsType.getElementType();
+    xla_hlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+    auto extract1DVector = [](DenseIntElementsAttr elements) {
+      SmallVector<int64_t, 6> ret;
+      for (const APInt &element : elements) {
+        ret.push_back(element.getLimitedValue());
+      }
+      return ret;
+    };
+    auto lhsBatchingDims =
+        extract1DVector(dimNumbers.lhs_batching_dimensions());
+    auto rhsBatchingDims =
+        extract1DVector(dimNumbers.rhs_batching_dimensions());
+    auto lhsContractingDims =
+        extract1DVector(dimNumbers.lhs_contracting_dimensions());
+    auto rhsContractingDims =
+        extract1DVector(dimNumbers.rhs_contracting_dimensions());
+    // TODO(silvasean): Move this helper to MLIR core.
+    auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
+      auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+                                        rewriter.getIntegerType(64));
+      return DenseIntElementsAttr::get(type, integers);
+    };
+    auto totalElements = [&](ArrayRef<Value> extents) {
+      Value numElements = rewriter.create<mlir::ConstantOp>(
+          op.getLoc(), IntegerAttr::get(rewriter.getIndexType(), 1));
+      for (Value extent : extents) {
+        numElements =
+            rewriter.create<mlir::MulIOp>(op.getLoc(), numElements, extent);
+      }
+      return numElements;
+    };
+    auto handleOneSide = [&](ArrayRef<int64_t> batchingDims,
+                             ArrayRef<int64_t> contractingDims, Value &value,
+                             RankedTensorType &type,
+                             SmallVectorImpl<int64_t> &outFreeDims,
+                             SmallVectorImpl<Value> &outFreeDimExtents,
+                             SmallVectorImpl<Value> &outBatchingDimExtents) {
+      outBatchingDimExtents.clear();
+      RankedTensorType untransposedType = type;
+      SmallVector<int64_t, 6> permutation;
+      llvm::BitVector freeDims(untransposedType.getRank(), true);
+      SmallVector<Value, 6> contractingDimExtents;
+      Value valueShape =
+          rewriter.create<Shape::GetRankedShapeOp>(op.getLoc(), value);
+      auto getExtentValue = [&](int64_t dim) {
+        return rewriter.create<Shape::RankedDimOp>(op.getLoc(), valueShape,
+                                                   dim);
+      };
+      for (auto dims : {batchingDims, contractingDims}) {
+        for (int64_t dim : dims) {
+          freeDims.reset(dim);
+        }
+      }
+      for (int64_t dim : batchingDims) {
+        permutation.push_back(dim);
+        outBatchingDimExtents.push_back(getExtentValue(dim));
+      }
+      for (int64_t dim : freeDims.set_bits()) {
+        permutation.push_back(dim);
+        outFreeDims.push_back(dim);
+        outFreeDimExtents.push_back(getExtentValue(dim));
+      }
+      for (int64_t dim : contractingDims) {
+        permutation.push_back(dim);
+        contractingDimExtents.push_back(getExtentValue(dim));
+      }
+      // Construct the type that the transpose will result in.
+      SmallVector<int64_t, 6> transposeStaticShape;
+      for (int64_t index : permutation) {
+        (void)index;
+        transposeStaticShape.push_back(-1);
+      }
+      auto transposeType =
+          RankedTensorType::get(transposeStaticShape, elementType);
+      auto transpose = rewriter.create<xla_hlo::TransposeOp>(
+          op.getLoc(), transposeType, value, make1DElementsAttr(permutation));
+
+      SmallVector<Value, 6> reshapeShape;
+      reshapeShape.push_back(totalElements(outBatchingDimExtents));
+      reshapeShape.push_back(totalElements(outFreeDimExtents));
+      reshapeShape.push_back(totalElements(contractingDimExtents));
+      auto reshapeType = RankedTensorType::get(
+          {static_cast<int64_t>(-1), static_cast<int64_t>(-1),
+           static_cast<int64_t>(-1)},
+          elementType);
+      auto reshapeRankedShape = rewriter.create<Shape::MakeRankedShapeOp>(
+          op.getLoc(),
+          Shape::RankedShapeType::get(reshapeType.getShape(),
+                                      rewriter.getContext()),
+          reshapeShape);
+      auto reshapeShapeExtentTensor = rewriter.create<Shape::ToExtentTensorOp>(
+          op.getLoc(), reshapeRankedShape);
+      value = rewriter.create<xla_hlo::DynamicReshapeOp>(
+          op.getLoc(), reshapeType, transpose, reshapeShapeExtentTensor);
+    };
+    SmallVector<Value, 6> batchingDimExtents;
+    SmallVector<int64_t, 6> lhsFreeDims;
+    SmallVector<Value, 6> lhsFreeDimExtents;
+    handleOneSide(lhsBatchingDims, lhsContractingDims, lhs, lhsType,
+                  lhsFreeDims, lhsFreeDimExtents, batchingDimExtents);
+    SmallVector<int64_t, 6> rhsFreeDims;
+    SmallVector<Value, 6> rhsFreeDimExtents;
+    handleOneSide(rhsBatchingDims, rhsContractingDims, rhs, rhsType,
+                  rhsFreeDims, rhsFreeDimExtents, batchingDimExtents);
+
+    auto dstStaticShape = llvm::to_vector<6>(
+        llvm::makeArrayRef({static_cast<int64_t>(-1), static_cast<int64_t>(-1),
+                            static_cast<int64_t>(-1)}));
+    auto dstType = RankedTensorType::get(dstStaticShape, elementType);
+    Value dst = rewriter.create<IREE::VMLA::BatchMatMulPseudoOp>(
+        op.getLoc(), dstType, lhs, rhs);
+    RankedTensorType transposeType = RankedTensorType::get(
+        {dstStaticShape[0], dstStaticShape[2], dstStaticShape[1]}, elementType);
+    auto transpose = rewriter.create<xla_hlo::TransposeOp>(
+        op.getLoc(), transposeType, dst, make1DElementsAttr({0, 2, 1}));
+    auto reshapeShape = batchingDimExtents;
+    reshapeShape.append(lhsFreeDimExtents.begin(), lhsFreeDimExtents.end());
+    reshapeShape.append(rhsFreeDimExtents.begin(), rhsFreeDimExtents.end());
+    SmallVector<int64_t, 6> reshapeStaticShape;
+    for (int i = 0, e = batchingDimExtents.size() + lhsFreeDimExtents.size() +
+                        rhsFreeDimExtents.size();
+         i < e; i++) {
+      reshapeStaticShape.push_back(-1);
+    }
+    auto reshapeRankedShape = rewriter.create<Shape::MakeRankedShapeOp>(
+        op.getLoc(),
+        Shape::RankedShapeType::get(reshapeStaticShape, rewriter.getContext()),
+        reshapeShape);
+    auto reshapeShapeExtentTensor = rewriter.create<Shape::ToExtentTensorOp>(
+        op.getLoc(), reshapeRankedShape);
+    rewriter.replaceOpWithNewOp<xla_hlo::DynamicReshapeOp>(
+        op, op.getType(), transpose, reshapeShapeExtentTensor);
+    return success();
+  }
+};
+
+class LowerBroadcastInDimOp
+ : public OpRewritePattern<xla_hlo::BroadcastInDimOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(xla_hlo::BroadcastInDimOp op,
+ PatternRewriter &rewriter) const override {
+ auto type = op.getType().cast<RankedTensorType>();
+ auto shapeType =
+ Shape::RankedShapeType::get(type.getShape(), rewriter.getContext());
+ auto shape =
+ rewriter.create<Shape::ConstRankedShapeOp>(op.getLoc(), shapeType);
+ rewriter.replaceOpWithNewOp<Shape::RankedBroadcastInDimOp>(
+ op, op.getType(), op.operand(), shape, op.broadcast_dimensions());
+ return success();
+ }
+};
+
+// Lower xla_hlo::BroadcastOp via xla_hlo::BroadcastInDimOp.
+class LowerBroadcastOp : public OpRewritePattern<xla_hlo::BroadcastOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(xla_hlo::BroadcastOp op,
+ PatternRewriter &rewriter) const override {
+ auto type = op.getOperand().getType().cast<RankedTensorType>();
+ auto resultType = op.getType().cast<RankedTensorType>();
+ auto broadcastDimensions = llvm::to_vector<6>(llvm::seq<int64_t>(
+ resultType.getRank() - type.getRank(), resultType.getRank()));
+ rewriter.replaceOpWithNewOp<xla_hlo::BroadcastInDimOp>(
+ op, op.getType(), op.getOperand(),
+ rewriter.getI64TensorAttr(broadcastDimensions));
+ return success();
+ }
+};
+
+class PreConversionLoweringPass
+ : public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
+ public:
+ void runOnOperation() {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList patterns;
+ ConversionTarget target(*context);
+ target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<IREE::VMLA::VMLADialect>();
+ target.addLegalDialect<xla_hlo::XlaHloDialect>();
+ target.addLegalDialect<ShapeDialect>();
+
+ target.addIllegalOp<xla_hlo::DotGeneralOp>();
+ patterns.insert<LowerDotGeneralOp>(context);
+ target.addIllegalOp<xla_hlo::DotOp>();
+ patterns.insert<LowerDotOp>(context);
+ target.addIllegalOp<xla_hlo::BroadcastInDimOp>();
+ patterns.insert<LowerBroadcastInDimOp>(context);
+ target.addIllegalOp<xla_hlo::BroadcastOp>();
+ patterns.insert<LowerBroadcastOp>(context);
+
+ if (failed(applyPartialConversion(getOperation(), target, patterns))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+static PassRegistration<PreConversionLoweringPass> pass(
+ "iree-vmla-pre-conversion-lowering",
+ "Tensor-level pattern-based lowerings.");
+
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createPreConversionLoweringPass() {
+ return std::make_unique<PreConversionLoweringPass>();
+}
+
+} // namespace VMLA
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
new file mode 100644
index 0000000..ecd94f7
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -0,0 +1,33 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering %s | IreeFileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
+ // CHECK: vmla.batch.matmul
+ %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+ lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+ rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>
+ }} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
+ // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
+ %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
+ return %0 : tensor<4x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3xf32>) -> tensor<5x6x3xf32> {
+ // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs5_6_3)
+ %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[5, 6]> : tensor<2xi64>} : (tensor<3xf32>) -> tensor<5x6x3xf32>
+ return %0 : tensor<5x6x3xf32>
+}
diff --git a/iree/test/e2e/regression/dynamic_dot_general.mlir b/iree/test/e2e/regression/dynamic_dot_general.mlir
new file mode 100644
index 0000000..488136f
--- /dev/null
+++ b/iree/test/e2e/regression/dynamic_dot_general.mlir
@@ -0,0 +1,46 @@
+// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -input-value="2x2xf32=[[1.0, 0.0], [0.0, 1.0]]" -input-value="2x3xf32=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]" -input-value="2x2x2xf32=[[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]" -input-value="2x2x3xf32=[[[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]" | IreeFileCheck %s
+
+// TODO(silvasean): Extend xla_ops directory test infra to support
+// testing dynamic shapes.
+
+// CHECK-LABEL: EXEC @basic_dot
+func @basic_dot(
+ %lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>,
+ %unused0: tensor<?x?x?xf32>, %unused1: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+ %0 = "xla_hlo.dot_general"(%lhs, %rhs) {dot_dimension_numbers={
+ lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+ lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+ rhs_contracting_dimensions = dense<0> : tensor<1xi64>
+ }} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: 2x3xf32=[1 2 3][4 5 6]
+
+// CHECK-LABEL: EXEC @batch_dimension
+func @batch_dimension(
+ %unused0: tensor<?x?xf32>, %unused1: tensor<?x?xf32>,
+ %lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = "xla_hlo.dot_general"(%lhs, %rhs) {dot_dimension_numbers={
+ lhs_batching_dimensions = dense<[0]> : tensor<1xi64>,
+ lhs_contracting_dimensions = dense<[2]> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0]> : tensor<1xi64>,
+ rhs_contracting_dimensions = dense<[1]> : tensor<1xi64>
+ }} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK: 2x2x3xf32=[
+// CHECK-SAME: [1.5 2.5 3.5][4.5 5.5 6.5]
+// CHECK-SAME: ][
+// CHECK-SAME: [2 4 6][8 10 12]
+// CHECK-SAME: ]
+
+
+// TODO(silvasean): Add more tests when we have better test infra.
+// This is currently too verbose / unreadable. We should test:
+// - multiple contracting dimensions
+// - multiple batch dimensions
+// - multiple free dimensions
+// - intermingled batch, free, and contracting dimensions
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 08a5c4e..13a6ac7 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -74,6 +74,7 @@
"//iree/compiler/Dialect/HAL/Transforms",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/IREE/Transforms",
+ "//iree/compiler/Dialect/Shape/Conversion",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/VM/Analysis",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 7ba48ce..32f51ea 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -138,6 +138,7 @@
iree::compiler::Dialect::HAL::Transforms
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::IREE::Transforms
+ iree::compiler::Dialect::Shape::Conversion
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::VM::Analysis
diff --git a/iree/tools/init_passes.h b/iree/tools/init_passes.h
index 1b5631f..ee0ac73 100644
--- a/iree/tools/init_passes.h
+++ b/iree/tools/init_passes.h
@@ -26,6 +26,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Analysis/TestPasses.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
@@ -175,6 +176,7 @@
IREE::Flow::registerFlowAnalysisTestPasses();
IREE::HAL::registerHALPasses();
IREE::registerIreePasses();
+ Shape::registerShapeConversionPasses();
Shape::registerShapePasses();
IREE::VM::registerVMPasses();
IREE::VM::registerVMAnalysisTestPasses();