Rework lowering of shapex.ranked_broadcast_shape

This makes the lowering a lot simpler and actually correct in the dynamic
case (it would previously miscompile).

There's still a TODO for emitting an error on invalid broadcasts, which
should be implementable soon (it's waiting on lower-level parts of the stack).
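
For example, broadcasting two dynamic 1-D shapes now lowers to a
per-dimension unsigned max of the operand extents. A sketch of the emitted
IR (mirroring the new lit test below; SSA names are illustrative):

  %lhs = shapex.ranked_dim %arg0[0]
  %rhs = shapex.ranked_dim %arg1[0]
  %gt = cmpi "ugt", %lhs, %rhs : index
  %max = select %gt, %lhs, %rhs : index
  %result = shapex.make_ranked_shape %max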

PiperOrigin-RevId: 307537788
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 19b4ef6..5df0f1b 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -36,6 +36,7 @@
         ],
     )
     for name in [
+        "broadcasting_test",
         "batch_norm_test",
         "fill_test",
         "control_flow_test",
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
new file mode 100644
index 0000000..f2f74d3
--- /dev/null
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -0,0 +1,55 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test broadcasting support."""
+
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class BroadcastingModule(tf.Module):
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None], tf.float32),
+      tf.TensorSpec([None], tf.float32),
+  ])
+  def add(self, lhs, rhs):
+    return lhs + rhs
+
+
+@tf_test_utils.compile_modules(
+    backends=["tf", "iree_vmla"], m=BroadcastingModule)
+class BroadcastingTest(tf_test_utils.SavedModelTestCase):
+
+  def test_add_same_shape(self):
+    m = self.modules.m.all
+    dst = m.add(tf.random.uniform([4]), tf.random.uniform([4]))
+    dst.print().assert_all_close()
+
+
+# TODO(silvasean): Make these work.
+#   def test_add_broadcast_lhs(self):
+#     m = self.modules.m.all
+#     dst = m.add(tf.random.uniform([1]), tf.random.uniform([4]))
+#     dst.print().assert_all_close()
+#
+#   def test_add_broadcast_rhs(self):
+#     m = self.modules.m.all
+#     dst = m.add(tf.random.uniform([4]), tf.random.uniform([1]))
+#     dst.print().assert_all_close()
+
+if __name__ == "__main__":
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  tf.test.main()
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index 0ca9d19..f73c3b9 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -53,103 +53,54 @@
 }
 
 Value rewriteShapexRankedBroadcastShape(
-    RankedBroadcastShapeOp bcastOp,
-    RankedBroadcastShapeOp::OperandAdaptor operands, OpBuilder &builder) {
-  auto lhsRs = operands.lhs().getType().cast<RankedShapeType>();
-  auto rhsRs = operands.rhs().getType().cast<RankedShapeType>();
+    RankedBroadcastShapeOp op, RankedBroadcastShapeOp::OperandAdaptor operands,
+    OpBuilder &builder) {
+  auto lhs = operands.lhs();
+  auto rhs = operands.rhs();
+  auto loc = op.getLoc();
+  auto resultRs = op.getResult().getType().cast<RankedShapeType>();
 
-  auto loc = bcastOp.getLoc();
-  auto resultRs = bcastOp.getResult().getType().cast<RankedShapeType>();
-  auto dimType = IndexType::get(builder.getContext());
+  auto c1 = builder.create<ConstantIndexOp>(loc, 1);
+  // Each entry is the extent the given operand contributes to that output
+  // dimension, or 1 (which is neutral w.r.t. broadcasting).
+  SmallVector<Value, 4> lhsResultExtents(resultRs.getRank(), c1);
+  SmallVector<Value, 4> rhsResultExtents(resultRs.getRank(), c1);
 
-  // Pairs of the shape dim and corresponding value if dynamic.
-  SmallVector<std::pair<Optional<int>, Value>, 4> lhsDims;
-  SmallVector<std::pair<Optional<int>, Value>, 4> rhsDims;
-  lhsDims.resize(resultRs.getRank());
-  rhsDims.resize(resultRs.getRank());
+  for (auto dim : llvm::enumerate(op.lhs_broadcast_dimensions())) {
+    auto inputDim = dim.index();
+    auto outputDim = dim.value().getZExtValue();
+    lhsResultExtents[outputDim] =
+        builder.create<RankedDimOp>(loc, lhs, inputDim);
+  }
+  for (auto dim : llvm::enumerate(op.rhs_broadcast_dimensions())) {
+    auto inputDim = dim.index();
+    auto outputDim = dim.value().getZExtValue();
+    rhsResultExtents[outputDim] =
+        builder.create<RankedDimOp>(loc, rhs, inputDim);
+  }
 
-  // Populate the lhs dims.
-  for (auto dimMap : llvm::enumerate(bcastOp.lhs_broadcast_dimensions())) {
-    auto inputDimIndex = dimMap.index();
-    auto outputDimIndex = dimMap.value().getZExtValue();
-    assert(outputDimIndex < lhsDims.size());
-    if (!resultRs.isDimDynamic(outputDimIndex)) {
-      // No need to populate fully static dimensions.
-      continue;
-    }
-    if (lhsRs.isDimDynamic(inputDimIndex)) {
-      lhsDims[outputDimIndex] =
-          std::make_pair(-1, builder.create<RankedDimOp>(
-                                 loc, dimType, operands.lhs(),
-                                 builder.getI64IntegerAttr(inputDimIndex)));
-    } else {
-      lhsDims[outputDimIndex] = std::make_pair(inputDimIndex, nullptr);
+  SmallVector<Value, 4> resultExtents;
+  for (auto t : llvm::zip(lhsResultExtents, rhsResultExtents)) {
+    auto lhsExtent = std::get<0>(t);
+    auto rhsExtent = std::get<1>(t);
+    auto ugt =
+        builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhsExtent, rhsExtent);
+    auto max = builder.create<SelectOp>(loc, ugt, lhsExtent, rhsExtent);
+    resultExtents.push_back(max);
+    // TODO(silvasean): Create error handling code for invalid broadcasts.
+    // Use vm.cond_fail (or something that lowers to that).
+  }
+
+  // MakeRankedShapeOp only accepts the dynamic dims, so filter appropriately.
+  SmallVector<Value, 4> filteredResultExtents;
+  for (int i = 0, e = resultRs.getRank(); i < e; i++) {
+    if (resultRs.isDimDynamic(i)) {
+      filteredResultExtents.push_back(resultExtents[i]);
     }
   }
 
-  // Populate the rhs dims.
-  for (auto dimMap : llvm::enumerate(bcastOp.rhs_broadcast_dimensions())) {
-    auto inputDimIndex = dimMap.index();
-    auto outputDimIndex = dimMap.value().getZExtValue();
-    assert(outputDimIndex < rhsDims.size());
-    if (!resultRs.isDimDynamic(outputDimIndex)) {
-      // No need to populate fully static dimensions.
-      continue;
-    }
-    if (rhsRs.isDimDynamic(inputDimIndex)) {
-      rhsDims[outputDimIndex] =
-          std::make_pair(-1, builder.create<RankedDimOp>(
-                                 loc, dimType, operands.rhs(),
-                                 builder.getI64IntegerAttr(inputDimIndex)));
-    } else {
-      rhsDims[outputDimIndex] = std::make_pair(inputDimIndex, nullptr);
-    }
-  }
-
-  // Now compute dynamic dims for each output dim.
-  SmallVector<Value, 4> dynamicDims;
-  for (int i = 0; i < lhsDims.size(); ++i) {
-    if (!resultRs.isDimDynamic(i)) continue;
-    auto lhsDimInfo = lhsDims[i];
-    auto lhsDimSize = lhsDimInfo.first ? *lhsDimInfo.first : -1;
-    auto rhsDimInfo = rhsDims[i];
-    auto rhsDimSize = rhsDimInfo.first ? *rhsDimInfo.first : -1;
-
-    if (lhsDimSize > 1) {
-      // Non-degenerate static.
-      bcastOp.emitRemark(
-          "broadcast of non-degenerate lhs static dim not implemented");
-      return nullptr;
-    } else if (rhsDimSize > 1) {
-      // Non-degenerate static.
-      bcastOp.emitRemark(
-          "broadcast of non-degenerate rhs static dim not implemented");
-      return nullptr;
-    } else if (lhsDimSize == 1) {
-      // Degenerate static.
-      bcastOp.emitRemark(
-          "broadcast of degenerate lhs static dim not implemented");
-      return nullptr;
-    } else if (rhsDimSize == 1) {
-      // Degenerate static.
-      bcastOp.emitRemark(
-          "broadcast of degenerate rhs static dim not implemented");
-      return nullptr;
-    } else {
-      // Dynamic.
-      // TODO: Generate code to assert.
-      if (lhsDimInfo.second) {
-        dynamicDims.push_back(lhsDimInfo.second);
-      } else if (rhsDimInfo.second) {
-        dynamicDims.push_back(rhsDimInfo.second);
-      } else {
-        return nullptr;
-      }
-    }
-  }
-
-  // And make the result shape.
-  return builder.create<MakeRankedShapeOp>(loc, resultRs, dynamicDims);
+  return builder.create<MakeRankedShapeOp>(loc, resultRs,
+                                           filteredResultExtents);
 }
 
 LogicalResult expandRankedBroadcastShapePattern(
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
index 4ff445e..3fbc213 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
+++ b/iree/compiler/Dialect/Shape/Transforms/test/materialize_shape_calculations.mlir
@@ -17,6 +17,23 @@
 }
 
 // -----
+
+// CHECK-LABEL: @f
+func @f(%arg0 : !shapex.ranked_shape<[?]>, %arg1 : !shapex.ranked_shape<[?]>) -> (!shapex.ranked_shape<[?]>) {
+  // CHECK-DAG: %[[LHSEXTENT:.+]] = shapex.ranked_dim %arg0[0]
+  // CHECK-DAG: %[[RHSEXTENT:.+]] = shapex.ranked_dim %arg1[0]
+  // CHECK-DAG: %[[GT:.+]] = cmpi "ugt", %[[LHSEXTENT]], %[[RHSEXTENT]] : index
+  // CHECK-DAG: %[[MAX:.+]] = select %[[GT]], %[[LHSEXTENT]], %[[RHSEXTENT]] : index
+  // CHECK-DAG: %[[RS:.+]] = shapex.make_ranked_shape %[[MAX]]
+  // CHECK-DAG: return %[[RS]]
+  %0 = "shapex.ranked_broadcast_shape"(%arg0, %arg1) {
+    lhs_broadcast_dimensions = dense<[0]> : tensor<1xi64>,
+    rhs_broadcast_dimensions = dense<[0]> : tensor<1xi64>
+  } : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+  return %0 : !shapex.ranked_shape<[?]>
+}
+
+// -----
 // CHECK-LABEL: @runTimeFallback
 // CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<?x2xf32>
 // CHECK-SAME: %[[SHAPE:[^:[:space:]]+]]: !shapex.ranked_shape<[?,2]>