Add test to verify support for fusion in Issue #3579 (#3660)
Closes #3579
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir b/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir
index d383a9a..402166f 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir
@@ -89,3 +89,35 @@
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG1]] :
// CHECK-SAME: outs(%[[RET0]] :
+
+// -----
+
+module {
+ func @issue_3579() {
+ %c0 = constant 0 : index
+ %cst_1 = constant dense<1.000000e+00> : tensor<1x10xf32>
+ %4 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<5x1x1xf32>
+ %5 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<i32>
+ %6 = "mhlo.torch_index_select"(%4, %5) {batch_dims = 0 : i64, dim = 0 : i64} : (tensor<5x1x1xf32>, tensor<i32>) -> tensor<1x1xf32>
+ %7 = "mhlo.reshape"(%6) : (tensor<1x1xf32>) -> tensor<1xf32>
+ %8 = "mhlo.broadcast_in_dim"(%7) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>
+ %9 = mhlo.multiply %8, %cst_1 : tensor<1x10xf32>
+ hal.interface.store.tensor %9, @legacy_io::@ret0, offset = %c0 : tensor<1x10xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// CHECK-LABEL: func @issue_3579
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<5x1x1xf32>
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<i32>
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<10xf32>
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: ins(%[[ARG1]] : memref<i32>)
+// CHECK-SAME: outs(%[[RET0]] : memref<10xf32>
+// CHECK: load %[[ARG0]]
+// CHECK: linalg.yield