[GlobalOptimization] Support SetEncoding on batch matmul cases with producer CastOpInterface ops (#15371)

This commit allows SetEncoding to match CastOpInterface ops that produce
the inputs of BatchMatmulOps. This makes it possible to infer the correct
input element types when the cast is done by an explicit producer op
rather than implicitly inside the batch_matmul body, and it also provides
a way to infer the signedness of the input types from the specific
CastOpInterface op type (e.g., arith.extui vs. arith.extsi).
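
For illustration, here is a trimmed sketch of the producer pattern this
change handles, adapted from the new set_encoding.mlir test in this PR
(%rhs and %acc are placeholders): an explicit zero-extension feeding the
batch matmul LHS, from which SetEncoding can infer an unsigned i8 (ui8)
LHS element type:

  %lhs = arith.extui %arg0 : tensor<64x100x250xi8> to tensor<64x100x250xi32>
  %1 = linalg.batch_matmul
      ins(%lhs, %rhs : tensor<64x100x250xi32>, tensor<64x250x500xi32>)
      outs(%acc : tensor<64x100x500xi32>) -> tensor<64x100x500xi32>

In the test below, the RHS is instead cast through a generalized
arith.extsi inside a linalg.generic, so the resulting encodings carry
element_types = [ui8, i8, i32], and each cast is rebuilt as a
linalg.generic on the corresponding set_encoding result.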
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index d036b9c..7bb6437 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -53,9 +53,11 @@
         "Passes.cpp",
         "RemoveZeroExtentTensors.cpp",
         "SetEncoding.cpp",
+        "Utils.cpp",
     ],
     hdrs = [
         "Passes.h",
+        "Utils.h",
     ],
     deps = [
         ":PassHeaders",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 88781e7..ea706e1 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -38,6 +38,7 @@
     GlobalOptimization
   HDRS
     "Passes.h"
+    "Utils.h"
   SRCS
     "Convert1X1FilterConv2DToMatmul.cpp"
     "DetachElementwiseFromNamedOps.cpp"
@@ -48,6 +49,7 @@
     "Passes.cpp"
     "RemoveZeroExtentTensors.cpp"
     "SetEncoding.cpp"
+    "Utils.cpp"
   DEPS
     ::PassHeaders
     ::PassesIncGen
diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
index 342c19d..789e6a5 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -15,10 +15,12 @@
 #include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
 #include "iree/compiler/GlobalOptimization/PassDetail.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
+#include "iree/compiler/GlobalOptimization/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
@@ -136,17 +138,61 @@
       getAttr(narrow.M), getAttr(narrow.N));
 }
 
-static Value padAndSetEncoding(OpBuilder &builder, Location loc, Value source,
-                               IREE::LinalgExt::EncodingUser user,
-                               IREE::LinalgExt::EncodingRole role,
-                               TypeRange operandTypes,
-                               MatmulNarrowSizes narrow) {
-  // No need to specify original_type in the encoding passed to pad(), because
+// Creates a linalg::GenericOp that performs an element-wise cast of the same
+// type as performed in `castOp`, and returns the result encoded with
+// `encodingAttr`. The element type of `encoded` is expected to be the same as
+// the element type of the input to `castOp`, which may operate on a tensor or
+// on a single element.
+static Value castEncodedResult(OpBuilder &builder, Location loc, Value encoded,
+                               CastOpInterface castOp,
+                               IREE::LinalgExt::EncodingAttr encodingAttr) {
+  auto encodedType = cast<RankedTensorType>(encoded.getType());
+  auto castResultElemType = getElementTypeOrSelf(castOp->getResultTypes()[0]);
+  auto castedType = RankedTensorType::get(encodedType.getShape(),
+                                          castResultElemType, encodingAttr);
+  assert(encodedType.getElementType() ==
+             getElementTypeOrSelf(castOp->getOperandTypes()[0]) &&
+         "Expected encoded element type to be the same as the cast input "
+         "element type");
+  SmallVector<OpFoldResult> inputMixedSizes =
+      tensor::getMixedSizes(builder, loc, encoded);
+  Value init = builder.create<tensor::EmptyOp>(
+      loc, inputMixedSizes, castResultElemType, encodingAttr);
+  SmallVector<AffineMap> maps(
+      2, AffineMap::getMultiDimIdentityMap(castedType.getRank(),
+                                           builder.getContext()));
+  SmallVector<utils::IteratorType> iteratorTypes(castedType.getRank(),
+                                                 utils::IteratorType::parallel);
+  auto genericOp = castOp->getParentOfType<linalg::GenericOp>();
+  NamedAttrList castAttrs =
+      genericOp ? genericOp->getAttrs() : castOp->getAttrs();
+  return builder
+      .create<linalg::GenericOp>(
+          loc, castedType, encoded, init, maps, iteratorTypes,
+          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+            Value castRes =
+                b.create(nestedLoc, castOp->getName().getIdentifier(), args[0],
+                         castResultElemType)
+                    ->getResult(0);
+            b.create<linalg::YieldOp>(nestedLoc, castRes);
+          },
+          castAttrs)
+      ->getResult(0);
+}
+
+static Value
+padAndSetEncoding(OpBuilder &builder, Location loc, Value source,
+                  IREE::LinalgExt::EncodingUser user,
+                  IREE::LinalgExt::EncodingRole role, TypeRange operandTypes,
+                  MatmulNarrowSizes narrow,
+                  std::optional<CastOpInterface> castOp = std::nullopt) {
+  Value padSource = castOp ? source.getDefiningOp()->getOperand(0) : source;
+  // No need to specify original_type in the encoding passed to pad(), because
   // the operand there is the `source` tensor, so it will default to reading its
   // original shape.
   auto encodingForPad = makeEncoding(builder, user, role, operandTypes,
                                      /*originalType=*/Type{}, narrow);
-  Value padded = pad(builder, loc, source, encodingForPad);
+  Value padded = pad(builder, loc, padSource, encodingForPad);
   // For setEncoding() below, we potentially need to specify an encoding with an
   // explicit original_type, because the operand there is the padded tensor
   // returned by pad() above, but we want setEncoding to be aware of the
@@ -154,11 +200,16 @@
   // verbosity, we only specify the original original_type when it differs from
   // the tensor type that the encoding is applied to.
   auto encodingForSetEncoding = encodingForPad;
-  if (padded.getType() != source.getType()) {
+  if (padded.getType() != padSource.getType()) {
     encodingForSetEncoding = makeEncoding(builder, user, role, operandTypes,
-                                          source.getType(), narrow);
+                                          padSource.getType(), narrow);
   }
-  return setEncoding(builder, loc, padded, encodingForSetEncoding);
+  Value encoded = setEncoding(builder, loc, padded, encodingForSetEncoding);
+  if (castOp) {
+    encoded = castEncodedResult(builder, loc, encoded, castOp.value(),
+                                encodingForSetEncoding);
+  }
+  return encoded;
 }
 
 static Value unsetEncodingAndExtractSlice(OpBuilder &builder, Location loc,
@@ -297,8 +348,12 @@
       }
       return {};
     };
-    Type lhsElemType = getElemType(origLhs);
-    Type rhsElemType = getElemType(origRhs);
+    std::optional<CastOpInterface> maybeLhsCastOp = getDefiningCastOp(origLhs);
+    std::optional<CastOpInterface> maybeRhsCastOp = getDefiningCastOp(origRhs);
+    Type lhsElemType = maybeLhsCastOp ? getCastElemType(origLhs).value()
+                                      : getElemType(origLhs);
+    Type rhsElemType = maybeRhsCastOp ? getCastElemType(origRhs).value()
+                                      : getElemType(origRhs);
     Type outElemType = getElemType(origOut);
 
     if (!lhsElemType || !rhsElemType || !outElemType) {
@@ -310,13 +365,17 @@
     MatmulNarrowSizes narrowSizes =
         getMatmulNarrowSizes(origOut.getType().cast<ShapedType>());
     Location loc = matmulOp.getLoc();
-    TypeRange operandTypes = matmulOp->getOperandTypes();
-    Value encodedLhs = padAndSetEncoding(rewriter, loc, origLhs, user,
-                                         IREE::LinalgExt::EncodingRole::LHS,
-                                         operandTypes, narrowSizes);
-    Value encodedRhs = padAndSetEncoding(rewriter, loc, origRhs, user,
-                                         IREE::LinalgExt::EncodingRole::RHS,
-                                         operandTypes, narrowSizes);
+    SmallVector<Type> operandTypes(matmulOp->getOperandTypes());
+    operandTypes[0] =
+        cast<RankedTensorType>(operandTypes[0]).clone(lhsElemType);
+    operandTypes[1] =
+        cast<RankedTensorType>(operandTypes[1]).clone(rhsElemType);
+    Value encodedLhs = padAndSetEncoding(
+        rewriter, loc, origLhs, user, IREE::LinalgExt::EncodingRole::LHS,
+        operandTypes, narrowSizes, maybeLhsCastOp);
+    Value encodedRhs = padAndSetEncoding(
+        rewriter, loc, origRhs, user, IREE::LinalgExt::EncodingRole::RHS,
+        operandTypes, narrowSizes, maybeRhsCastOp);
     Value encodedOut = padAndSetEncoding(rewriter, loc, origOut, user,
                                          IREE::LinalgExt::EncodingRole::RESULT,
                                          operandTypes, narrowSizes);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp
new file mode 100644
index 0000000..a1a5859
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp
@@ -0,0 +1,57 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/GlobalOptimization/Utils.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace GlobalOptimization {
+
+std::optional<CastOpInterface> getDefiningCastOp(Value input) {
+  auto castOp = input.getDefiningOp<CastOpInterface>();
+  if (castOp) {
+    return castOp;
+  }
+  auto genericOp = input.getDefiningOp<linalg::GenericOp>();
+  if (!genericOp || genericOp.getNumDpsInputs() != 1 ||
+      genericOp.getNumDpsInits() != 1 ||
+      genericOp.getBody()->getOperations().size() != 2 ||
+      !isElementwise(genericOp)) {
+    return std::nullopt;
+  }
+  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
+  castOp = yieldOp->getOperand(0).getDefiningOp<CastOpInterface>();
+  if (!castOp) {
+    return std::nullopt;
+  }
+  Value castIn = castOp->getOperand(0);
+  if (castIn.isa<BlockArgument>() &&
+      castIn.cast<BlockArgument>().getArgNumber() != 0) {
+    return std::nullopt;
+  }
+  return castOp;
+}
+
+std::optional<Type> getCastElemType(Value input) {
+  std::optional<CastOpInterface> castOp = getDefiningCastOp(input);
+  if (!castOp) {
+    return std::nullopt;
+  }
+  Type castSrcElemType = getElementTypeOrSelf(castOp.value()->getOperand(0));
+  if (isa<arith::ExtUIOp>(castOp.value())) {
+    int64_t bitWidth = castSrcElemType.getIntOrFloatBitWidth();
+    return IntegerType::get(castOp->getContext(), bitWidth,
+                            IntegerType::SignednessSemantics::Unsigned);
+  }
+  return castSrcElemType;
+}
+
+} // namespace GlobalOptimization
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Utils.h b/compiler/src/iree/compiler/GlobalOptimization/Utils.h
new file mode 100644
index 0000000..03987bb
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/Utils.h
@@ -0,0 +1,35 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#ifndef IREE_COMPILER_GLOBALOPTIMIZATION_UTILS_H_
+#define IREE_COMPILER_GLOBALOPTIMIZATION_UTILS_H_
+
+#include <optional>
+
+namespace mlir {
+class Type;
+class Value;
+class CastOpInterface;
+
+namespace iree_compiler {
+namespace GlobalOptimization {
+
+// If the producer is a CastOpInterface, or a linalg::GenericOp that performs
+// only a CastOpInterface on its input, return the CastOpInterface op.
+// Otherwise, return std::nullopt.
+//
+// **Note: If the CastOpInterface has been generalized, the returned Operation
+//         is the body CastOpInterface op, not the linalg::GenericOp.
+std::optional<CastOpInterface> getDefiningCastOp(Value input);
+
+// Returns the source element type of the defining CastOpInterface of `input`,
+// if there is one (unsigned for zero-extending casts); otherwise std::nullopt.
+std::optional<Type> getCastElemType(Value input);
+
+} // namespace GlobalOptimization
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_GLOBALOPTIMIZATION_UTILS_H_
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
index e675c55..5c7495c 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
@@ -669,3 +669,89 @@
 // CHECK-NOT:     unset_encoding
 // CHECK:         linalg.matmul
 // CHECK:         linalg.batch_matmul
+
+// -----
+
+func.func @batch_matmul_casted_i8i8i32(%arg0 : tensor<64x100x250xi8>, %arg1 : tensor<64x250x500xi8>,
+      %arg2 : tensor<64x100x500xi32>) -> tensor<64x100x500xi32> {
+  %0 = tensor.empty() : tensor<64x250x500xi32>
+  %casted0 = arith.extui %arg0 : tensor<64x100x250xi8> to tensor<64x100x250xi32>
+  %casted1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
+                                              affine_map<(d0, d1, d2) -> (d0, d1, d2)>], 
+                              iterator_types = ["parallel", "parallel", "parallel"]} 
+                              ins(%arg1 : tensor<64x250x500xi8>) 
+                              outs(%0 : tensor<64x250x500xi32>) {
+  ^bb0(%in: i8, %out: i32):
+      %2 = arith.extsi %in : i8 to i32
+      linalg.yield %2 : i32
+  } -> tensor<64x250x500xi32>
+  %1 = linalg.batch_matmul ins(%casted0, %casted1 : tensor<64x100x250xi32>, tensor<64x250x500xi32>)
+      outs(%arg2 : tensor<64x100x500xi32>) -> tensor<64x100x500xi32>
+  return %1 : tensor<64x100x500xi32>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0 + 250)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0 + 100)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0 + 64)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0 + 500)>
+//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//      CHECK: func @batch_matmul_casted_i8i8i32(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<64x100x250xi8>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<64x250x500xi8>
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<64x100x500xi32>
+//  CHECK-DAG:     %[[C64:.+]] = arith.constant 64 : index
+//  CHECK-DAG:     %[[C100:.+]] = arith.constant 100 : index
+//  CHECK-DAG:     %[[C250:.+]] = arith.constant 250 : index
+//  CHECK-DAG:     %[[C500:.+]] = arith.constant 500 : index
+//      CHECK:   %[[LHS_TILE_SIZE:.+]]:3 = iree_linalg_ext.upper_bound_tile_size tensor<64x100x250xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [ui8, i8, i32]>> -> index, index, index
+//      CHECK:   %[[LHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#0, %[[C64]]]
+//      CHECK:   %[[LHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#1, %[[C100]]]
+//      CHECK:   %[[LHS_PADDING_SIZE2:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#2, %[[C250]]]
+//      CHECK:   %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0] high[%[[LHS_PADDING_SIZE0]], %[[LHS_PADDING_SIZE1]], %[[LHS_PADDING_SIZE2]]]
+//      CHECK:       tensor<64x100x250xi8> to tensor<?x?x?xi8>
+//  CHECK-DAG:   %[[LHS_DIM0:.+]] = affine.apply #[[MAP1]]()[%[[LHS_TILE_SIZE]]#2, %[[C250]]]
+//  CHECK-DAG:   %[[LHS_DIM1:.+]] = affine.apply #[[MAP2]]()[%[[LHS_TILE_SIZE]]#1, %[[C100]]]
+//  CHECK-DAG:   %[[LHS_DIM2:.+]] = affine.apply #[[MAP3]]()[%[[LHS_TILE_SIZE]]#0, %[[C64]]]
+//      CHECK:   %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME:       tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [ui8, i8, i32], original_type = tensor<64x100x250xi8>>>
+//      CHECK:   %[[INIT_LHS_CAST:.+]] = tensor.empty(%[[LHS_DIM2]], %[[LHS_DIM1]], %[[LHS_DIM0]]) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [ui8, i8, i32], original_type = tensor<64x100x250xi8>>>
+//      CHECK:   %[[LHS_CASTED:.+]] = linalg.generic {indexing_maps = [#[[MAP5]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:       ins(%[[LHS]] : tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [ui8, i8, i32], original_type = tensor<64x100x250xi8>>>)
+// CHECK-SAME:       outs(%[[INIT_LHS_CAST]] : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [ui8, i8, i32], original_type = tensor<64x100x250xi8>>>)
+// CHECK-NEXT:   ^bb0(%[[LHS_ARG_IN:.+]]: i8, %[[LHS_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:   %[[LHS_CAST_OP:.+]] = arith.extui %[[LHS_ARG_IN]] : i8 to i32
+// CHECK-NEXT:   linalg.yield %[[LHS_CAST_OP]] : i32
+// CHECK-NEXT:   -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [ui8, i8, i32], original_type = tensor<64x100x250xi8>>>
+//      CHECK:   %[[RHS_TILE_SIZE:.+]]:3 = iree_linalg_ext.upper_bound_tile_size tensor<64x250x500xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [ui8, i8, i32]>> -> index, index, index
+//      CHECK:   %[[RHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#0, %[[C64]]]
+//      CHECK:   %[[RHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#1, %[[C250]]]
+//      CHECK:   %[[RHS_PADDING_SIZE2:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#2, %[[C500]]]
+//      CHECK:   %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0, 0] high[%[[RHS_PADDING_SIZE0]], %[[RHS_PADDING_SIZE1]], %[[RHS_PADDING_SIZE2]]]
+//      CHECK:       tensor<64x250x500xi8> to tensor<?x?x?xi8>
+//  CHECK-DAG:   %[[RHS_DIM0:.+]] = affine.apply #[[MAP4]]()[%[[RHS_TILE_SIZE]]#2, %[[C500]]]
+//  CHECK-DAG:   %[[RHS_DIM1:.+]] = affine.apply #[[MAP1]]()[%[[RHS_TILE_SIZE]]#1, %[[C250]]]
+//  CHECK-DAG:   %[[RHS_DIM2:.+]] = affine.apply #[[MAP3]]()[%[[RHS_TILE_SIZE]]#0, %[[C64]]]
+//      CHECK:   %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME:       tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [ui8, i8, i32], original_type = tensor<64x250x500xi8>>>
+//      CHECK:   %[[INIT_RHS_CAST:.+]] = tensor.empty(%[[RHS_DIM2]], %[[RHS_DIM1]], %[[RHS_DIM0]]) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [ui8, i8, i32], original_type = tensor<64x250x500xi8>>>
+//      CHECK:   %[[RHS_CASTED:.+]] = linalg.generic {indexing_maps = [#[[MAP5]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:       ins(%[[RHS]] : tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [ui8, i8, i32], original_type = tensor<64x250x500xi8>>>)
+// CHECK-SAME:       outs(%[[INIT_RHS_CAST]] : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [ui8, i8, i32], original_type = tensor<64x250x500xi8>>>)
+// CHECK-NEXT:   ^bb0(%[[RHS_ARG_IN:.+]]: i8, %[[RHS_ARG_OUT:.+]]: i32):
+// CHECK-NEXT:   %[[RHS_CAST_OP:.+]] = arith.extsi %[[RHS_ARG_IN]] : i8 to i32
+// CHECK-NEXT:   linalg.yield %[[RHS_CAST_OP]] : i32
+// CHECK-NEXT:   -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [ui8, i8, i32], original_type = tensor<64x250x500xi8>>>
+//      CHECK:   %[[OUTS_TILE_SIZE:.+]]:3 = iree_linalg_ext.upper_bound_tile_size tensor<64x100x500xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [ui8, i8, i32]>> -> index, index, index
+//      CHECK:   %[[OUTS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#0, %[[C64]]]
+//      CHECK:   %[[OUTS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#1, %[[C100]]]
+//      CHECK:   %[[OUTS_PADDING_SIZE2:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#2, %[[C500]]]
+//      CHECK:   %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0, 0, 0] high[%[[OUTS_PADDING_SIZE0]], %[[OUTS_PADDING_SIZE1]], %[[OUTS_PADDING_SIZE2]]]
+//      CHECK:       tensor<64x100x500xi32> to tensor<?x?x?xi32>
+//      CHECK:   %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME:       tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [ui8, i8, i32], original_type = tensor<64x100x500xi32>>>
+//      CHECK:   %[[BATCH_MATMUL:.+]] = linalg.batch_matmul
+// CHECK-SAME:       ins(%[[LHS_CASTED]], %[[RHS_CASTED]] :
+// CHECK-SAME:       outs(%[[OUTS]] :
+//      CHECK:   %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[BATCH_MATMUL]]
+//      CHECK:   %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0, 0, 0] [64, 100, 500] [1, 1, 1]
+//      CHECK:   return %[[RESULT]]