[DataTiling] Add support for materializing elementwise ops. (#15507)

This PR adds a conversion case for element-wise `linalg.generic` ops with a
single input and a single output, so that ops such as an `i8` to `i32`
sign-extension survive encoding materialization. It also fixes a bug where
some conversions used the element type of `original_type` instead of the
element type of the tensor actually being converted; the two differ exactly
when such an elementwise op changes the element type.
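
For illustration, here is a minimal before/after sketch, abridged from the new
test below (`#enc` abbreviates the full `#iree_linalg_ext.encoding<...>`
attribute; the concrete tile sizes come from the `+avx512vnni` test target and
vary by target):

```mlir
// Before: elementwise sign-extension on an encoded, dynamically shaped tensor.
%ext = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%lhs : tensor<?x?x?xi8, #enc>) outs(%init : tensor<?x?x?xi32, #enc>) {
  ^bb0(%in: i8, %out: i32):
    %0 = arith.extsi %in : i8 to i32
    linalg.yield %0 : i32
  } -> tensor<?x?x?xi32, #enc>

// After: the same payload on the packed layout, with identity indexing maps
// and parallel iterators rebuilt at the packed rank (5 here).
%ext_packed = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%lhs_pack : tensor<32x1x64x1x2xi8>)
    outs(%init_packed : tensor<32x1x64x1x2xi32>) {
  ^bb0(%in: i8, %out: i32):
    %0 = arith.extsi %in : i8 to i32
    linalg.yield %0 : i32
  } -> tensor<32x1x64x1x2xi32>
```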
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index b231108..3d00bb7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -1169,3 +1169,79 @@
 // CHECK-SAME:       outs(%[[OUTS]] :
 //      CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
 // CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @extend_batch_vecmat(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+} {
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c11008 = arith.constant 11008 : index
+  %c0_i8 = arith.constant 0 : i8
+  %c0_i32 = arith.constant 0 : i32
+  %0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<32x1x128xi8>
+  %1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<32x128x11008xi8>
+  %padded = tensor.pad %0 low[0, 0, 0] high[%c0, %c0, %c0] {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %c0_i8 : i8
+  } : tensor<32x1x128xi8> to tensor<?x?x?xi8>
+  %4 = iree_linalg_ext.set_encoding %padded : tensor<?x?x?xi8> -> tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>
+  %5 = tensor.empty(%c32, %c1, %c128) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>
+  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>) outs(%5 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>) {
+  ^bb0(%in: i8, %out: i32):
+    %17 = arith.extsi %in : i8 to i32
+    linalg.yield %17 : i32
+  } -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>
+  %padded_0 = tensor.pad %1 low[0, 0, 0] high[%c0, %c0, %c0] {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %c0_i8 : i8
+  } : tensor<32x128x11008xi8> to tensor<?x?x?xi8>
+  %7 = iree_linalg_ext.set_encoding %padded_0 : tensor<?x?x?xi8> -> tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>
+  %8 = tensor.empty(%c32, %c128, %c11008) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>
+  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7 : tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>) outs(%8 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>) {
+  ^bb0(%in: i8, %out: i32):
+    %17 = arith.extsi %in : i8 to i32
+    linalg.yield %17 : i32
+  } -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>
+  %10 = tensor.empty(%c32, %c1, %c11008) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>
+  %11 = linalg.fill ins(%c0_i32 : i32) outs(%10 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>) -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>
+  %12 = linalg.batch_matmul ins(%6, %9 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>, tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>) outs(%11 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>) -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>
+  %13 = iree_linalg_ext.unset_encoding %12 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>> -> tensor<?x?x?xi32>
+  %extracted_slice = tensor.extract_slice %13[0, 0, 0] [32, 1, 11008] [1, 1, 1] : tensor<?x?x?xi32> to tensor<32x1x11008xi32>
+  %16 = hal.tensor.export %extracted_slice "output 0" : tensor<32x1x11008xi32> -> !hal.buffer_view
+  return %16 : !hal.buffer_view
+}
+
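+// The elementwise extension is expected to materialize on the packed layouts:
+// the i8 LHS packs to 32x1x64x1x2 (vecmat narrow-M tiles [1, 2]), the RHS to
+// 32x688x64x16x2 (tiles [16, 2]), and the extsi payload is rebuilt at rank 5.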
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//      CHECK: func @extend_batch_vecmat(%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view attributes
+//  CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+//  CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+//      CHECK: %[[LHS:.+]] = hal.tensor.import %[[ARG0]] "input 0" : !hal.buffer_view -> tensor<32x1x128xi8>
+//      CHECK: %[[RHS:.+]] = hal.tensor.import %[[ARG1]] "input 1" : !hal.buffer_view -> tensor<32x128x11008xi8>
+//      CHECK: %[[INIT_LHS_PACK:.+]] = tensor.empty() : tensor<32x1x64x1x2xi8>
+//      CHECK: %[[LHS_PACK:.+]] = tensor.pack %[[LHS]] padding_value(%[[C0_I8]] : i8) inner_dims_pos = [1, 2] inner_tiles = [1, 2] into %[[INIT_LHS_PACK]] : tensor<32x1x128xi8> -> tensor<32x1x64x1x2xi8>
+//      CHECK: %[[INIT_LHS_EXT:.+]] = tensor.empty() : tensor<32x1x64x1x2xi32>
+//      CHECK: %[[LHS_EXT:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[LHS_PACK]] : tensor<32x1x64x1x2xi8>) outs(%[[INIT_LHS_EXT]] : tensor<32x1x64x1x2xi32>) {
+// CHECK-NEXT:     ^bb0(%[[LHS_EXT_ARG_IN:.+]]: i8, %[[LHS_EXT_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[LHS_EXT_OP:.+]] = arith.extsi %[[LHS_EXT_ARG_IN]] : i8 to i32
+// CHECK-NEXT:     linalg.yield %[[LHS_EXT_OP]] : i32
+//      CHECK: %[[INIT_RHS_PACK:.+]] = tensor.empty() : tensor<32x688x64x16x2xi8>
+//      CHECK: %[[RHS_PACK:.+]] = tensor.pack %[[RHS]] padding_value(%[[C0_I8]] : i8) outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %[[INIT_RHS_PACK]] : tensor<32x128x11008xi8> -> tensor<32x688x64x16x2xi8>
+//      CHECK: %[[INIT_RHS_EXT:.+]] = tensor.empty() : tensor<32x688x64x16x2xi32>
+//      CHECK: %[[RHS_EXT:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[RHS_PACK]] : tensor<32x688x64x16x2xi8>) outs(%[[INIT_RHS_EXT]] : tensor<32x688x64x16x2xi32>) {
+// CHECK-NEXT:     ^bb0(%[[RHS_EXT_ARG_IN:.+]]: i8, %[[RHS_EXT_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[RHS_EXT_OP:.+]] = arith.extsi %[[RHS_EXT_ARG_IN]] : i8 to i32
+// CHECK-NEXT:     linalg.yield %[[RHS_EXT_OP]] : i32
+//      CHECK: %[[INIT_FILL:.+]] = tensor.empty() : tensor<32x1x688x1x16xi32>
+//      CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[INIT_FILL]] : tensor<32x1x688x1x16xi32>) -> tensor<32x1x688x1x16xi32>
+//      CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d ins(%[[LHS_EXT]], %[[RHS_EXT]] : tensor<32x1x64x1x2xi32>, tensor<32x688x64x16x2xi32>) outs(%[[FILL]] : tensor<32x1x688x1x16xi32>) -> tensor<32x1x688x1x16xi32>
+//      CHECK: %[[INIT_UNPACK:.+]] = tensor.empty() : tensor<32x1x11008xi32>
+//      CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]] inner_dims_pos = [1, 2] inner_tiles = [1, 16] into %[[INIT_UNPACK]] : tensor<32x1x688x1x16xi32> -> tensor<32x1x11008xi32>
+//      CHECK: %[[EXPORT:.+]] = hal.tensor.export %[[UNPACK]] "output 0" : tensor<32x1x11008xi32> -> !hal.buffer_view
+//      CHECK: return %[[EXPORT]] : !hal.buffer_view
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 3ed53a5..a73e2d3 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -66,7 +66,11 @@
     return dropEncoding(tensorType);
   }
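+  // Clone with the element type of the type being converted: original_type
+  // only supplies the pre-padding shape, and an elementwise op may have
+  // changed the element type (e.g. i8 -> i32) without changing the layout.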
   return tensor::PackOp::inferPackedType(
-             getOriginalTypeWithEncoding(tensorType),
+             getOriginalTypeWithEncoding(tensorType)
+                 .clone(tensorType.getElementType()),
              materializeEncodingInfo->innerTileSizes,
              materializeEncodingInfo->innerDimsPos,
              materializeEncodingInfo->outerDimsPerm)
@@ -361,8 +362,11 @@
                     ValueRange convertedOperands,
                     MaterializeEncodingFn materializeEncodingFn,
                     MaterializeEncodingValueFn materializeEncodingValueFn) {
-  auto resultType = getOriginalTypeWithEncoding(
-      emptyOp->getResultTypes()[0].cast<RankedTensorType>());
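+  // As above, take the shape from original_type but keep the element type of
+  // the empty op's result, which may differ from the original element type.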
+  auto emptyType = emptyOp->getResultTypes()[0].cast<RankedTensorType>();
+  auto resultType =
+      getOriginalTypeWithEncoding(emptyType).clone(emptyType.getElementType());
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
       materializeEncodingFn(resultType);
   Location loc = emptyOp.getLoc();
@@ -391,6 +393,46 @@
   return newEmptyOp;
 }
 
+/// Utility method to convert a `linalg.generic` op on tensors with encoding
+/// into the equivalent `linalg.generic` op on the materialized type.
+static FailureOr<Operation *>
+lowerOpWithEncoding(RewriterBase &rewriter, linalg::GenericOp genericOp,
+                    ValueRange convertedInputOperands,
+                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
+                    MaterializeEncodingValueFn) {
+  if (!genericOp.hasTensorSemantics() || !isElementwise(genericOp) ||
+      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "linalg.generic op is not elementwise "
+                                       "with single input and single output");
+  }
+  if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap m) { return m.isIdentity(); })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "indexing maps are not all identity maps");
+  }
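+  // The converted output operand already carries the materialized (packed)
+  // type; its rank may be higher than the original, so rebuild identity
+  // indexing maps and parallel iterators at the new rank.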
+  auto convertedResultType =
+      convertedOutputOperands[0].getType().cast<RankedTensorType>();
+  SmallVector<AffineMap> maps(
+      2, AffineMap::getMultiDimIdentityMap(convertedResultType.getRank(),
+                                           rewriter.getContext()));
+  SmallVector<utils::IteratorType> iteratorTypes(convertedResultType.getRank(),
+                                                 utils::IteratorType::parallel);
+  auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
+      genericOp.getLoc(), convertedResultType, convertedInputOperands,
+      convertedOutputOperands, maps, iteratorTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
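+  // Materialization leaves the scalar payload unchanged, so move the original
+  // region into the new op instead of rebuilding it.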
+  rewriter.inlineRegionBefore(genericOp.getRegion(),
+                              materializedGenericOp.getRegion(),
+                              materializedGenericOp.getRegion().begin());
+  return materializedGenericOp.getOperation();
+}
+
 namespace {
 //===---------------------------------------------------------------------===//
 // Patterns to lower ops with encodings. These are written as
@@ -491,7 +528,7 @@
   MaterializeEncodingFn materializeEncodingFn;
 };
 
-/// Generic pattern to convert operaiton that is in Destination Passing Style.
+/// Generic pattern to convert an operation that is in Destination Passing Style.
 template <typename OpTy>
 struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
   using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
@@ -633,6 +670,7 @@
   patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
                   MaterializeDPSOperation<linalg::MatmulOp>,
                   MaterializeDPSOperation<linalg::BatchMatmulOp>,
+                  MaterializeDPSOperation<linalg::GenericOp>,
                   MaterializeOperation<tensor::EmptyOp>,
                   SetEncodingOpToPackOpConversion,
                   UnsetEncodingOpToUnPackOpConversion>(