Merge pull request #2812 from hanhanW:main-to-google

PiperOrigin-RevId: 325296102
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index b0b66eb..3d064c9 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -58,9 +58,6 @@
 
 # keep sorted
 VMLA_FAILING = [
-    "broadcasting_test.py",  # TODO(b/162816314)
-    "dynamic_mlp_relu_test.py",  # TODO(b/162816314)
-    "dynamic_mlp_test.py",  # TODO(b/162816314)
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "ring_buffer_test.py",  # TODO(b/148747011)
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 568327a..fc3143d 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -54,7 +54,6 @@
     return tf.nn.depthwise_conv2d(
         img, kernel, [1, 2, 2, 1], "SAME", name="result")
 
-
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 4], tf.float32),
       tf.TensorSpec([2, 4, 4, 1], tf.float32),
@@ -113,7 +112,6 @@
     self.compare_backends(batched_feature_padded_same_stride_1_output_1)
 
 
-
 if __name__ == "__main__":
   if hasattr(tf, "enable_v2_behavior"):
     tf.enable_v2_behavior()
diff --git a/iree/compiler/Dialect/Shape/Conversion/BUILD b/iree/compiler/Dialect/Shape/Conversion/BUILD
index 11ccfa3..de795a4 100644
--- a/iree/compiler/Dialect/Shape/Conversion/BUILD
+++ b/iree/compiler/Dialect/Shape/Conversion/BUILD
@@ -13,6 +13,7 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
index 1f28cb4..7e51c71 100644
--- a/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@
     MLIRIR
     MLIRPass
     MLIRShape
+    MLIRStandardOps
     MLIRTransforms
     iree::compiler::Dialect::Shape::IR
   PUBLIC
diff --git a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
index decf426..2dea237 100644
--- a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
+++ b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
@@ -16,6 +16,7 @@
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
@@ -208,6 +209,22 @@
   }
 };
 
+// Upstream shape lowering can represent a shape as tensor<?xindex> and insert
+// tensor_cast ops to get specific extent tensor types. Only casts whose operand
+// is a RankedShapeType are rewritten to ToExtentTensorOp; others do not match.
+class ConvertTensorCastOp : public OpConversionPattern<TensorCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      TensorCastOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!operands[0].getType().isa<RankedShapeType>())
+      return rewriter.notifyMatchFailure(op, "not a shape-related tensor_cast");
+    rewriter.replaceOpWithNewOp<Shape::ToExtentTensorOp>(op, op.getType(),
+                                                         operands[0]);
+    return success();
+  }
+};
+
 class ConvertShapeToShapex
     : public PassWrapper<ConvertShapeToShapex, OperationPass<ModuleOp>> {
   void runOnOperation() override {
@@ -227,6 +244,7 @@
     patterns.insert<ConvertBroadcastOp>(context);
     patterns.insert<ConvertConcatOp>(context);
     patterns.insert<ConvertToExtentTensorOp>(context);
+    patterns.insert<ConvertTensorCastOp>(context);
 
     if (failed(applyPartialConversion(module, conversionTarget, patterns))) {
       return signalPassFailure();
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
index ea613c9..d470afd 100644
--- a/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
+++ b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
@@ -91,3 +91,22 @@
   return
 }
 
+// -----
+// tensor_cast of a broadcasted shape becomes shapex.to_extent_tensor
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // CHECK: %[[LHSRS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  // CHECK: %[[RHSRS:.+]] = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<?xindex>
+  // CHECK: %[[BROADCASTED:.+]] = "shapex.ranked_broadcast_shape"(%[[LHSRS]], %[[RHSRS]]) {
+  // CHECK-SAME: lhs_broadcast_dimensions = dense<0> : tensor<1xi64>,
+  // CHECK-SAME: rhs_broadcast_dimensions = dense<0> : tensor<1xi64>}
+  // CHECK-SAME: : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  // CHECK: %[[EXTENTS:.+]] = "shapex.to_extent_tensor"(%[[BROADCASTED]]) : (!shapex.ranked_shape<[?]>) -> tensor<1xindex>
+  %3 = tensor_cast %2 : tensor<?xindex> to tensor<1xindex>
+  // CHECK: "foo.use"(%[[EXTENTS]])
+  "foo.use"(%3) : (tensor<1xindex>) -> ()
+  return
+}
diff --git a/iree/test/e2e/vulkan_specific/gemm.mlir b/iree/test/e2e/vulkan_specific/gemm.mlir
index 55afab9..2833da4 100644
--- a/iree/test/e2e/vulkan_specific/gemm.mlir
+++ b/iree/test/e2e/vulkan_specific/gemm.mlir
@@ -39,7 +39,7 @@
        [0.92783505, 0.93298969, 0.93814433, 0.94329897, 0.94845361,
         0.95360825, 0.95876289, 0.96391753, 0.96907216, 0.9742268 ,
         0.97938144, 0.98453608, 0.98969072, 0.99484536, 1.        ]]>
-	: tensor<13x15xf32>
+    : tensor<13x15xf32>
   %1 = iree.unfoldable_constant dense<
        [[0.        , 0.00558659, 0.01117318, 0.01675978, 0.02234637,
         0.02793296, 0.03351955, 0.03910615, 0.04469274, 0.05027933,