Merge pull request #2812 from hanhanW:main-to-google
PiperOrigin-RevId: 325296102
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index b0b66eb..3d064c9 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -58,9 +58,6 @@

# keep sorted
VMLA_FAILING = [
- "broadcasting_test.py", # TODO(b/162816314)
- "dynamic_mlp_relu_test.py", # TODO(b/162816314)
- "dynamic_mlp_test.py", # TODO(b/162816314)
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"ring_buffer_test.py", # TODO(b/148747011)
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 568327a..fc3143d 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -54,7 +54,6 @@
return tf.nn.depthwise_conv2d(
img, kernel, [1, 2, 2, 1], "SAME", name="result")
-

@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 4], tf.float32),
tf.TensorSpec([2, 4, 4, 1], tf.float32),
@@ -113,7 +112,6 @@
self.compare_backends(batched_feature_padded_same_stride_1_output_1)
-


if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
diff --git a/iree/compiler/Dialect/Shape/Conversion/BUILD b/iree/compiler/Dialect/Shape/Conversion/BUILD
index 11ccfa3..de795a4 100644
--- a/iree/compiler/Dialect/Shape/Conversion/BUILD
+++ b/iree/compiler/Dialect/Shape/Conversion/BUILD
@@ -13,6 +13,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
index 1f28cb4..7e51c71 100644
--- a/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@
MLIRIR
MLIRPass
MLIRShape
+ MLIRStandardOps
MLIRTransforms
iree::compiler::Dialect::Shape::IR
PUBLIC
diff --git a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
index decf426..2dea237 100644
--- a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
+++ b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
@@ -208,6 +209,22 @@
return success();
}
};
+// Upstream shape lowering can represent a shape as tensor<?xindex> and insert
+// tensor_cast ops to convert to specific extent tensor types; only casts
+// whose converted source is a ranked shape are shape-related and rewritten.
+class ConvertTensorCastOp : public OpConversionPattern<TensorCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ TensorCastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!operands[0].getType().isa<RankedShapeType>())
+ return rewriter.notifyMatchFailure(op, "not a shape-related tensor_cast");
+ rewriter.replaceOpWithNewOp<Shape::ToExtentTensorOp>(op, op.getType(),
+ operands[0]);
+ return success();
+ }
+};
+
class ConvertShapeToShapex
: public PassWrapper<ConvertShapeToShapex, OperationPass<ModuleOp>> {
void runOnOperation() override {
@@ -227,6 +244,7 @@
patterns.insert<ConvertBroadcastOp>(context);
patterns.insert<ConvertConcatOp>(context);
patterns.insert<ConvertToExtentTensorOp>(context);
+ patterns.insert<ConvertTensorCastOp>(context);
if (failed(applyPartialConversion(module, conversionTarget, patterns))) {
return signalPassFailure();
}
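
A note on ConvertTensorCastOp above: applyPartialConversion hands matchAndRewrite the already-remapped operands, so a shape-related tensor_cast sees a !shapex.ranked_shape value in operands[0] even though op.getOperand() still has type tensor<?xindex> in the source IR. That is why the guard inspects operands[0] rather than the original operand type. As a sketch of the casts the pattern must skip, the following hypothetical IR (not part of this change) contains a tensor_cast with no shape-producing operand, so the pattern reports a match failure and the op survives the partial conversion:

// Hypothetical IR, not from the test suite: this tensor_cast has no shape.*
// producer, so its converted operand is still a plain tensor, the
// RankedShapeType check fails, and notifyMatchFailure leaves the op intact.
func @not_shape_related(%arg0: tensor<?xindex>) -> tensor<3xindex> {
  %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<3xindex>
  return %0 : tensor<3xindex>
}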
diff --git a/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
index ea613c9..d470afd 100644
--- a/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
+++ b/iree/compiler/Dialect/Shape/Conversion/test/shape_to_shapex.mlir
@@ -91,3 +91,22 @@
return
}
+// -----
+// tensor_cast
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+ // CHECK: %[[LHSRS:.+]] = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ // CHECK: %[[RHSRS:.+]] = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
+ %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+ %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<?xindex>
+ // CHECK: %[[BROADCASTED:.+]] = "shapex.ranked_broadcast_shape"(%[[LHSRS]], %[[RHSRS]]) {
+ // CHECK-SAME: lhs_broadcast_dimensions = dense<0> : tensor<1xi64>,
+ // CHECK-SAME: rhs_broadcast_dimensions = dense<0> : tensor<1xi64>}
+ // CHECK-SAME: : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
+ %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+ // CHECK: %[[EXTENTS:.+]] = "shapex.to_extent_tensor"(%[[BROADCASTED]]) : (!shapex.ranked_shape<[?]>) -> tensor<1xindex>
+ %3 = tensor_cast %2 : tensor<?xindex> to tensor<1xindex>
+ // CHECK: "foo.use"(%[[EXTENTS]])
+ "foo.use"(%3) : (tensor<1xindex>) -> ()
+ return
+}
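For readability, here is the post-conversion IR that the CHECK lines above match, assembled into a plain function. This is reconstructed directly from those CHECK lines; the SSA value names are illustrative:

// Post-conversion form of func @f above, per the CHECK lines (value names
// illustrative; produced by the ConvertShapeToShapex pass).
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
  %lhs = shapex.get_ranked_shape %arg0 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
  %rhs = shapex.get_ranked_shape %arg1 : tensor<?xf32> -> !shapex.ranked_shape<[?]>
  %bcast = "shapex.ranked_broadcast_shape"(%lhs, %rhs) {
      lhs_broadcast_dimensions = dense<0> : tensor<1xi64>,
      rhs_broadcast_dimensions = dense<0> : tensor<1xi64>}
      : (!shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?]>) -> !shapex.ranked_shape<[?]>
  %extents = "shapex.to_extent_tensor"(%bcast) : (!shapex.ranked_shape<[?]>) -> tensor<1xindex>
  "foo.use"(%extents) : (tensor<1xindex>) -> ()
  return
}

Note that the tensor<?xindex> to tensor<1xindex> cast disappears entirely: ConvertTensorCastOp builds the shapex.to_extent_tensor op with the cast's result type (op.getType()), so the static-extent tensor is produced directly.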
diff --git a/iree/test/e2e/vulkan_specific/gemm.mlir b/iree/test/e2e/vulkan_specific/gemm.mlir
index 55afab9..2833da4 100644
--- a/iree/test/e2e/vulkan_specific/gemm.mlir
+++ b/iree/test/e2e/vulkan_specific/gemm.mlir
@@ -39,7 +39,7 @@
[0.92783505, 0.93298969, 0.93814433, 0.94329897, 0.94845361,
0.95360825, 0.95876289, 0.96391753, 0.96907216, 0.9742268 ,
0.97938144, 0.98453608, 0.98969072, 0.99484536, 1. ]]>
- : tensor<13x15xf32>
+ : tensor<13x15xf32>
%1 = iree.unfoldable_constant dense<
[[0. , 0.00558659, 0.01117318, 0.01675978, 0.02234637,
0.02793296, 0.03351955, 0.03910615, 0.04469274, 0.05027933,