Handle 0D tensor while converting `tensor.from_elements` to `flow.tensor.splat (#8939)
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_std.run b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_std.run
index b942bd9..b8ef157 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_std.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_std.run
@@ -1,3 +1,2 @@
-# XFAIL: *
# REQUIRES: llvmaot
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_llvmaot --dynamic_dims=true --functions=reduce_std -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_variance.run b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_variance.run
index fe136e7..68fbe6d 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_variance.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__complex_dynamic_dim_reduce_variance.run
@@ -1,3 +1,2 @@
-# XFAIL: *
# REQUIRES: llvmaot
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_llvmaot --dynamic_dims=true --functions=reduce_variance -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_mean.run b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_mean.run
index 1cfde11..b6c34cf 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_mean.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_mean.run
@@ -1,3 +1,2 @@
-# XFAIL: *
# REQUIRES: llvmaot
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_llvmaot --dynamic_dims=true --functions=reduce_mean -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_std.run b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_std.run
index b942bd9..b8ef157 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_std.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_std.run
@@ -1,3 +1,2 @@
-# XFAIL: *
# REQUIRES: llvmaot
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_llvmaot --dynamic_dims=true --functions=reduce_std -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_variance.run b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_variance.run
index fe136e7..68fbe6d 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_variance.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/llvmaot__dynamic_dim_reduce_variance.run
@@ -1,3 +1,2 @@
-# XFAIL: *
# REQUIRES: llvmaot
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_llvmaot --dynamic_dims=true --functions=reduce_variance -artifacts_dir=%t
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_std.run b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_std.run
index f711012..60031d4 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_std.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_std.run
@@ -1,3 +1,2 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_vulkan --dynamic_dims=true --functions=reduce_std -artifacts_dir=%t
-# XFAIL: *
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_variance.run b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_variance.run
index 86df659..4019a71 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_variance.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__complex_dynamic_dim_reduce_variance.run
@@ -1,3 +1,2 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_vulkan --dynamic_dims=true --functions=reduce_variance -artifacts_dir=%t
-# XFAIL: *
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_mean.run b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_mean.run
index 3c39588..b8b8181 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_mean.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_mean.run
@@ -1,3 +1,2 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_vulkan --dynamic_dims=true --functions=reduce_mean -artifacts_dir=%t
-# XFAIL: *
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_std.run b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_std.run
index f711012..60031d4 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_std.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_std.run
@@ -1,3 +1,2 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_vulkan --dynamic_dims=true --functions=reduce_std -artifacts_dir=%t
-# XFAIL: *
diff --git a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_variance.run b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_variance.run
index 86df659..4019a71 100644
--- a/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_variance.run
+++ b/integrations/tensorflow/test/iree_tf_tests/math/vulkan__dynamic_dim_reduce_variance.run
@@ -1,3 +1,2 @@
# REQUIRES: vulkan
# RUN: %PYTHON -m iree_tf_tests.math.math_test --target_backends=iree_vulkan --dynamic_dims=true --functions=reduce_variance -artifacts_dir=%t
-# XFAIL: *
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
index f9a2410..e7bb849 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
@@ -307,13 +307,22 @@
// TODO: This pattern was mainly added to iron out some kinks specific to
// detensoring (see: https://github.com/google/iree/issues/1159). Do we need
// to expand this check for other uses?
- if (op->getParentOfType<Flow::DispatchWorkgroupsOp>() ||
- op.getType().getDimSize(0) != 1) {
+ if (op->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
+ return failure();
+ }
+ auto tensorType = op.getType();
+ if (!tensorType.hasRank()) {
+ return failure();
+ }
+
+ // Check that all the dimensions are 1.
+ if (!llvm::all_of(tensorType.getShape(),
+ [](int64_t dim) { return dim == 1; })) {
return failure();
}
rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
- op, op.getType(), op.getOperand(0), ValueRange());
+ op, tensorType, op.getOperand(0), ValueRange());
return success();
}
};
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
index 2dbbd01..f1af810 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
@@ -33,3 +33,14 @@
}
return %0 : tensor<f32>
}
+
+// -----
+
+func.func @tensor.from_elements_0D(%arg0 : f32) -> tensor<f32> {
+ %0 = tensor.from_elements %arg0 : tensor<f32>
+ return %0 : tensor<f32>
+}
+// CHECK: func @tensor.from_elements_0D
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[SPLAT:.+]] = flow.tensor.splat %[[ARG0]] : tensor<f32>
+// CHECK: return %[[SPLAT]]