[DT][NFC] Remove FailureOr<> from getEncodingInfo methods. (#19435)
We are able to use an identity MaterializeEncodingInfo to represent the
"failure" case. Thus, we no longer need the `FailureOr` wrapper. The revision
removes the wrapper and updates the `lowerContractionOpWithEncoding`
function type signature. Callers no longer need to pass a callback
function. Instead, we can pass the `IREE::Codegen::LayoutAttrInterface`
attribute, which provides the method to query the materialization
information directly.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index c6b0d38..05caca8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -100,16 +100,13 @@
// itself.
RankedTensorType tensorType =
transposeNarrowN ? transposeIfNarrowNResult(type) : type;
- FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
- getEncodingInfo(tensorType);
- if (failed(maybeEncodingInfo) ||
- IREE::Codegen::isIdentityLayout(maybeEncodingInfo.value())) {
+ MaterializeEncodingInfo encodingInfo = getEncodingInfo(tensorType);
+ if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return dropEncoding(type);
}
- auto encodingInfo = *maybeEncodingInfo;
auto packedType = cast<RankedTensorType>(tensor::PackOp::inferPackedType(
- tensorType, maybeEncodingInfo->innerTileSizes,
- maybeEncodingInfo->innerDimsPos, maybeEncodingInfo->outerDimsPerm));
+ tensorType, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
+ encodingInfo.outerDimsPerm));
// There is no swizzle, we are already done. Typically the case on CPU.
if (!encodingInfo.swizzle) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index 0a89d3a..3b59cbd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -42,7 +42,7 @@
return layoutAttr;
}
- FailureOr<IREE::Codegen::MaterializeEncodingInfo>
+ IREE::Codegen::MaterializeEncodingInfo
getEncodingInfo(RankedTensorType type) const {
return layoutAttr.getEncodingInfo(type);
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 3a86d51..92b8ac4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -108,13 +108,9 @@
return success();
}
- FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
+ MaterializeEncodingInfo encodingInfo =
converter->getEncodingInfo(encodingOp.getResultType());
- if (failed(maybeEncodingInfo)) {
- return rewriter.notifyMatchFailure(encodingOp,
- "unhandled result encoding");
- }
- if (!maybeEncodingInfo->swizzle) {
+ if (!encodingInfo.swizzle) {
rewriter.replaceOp(encodingOp, packedValue.value());
return success();
}
@@ -128,18 +124,18 @@
.getShape()
.take_front(origRank));
expandShapeShape.append(
- getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape));
+ getExpandedTileShape(encodingInfo.swizzle->expandShape));
RankedTensorType expandShapeType =
encodingOp.getSourceType().clone(expandShapeShape);
- SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
- origRank, maybeEncodingInfo->swizzle->expandShape);
+ SmallVector<ReassociationIndices> reassociation =
+ getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, packedValue.value(), reassociation);
SmallVector<int64_t> transposePerm =
llvm::to_vector(llvm::seq<int64_t>(0, origRank));
- for (auto perm : maybeEncodingInfo->swizzle->permutation) {
+ for (auto perm : encodingInfo.swizzle->permutation) {
transposePerm.push_back(origRank + perm);
}
SmallVector<OpFoldResult> transposeResultDims =
@@ -168,9 +164,9 @@
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
- FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
+ MaterializeEncodingInfo encodingInfo =
converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
- if (failed(maybeEncodingInfo)) {
+ if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
Type targetType =
getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
Value result = rewriter.createOrFold<tensor::CastOp>(
@@ -181,15 +177,14 @@
Location loc = unsetEncodingOp.getLoc();
Value unpackSrc = adaptor.getSource();
- if (maybeEncodingInfo->swizzle) {
+ if (encodingInfo.swizzle) {
int targetRank = unsetEncodingOp.getResultType().getRank();
auto srcConvertedType =
cast<RankedTensorType>(adaptor.getSource().getType());
SmallVector<OpFoldResult> emptyShape =
tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
emptyShape.resize(targetRank);
- for (auto i :
- getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape)) {
+ for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
emptyShape.push_back(rewriter.getIndexAttr(i));
}
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -197,7 +192,7 @@
SmallVector<int64_t> transposePerm =
llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
- for (auto perm : maybeEncodingInfo->swizzle->permutation) {
+ for (auto perm : encodingInfo.swizzle->permutation) {
transposePerm.push_back(targetRank + perm);
}
auto invertedTransposePerm = invertPermutationVector(transposePerm);
@@ -205,11 +200,11 @@
loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);
SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
- targetRank, maybeEncodingInfo->swizzle->expandShape);
+ targetRank, encodingInfo.swizzle->expandShape);
SmallVector<int64_t> unpackSrcShape(
srcConvertedType.getShape().take_front(targetRank));
- unpackSrcShape.append(maybeEncodingInfo->innerTileSizes.begin(),
- maybeEncodingInfo->innerTileSizes.end());
+ unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
+ encodingInfo.innerTileSizes.end());
RankedTensorType unpackSrcType =
unsetEncodingOp.getResultType().clone(unpackSrcShape);
unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 4d36b53..84b8540 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -126,14 +126,11 @@
Value source, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType resultType = encodingOp.getResultType();
- FailureOr<MaterializeEncodingInfo> encodingInfo =
+ MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(resultType);
- if (failed(encodingInfo)) {
- return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
- }
// Shortcut to avoid creating new operations.
- if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) {
+ if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return source;
}
@@ -142,13 +139,13 @@
return failure();
}
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
- transposeInPlace(*encodingInfo);
+ transposeInPlace(encodingInfo);
}
// Create `tensor.empty` operation for the result of the pack operation.
Location loc = encodingOp.getLoc();
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
- rewriter, loc, resultType, *encodingInfo, materializeEncodingValueFn);
+ rewriter, loc, resultType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
@@ -158,14 +155,14 @@
SmallVector<OpFoldResult> sourceDims =
tensor::getMixedSizes(rewriter, loc, source);
SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
- rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo->innerDimsPos,
- encodingInfo->outerDimsPerm);
+ rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
+ encodingInfo.outerDimsPerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
resultType.getElementType());
return rewriter
- .create<tensor::PackOp>(loc, source, emptyOp, encodingInfo->innerDimsPos,
+ .create<tensor::PackOp>(loc, source, emptyOp, encodingInfo.innerDimsPos,
*innerTileSizesOfr, paddingValue,
- encodingInfo->outerDimsPerm)
+ encodingInfo.outerDimsPerm)
.getResult();
}
@@ -174,20 +171,17 @@
Value packedValue, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType sourceType = encodingOp.getSourceType();
- FailureOr<MaterializeEncodingInfo> encodingInfo =
+ MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(sourceType);
- if (failed(encodingInfo)) {
- return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
- }
// Shortcut to avoid creating new operations.
- if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) {
+ if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return packedValue;
}
auto encoding = IREE::Encoding::getEncodingAttr(sourceType);
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
- transposeInPlace(*encodingInfo);
+ transposeInPlace(encodingInfo);
}
// Create an `tensor.empty` for the result of the unpack operation.
Location loc = encodingOp.getLoc();
@@ -197,15 +191,15 @@
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
sourceType.getElementType());
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
- rewriter, loc, sourceType, *encodingInfo, materializeEncodingValueFn);
+ rewriter, loc, sourceType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
return rewriter
.create<tensor::UnPackOp>(loc, packedValue, emptyOp,
- encodingInfo->innerDimsPos, *innerTileSizesOfr,
- encodingInfo->outerDimsPerm)
+ encodingInfo.innerDimsPos, *innerTileSizesOfr,
+ encodingInfo.outerDimsPerm)
.getResult();
}
@@ -217,22 +211,23 @@
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
- FailureOr<MaterializeEncodingInfo> encodingInfo =
+ MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(emptyType);
Location loc = emptyOp.getLoc();
- if (failed(encodingInfo)) {
- Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
- loc, emptyOp.getMixedSizes(), emptyType.getElementType());
- return newEmptyOp;
+ if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
+ return rewriter
+ .create<tensor::EmptyOp>(loc, emptyOp.getMixedSizes(),
+ emptyType.getElementType())
+ .getOperation();
}
if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
- transposeInPlace(*encodingInfo);
+ transposeInPlace(encodingInfo);
}
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
- rewriter, loc, emptyType, *encodingInfo, materializeEncodingValueFn);
+ rewriter, loc, emptyType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
emptyOp, "failed to generate runtime tile size query");
@@ -241,9 +236,9 @@
SmallVector<OpFoldResult> sourceDims = emptyOp.getMixedSizes();
(void)foldDynamicIndexList(sourceDims);
SmallVector<OpFoldResult> newShape = tensor::PackOp::getResultShape(
- rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo->innerDimsPos,
- encodingInfo->outerDimsPerm);
- newShape = getSwizzledShape(newShape, *encodingInfo);
+ rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
+ encodingInfo.outerDimsPerm);
+ newShape = getSwizzledShape(newShape, encodingInfo);
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
loc, newShape, emptyType.getElementType());
return newEmptyOp;
@@ -262,10 +257,10 @@
return rewriter.notifyMatchFailure(genericOp,
"Output indexing map is not identity");
}
- FailureOr<MaterializeEncodingInfo> outMaterializeEncodingInfo =
+ MaterializeEncodingInfo outMaterializeEncodingInfo =
typeConverter.getEncodingInfo(
cast<RankedTensorType>(outputOperand->get().getType()));
- if (failed(outMaterializeEncodingInfo)) {
+ if (IREE::Codegen::isIdentityLayout(outMaterializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
genericOp, "MaterializeEncodingInfo failed for output");
}
@@ -277,20 +272,20 @@
// Compute the new indexing maps for the packed layout. This assumes that
// the output map is identity, and that all iterator types are parallel.
SmallVector<int64_t> outInnerDimsPos =
- outMaterializeEncodingInfo->innerDimsPos;
+ outMaterializeEncodingInfo.innerDimsPos;
SmallVector<int64_t> outInverseOuterDimsPerm =
- invertPermutationVector(outMaterializeEncodingInfo->outerDimsPerm);
+ invertPermutationVector(outMaterializeEncodingInfo.outerDimsPerm);
SmallVector<AffineMap> packedIndexingMaps;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
- FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+ MaterializeEncodingInfo materializeEncodingInfo =
typeConverter.getEncodingInfo(
cast<RankedTensorType>(inputOperand->get().getType()));
- if (failed(materializeEncodingInfo)) {
+ if (IREE::Codegen::isIdentityLayout(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
genericOp, "MaterializeEncodingInfo failed for input");
}
- SmallVector<int64_t> innerDimsPos = materializeEncodingInfo->innerDimsPos;
- SmallVector<int64_t> outerDimsPerm = materializeEncodingInfo->outerDimsPerm;
+ ArrayRef<int64_t> innerDimsPos = materializeEncodingInfo.innerDimsPos;
+ ArrayRef<int64_t> outerDimsPerm = materializeEncodingInfo.outerDimsPerm;
AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
// Permute result dims to the input packed domain, and map dims to the
// output packed domain.
@@ -388,28 +383,28 @@
return failure();
}
- FailureOr<MaterializeEncodingInfo> encodingInfo =
+ MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(boundTensorType);
- if (failed(encodingInfo)) {
+ if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return failure();
}
if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) {
- transposeInPlace(*encodingInfo);
+ transposeInPlace(encodingInfo);
}
SmallVector<OpFoldResult> targetShape =
getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
auto innerTileSizes = getInnerTileSizesOfr(
- builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
+ builder, loc, boundTensorType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizes)) {
return failure();
}
SmallVector<OpFoldResult> convertedTargetShape =
tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
- encodingInfo->innerDimsPos,
- encodingInfo->outerDimsPerm);
- return getSwizzledShape(convertedTargetShape, *encodingInfo);
+ encodingInfo.innerDimsPos,
+ encodingInfo.outerDimsPerm);
+ return getSwizzledShape(convertedTargetShape, encodingInfo);
}
/// For `dispatchTensorType` that bind a `RankedTensorType` with encoding,
@@ -756,17 +751,10 @@
return success();
}
- // TODO(hanchung): This is a transition state for moving the implementation
- // details to backend attributes. We won't need the function type argument
- // after all the backends that support encodings implement the attribute.
- auto getEncodingInfoWrapper =
- [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
- return converter->getEncodingInfo(type);
- };
FailureOr<Operation *> convertedOp =
IREE::Codegen::lowerContractionOpWithEncoding(
rewriter, op, operands, converter->getTransposeNarrowN(),
- getEncodingInfoWrapper);
+ converter->getLayoutAttr());
if (failed(convertedOp)) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
index c4ad11e..c152aee 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
@@ -90,8 +90,5 @@
std::optional<TileSwizzle> swizzle;
};
-using ResolveEncodingInfoFn =
- std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;
-
} // namespace mlir::iree_compiler::IREE::Codegen
#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IR_IREECODEGENTYPES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
index bf0a569..32dbc46 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
@@ -510,7 +510,7 @@
FailureOr<Operation *>
lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
ValueRange operands, bool transposeNarrowN,
- ResolveEncodingInfoFn getEncodingInfo) {
+ LayoutAttrInterface layoutAttr) {
if (!linalgOp.hasPureTensorSemantics()) {
return failure();
}
@@ -535,42 +535,42 @@
return failure();
}
- FailureOr<MaterializeEncodingInfo> encodingInfo =
- getEncodingInfo(cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
+ MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
+ cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
+ if (isIdentityLayout(encodingInfo)) {
+ return dropEncodingAndCloneOp(builder, linalgOp,
+ operands.take_front(inputs.size()),
+ operands.drop_front(inputs.size()));
+ }
+
+ bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
+ SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
+ SmallVector<ReassociationIndices> ri;
+ Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, ri,
+ elemTypes, /*operandIdx=*/0);
+ Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, ri,
+ elemTypes, /*operandIdx=*/1);
+ Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
+ ri, elemTypes, /*operandIdx=*/2);
+ if (transpose) {
+ std::swap(newLhs, newRhs);
+ }
+ Type newResultType = newResult.getType();
+ auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
Operation *result;
- if (failed(encodingInfo) || isIdentityLayout(encodingInfo.value())) {
- result = dropEncodingAndCloneOp(builder, linalgOp,
- operands.take_front(inputs.size()),
- operands.drop_front(inputs.size()));
+ if (cDims->batch.empty()) {
+ result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
+ ValueRange{newLhs, newRhs},
+ ValueRange{newResult});
} else {
- bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
- SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
- SmallVector<ReassociationIndices> ri;
- Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder,
- ri, elemTypes, /*operandIdx=*/0);
- Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder,
- ri, elemTypes, /*operandIdx=*/1);
- Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
- ri, elemTypes, /*operandIdx=*/2);
- if (transpose) {
- std::swap(newLhs, newRhs);
- }
- Type newResultType = newResult.getType();
- auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
- if (cDims->batch.empty()) {
- result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
- ValueRange{newLhs, newRhs},
- ValueRange{newResult});
- } else {
- result = builder.create<linalg::BatchMmt4DOp>(
- linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
- ValueRange{newResult});
- }
- if (!ri.empty()) {
- result = builder.create<tensor::CollapseShapeOp>(
- linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
- }
+ result = builder.create<linalg::BatchMmt4DOp>(
+ linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
+ ValueRange{newResult});
+ }
+ if (!ri.empty()) {
+ result = builder.create<tensor::CollapseShapeOp>(
+ linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
}
return result;
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
index f08aec2..1bee3ec 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
#define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "llvm/Support/raw_ostream.h"
@@ -95,7 +96,7 @@
FailureOr<Operation *>
lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
ValueRange operands, bool transposeNarrowN,
- ResolveEncodingInfoFn getEncodingInfo);
+ LayoutAttrInterface layoutAttr);
} // namespace mlir::iree_compiler::IREE::Codegen
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
index 53c1fed..a847abd 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
@@ -308,12 +309,9 @@
return nullptr;
}
- auto resolver =
- [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
- return getEncodingInfo(layoutAttr, type);
- };
FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
- b, linalgOp, convertedOperands, /*transposeNarrowN=*/true, resolver);
+ b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+ cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
return newOp.value_or(nullptr);
}
};
@@ -395,12 +393,9 @@
return nullptr;
}
- auto resolver =
- [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
- return getEncodingInfo(layoutAttr, type);
- };
FailureOr<Operation *> newOp = Codegen::lowerContractionOpWithEncoding(
- b, linalgOp, convertedOperands, /*transposeNarrowN=*/true, resolver);
+ b, linalgOp, convertedOperands, /*transposeNarrowN=*/true,
+ cast<IREE::Codegen::LayoutAttrInterface>(layoutAttr));
return newOp.value_or(nullptr);
}
};