Rework lowering of shapex.ranked_broadcast_shape
This makes the lowering a lot simpler and actually correct in the dynamic
case (it would previously miscompile).
There's still a TODO for emitting an error on invalid broadcasts, which
should be implementable soon (it's waiting on lower-level parts of the stack).
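For reference, the runtime computation the new lowering emits can be sketched
in plain Python (illustrative only, not part of this change; the helper name
is made up): each side's extents are scattered into result-rank positions via
its broadcast_dimensions, uncovered positions default to 1 (neutral w.r.t.
broadcasting), and each result extent is the elementwise max. The real
lowering additionally has a TODO to fail when extents mismatch and neither
is 1.

    # Hypothetical sketch of the emitted computation, not the compiler code.
    def ranked_broadcast_shape(lhs_extents, rhs_extents,
                               lhs_broadcast_dimensions,
                               rhs_broadcast_dimensions, result_rank):
      # Positions not covered by a side's broadcast_dimensions stay 1.
      lhs_result = [1] * result_rank
      rhs_result = [1] * result_rank
      for input_dim, output_dim in enumerate(lhs_broadcast_dimensions):
        lhs_result[output_dim] = lhs_extents[input_dim]
      for input_dim, output_dim in enumerate(rhs_broadcast_dimensions):
        rhs_result[output_dim] = rhs_extents[input_dim]
      # TODO in the real lowering: error when extents mismatch and neither is 1.
      return [max(l, r) for l, r in zip(lhs_result, rhs_result)]

    assert ranked_broadcast_shape([4], [1], [0], [0], 1) == [4]
    assert ranked_broadcast_shape([3], [2, 3], [1], [0, 1], 2) == [2, 3]
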
PiperOrigin-RevId: 307537788
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 19b4ef6..5df0f1b 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -36,6 +36,7 @@
],
)
for name in [
+ "broadcasting_test",
"batch_norm_test",
"fill_test",
"control_flow_test",
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
new file mode 100644
index 0000000..f2f74d3
--- /dev/null
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -0,0 +1,55 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test broadcasting support."""
+
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class BroadcastingModule(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([None], tf.float32),
+ tf.TensorSpec([None], tf.float32),
+ ])
+ def add(self, lhs, rhs):
+ return lhs + rhs
+
+
+@tf_test_utils.compile_modules(
+ backends=["tf", "iree_vmla"], m=BroadcastingModule)
+class BroadcastingTest(tf_test_utils.SavedModelTestCase):
+
+ def test_add_same_shape(self):
+ m = self.modules.m.all
+ dst = m.add(tf.random.uniform([4]), tf.random.uniform([4]))
+ dst.print().assert_all_close()
+
+
+# TODO(silvasean): Make these work.
+# def test_add_broadcast_lhs(self):
+# m = self.modules.m.all
+# dst = m.add(tf.random.uniform([1]), tf.random.uniform([4]))
+# dst.print().assert_all_close()
+#
+# def test_add_broadcast_rhs(self):
+# m = self.modules.m.all
+# dst = m.add(tf.random.uniform([4]), tf.random.uniform([1]))
+# dst.print().assert_all_close()
+
+if __name__ == "__main__":
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+ tf.test.main()
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index 0ca9d19..f73c3b9 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -53,103 +53,54 @@
}
Value rewriteShapexRankedBroadcastShape(
- RankedBroadcastShapeOp bcastOp,
- RankedBroadcastShapeOp::OperandAdaptor operands, OpBuilder &builder) {
- auto lhsRs = operands.lhs().getType().cast<RankedShapeType>();
- auto rhsRs = operands.rhs().getType().cast<RankedShapeType>();
+ RankedBroadcastShapeOp op, RankedBroadcastShapeOp::OperandAdaptor operands,
+ OpBuilder &builder) {
+ auto lhs = operands.lhs();
+ auto rhs = operands.rhs();
+ auto loc = op.getLoc();
+ auto resultRs = op.getResult().getType().cast<RankedShapeType>();
- auto loc = bcastOp.getLoc();
- auto resultRs = bcastOp.getResult().getType().cast<RankedShapeType>();
- auto dimType = IndexType::get(builder.getContext());
+ auto c1 = builder.create<ConstantIndexOp>(loc, 1);
+ // Entries are the extent of the output along that dimension corresponding to
+ // the given side, or 1 (which is neutral w.r.t. broadcasting).
+ SmallVector<Value, 4> lhsResultExtents(resultRs.getRank(), c1);
+ SmallVector<Value, 4> rhsResultExtents(resultRs.getRank(), c1);
- // Pairs of the shape dim and corresponding value if dynamic.
- SmallVector<std::pair<Optional<int>, Value>, 4> lhsDims;
- SmallVector<std::pair<Optional<int>, Value>, 4> rhsDims;
- lhsDims.resize(resultRs.getRank());
- rhsDims.resize(resultRs.getRank());
+ for (auto dim : llvm::enumerate(op.lhs_broadcast_dimensions())) {
+ auto inputDim = dim.index();
+ auto outputDim = dim.value().getZExtValue();
+ lhsResultExtents[outputDim] =
+ builder.create<RankedDimOp>(loc, lhs, inputDim);
+ }
+ for (auto dim : llvm::enumerate(op.rhs_broadcast_dimensions())) {
+ auto inputDim = dim.index();
+ auto outputDim = dim.value().getZExtValue();
+ rhsResultExtents[outputDim] =
+ builder.create<RankedDimOp>(loc, rhs, inputDim);
+ }
- // Populate the lhs dims.
- for (auto dimMap : llvm::enumerate(bcastOp.lhs_broadcast_dimensions())) {
- auto inputDimIndex = dimMap.index();
- auto outputDimIndex = dimMap.value().getZExtValue();
- assert(outputDimIndex < lhsDims.size());
- if (!resultRs.isDimDynamic(outputDimIndex)) {
- // No need to populate fully static dimensions.
- continue;
- }
- if (lhsRs.isDimDynamic(inputDimIndex)) {
- lhsDims[outputDimIndex] =
- std::make_pair(-1, builder.create<RankedDimOp>(
- loc, dimType, operands.lhs(),
- builder.getI64IntegerAttr(inputDimIndex)));
- } else {
- lhsDims[outputDimIndex] = std::make_pair(inputDimIndex, nullptr);
+ SmallVector<Value, 4> resultExtents;
+ for (auto t : llvm::zip(lhsResultExtents, rhsResultExtents)) {
+ auto lhsExtent = std::get<0>(t);
+ auto rhsExtent = std::get<1>(t);
+ auto ugt =
+ builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhsExtent, rhsExtent);
+ auto max = builder.create<SelectOp>(loc, ugt, lhsExtent, rhsExtent);
+ resultExtents.push_back(max);
+ // TODO(silvasean): Create error handling code for invalid broadcasts.
+ // Use vm.cond_fail (or something that lowers to that).
+ }
+
+ // MakeRankedShapeOp only accepts the dynamic dims, so filter appropriately.
+ SmallVector<Value, 4> filteredResultExtents;
+ for (int i = 0, e = resultRs.getRank(); i < e; i++) {
+ if (resultRs.isDimDynamic(i)) {
+ filteredResultExtents.push_back(resultExtents[i]);
}
}
- // Populate the rhs dims.
- for (auto dimMap : llvm::enumerate(bcastOp.rhs_broadcast_dimensions())) {
- auto inputDimIndex = dimMap.index();
- auto outputDimIndex = dimMap.value().getZExtValue();
- assert(outputDimIndex < rhsDims.size());
- if (!resultRs.isDimDynamic(outputDimIndex)) {
- // No need to populate fully static dimensions.
- continue;
- }
- if (rhsRs.isDimDynamic(inputDimIndex)) {
- rhsDims[outputDimIndex] =
- std::make_pair(-1, builder.create<RankedDimOp>(
- loc, dimType, operands.rhs(),
- builder.getI64IntegerAttr(inputDimIndex)));
- } else {
- rhsDims[outputDimIndex] = std::make_pair(inputDimIndex, nullptr);
- }
- }
-
- // Now compute dynamic dims for each output dim.
- SmallVector<Value, 4> dynamicDims;
- for (int i = 0; i < lhsDims.size(); ++i) {
- if (!resultRs.isDimDynamic(i)) continue;
- auto lhsDimInfo = lhsDims[i];
- auto lhsDimSize = lhsDimInfo.first ? *lhsDimInfo.first : -1;
- auto rhsDimInfo = rhsDims[i];
- auto rhsDimSize = rhsDimInfo.first ? *rhsDimInfo.first : -1;
-
- if (lhsDimSize > 1) {
- // Non-degenerate static.
- bcastOp.emitRemark(
- "broadcast of non-degenerate lhs static dim not implemented");
- return nullptr;
- } else if (rhsDimSize > 1) {
- // Non-degenerate static.
- bcastOp.emitRemark(
- "broadcast of non-degenerate rhs static dim not implemented");
- return nullptr;
- } else if (lhsDimSize == 1) {
- // Degenerate static.
- bcastOp.emitRemark(
- "broadcast of degenerate lhs static dim not implemented");
- return nullptr;
- } else if (rhsDimSize == 1) {
- // Degenerate static.
- bcastOp.emitRemark(
- "broadcast of degenerate rhs static dim not implemented");
- return nullptr;
- } else {
- // Dynamic.
- // TODO: Generate code to assert.
- if (lhsDimInfo.second) {
- dynamicDims.push_back(lhsDimInfo.second);
- } else if (rhsDimInfo.second) {
- dynamicDims.push_back(rhsDimInfo.second);
- } else {
- return nullptr;
- }
- }
- }
-
- // And make the result shape.
- return builder.create<MakeRankedShapeOp>(loc, resultRs, dynamicDims);
+ return builder.create<MakeRankedShapeOp>(loc, resultRs,
+ filteredResultExtents);
}
LogicalResult expandRankedBroadcastShapePattern(
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
index 4ff445e..3fbc213 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
@@ -17,6 +17,23 @@
}
// -----
+
+// CHECK-LABEL: @f
+func @f(%arg0 : !shapex.ranked_shape<[?]>, %arg1 : !shapex.ranked_shape<[?]>) -> (!shapex.ranked_shape<[?]>) {
+ // CHECK-DAG: %[[LHSEXTENT:.+]] = shapex.ranked_dim %arg0[0]
+ // CHECK-DAG: %[[RHSEXTENT:.+]] = shapex.ranked_dim %arg1[0]
+ // CHECK-DAG: %[[GT:.+]] = cmpi "ugt", %[[LHSEXTENT]], %[[RHSEXTENT]] : index
+ // CHECK-DAG: %[[MAX:.+]] = select %[[GT]], %[[LHSEXTENT]], %[[RHSEXTENT]] : index
+ // CHECK-DAG: %[[RS:.+]] = shapex.make_ranked_shape %[[MAX]]
+ // CHECK-DAG: return %[[RS]]
+ %0 = "shapex.ranked_broadcast_shape"(%arg0, %arg1) {
+ lhs_broadcast_dimensions = dense<[0]> : tensor<1xi64>,
+ rhs_broadcast_dimensions = dense<[0]> : tensor<1xi64>
+ } : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+ return %0 : !shapex.ranked_shape<[?]>
+}
+
+// -----
// CHECK-LABEL: @runTimeFallback
// CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<?x2xf32>
// CHECK-SAME: %[[SHAPE:[^:[:space:]]+]]: !shapex.ranked_shape<[?,2]>