[DT][NFC] Remove FailureOr<> from getEncodingInfo methods. (#19435)

We are able to use identity MaterializationEncodingInfo to represent the
"failure". Thus, we no longer need the `FailureOr` wrapper. The revision
removes the wrapper and updates the `lowerContractionOpWithEncoding`
function type signature. It does not need to pass a callback function.
Instead, we can pass the `IREE::Codegen::LayoutAttrInterface` which has
the method to query the materialization information.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index c6b0d38..05caca8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -100,16 +100,13 @@
     // itself.
     RankedTensorType tensorType =
         transposeNarrowN ? transposeIfNarrowNResult(type) : type;
-    FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
-        getEncodingInfo(tensorType);
-    if (failed(maybeEncodingInfo) ||
-        IREE::Codegen::isIdentityLayout(maybeEncodingInfo.value())) {
+    MaterializeEncodingInfo encodingInfo = getEncodingInfo(tensorType);
+    if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
       return dropEncoding(type);
     }
-    auto encodingInfo = *maybeEncodingInfo;
     auto packedType = cast<RankedTensorType>(tensor::PackOp::inferPackedType(
-        tensorType, maybeEncodingInfo->innerTileSizes,
-        maybeEncodingInfo->innerDimsPos, maybeEncodingInfo->outerDimsPerm));
+        tensorType, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
+        encodingInfo.outerDimsPerm));
 
     // There is no swizzle, we are already done. Typically the case on CPU.
     if (!encodingInfo.swizzle) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index 0a89d3a..3b59cbd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -42,7 +42,7 @@
     return layoutAttr;
   }
 
-  FailureOr<IREE::Codegen::MaterializeEncodingInfo>
+  IREE::Codegen::MaterializeEncodingInfo
   getEncodingInfo(RankedTensorType type) const {
     return layoutAttr.getEncodingInfo(type);
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 3a86d51..92b8ac4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -108,13 +108,9 @@
       return success();
     }
 
-    FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
+    MaterializeEncodingInfo encodingInfo =
         converter->getEncodingInfo(encodingOp.getResultType());
-    if (failed(maybeEncodingInfo)) {
-      return rewriter.notifyMatchFailure(encodingOp,
-                                         "unhandled result encoding");
-    }
-    if (!maybeEncodingInfo->swizzle) {
+    if (!encodingInfo.swizzle) {
       rewriter.replaceOp(encodingOp, packedValue.value());
       return success();
     }
@@ -128,18 +124,18 @@
             .getShape()
             .take_front(origRank));
     expandShapeShape.append(
-        getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape));
+        getExpandedTileShape(encodingInfo.swizzle->expandShape));
     RankedTensorType expandShapeType =
         encodingOp.getSourceType().clone(expandShapeShape);
 
-    SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
-        origRank, maybeEncodingInfo->swizzle->expandShape);
+    SmallVector<ReassociationIndices> reassociation =
+        getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
     auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
         loc, expandShapeType, packedValue.value(), reassociation);
 
     SmallVector<int64_t> transposePerm =
         llvm::to_vector(llvm::seq<int64_t>(0, origRank));
-    for (auto perm : maybeEncodingInfo->swizzle->permutation) {
+    for (auto perm : encodingInfo.swizzle->permutation) {
       transposePerm.push_back(origRank + perm);
     }
     SmallVector<OpFoldResult> transposeResultDims =
@@ -168,9 +164,9 @@
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         getTypeConverter());
 
-    FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
+    MaterializeEncodingInfo encodingInfo =
         converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
-    if (failed(maybeEncodingInfo)) {
+    if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
       Type targetType =
           getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
       Value result = rewriter.createOrFold<tensor::CastOp>(
@@ -181,15 +177,14 @@
 
     Location loc = unsetEncodingOp.getLoc();
     Value unpackSrc = adaptor.getSource();
-    if (maybeEncodingInfo->swizzle) {
+    if (encodingInfo.swizzle) {
       int targetRank = unsetEncodingOp.getResultType().getRank();
       auto srcConvertedType =
           cast<RankedTensorType>(adaptor.getSource().getType());
       SmallVector<OpFoldResult> emptyShape =
           tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
       emptyShape.resize(targetRank);
-      for (auto i :
-           getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape)) {
+      for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
         emptyShape.push_back(rewriter.getIndexAttr(i));
       }
       auto emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -197,7 +192,7 @@
 
       SmallVector<int64_t> transposePerm =
           llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
-      for (auto perm : maybeEncodingInfo->swizzle->permutation) {
+      for (auto perm : encodingInfo.swizzle->permutation) {
         transposePerm.push_back(targetRank + perm);
       }
       auto invertedTransposePerm = invertPermutationVector(transposePerm);
@@ -205,11 +200,11 @@
           loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);
 
       SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
-          targetRank, maybeEncodingInfo->swizzle->expandShape);
+          targetRank, encodingInfo.swizzle->expandShape);
       SmallVector<int64_t> unpackSrcShape(
           srcConvertedType.getShape().take_front(targetRank));
-      unpackSrcShape.append(maybeEncodingInfo->innerTileSizes.begin(),
-                            maybeEncodingInfo->innerTileSizes.end());
+      unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
+                            encodingInfo.innerTileSizes.end());
       RankedTensorType unpackSrcType =
           unsetEncodingOp.getResultType().clone(unpackSrcShape);
       unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 4d36b53..84b8540 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -126,14 +126,11 @@
     Value source, const MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
   RankedTensorType resultType = encodingOp.getResultType();
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
+  MaterializeEncodingInfo encodingInfo =
       typeConverter.getEncodingInfo(resultType);
-  if (failed(encodingInfo)) {
-    return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
-  }
 
   // Shortcut to avoid creating new operations.
-  if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) {
+  if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
     return source;
   }
 
@@ -142,13 +139,13 @@
     return failure();
   }
   if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
-    transposeInPlace(*encodingInfo);
+    transposeInPlace(encodingInfo);
   }
 
   // Create `tensor.empty` operation for the result of the pack operation.
   Location loc = encodingOp.getLoc();
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
-      rewriter, loc, resultType, *encodingInfo, materializeEncodingValueFn);
+      rewriter, loc, resultType, encodingInfo, materializeEncodingValueFn);
   if (failed(innerTileSizesOfr)) {
     return rewriter.notifyMatchFailure(
         encodingOp, "failed to generate runtime tile size query");
@@ -158,14 +155,14 @@
   SmallVector<OpFoldResult> sourceDims =
       tensor::getMixedSizes(rewriter, loc, source);
   SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
-      rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo->innerDimsPos,
-      encodingInfo->outerDimsPerm);
+      rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
+      encodingInfo.outerDimsPerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
                                                   resultType.getElementType());
   return rewriter
-      .create<tensor::PackOp>(loc, source, emptyOp, encodingInfo->innerDimsPos,
+      .create<tensor::PackOp>(loc, source, emptyOp, encodingInfo.innerDimsPos,
                               *innerTileSizesOfr, paddingValue,
-                              encodingInfo->outerDimsPerm)
+                              encodingInfo.outerDimsPerm)
       .getResult();
 }
 
@@ -174,20 +171,17 @@
     Value packedValue, const MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
   RankedTensorType sourceType = encodingOp.getSourceType();
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
+  MaterializeEncodingInfo encodingInfo =
       typeConverter.getEncodingInfo(sourceType);
-  if (failed(encodingInfo)) {
-    return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
-  }
 
   // Shortcut to avoid creating new operations.
-  if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) {
+  if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
     return packedValue;
   }
 
   auto encoding = IREE::Encoding::getEncodingAttr(sourceType);
   if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
-    transposeInPlace(*encodingInfo);
+    transposeInPlace(encodingInfo);
   }
   // Create an `tensor.empty` for the result of the unpack operation.
   Location loc = encodingOp.getLoc();
@@ -197,15 +191,15 @@
   auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
                                                   sourceType.getElementType());
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
-      rewriter, loc, sourceType, *encodingInfo, materializeEncodingValueFn);
+      rewriter, loc, sourceType, encodingInfo, materializeEncodingValueFn);
   if (failed(innerTileSizesOfr)) {
     return rewriter.notifyMatchFailure(
         encodingOp, "failed to generate runtime tile size query");
   }
   return rewriter
       .create<tensor::UnPackOp>(loc, packedValue, emptyOp,
-                                encodingInfo->innerDimsPos, *innerTileSizesOfr,
-                                encodingInfo->outerDimsPerm)
+                                encodingInfo.innerDimsPos, *innerTileSizesOfr,
+                                encodingInfo.outerDimsPerm)
       .getResult();
 }
 
@@ -217,22 +211,23 @@
                     const MaterializeEncodingTypeConverter &typeConverter,
                     MaterializeEncodingValueFn materializeEncodingValueFn) {
   auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
+  MaterializeEncodingInfo encodingInfo =
       typeConverter.getEncodingInfo(emptyType);
   Location loc = emptyOp.getLoc();
-  if (failed(encodingInfo)) {
-    Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
-        loc, emptyOp.getMixedSizes(), emptyType.getElementType());
-    return newEmptyOp;
+  if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
+    return rewriter
+        .create<tensor::EmptyOp>(loc, emptyOp.getMixedSizes(),
+                                 emptyType.getElementType())
+        .getOperation();
   }
 
   if (typeConverter.getTransposeNarrowN() &&
       isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
-    transposeInPlace(*encodingInfo);
+    transposeInPlace(encodingInfo);
   }
 
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
-      rewriter, loc, emptyType, *encodingInfo, materializeEncodingValueFn);
+      rewriter, loc, emptyType, encodingInfo, materializeEncodingValueFn);
   if (failed(innerTileSizesOfr)) {
     return rewriter.notifyMatchFailure(
         emptyOp, "failed to generate runtime tile size query");
@@ -241,9 +236,9 @@
   SmallVector<OpFoldResult> sourceDims = emptyOp.getMixedSizes();
   (void)foldDynamicIndexList(sourceDims);
   SmallVector<OpFoldResult> newShape = tensor::PackOp::getResultShape(
-      rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo->innerDimsPos,
-      encodingInfo->outerDimsPerm);
-  newShape = getSwizzledShape(newShape, *encodingInfo);
+      rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
+      encodingInfo.outerDimsPerm);
+  newShape = getSwizzledShape(newShape, encodingInfo);
   Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
       loc, newShape, emptyType.getElementType());
   return newEmptyOp;
@@ -262,10 +257,10 @@
     return rewriter.notifyMatchFailure(genericOp,
                                        "Output indexing map is not identity");
   }
-  FailureOr<MaterializeEncodingInfo> outMaterializeEncodingInfo =
+  MaterializeEncodingInfo outMaterializeEncodingInfo =
       typeConverter.getEncodingInfo(
           cast<RankedTensorType>(outputOperand->get().getType()));
-  if (failed(outMaterializeEncodingInfo)) {
+  if (IREE::Codegen::isIdentityLayout(outMaterializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(
         genericOp, "MaterializeEncodingInfo failed for output");
   }
@@ -277,20 +272,20 @@
   // Compute the new indexing maps for the packed layout. This assumes that
   // the output map is identity, and that all iterator types are parallel.
   SmallVector<int64_t> outInnerDimsPos =
-      outMaterializeEncodingInfo->innerDimsPos;
+      outMaterializeEncodingInfo.innerDimsPos;
   SmallVector<int64_t> outInverseOuterDimsPerm =
-      invertPermutationVector(outMaterializeEncodingInfo->outerDimsPerm);
+      invertPermutationVector(outMaterializeEncodingInfo.outerDimsPerm);
   SmallVector<AffineMap> packedIndexingMaps;
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+    MaterializeEncodingInfo materializeEncodingInfo =
         typeConverter.getEncodingInfo(
             cast<RankedTensorType>(inputOperand->get().getType()));
-    if (failed(materializeEncodingInfo)) {
+    if (IREE::Codegen::isIdentityLayout(materializeEncodingInfo)) {
       return rewriter.notifyMatchFailure(
           genericOp, "MaterializeEncodingInfo failed for input");
     }
-    SmallVector<int64_t> innerDimsPos = materializeEncodingInfo->innerDimsPos;
-    SmallVector<int64_t> outerDimsPerm = materializeEncodingInfo->outerDimsPerm;
+    ArrayRef<int64_t> innerDimsPos = materializeEncodingInfo.innerDimsPos;
+    ArrayRef<int64_t> outerDimsPerm = materializeEncodingInfo.outerDimsPerm;
     AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
     // Permute result dims to the input packed domain, and map dims to the
     // output packed domain.
@@ -388,28 +383,28 @@
     return failure();
   }
 
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
+  MaterializeEncodingInfo encodingInfo =
       typeConverter.getEncodingInfo(boundTensorType);
-  if (failed(encodingInfo)) {
+  if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
     return failure();
   }
   if (typeConverter.getTransposeNarrowN() &&
       isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) {
-    transposeInPlace(*encodingInfo);
+    transposeInPlace(encodingInfo);
   }
 
   SmallVector<OpFoldResult> targetShape =
       getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
   auto innerTileSizes = getInnerTileSizesOfr(
-      builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
+      builder, loc, boundTensorType, encodingInfo, materializeEncodingValueFn);
   if (failed(innerTileSizes)) {
     return failure();
   }
   SmallVector<OpFoldResult> convertedTargetShape =
       tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
-                                     encodingInfo->innerDimsPos,
-                                     encodingInfo->outerDimsPerm);
-  return getSwizzledShape(convertedTargetShape, *encodingInfo);
+                                     encodingInfo.innerDimsPos,
+                                     encodingInfo.outerDimsPerm);
+  return getSwizzledShape(convertedTargetShape, encodingInfo);
 }
 
 /// For `dispatchTensorType` that bind a `RankedTensorType` with encoding,
@@ -756,17 +751,10 @@
       return success();
     }
 
-    // TODO(hanchung): This is a transition state for moving the implementation
-    // details to backend attributes. We won't need the function type argument
-    // after all the backends that support encodings implement the attribute.
-    auto getEncodingInfoWrapper =
-        [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
-      return converter->getEncodingInfo(type);
-    };
     FailureOr<Operation *> convertedOp =
         IREE::Codegen::lowerContractionOpWithEncoding(
             rewriter, op, operands, converter->getTransposeNarrowN(),
-            getEncodingInfoWrapper);
+            converter->getLayoutAttr());
     if (failed(convertedOp)) {
       return failure();
     }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
index c4ad11e..c152aee 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
@@ -90,8 +90,5 @@
   std::optional<TileSwizzle> swizzle;
 };
 
-using ResolveEncodingInfoFn =
-    std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;
-
 } // namespace mlir::iree_compiler::IREE::Codegen
 #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IR_IREECODEGENTYPES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
index bf0a569..32dbc46 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
@@ -510,7 +510,7 @@
 FailureOr<Operation *>
 lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
                                ValueRange operands, bool transposeNarrowN,
-                               ResolveEncodingInfoFn getEncodingInfo) {
+                               LayoutAttrInterface layoutAttr) {
   if (!linalgOp.hasPureTensorSemantics()) {
     return failure();
   }
@@ -535,42 +535,42 @@
     return failure();
   }
 
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
-      getEncodingInfo(cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
+  MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
+      cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
 
+  if (isIdentityLayout(encodingInfo)) {
+    return dropEncodingAndCloneOp(builder, linalgOp,
+                                  operands.take_front(inputs.size()),
+                                  operands.drop_front(inputs.size()));
+  }
+
+  bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
+  SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
+  SmallVector<ReassociationIndices> ri;
+  Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, ri,
+                                 elemTypes, /*operandIdx=*/0);
+  Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, ri,
+                                 elemTypes, /*operandIdx=*/1);
+  Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
+                                    ri, elemTypes, /*operandIdx=*/2);
+  if (transpose) {
+    std::swap(newLhs, newRhs);
+  }
+  Type newResultType = newResult.getType();
+  auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
   Operation *result;
-  if (failed(encodingInfo) || isIdentityLayout(encodingInfo.value())) {
-    result = dropEncodingAndCloneOp(builder, linalgOp,
-                                    operands.take_front(inputs.size()),
-                                    operands.drop_front(inputs.size()));
+  if (cDims->batch.empty()) {
+    result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
+                                             ValueRange{newLhs, newRhs},
+                                             ValueRange{newResult});
   } else {
-    bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
-    SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-    SmallVector<ReassociationIndices> ri;
-    Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder,
-                                   ri, elemTypes, /*operandIdx=*/0);
-    Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder,
-                                   ri, elemTypes, /*operandIdx=*/1);
-    Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
-                                      ri, elemTypes, /*operandIdx=*/2);
-    if (transpose) {
-      std::swap(newLhs, newRhs);
-    }
-    Type newResultType = newResult.getType();
-    auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-    if (cDims->batch.empty()) {
-      result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
-                                               ValueRange{newLhs, newRhs},
-                                               ValueRange{newResult});
-    } else {
-      result = builder.create<linalg::BatchMmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    }
-    if (!ri.empty()) {
-      result = builder.create<tensor::CollapseShapeOp>(
-          linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-    }
+    result = builder.create<linalg::BatchMmt4DOp>(
+        linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
+        ValueRange{newResult});
+  }
+  if (!ri.empty()) {
+    result = builder.create<tensor::CollapseShapeOp>(
+        linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
   }
   return result;
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
index f08aec2..1bee3ec 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
 #define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
 #include "llvm/Support/raw_ostream.h"
@@ -95,7 +96,7 @@
 FailureOr<Operation *>
 lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
                                ValueRange operands, bool transposeNarrowN,
-                               ResolveEncodingInfoFn getEncodingInfo);
+                               LayoutAttrInterface layoutAttr);
 
 } // namespace mlir::iree_compiler::IREE::Codegen
 
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
index 53c1fed..a847abd 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.h"
 
 #include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
@@ -308,12 +309,9 @@
       return nullptr;
     }
 
-    auto resolver =
-        [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
-      return getEncodingInfo(layoutAttr, type);
-    };
     FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
-        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true, resolver);
+        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+        cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
     return newOp.value_or(nullptr);
   }
 };
@@ -395,12 +393,9 @@
       return nullptr;
     }
 
-    auto resolver =
-        [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
-      return getEncodingInfo(layoutAttr, type);
-    };
     FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
-        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true, resolver);
+        b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+        cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
     return newOp.value_or(nullptr);
   }
 };