[GlobalOpt] Create linalg.generic ops instead of CastOpInterface ops on tensors (#15591)

Closes #15547
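
Instead of building CastOpInterface ops (arith.extf, arith.extsi,
arith.extui, ...) directly on tensors, ExpandVectors,
LiftGenericToTransposeBatchMatmul, and SetEncoding now emit an
elementwise linalg.generic that performs the same scalar cast, using a
new shared utility, createGenericElementwiseCastOp. When the cast
originates inside an existing linalg.generic, its attributes are
carried over after linalg::getPrunedAttributeList strips the
linalg-specific ones (e.g. indexing maps and iterator types).

A minimal sketch of the rewrite for a rank-2 input (SSA names and the
dynamic-size operands %d0/%d1 are illustrative): a tensor-level cast
such as

    %0 = arith.extf %in : tensor<?x?xbf16> to tensor<?x?xf32>

is now created as

    #map = affine_map<(d0, d1) -> (d0, d1)>
    %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
    %0 = linalg.generic
           {indexing_maps = [#map, #map],
            iterator_types = ["parallel", "parallel"]}
           ins(%in : tensor<?x?xbf16>) outs(%init : tensor<?x?xf32>) {
         ^bb0(%b_in: bf16, %b_out: f32):
           %1 = arith.extf %b_in : bf16 to f32
           linalg.yield %1 : f32
         } -> tensor<?x?xf32>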
diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandVectors.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandVectors.cpp
index feebe49..20d26b3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/ExpandVectors.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandVectors.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/GlobalOptimization/PassDetail.h"
 #include "iree/compiler/GlobalOptimization/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -106,11 +107,12 @@
           rewriter
               .create<tensor::ExpandShapeOp>(loc, newVectorCastInTy, castIn, ri)
               .getResult();
-      expandedIn =
-          rewriter
-              .create(loc, castOp.value()->getName().getIdentifier(),
-                      expandedIn, newVectorInTy, castOp.value()->getAttrs())
-              ->getResult(0);
+      auto genericOp = castOp.value()->getParentOfType<linalg::GenericOp>();
+      NamedAttrList castAttrs = genericOp
+                                    ? linalg::getPrunedAttributeList(genericOp)
+                                    : castOp.value()->getAttrs();
+      expandedIn = createGenericElementwiseCastOp(rewriter, loc, expandedIn,
+                                                  castOp.value(), castAttrs);
     } else {
       expandedIn =
           rewriter
diff --git a/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
index a1b911b..e1d473b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
@@ -6,11 +6,13 @@
 
 #include "iree/compiler/GlobalOptimization/PassDetail.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
+#include "iree/compiler/GlobalOptimization/Utils.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/PatternMatch.h"
@@ -151,15 +153,12 @@
   if (inputType.getElementType() == outputType.getElementType()) {
     return input;
   }
-  auto castedType =
-      RankedTensorType::get(inputType.getShape(), outputType.getElementType());
   for (auto bodyOp : genericOp.getBody()->getOps<CastOpInterface>()) {
     Value castInput = bodyOp->getOperand(0);
     if (isBlockArgumentAtIndex(castInput, inputIdx)) {
-      return rewriter
-          .create(bodyOp->getLoc(), bodyOp->getName().getIdentifier(), input,
-                  castedType, bodyOp->getAttrs())
-          ->getResult(0);
+      return createGenericElementwiseCastOp(
+          rewriter, loc, input, bodyOp,
+          linalg::getPrunedAttributeList(genericOp));
     }
   }
   return failure();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
index 26904ff..06c9b82 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -146,38 +146,12 @@
 static Value castEncodedResult(OpBuilder &builder, Location loc, Value encoded,
                                CastOpInterface castOp,
                                IREE::LinalgExt::EncodingAttr encodingAttr) {
-  auto encodedType = cast<RankedTensorType>(encoded.getType());
-  auto castResultElemType = getElementTypeOrSelf(castOp->getResultTypes()[0]);
-  auto castedType = RankedTensorType::get(encodedType.getShape(),
-                                          castResultElemType, encodingAttr);
-  assert(encodedType.getElementType() ==
-             getElementTypeOrSelf(castOp->getOperandTypes()[0]) &&
-         "Expected encoded element type to be the same as the cast input "
-         "element type");
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(builder, loc, encoded);
-  Value init = builder.create<tensor::EmptyOp>(
-      loc, inputMixedSizes, castResultElemType, encodingAttr);
-  SmallVector<AffineMap> maps(
-      2, AffineMap::getMultiDimIdentityMap(castedType.getRank(),
-                                           builder.getContext()));
-  SmallVector<utils::IteratorType> iteratorTypes(castedType.getRank(),
-                                                 utils::IteratorType::parallel);
   auto genericOp = castOp->getParentOfType<linalg::GenericOp>();
-  NamedAttrList castAttrs =
-      genericOp ? genericOp->getAttrs() : castOp->getAttrs();
-  return builder
-      .create<linalg::GenericOp>(
-          loc, castedType, encoded, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create(nestedLoc, castOp->getName().getIdentifier(), args[0],
-                         castResultElemType)
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          },
-          castAttrs)
-      ->getResult(0);
+  NamedAttrList castAttrs = genericOp
+                                ? linalg::getPrunedAttributeList(genericOp)
+                                : castOp->getAttrs();
+  return createGenericElementwiseCastOp(builder, loc, encoded, castOp,
+                                        castAttrs, encodingAttr);
 }
 
 static Value
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp
index a1a5859..70ea19a 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -52,6 +53,42 @@
   return castSrcElemType;
 }
 
+Value createGenericElementwiseCastOp(
+    OpBuilder &builder, Location loc, Value input, CastOpInterface castOp,
+    ArrayRef<NamedAttribute> attrs,
+    std::optional<IREE::LinalgExt::EncodingAttr> encoding) {
+  auto inputType = cast<RankedTensorType>(input.getType());
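+  // Identity indexing maps and all-parallel iterators: a pure elementwise op.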
+  SmallVector<AffineMap> maps(
+      2, AffineMap::getMultiDimIdentityMap(inputType.getRank(),
+                                           builder.getContext()));
+  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
+                                                 utils::IteratorType::parallel);
+  auto elementType = getElementTypeOrSelf(castOp->getResultTypes()[0]);
+  auto castedType = inputType.clone(elementType);
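+  // Size the init tensor from the input, carrying the encoding if provided.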
+  SmallVector<OpFoldResult> inputMixedSizes =
+      tensor::getMixedSizes(builder, loc, input);
+  Value init =
+      encoding
+          ? builder.create<tensor::EmptyOp>(loc, inputMixedSizes, elementType,
+                                            *encoding)
+          : builder.create<tensor::EmptyOp>(loc, inputMixedSizes, elementType);
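+  // Recreate the same scalar cast op in the generic body and yield its result.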
+  return builder
+      .create<linalg::GenericOp>(
+          loc, castedType, input, init, maps, iteratorTypes,
+          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+            Value castRes =
+                b.create(nestedLoc, castOp->getName().getIdentifier(), args[0],
+                         elementType)
+                    ->getResult(0);
+            b.create<linalg::YieldOp>(nestedLoc, castRes);
+          },
+          attrs)
+      .getResult(0);
+}
+
 } // namespace GlobalOptimization
 } // namespace iree_compiler
 } // namespace mlir
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Utils.h b/compiler/src/iree/compiler/GlobalOptimization/Utils.h
index 03987bb..8a9f910 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Utils.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Utils.h
@@ -7,27 +7,38 @@
 #define IREE_COMPILER_GLOBALOPTIMIZATION_UTILS_H_
 
 #include <optional>
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 
 namespace mlir {
 class Type;
 class Value;
 class CastOpInterface;
+class OpBuilder;
+class Location;
+class NamedAttribute;
 
 namespace iree_compiler {
 namespace GlobalOptimization {
 
-// If the producer is a CastOpInterface, or a linalg::GenericOp that performs
-// only a CastOpInterface on its input, return the CastOpInterface op.
-// Otherwise, return std::nullopt.
-//
-// **Note: If the CastOpInterface has been generalized, the return Operation
-//         is the body CastOpInterface op, not the linalg::GenericOp.
+/// If the producer is a CastOpInterface, or a linalg::GenericOp that performs
+/// only a CastOpInterface on its input, return the CastOpInterface op.
+/// Otherwise, return std::nullopt.
+///
+/// Note: If the CastOpInterface has been generalized, the returned op is the
+///       CastOpInterface inside the linalg::GenericOp body, not the GenericOp.
 std::optional<CastOpInterface> getDefiningCastOp(Value input);
 
-// Returns the source element type of the defining CastOpInterface of `input`,
-// if there is one. Otherwise return std::nullopt.
+/// Returns the source element type of the defining CastOpInterface of `input`,
+/// if there is one. Otherwise return std::nullopt.
 std::optional<Type> getCastElemType(Value input);
 
+/// Creates an elementwise linalg::GenericOp with identity indexing maps that
+/// casts `input` using the same cast operation as the given `castOp`.
+Value createGenericElementwiseCastOp(
+    OpBuilder &builder, Location loc, Value input, CastOpInterface castOp,
+    ArrayRef<NamedAttribute> attrs,
+    std::optional<IREE::LinalgExt::EncodingAttr> encoding = std::nullopt);
+
 } // namespace GlobalOptimization
 } // namespace iree_compiler
 } // namespace mlir
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/expand_vectors.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/expand_vectors.mlir
index 195ad58..bad1503 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/expand_vectors.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/expand_vectors.mlir
@@ -128,7 +128,7 @@
 
 // -----
 
-func.func @vecmat_bf16bf16f32_casted_dynamic(%arg0 : tensor<?xbf16>, %arg1 : tensor<?x?xbf16>,
+func.func @vecmat_bf16bf16f32_casted_dynamic(%arg0 : tensor<?xbf16>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg0, %c0 : tensor<?xbf16>
@@ -142,29 +142,33 @@
     %2 = arith.extf %in : bf16 to f32
     linalg.yield %2 : f32
   } -> tensor<?xf32>
-  %casted1 = arith.extf %arg1 : tensor<?x?xbf16> to tensor<?x?xf32>
-  %1 = linalg.vecmat ins(%casted0, %casted1 : tensor<?xf32>, tensor<?x?xf32>)
+  %1 = linalg.vecmat ins(%casted0, %arg1 : tensor<?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 //      CHECK:  func @vecmat_bf16bf16f32_casted_dynamic(
-// CHECK-SAME:  %[[ARG0:.+]]: tensor<?xbf16>, %[[ARG1:.+]]: tensor<?x?xbf16>, %[[ARG2:.+]]: tensor<?xf32>
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<?xbf16>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?xf32>
+//  CHECK-DAG:  %[[C1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:  %[[EXPANDED_IN:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<?xbf16> into tensor<1x?xbf16>
+//  CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[EXPANDED_IN]], %[[C1]] : tensor<1x?xbf16>
+//      CHECK:  %[[INIT_CASTED0:.+]] = tensor.empty(%[[DIM0]]) : tensor<1x?xf32>
+//      CHECK:  %[[CASTED0:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXPANDED_IN]] : tensor<1x?xbf16>) outs(%[[INIT_CASTED0]] : tensor<1x?xf32>) {
+// CHECK-NEXT:     ^bb0(%[[CAST_ARG_IN:.+]]: bf16, %[[CAST_ARG_OUT:.+]]: f32):
+// CHECK-NEXT:     %[[CAST_OP:.+]] = arith.extf %[[CAST_ARG_IN]] : bf16 to f32
+// CHECK-NEXT:     linalg.yield %[[CAST_OP]] : f32
 //  CHECK-DAG:  %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
-//  CHECK-DAG:  %[[CASTED0:.+]] = arith.extf %[[EXPANDED_IN]] : tensor<1x?xbf16> to tensor<1x?xf32>
-//  CHECK-DAG:  %[[CASTED1:.+]] = arith.extf %[[ARG1]] : tensor<?x?xbf16> to tensor<?x?xf32>
-//  CHECK-DAG:  %[[MATMUL:.+]] = linalg.matmul ins(%[[CASTED0]], %[[CASTED1]] : tensor<1x?xf32>, tensor<?x?xf32>) outs(%[[EXPANDED_OUT]] : tensor<1x?xf32>)
+//  CHECK-DAG:  %[[MATMUL:.+]] = linalg.matmul ins(%[[CASTED0]], %[[ARG1]] : tensor<1x?xf32>, tensor<?x?xf32>) outs(%[[EXPANDED_OUT]] : tensor<1x?xf32>)
 //  CHECK-DAG:  %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MATMUL]] {{\[}}[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
 //      CHECK:  return %[[COLLAPSED]]
 
 // -----
 
-func.func @matvec_i8i8i32_casted_dynamic(%arg0 : tensor<?x?xi8>, %arg1 : tensor<?xi8>,
+func.func @matvec_i8i8i32_casted_dynamic(%arg0 : tensor<?x?xi32>, %arg1 : tensor<?xi8>,
     %arg2 : tensor<?xi32>) -> tensor<?xi32> {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg1, %c0 : tensor<?xi8>
   %0 = tensor.empty(%dim) : tensor<?xi32>
-  %casted0 = arith.extui %arg0 : tensor<?x?xi8> to tensor<?x?xi32>
   %casted1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, 
                                               affine_map<(d0) -> (d0)>], 
                              iterator_types = ["parallel"]} 
@@ -174,20 +178,29 @@
     %2 = arith.extsi %in : i8 to i32
     linalg.yield %2 : i32
   } -> tensor<?xi32>
-  %1 = linalg.matvec ins(%casted0, %casted1 : tensor<?x?xi32>, tensor<?xi32>)
+  %1 = linalg.matvec ins(%arg0, %casted1 : tensor<?x?xi32>, tensor<?xi32>)
       outs(%arg2 : tensor<?xi32>) -> tensor<?xi32>
   return %1 : tensor<?xi32>
 }
+
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 //      CHECK:  func @matvec_i8i8i32_casted_dynamic(
-// CHECK-SAME:  %[[ARG0:.+]]: tensor<?x?xi8>, %[[ARG1:.+]]: tensor<?xi8>, %[[ARG2:.+]]: tensor<?xi32>
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<?x?xi32>, %[[ARG1:.+]]: tensor<?xi8>, %[[ARG2:.+]]: tensor<?xi32>
+//  CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:  %[[EXPANDED_IN:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1]] : tensor<?xi8> into tensor<?x1xi8>
+//  CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[EXPANDED_IN]], %[[C0]] : tensor<?x1xi8>
+//      CHECK:  %[[INIT_CASTED1:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1xi32>
+//      CHECK:  %[[CASTED1:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXPANDED_IN]] : tensor<?x1xi8>) outs(%[[INIT_CASTED1]] : tensor<?x1xi32>) {
+// CHECK-NEXT:     ^bb0(%[[CAST_ARG_IN:.+]]: i8, %[[CAST_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[CAST_OP:.+]] = arith.extsi %[[CAST_ARG_IN]] : i8 to i32
+// CHECK-NEXT:     linalg.yield %[[CAST_OP]] : i32
 //  CHECK-DAG:  %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1]] : tensor<?xi32> into tensor<?x1xi32>
-//  CHECK-DAG:  %[[CASTED0:.+]] = arith.extui %[[ARG0]] : tensor<?x?xi8> to tensor<?x?xi32>
-//  CHECK-DAG:  %[[CASTED1:.+]] = arith.extsi %[[EXPANDED_IN]] : tensor<?x1xi8> to tensor<?x1xi32>
-//  CHECK-DAG:  %[[MATMUL:.+]] = linalg.matmul ins(%[[CASTED0]], %[[CASTED1]] : tensor<?x?xi32>, tensor<?x1xi32>) outs(%[[EXPANDED_OUT]] : tensor<?x1xi32>)
+//  CHECK-DAG:  %[[MATMUL:.+]] = linalg.matmul ins(%[[ARG0]], %[[CASTED1]] : tensor<?x?xi32>, tensor<?x1xi32>) outs(%[[EXPANDED_OUT]] : tensor<?x1xi32>)
 //  CHECK-DAG:  %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MATMUL]] {{\[}}[0, 1]] : tensor<?x1xi32> into tensor<?xi32>
 //      CHECK:  return %[[COLLAPSED]]
 
+// -----
+
 func.func @batch_vecmat_casted_f16f32f32_dynamic(%arg0 : tensor<3x?xf16>, %arg1 : tensor<3x?x?xf32>,
     %arg2 : tensor<3x?xf32>) -> tensor<3x?xf32> {
   %c1 = arith.constant 1 : index
@@ -207,10 +220,17 @@
   return %1 : tensor<3x?xf32>
 }
 
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK:  func @batch_vecmat_casted_f16f32f32_dynamic(
 // CHECK-SAME:  %[[ARG0:.+]]: tensor<3x?xf16>, %[[ARG1:.+]]: tensor<3x?x?xf32>, %[[ARG2:.+]]: tensor<3x?xf32>
 //  CHECK-DAG:  %[[EXPANDED_IN:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] : tensor<3x?xf16> into tensor<3x1x?xf16>
-//  CHECK-DAG:  %[[CASTED0:.+]] = arith.extf %[[EXPANDED_IN]] : tensor<3x1x?xf16> to tensor<3x1x?xf32>
+//  CHECK-DAG:  %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[EXPANDED_IN]], %[[C2]] : tensor<3x1x?xf16>
+//      CHECK:  %[[INIT_CASTED0:.+]] = tensor.empty(%[[DIM0]]) : tensor<3x1x?xf32>
+//      CHECK:  %[[CASTED0:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[EXPANDED_IN]] : tensor<3x1x?xf16>) outs(%[[INIT_CASTED0]] : tensor<3x1x?xf32>) {
+// CHECK-NEXT:     ^bb0(%[[CAST_ARG_IN:.+]]: f16, %[[CAST_ARG_OUT:.+]]: f32):
+// CHECK-NEXT:     %[[CAST_OP:.+]] = arith.extf %[[CAST_ARG_IN]] : f16 to f32
+// CHECK-NEXT:     linalg.yield %[[CAST_OP]] : f32
 //  CHECK-DAG:  %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2]] : tensor<3x?xf32> into tensor<3x1x?xf32>
 //  CHECK-DAG:  %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[CASTED0]], %[[ARG1]] : tensor<3x1x?xf32>, tensor<3x?x?xf32>) outs(%[[EXPANDED_OUT]] : tensor<3x1x?xf32>)
 //  CHECK-DAG:  %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MATMUL]] {{\[}}[0, 1], [2]] : tensor<3x1x?xf32> into tensor<3x?xf32>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
index 815f815..e8eb5ee 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
@@ -21,16 +21,25 @@
     return %2 : tensor<11008x32xi32>
   }
 }
-
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK:  func @raise_batch_vecmat(
 // CHECK-SAME:  %[[ARG0:.+]]: tensor<32x128xi16>, %[[ARG1:.+]]: tensor<11008x32x128xi4>
 //  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
-//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
-//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
 //  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty() : tensor<32x128x11008xi4>
 //  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x32x128xi4>) outs(%[[INIT1]] : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
-//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[ARG0]] : tensor<32x128xi16> to tensor<32x128xi32>
-//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE0]] : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
+//      CHECK:  %[[INIT_EXTSI:.+]] = tensor.empty() : tensor<32x128xi32>
+//      CHECK:  %[[EXTSI:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<32x128xi16>) outs(%[[INIT_EXTSI]] : tensor<32x128xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTSI_ARG_IN:.+]]: i16, %[[EXTSI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTSI_OP:.+]] = arith.extsi %[[EXTSI_ARG_IN]] : i16 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTSI_OP]] : i32
+//      CHECK:  %[[INIT_EXTUI:.+]] = tensor.empty() : tensor<32x128x11008xi32>
+//      CHECK:  %[[EXTUI:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[TRANSPOSE0]] : tensor<32x128x11008xi4>) outs(%[[INIT_EXTUI]] : tensor<32x128x11008xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTUI_ARG_IN:.+]]: i4, %[[EXTUI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTUI_OP:.+]] = arith.extui %[[EXTUI_ARG_IN]] : i4 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTUI_OP]] : i32
+//      CHECK:  %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
+//      CHECK:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
 //      CHECK:  %[[VECMAT:.+]] = linalg.batch_vecmat ins(%[[EXTSI]], %[[EXTUI]] : tensor<32x128xi32>, tensor<32x128x11008xi32>) outs(%[[FILL]] : tensor<32x11008xi32>)
 //      CHECK:  %[[INIT2:.+]] = tensor.empty() : tensor<11008x32xi32>
 //      CHECK:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[VECMAT]] : tensor<32x11008xi32>) outs(%[[INIT2]] : tensor<11008x32xi32>) permutation = [1, 0]
@@ -59,18 +68,27 @@
     return %2 : tensor<11008x32xi32>
   }
 }
-
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 //      CHECK:  func @raise_batch_matvec(
 // CHECK-SAME:  %[[ARG0:.+]]: tensor<11008x32x128xi4>, %[[ARG1:.+]]: tensor<128x32xi16>
 //  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
-//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
-//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
 //  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty() : tensor<32x11008x128xi4>
 //  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<11008x32x128xi4>) outs(%[[INIT1]] : tensor<32x11008x128xi4>) permutation = [1, 0, 2]
 //  CHECK-DAG:  %[[INIT2:.+]] = tensor.empty() : tensor<32x128xi16>
 //  CHECK-DAG:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<128x32xi16>) outs(%[[INIT2]] : tensor<32x128xi16>) permutation = [1, 0]
-//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE0]] : tensor<32x11008x128xi4> to tensor<32x11008x128xi32>
-//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE1]] : tensor<32x128xi16> to tensor<32x128xi32>
+//      CHECK:  %[[INIT_EXTUI:.+]] = tensor.empty() : tensor<32x11008x128xi32>
+//      CHECK:  %[[EXTUI:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[TRANSPOSE0]] : tensor<32x11008x128xi4>) outs(%[[INIT_EXTUI]] : tensor<32x11008x128xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTUI_ARG_IN:.+]]: i4, %[[EXTUI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTUI_OP:.+]] = arith.extui %[[EXTUI_ARG_IN]] : i4 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTUI_OP]] : i32
+//      CHECK:  %[[INIT_EXTSI:.+]] = tensor.empty() : tensor<32x128xi32>
+//      CHECK:  %[[EXTSI:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[TRANSPOSE1]] : tensor<32x128xi16>) outs(%[[INIT_EXTSI]] : tensor<32x128xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTSI_ARG_IN:.+]]: i16, %[[EXTSI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTSI_OP:.+]] = arith.extsi %[[EXTSI_ARG_IN]] : i16 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTSI_OP]] : i32
+//      CHECK:  %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
+//      CHECK:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
 //      CHECK:  %[[MATMUL:.+]] = linalg.batch_matvec ins(%[[EXTUI]], %[[EXTSI]] : tensor<32x11008x128xi32>, tensor<32x128xi32>) outs(%[[FILL]] : tensor<32x11008xi32>)
 //      CHECK:  %[[INIT3:.+]] = tensor.empty() : tensor<11008x32xi32>
 //      CHECK:  %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<32x11008xi32>) outs(%[[INIT3]] : tensor<11008x32xi32>) permutation = [1, 0]
@@ -99,18 +117,26 @@
     return %2 : tensor<11008x32x8xi32>
   }
 }
-
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK:  func @raise_batch_matmul(
 // CHECK-SAME:  %[[ARG0:.+]]: tensor<8x32x128xi16>, %[[ARG1:.+]]: tensor<11008x32x128xi4>
 //  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
-//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty() : tensor<32x8x11008xi32>
-//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x8x11008xi32>)
 //  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty() : tensor<32x8x128xi16>
 //  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<8x32x128xi16>) outs(%[[INIT1]] : tensor<32x8x128xi16>) permutation = [1, 0, 2]
 //  CHECK-DAG:  %[[INIT2:.+]] = tensor.empty() : tensor<32x128x11008xi4>
 //  CHECK-DAG:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x32x128xi4>) outs(%[[INIT2]] : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
-//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE0]] : tensor<32x8x128xi16> to tensor<32x8x128xi32>
-//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE1]] : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
+//      CHECK:  %[[INIT_EXTSI:.+]] = tensor.empty() : tensor<32x8x128xi32>
+//      CHECK:  %[[EXTSI:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[TRANSPOSE0]] : tensor<32x8x128xi16>) outs(%[[INIT_EXTSI]] : tensor<32x8x128xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTSI_ARG_IN:.+]]: i16, %[[EXTSI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTSI_OP:.+]] = arith.extsi %[[EXTSI_ARG_IN]] : i16 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTSI_OP]] : i32
+//      CHECK:  %[[INIT_EXTUI:.+]] = tensor.empty() : tensor<32x128x11008xi32>
+//      CHECK:  %[[EXTUI:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[TRANSPOSE1]] : tensor<32x128x11008xi4>) outs(%[[INIT_EXTUI]] : tensor<32x128x11008xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTUI_ARG_IN:.+]]: i4, %[[EXTUI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTUI_OP:.+]] = arith.extui %[[EXTUI_ARG_IN]] : i4 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTUI_OP]] : i32
+//      CHECK:  %[[INIT0:.+]] = tensor.empty() : tensor<32x8x11008xi32>
+//      CHECK:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x8x11008xi32>)
 //      CHECK:  %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[EXTSI]], %[[EXTUI]] : tensor<32x8x128xi32>, tensor<32x128x11008xi32>) outs(%[[FILL]] : tensor<32x8x11008xi32>)
 //      CHECK:  %[[INIT3:.+]] = tensor.empty() : tensor<11008x32x8xi32>
 //      CHECK:  %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<32x8x11008xi32>) outs(%[[INIT3]] : tensor<11008x32x8xi32>) permutation = [2, 0, 1]
@@ -141,21 +167,29 @@
     return %2 : tensor<11008x?x8xi32>
   }
 }
-
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK:  func @raise_batch_matmul_dyn(
 // CHECK-SAME:  %[[ARG0:.+]]: tensor<8x?x128xi16>, %[[ARG1:.+]]: tensor<11008x?x128xi4>
 //  CHECK-DAG:  %[[C1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
 //  CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<8x?x128xi16>
-//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x11008xi32>
-//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<?x8x11008xi32>)
 //  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x128xi16>
 //  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<8x?x128xi16>) outs(%[[INIT1]] : tensor<?x8x128xi16>) permutation = [1, 0, 2]
 //  CHECK-DAG:  %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<11008x?x128xi4>
 //  CHECK-DAG:  %[[INIT2:.+]] = tensor.empty(%[[DIM1]]) : tensor<?x128x11008xi4>
 //  CHECK-DAG:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x?x128xi4>) outs(%[[INIT2]] : tensor<?x128x11008xi4>) permutation = [1, 2, 0]
-//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE0]] : tensor<?x8x128xi16> to tensor<?x8x128xi32>
-//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE1]] : tensor<?x128x11008xi4> to tensor<?x128x11008xi32>
+//      CHECK:  %[[INIT_EXTSI:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x128xi32>
+//      CHECK:  %[[EXTSI:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[TRANSPOSE0]] : tensor<?x8x128xi16>) outs(%[[INIT_EXTSI]] : tensor<?x8x128xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTSI_ARG_IN:.+]]: i16, %[[EXTSI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTSI_OP:.+]] = arith.extsi %[[EXTSI_ARG_IN]] : i16 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTSI_OP]] : i32
+//      CHECK:  %[[INIT_EXTUI:.+]] = tensor.empty(%[[DIM1]]) : tensor<?x128x11008xi32>
+//      CHECK:  %[[EXTUI:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[TRANSPOSE1]] : tensor<?x128x11008xi4>) outs(%[[INIT_EXTUI]] : tensor<?x128x11008xi32>) {
+// CHECK-NEXT:     ^bb0(%[[EXTUI_ARG_IN:.+]]: i4, %[[EXTUI_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:     %[[EXTUI_OP:.+]] = arith.extui %[[EXTUI_ARG_IN]] : i4 to i32
+// CHECK-NEXT:     linalg.yield %[[EXTUI_OP]] : i32
+//      CHECK:  %[[INIT0:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x11008xi32>
+//      CHECK:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<?x8x11008xi32>)
 //      CHECK:  %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[EXTSI]], %[[EXTUI]] : tensor<?x8x128xi32>, tensor<?x128x11008xi32>) outs(%[[FILL]] : tensor<?x8x11008xi32>)
 //      CHECK:  %[[INIT3:.+]] = tensor.empty(%[[DIM0]]) : tensor<11008x?x8xi32>
 //      CHECK:  %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<?x8x11008xi32>) outs(%[[INIT3]] : tensor<11008x?x8xi32>) permutation = [2, 0, 1]