Enable fusion for elementwise Linalg op + pack op (#11374)

It also updates the LinalgExtVectorization to use tile+fuse, so we can
tile+fuse the generic ops.

Some metric data w/ mobilebert fp32:

The number of dispatches:

- Legacy mmt4d: 39
- data tiling w/o fusion: 57
- data tiling w/ pack fusion: 59

It's reasonable for having more different dispatches because some of
different set_encoding ops could be folded into same producer
dispatches. E.g., we could have dispatch_A, LHS_encoding, RHS_encoding
in the beginning. After more aggressive fusion, we could get `dispatch_A
+ LHS_encoding` + `dispatch_A + RHS_encoding` + `LHS_encoding` +
`RHS_encoding`. There would be 4 dispatches after fusion. We should use
the metric about the number of kernel launch.

The number of `flow.dispatch` launch:

- Legacy mmt4d: 1980
- data tiling w/o fusion: 2871
- data tiling w/ pack fusion: 2750

The legacy mmt4d path has less kernel launches because

1. Need unpack op fusion, which is WIP.
2. Propagation helps better fusion.
3. We don't have canonicalization patterns for packing on constant.

I verified that (3.) can save 361 times of kernel launch, tracking in
https://github.com/iree-org/iree/issues/11360

Relands https://github.com/iree-org/iree/pull/11284 with fixes for
mid-air collision.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
index 551ec12..e35121c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -27,6 +27,7 @@
             "check_ir_before_llvm_conversion.mlir",
             "check_ir_before_llvm_conversion_not_fail_unbound.mlir",
             "convert_to_llvm.mlir",
+            "data_tiling_pipeline.mlir",
             "emit_vectorization_remarks.mlir",
             "hal_executable_constants.mlir",
             "hal_interface_bindings.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index e899873..348d33b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -22,6 +22,7 @@
     "check_ir_before_llvm_conversion.mlir"
     "check_ir_before_llvm_conversion_not_fail_unbound.mlir"
     "convert_to_llvm.mlir"
+    "data_tiling_pipeline.mlir"
     "emit_vectorization_remarks.mlir"
     "hal_executable_constants.mlir"
     "hal_interface_bindings.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
new file mode 100644
index 0000000..4b6bc58
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))' --split-input-file %s | FileCheck %s
+
+hal.executable private @elem_pack {
+  hal.executable.variant public @embedded_elf_x86_64, target = <"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}> {
+    hal.executable.export public @elem_pack ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<CPUDataTiling>} {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+      %c1 = arith.constant 1 : index
+      %0 = affine.apply affine_map<()[s0] -> ((s0 ceildiv 8) ceildiv 64)>()[%arg1]
+      %1 = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%arg2]
+      hal.return %1, %0, %c1 : index, index, index
+    }
+    builtin.module {
+      func.func @elem_pack() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<128x384xf32>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<16x384x8x1xf32>>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x384xf32>> -> tensor<128x384xf32>
+        %3 = tensor.empty() : tensor<128x384xf32>
+        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<128x384xf32>) outs(%3 : tensor<128x384xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %7 = arith.addf %in, %in : f32
+          linalg.yield %7 : f32
+        } -> tensor<128x384xf32>
+        %5 = tensor.empty() : tensor<16x384x8x1xf32>
+        %6 = iree_linalg_ext.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5 : (tensor<128x384xf32> tensor<16x384x8x1xf32>) -> tensor<16x384x8x1xf32>
+        flow.dispatch.tensor.store %6, %1, offsets = [0, 0, 0, 0], sizes = [16, 384, 8, 1], strides = [1, 1, 1, 1] : tensor<16x384x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<16x384x8x1xf32>>
+        return
+      }
+    }
+  }
+}
+// CHECK: func.func @elem_pack
+// CHECK:   %[[READ:.+]] = vector.transfer_read
+// CHECK:   %[[ADD:.+]] = arith.addf %[[READ]], %[[READ]]
+// CHECK:   %[[BCAST:.+]] = vector.broadcast %[[ADD]]
+// CHECK:   vector.transfer_write %[[BCAST]], %{{.+}}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index bd3d50f..5798b67 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -619,6 +619,14 @@
   Operation *producer = operand.get().getDefiningOp();
   Operation *consumer = operand.getOwner();
 
+  auto linalgProducerOp = dyn_cast<linalg::LinalgOp>(producer);
+  auto setEncodingOp = dyn_cast<IREE::LinalgExt::SetEncodingOp>(consumer);
+  if (linalgProducerOp && setEncodingOp) {
+    return linalg::isElementwise(linalgProducerOp) &&
+           linalgProducerOp.getNumLoops() ==
+               setEncodingOp.getSourceType().getRank();
+  }
+
   if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer)) {
     return false;
   }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir
index d2f8db0..ff595a7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_default.mlir
@@ -59,3 +59,28 @@
 // CHECK-LABEL: func.func @reduction_broadcast_elementwise_type_mismatch
 //      CHECK: flow.dispatch.workgroups
 //      CHECK: flow.dispatch.workgroups
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_set_encoding(%arg0: tensor<512xf32>, %arg1: tensor<384x512xf32>,
+    %arg2: tensor<384x512xf32>) -> tensor<384x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>> {
+  %0 = tensor.empty() : tensor<384x512xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map1],
+                       iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<512xf32>, tensor<384x512xf32>, tensor<384x512xf32>)
+    outs(%0 : tensor<384x512xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32):
+    %3 = arith.addf %in, %in_0 : f32
+    %4 = arith.addf %3, %in_1 : f32
+    linalg.yield %4 : f32
+  } -> tensor<384x512xf32>
+  %2 = iree_linalg_ext.set_encoding %1 : tensor<384x512xf32> -> tensor<384x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+  return %2 : tensor<384x512xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>>
+}
+// CHECK-LABEL: func.func @elem_set_encoding
+// CHECK:         flow.dispatch.workgroups
+// CHECK:           linalg.generic
+// CHECK:           iree_linalg_ext.set_encoding
+// CHECK-NOT:     flow.dispatch.workgroups
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index 1283374..725d2ad 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -952,7 +952,7 @@
     // Apply tiling to make outer dims be all 1s.
     {
       SimpleRewriter rewriter(ctx);
-      auto packTilingOptions =
+      auto packOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
           scf::SCFTilingOptions().setTileSizeComputationFunction(
               [](OpBuilder &builder, Operation *op) {
                 Location loc = op->getLoc();
@@ -960,15 +960,16 @@
                 SmallVector<Value> tileSizes(
                     inputRank, builder.create<arith::ConstantIndexOp>(loc, 1));
                 return tileSizes;
-              });
+              }));
       auto funcOp = getOperation();
       funcOp->walk([&](LinalgExt::PackOp op) {
-        FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCFForOp(
-            rewriter, cast<TilingInterface>(op.getOperation()),
-            packTilingOptions);
-        if (failed(tilingResult))
+        FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+            scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op,
+                                                                  packOptions);
+        if (failed(tileAndFuseResult))
           return signalPassFailure();
-        rewriter.replaceOp(op, tilingResult->replacements);
+        rewriter.replaceOp(op,
+                           tileAndFuseResult->replacements[op.getResult(0)]);
       });
 
       auto unpackTilingOptions =