Add support for converting multi-element `tensor.from_elements` to `flow` (#18034)

https://github.com/iree-org/iree/issues/17086

We can support `tensor.from_elements` with multiple elements by emitting a
`flow.tensor.store` for each element into an empty tensor.

---------

Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
diff --git a/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json b/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json
index 8d461e7..3646f61 100644
--- a/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json
+++ b/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json
@@ -86,16 +86,13 @@
     "onnx/node/generated/test_castlike_STRING_to_FLOAT_expanded",
     "onnx/node/generated/test_center_crop_pad_crop",
     "onnx/node/generated/test_center_crop_pad_crop_and_pad",
-    "onnx/node/generated/test_center_crop_pad_crop_and_pad_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_axes_chw",
     "onnx/node/generated/test_center_crop_pad_crop_axes_chw_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_axes_hwc",
     "onnx/node/generated/test_center_crop_pad_crop_axes_hwc_expanded",
-    "onnx/node/generated/test_center_crop_pad_crop_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc",
     "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc_expanded",
     "onnx/node/generated/test_center_crop_pad_pad",
-    "onnx/node/generated/test_center_crop_pad_pad_expanded",
     "onnx/node/generated/test_col2im",
     "onnx/node/generated/test_col2im_5d",
     "onnx/node/generated/test_col2im_dilations",
diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json b/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json
index a41a6e0..d2689f5 100644
--- a/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json
+++ b/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json
@@ -92,16 +92,13 @@
     "onnx/node/generated/test_castlike_STRING_to_FLOAT_expanded",
     "onnx/node/generated/test_center_crop_pad_crop",
     "onnx/node/generated/test_center_crop_pad_crop_and_pad",
-    "onnx/node/generated/test_center_crop_pad_crop_and_pad_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_axes_chw",
     "onnx/node/generated/test_center_crop_pad_crop_axes_chw_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_axes_hwc",
     "onnx/node/generated/test_center_crop_pad_crop_axes_hwc_expanded",
-    "onnx/node/generated/test_center_crop_pad_crop_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc",
     "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc_expanded",
     "onnx/node/generated/test_center_crop_pad_pad",
-    "onnx/node/generated/test_center_crop_pad_pad_expanded",
     "onnx/node/generated/test_col2im",
     "onnx/node/generated/test_col2im_5d",
     "onnx/node/generated/test_col2im_dilations",
@@ -158,9 +155,7 @@
     "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0",
     "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1",
     "onnx/node/generated/test_group_normalization_epsilon",
-    "onnx/node/generated/test_group_normalization_epsilon_expanded",
     "onnx/node/generated/test_group_normalization_example",
-    "onnx/node/generated/test_group_normalization_example_expanded",
     "onnx/node/generated/test_gru_batchwise",
     "onnx/node/generated/test_gru_defaults",
     "onnx/node/generated/test_gru_seq_length",
diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json b/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json
index c6a9a35..06beb22 100644
--- a/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json
+++ b/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json
@@ -87,16 +87,13 @@
     "onnx/node/generated/test_castlike_STRING_to_FLOAT_expanded",
     "onnx/node/generated/test_center_crop_pad_crop",
     "onnx/node/generated/test_center_crop_pad_crop_and_pad",
-    "onnx/node/generated/test_center_crop_pad_crop_and_pad_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_axes_chw",
     "onnx/node/generated/test_center_crop_pad_crop_axes_chw_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_axes_hwc",
     "onnx/node/generated/test_center_crop_pad_crop_axes_hwc_expanded",
-    "onnx/node/generated/test_center_crop_pad_crop_expanded",
     "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc",
     "onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc_expanded",
     "onnx/node/generated/test_center_crop_pad_pad",
-    "onnx/node/generated/test_center_crop_pad_pad_expanded",
     "onnx/node/generated/test_col2im",
     "onnx/node/generated/test_col2im_5d",
     "onnx/node/generated/test_col2im_dilations",
@@ -153,9 +150,7 @@
     "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0",
     "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1",
     "onnx/node/generated/test_group_normalization_epsilon",
-    "onnx/node/generated/test_group_normalization_epsilon_expanded",
     "onnx/node/generated/test_group_normalization_example",
-    "onnx/node/generated/test_group_normalization_example_expanded",
     "onnx/node/generated/test_gru_batchwise",
     "onnx/node/generated/test_gru_defaults",
     "onnx/node/generated/test_gru_seq_length",
diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json b/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json
index bf88b19..bbdeac9 100644
--- a/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json
+++ b/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json
@@ -187,7 +187,6 @@
     "onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1",
     "onnx/node/generated/test_gridsample_zeros_padding",
     "onnx/node/generated/test_group_normalization_epsilon",
-    "onnx/node/generated/test_group_normalization_epsilon_expanded",
     "onnx/node/generated/test_group_normalization_example",
     "onnx/node/generated/test_group_normalization_example_expanded",
     "onnx/node/generated/test_gru_batchwise",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
index 0a03362..d5794af 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
@@ -187,17 +187,33 @@
     }
     auto tensorType = op.getType();
     if (!tensorType.hasRank()) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "unranked result type not supported");
     }
 
-    // Check that all the dimensions are 1.
-    if (!llvm::all_of(tensorType.getShape(),
-                      [](int64_t dim) { return dim == 1; })) {
-      return failure();
+    if (op.getNumOperands() == 1) {
+      rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
+          op, tensorType, op.getOperand(0), ValueRange());
+      return success();
     }
 
-    rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
-        op, tensorType, op.getOperand(0), ValueRange());
+    const int64_t rank = tensorType.getRank();
+    Value result = rewriter.create<tensor::EmptyOp>(
+        op.getLoc(), tensorType.getShape(), tensorType.getElementType());
+    SmallVector<Value> ivs(rank);
+    for (int i = 0, s = op.getNumOperands(); i < s; ++i) {
+      int64_t index = i;
+      for (int j = rank - 1; j >= 0; --j) {
+        int64_t iv = index % tensorType.getDimSize(j);
+        index = index / tensorType.getDimSize(j);
+        ivs[j] = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), iv);
+      }
+
+      result = rewriter.create<Flow::TensorStoreOp>(
+          op.getLoc(), op.getOperand(i), result, ivs);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
index 13489f8..087d0e3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
@@ -10,18 +10,6 @@
 }
 
 // -----
-// CHECK: util.func public @tensor.from_elements__not_convertible(%[[arg0:.*]]: i8)
-util.func public @tensor.from_elements__not_convertible(%arg0: i8) -> (i8) {
-  // CHECK: %[[c0:.*]] = arith.constant 0
-  %c0 = arith.constant 0 : index
-  // CHECK: %[[res:.*]] = tensor.from_elements %[[arg0]], %[[arg0]] : tensor<2xi8>
-  %0 = tensor.from_elements %arg0, %arg0 : tensor<2xi8>
-  // CHECK: flow.tensor.load %[[res]][%[[c0]]]
-  %1 = flow.tensor.load %0[%c0] : tensor<2xi8>
-  util.return %1 : i8
-}
-
-// -----
 util.func public @tensor.from_elements__within_dispatch_workgroups_not_converted() -> tensor<f32> {
   %x = arith.constant 100 : index
   %0 = flow.dispatch.workgroups[%x]() : () -> (tensor<f32>) = () {
@@ -44,3 +32,21 @@
 // CHECK-SAME:     %[[ARG0:.+]]: f32
 //      CHECK:   %[[SPLAT:.+]] = flow.tensor.splat %[[ARG0]] : tensor<f32>
 //      CHECK:   util.return %[[SPLAT]]
+
+// -----
+
+// CHECK-LABEL: util.func public @tensor.from_elements_2D
+util.func @tensor.from_elements_2D(%arg0 : f32, %arg1 : f32, %arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32) -> tensor<2x3xf32> {
+  %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3xf32>
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK:     %[[EMPTY:.+]] = tensor.empty() : tensor<2x3xf32>
+  // CHECK:     %[[STORE0:.+]] = flow.tensor.store %arg0, %[[EMPTY]][%[[C0]], %[[C0]]] : tensor<2x3xf32>
+  // CHECK:     %[[STORE1:.+]] = flow.tensor.store %arg1, %[[STORE0]][%[[C0]], %[[C1]]] : tensor<2x3xf32>
+  // CHECK:     %[[STORE2:.+]] = flow.tensor.store %arg2, %[[STORE1]][%[[C0]], %[[C2]]] : tensor<2x3xf32>
+  // CHECK:     %[[STORE3:.+]] = flow.tensor.store %arg3, %[[STORE2]][%[[C1]], %[[C0]]] : tensor<2x3xf32>
+  // CHECK:     %[[STORE4:.+]] = flow.tensor.store %arg4, %[[STORE3]][%[[C1]], %[[C1]]] : tensor<2x3xf32>
+  // CHECK:     %[[STORE5:.+]] = flow.tensor.store %arg5, %[[STORE4]][%[[C1]], %[[C2]]] : tensor<2x3xf32>
+  util.return %0 : tensor<2x3xf32>
+}