Add fallback for undoing encodings. (#15302)
If the materialization callback returns a failure (which means that the
configuration is not implemented yet), the pass now undoes the encodings
instead of failing. This revision adds an i32.i32.i32 matmul as a test case.
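
As a rough sketch of the effect (not verbatim pass output; the encoding
attributes are abbreviated as #encoding_*), an encoded matmul whose element
types are not supported by the callback now lowers back to a plain
linalg.matmul:

    // Before materialization: operands and result carry encodings.
    %lhs = iree_linalg_ext.set_encoding %padded_lhs
        : tensor<?x?xi32> -> tensor<?x?xi32, #encoding_lhs>
    %rhs = iree_linalg_ext.set_encoding %padded_rhs
        : tensor<?x?xi32> -> tensor<?x?xi32, #encoding_rhs>
    %res = linalg.matmul ins(%lhs, %rhs : ...) outs(%fill : ...)
        -> tensor<?x?xi32, #encoding_result>
    %out = iree_linalg_ext.unset_encoding %res
        : tensor<?x?xi32, #encoding_result> -> tensor<?x?xi32>

    // After materialization: the encodings are dropped and the
    // set_encoding/unset_encoding ops fold away.
    %res = linalg.matmul
        ins(%padded_lhs, %padded_rhs : tensor<?x?xi32>, tensor<?x?xi32>)
        outs(%fill : tensor<?x?xi32>) -> tensor<?x?xi32>
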
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 5207163..1fff730 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -50,17 +50,20 @@
originalType.getElementType(), encoding);
}
+static RankedTensorType dropEncoding(RankedTensorType type) {
+ return RankedTensorType::get(type.getShape(), type.getElementType());
+}
+
/// For a given tensor type with an encoding, return the materialized
/// type to use for it. If no encoding is set, then return the tensor type
/// itself.
static RankedTensorType
getMaterializedType(RankedTensorType tensorType,
MaterializeEncodingFn materializeEncodingFn) {
-
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(tensorType);
if (failed(materializeEncodingInfo)) {
- return tensorType;
+ return dropEncoding(tensorType);
}
return tensor::PackOp::inferPackedType(
getOriginalTypeWithEncoding(tensorType),
@@ -70,15 +73,30 @@
.cast<RankedTensorType>();
}
+static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
+ ValueRange convertedInputOperands,
+ ValueRange convertedOutputOperands) {
+ SmallVector<Value> operands;
+ operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
+ operands.append(convertedOutputOperands.begin(),
+ convertedOutputOperands.end());
+ return mlir::clone(
+ builder, op,
+ {dropEncoding(
+ convertedOutputOperands[0].getType().cast<RankedTensorType>())},
+ operands);
+}
+
//===---------------------------------------------------------------------===//
// Methods to convert the encoding to parameters of the Pack operation
//===---------------------------------------------------------------------===//
/// Given the `encoding` return the `MaterializeEncodingInfo` to use for
-/// materializing the pack op.
-// TODO(ravishankarm): This is currently hard-coded here for convenience. When
-// used in IREE, this will be computed based on the architecture information in
-// `hal.executable.variant`.
+/// materializing the pack op. This is mainly for testing. The configurations
+/// are arbitrary values.
+// TODO(hanchung): Move the implementation to Codegen/Common. This is currently
+// hard-coded here for testing convenience. When used in IREE, this will be
+// computed based on the architecture information in `hal.executable.variant`.
// A real implementation would return tile sizes that depend on at least the
// `tensorType`'s element type (e.g. different tile sizes for i8 vs f32, because
// the SIMD instructions may have different shapes).
@@ -93,12 +111,15 @@
auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
+  // Below is for testing purposes; it only materializes the f32 cases.
switch (user) {
case EncodingUser::MATMUL:
case EncodingUser::BATCH_MATMUL:
- return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
+ if (tensorType.getElementType().isF32()) {
+ return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
+ }
}
- llvm_unreachable("unhandled EncodingUser case");
+ return failure();
}
//===---------------------------------------------------------------------===//
@@ -227,11 +248,10 @@
/// - rhs encoding with role=RHS
/// - result encoding with role=RESULT
/// to linalg.mmt4d op.
-static FailureOr<Operation *>
-lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
- ValueRange convertedInputOperands,
- ValueRange convertedOutputOperands, MaterializeEncodingFn,
- MaterializeEncodingValueFn) {
+static FailureOr<Operation *> lowerOpWithEncoding(
+ RewriterBase &rewriter, linalg::MatmulOp matmulOp,
+ ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
+ MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn) {
if (!matmulOp.hasTensorSemantics())
return failure();
auto inputs = matmulOp.getDpsInputOperands();
@@ -256,10 +276,20 @@
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RESULT) {
return failure();
}
- Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
- matmulOp.getLoc(), convertedOutputOperands[0].getType(),
- convertedInputOperands, convertedOutputOperands);
- return mmt4DOp;
+
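+  // If the encoding cannot be materialized (the callback fails), drop the
+  // encodings and emit a plain linalg.matmul instead of linalg.mmt4d.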
+ FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+ materializeEncodingFn(getOriginalTypeWithEncoding(
+ matmulOp.getResultTypes()[0].cast<RankedTensorType>()));
+ Operation *result;
+ if (failed(materializeEncodingInfo)) {
+ result = dropEncodingAndCloneOp(rewriter, matmulOp, convertedInputOperands,
+ convertedOutputOperands);
+ } else {
+ result = rewriter.create<linalg::Mmt4DOp>(
+ matmulOp.getLoc(), convertedOutputOperands[0].getType(),
+ convertedInputOperands, convertedOutputOperands);
+ }
+ return result;
}
/// Utility method to convert from `linalg.batch_matmul` with
@@ -267,11 +297,10 @@
/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
/// - result encoding with user=BATCH_MATMUL_*, role=RESULT
/// to linalg.batch_mmt4d op.
-static FailureOr<Operation *>
-lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
- ValueRange convertedInputOperands,
- ValueRange convertedOutputOperands, MaterializeEncodingFn,
- MaterializeEncodingValueFn) {
+static FailureOr<Operation *> lowerOpWithEncoding(
+ RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
+ ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
+ MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn) {
if (!batchMatmulOp.hasTensorSemantics())
return failure();
auto inputs = batchMatmulOp.getDpsInputOperands();
@@ -294,10 +323,20 @@
resultEncoding.getRole().getValue() != EncodingRole::RESULT) {
return failure();
}
- Operation *batchMmt4DOp = rewriter.create<linalg::BatchMmt4DOp>(
- batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
- convertedInputOperands, convertedOutputOperands);
- return batchMmt4DOp;
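+  // Same fallback as for linalg.matmul: emit a plain linalg.batch_matmul when
+  // the encoding cannot be materialized.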
+ FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+ materializeEncodingFn(getOriginalTypeWithEncoding(
+ batchMatmulOp.getResultTypes()[0].cast<RankedTensorType>()));
+ Operation *result;
+ if (failed(materializeEncodingInfo)) {
+ result =
+ dropEncodingAndCloneOp(rewriter, batchMatmulOp, convertedInputOperands,
+ convertedOutputOperands);
+ } else {
+ result = rewriter.create<linalg::BatchMmt4DOp>(
+ batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
+ convertedInputOperands, convertedOutputOperands);
+ }
+ return result;
}
/// Utility method to convert from `linalg.fill` on `tensor` type with
@@ -326,10 +365,13 @@
emptyOp->getResultTypes()[0].cast<RankedTensorType>());
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(resultType);
- if (failed(materializeEncodingInfo)) {
- return rewriter.notifyMatchFailure(emptyOp, "unhandled result encoding");
- }
Location loc = emptyOp.getLoc();
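+  // If the encoding cannot be materialized, create a tensor.empty without the
+  // encoding instead.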
+ if (failed(materializeEncodingInfo)) {
+ Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
+ loc, emptyOp.getMixedSizes(), resultType.getElementType());
+ return newEmptyOp;
+ }
+
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
materializeEncodingValueFn);
@@ -372,16 +414,24 @@
auto packOp = lowerSetEncodingOpToPackOp(
rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
this->materializeEncodingValueFn);
- if (failed(packOp))
- return rewriter.notifyMatchFailure(encodingOp,
- "failed to convert to pack op");
+ if (failed(packOp)) {
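+      // Fallback: forward the converted source, casting if the converted
+      // result type differs.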
+ Value result = adaptor.getSource();
+ Type targetType =
+ getTypeConverter()->convertType(encodingOp.getResultType());
+ if (targetType != result.getType()) {
+ result = rewriter.create<tensor::CastOp>(encodingOp.getLoc(),
+ targetType, result);
+ }
+ rewriter.replaceOp(encodingOp, result);
+ return success();
+ }
rewriter.replaceOp(encodingOp, packOp->getResult());
return success();
}
};
/// Convert `unset_encoding` op to `unpack` op.
-struct UnsetEncodingOpToPackOpConversion
+struct UnsetEncodingOpToUnPackOpConversion
: public OpMaterializeEncodingPattern<UnsetEncodingOp> {
using OpMaterializeEncodingPattern<
UnsetEncodingOp>::OpMaterializeEncodingPattern;
@@ -396,15 +446,25 @@
auto unpackOp = lowerUnsetEncodingToUnpackOp(
rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
this->materializeEncodingValueFn);
- if (failed(unpackOp))
- return rewriter.notifyMatchFailure(encodingOp,
- "failed to convert to unpack op");
+ if (failed(unpackOp)) {
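+      // Fallback: forward the converted source, as for set_encoding.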
+ Value result = adaptor.getSource();
+ Type targetType =
+ getTypeConverter()->convertType(encodingOp.getResultType());
+ if (targetType != result.getType()) {
+ result = rewriter.create<tensor::CastOp>(encodingOp.getLoc(),
+ targetType, result);
+ }
+ rewriter.replaceOp(encodingOp, result);
+ return success();
+ }
rewriter.replaceOp(encodingOp, unpackOp->getResult());
return success();
}
};
-/// Convert `upper_bound_tile_size` op to `constant` op.
+/// Convert `upper_bound_tile_size` op to `constant` op. If
+/// `materializeEncodingFn` returns a failure, the op is lowered to constant
+/// 1s, i.e. the original shape is preserved.
struct UpperBoundTileSizeToConstantOpConversion
: public OpRewritePattern<UpperBoundTileSizeOp> {
UpperBoundTileSizeToConstantOpConversion(
@@ -418,8 +478,11 @@
auto constants = lowerUpperBoundTileSizeOpToConstants(
rewriter, upperBoundTileSizeOp, materializeEncodingFn);
if (failed(constants)) {
- return rewriter.notifyMatchFailure(upperBoundTileSizeOp,
- "failed to convert to constant op");
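+      // Fallback: lower all tile-size upper bounds to 1, so the shape is
+      // left unchanged.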
+ SmallVector<Value> results(upperBoundTileSizeOp.getNumResults(),
+ rewriter.create<arith::ConstantIndexOp>(
+ upperBoundTileSizeOp.getLoc(), 1));
+ rewriter.replaceOp(upperBoundTileSizeOp, results);
+ return success();
}
rewriter.replaceOp(upperBoundTileSizeOp, *constants);
return success();
@@ -560,7 +623,7 @@
MaterializeDPSOperation<linalg::BatchMatmulOp>,
MaterializeOperation<tensor::EmptyOp>,
SetEncodingOpToPackOpConversion,
- UnsetEncodingOpToPackOpConversion>(
+ UnsetEncodingOpToUnPackOpConversion>(
patterns.getContext(), typeConverter, materializeEncodingValueFn);
::mlir::memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 9c87e95..2ca2772 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -338,3 +338,30 @@
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @matmul_i32i32.i32(%arg0: tensor<2x4xi32>, %arg1: tensor<4x2xi32>) -> tensor<2x2xi32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c0_i32 = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[0, 0] high[%c0, %c0] {
+ ^bb0(%arg2: index, %arg3: index):
+ tensor.yield %c0_i32 : i32
+ } : tensor<2x4xi32> to tensor<?x?xi32>
+ %0 = iree_linalg_ext.set_encoding %padded : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [i32, i32, i32], original_type = tensor<2x4xi32>>>
+ %padded_0 = tensor.pad %arg1 low[0, 0] high[%c0, %c0] {
+ ^bb0(%arg2: index, %arg3: index):
+ tensor.yield %c0_i32 : i32
+ } : tensor<4x2xi32> to tensor<?x?xi32>
+ %1 = iree_linalg_ext.set_encoding %padded_0 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [i32, i32, i32], original_type = tensor<4x2xi32>>>
+ %2 = tensor.empty(%c2, %c2) : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>
+ %3 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>
+ %4 = linalg.matmul ins(%0, %1 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [i32, i32, i32], original_type = tensor<2x4xi32>>>, tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [i32, i32, i32], original_type = tensor<4x2xi32>>>) outs(%3 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>
+ %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>> -> tensor<?x?xi32>
+ %extracted_slice = tensor.extract_slice %5[0, 0] [2, 2] [1, 1] : tensor<?x?xi32> to tensor<2x2xi32>
+ return %extracted_slice : tensor<2x2xi32>
+}
+// CHECK: func.func @matmul_i32i32.i32
+// CHECK: linalg.fill
+// CHECK: linalg.matmul