Add fallback for undoing encodings. (#15302)

If the callback function returns a failure (meaning the configuration is
not implemented yet), the pass falls back to undoing the encodings.

The revision uses an i32i32.i32 matmul as a test case.
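
For illustration, below is a rough sketch (not the exact test IR; encoding
attributes are elided as `<...>`) of what the fallback produces when the
element types are not handled: the set_encoding/unset_encoding ops fold away
(with tensor.cast ops inserted where the converted types differ), and the
matmul is kept in its unpacked form instead of being rewritten to
linalg.mmt4d.

  // Before materialization (encodings elided):
  %0 = iree_linalg_ext.set_encoding %lhs : tensor<?x?xi32> -> tensor<?x?xi32, <...>>
  %1 = iree_linalg_ext.set_encoding %rhs : tensor<?x?xi32> -> tensor<?x?xi32, <...>>
  %2 = linalg.matmul ins(%0, %1 : ...) outs(%acc : ...) -> tensor<?x?xi32, <...>>
  %3 = iree_linalg_ext.unset_encoding %2 : tensor<?x?xi32, <...>> -> tensor<?x?xi32>

  // After materialization with the fallback (encodings dropped, op kept as-is):
  %2 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>)
                     outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>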
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 5207163..1fff730 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -50,17 +50,20 @@
                                originalType.getElementType(), encoding);
 }
 
+static RankedTensorType dropEncoding(RankedTensorType type) {
+  return RankedTensorType::get(type.getShape(), type.getElementType());
+}
+
 /// For a given tensor type with an encoding, return the materialized
 /// type to use for it. If no encoding is set, then return the tensor type
 /// itself.
 static RankedTensorType
 getMaterializedType(RankedTensorType tensorType,
                     MaterializeEncodingFn materializeEncodingFn) {
-
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
       materializeEncodingFn(tensorType);
   if (failed(materializeEncodingInfo)) {
-    return tensorType;
+    return dropEncoding(tensorType);
   }
   return tensor::PackOp::inferPackedType(
              getOriginalTypeWithEncoding(tensorType),
@@ -70,15 +73,30 @@
       .cast<RankedTensorType>();
 }
 
+static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
+                                         ValueRange convertedInputOperands,
+                                         ValueRange convertedOutputOperands) {
+  SmallVector<Value> operands;
+  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
+  operands.append(convertedOutputOperands.begin(),
+                  convertedOutputOperands.end());
+  return mlir::clone(
+      builder, op,
+      {dropEncoding(
+          convertedOutputOperands[0].getType().cast<RankedTensorType>())},
+      operands);
+}
+
 //===---------------------------------------------------------------------===//
 // Methods to convert the encoding to parameters of the Pack operation
 //===---------------------------------------------------------------------===//
 
 /// Given the `encoding` return the `MaterializeEncodingInfo` to use for
-/// materializing the pack op.
-// TODO(ravishankarm): This is currently hard-coded here for convenience. When
-// used in IREE, this will be computed based on the architecture information in
-// `hal.executable.variant`.
+/// materializing the pack op. This is mainly for testing. The configurations
+/// are arbitrary values.
+// TODO(hanchung): Move the implementation to Codegen/Common. This is currently
+// hard-coded here for testing convenience. When used in IREE, this will be
+// computed based on the architecture information in `hal.executable.variant`.
 // A real implementation would return tile sizes that depend on at least the
 // `tensorType`'s element type (e.g. different tile sizes for i8 vs f32, because
 // the SIMD instructions may have different shapes).
@@ -93,12 +111,15 @@
 
   auto user = encoding.getUser().getValue();
   auto role = encoding.getRole().getValue();
+  // Below is for testing purposes. It only materializes encodings for f32 cases.
   switch (user) {
   case EncodingUser::MATMUL:
   case EncodingUser::BATCH_MATMUL:
-    return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
+    if (tensorType.getElementType().isF32()) {
+      return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
+    }
   }
-  llvm_unreachable("unhandled EncodingUser case");
+  return failure();
 }
 
 //===---------------------------------------------------------------------===//
@@ -227,11 +248,10 @@
 /// - rhs encoding with role=RHS
 /// - result encoding with role=RESULT
 /// to linalg.mmt4d op.
-static FailureOr<Operation *>
-lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
-                    ValueRange convertedInputOperands,
-                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
-                    MaterializeEncodingValueFn) {
+static FailureOr<Operation *> lowerOpWithEncoding(
+    RewriterBase &rewriter, linalg::MatmulOp matmulOp,
+    ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
+    MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn) {
   if (!matmulOp.hasTensorSemantics())
     return failure();
   auto inputs = matmulOp.getDpsInputOperands();
@@ -256,10 +276,20 @@
           mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RESULT) {
     return failure();
   }
-  Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
-      matmulOp.getLoc(), convertedOutputOperands[0].getType(),
-      convertedInputOperands, convertedOutputOperands);
-  return mmt4DOp;
+
+  FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+      materializeEncodingFn(getOriginalTypeWithEncoding(
+          matmulOp.getResultTypes()[0].cast<RankedTensorType>()));
+  Operation *result;
+  if (failed(materializeEncodingInfo)) {
+    result = dropEncodingAndCloneOp(rewriter, matmulOp, convertedInputOperands,
+                                    convertedOutputOperands);
+  } else {
+    result = rewriter.create<linalg::Mmt4DOp>(
+        matmulOp.getLoc(), convertedOutputOperands[0].getType(),
+        convertedInputOperands, convertedOutputOperands);
+  }
+  return result;
 }
 
 /// Utility method to convert from `linalg.batch_matmul` with
@@ -267,11 +297,10 @@
 /// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
 /// - result encoding with user=BATCH_MATMUL_*, role=RESULT
 /// to linalg.batch_mmt4d op.
-static FailureOr<Operation *>
-lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
-                    ValueRange convertedInputOperands,
-                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
-                    MaterializeEncodingValueFn) {
+static FailureOr<Operation *> lowerOpWithEncoding(
+    RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
+    ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
+    MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn) {
   if (!batchMatmulOp.hasTensorSemantics())
     return failure();
   auto inputs = batchMatmulOp.getDpsInputOperands();
@@ -294,10 +323,20 @@
       resultEncoding.getRole().getValue() != EncodingRole::RESULT) {
     return failure();
   }
-  Operation *batchMmt4DOp = rewriter.create<linalg::BatchMmt4DOp>(
-      batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
-      convertedInputOperands, convertedOutputOperands);
-  return batchMmt4DOp;
+  FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+      materializeEncodingFn(getOriginalTypeWithEncoding(
+          batchMatmulOp.getResultTypes()[0].cast<RankedTensorType>()));
+  Operation *result;
+  if (failed(materializeEncodingInfo)) {
+    result =
+        dropEncodingAndCloneOp(rewriter, batchMatmulOp, convertedInputOperands,
+                               convertedOutputOperands);
+  } else {
+    result = rewriter.create<linalg::BatchMmt4DOp>(
+        batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
+        convertedInputOperands, convertedOutputOperands);
+  }
+  return result;
 }
 
 /// Utility method to convert from `linalg.fill` on `tensor` type with
@@ -326,10 +365,13 @@
       emptyOp->getResultTypes()[0].cast<RankedTensorType>());
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
       materializeEncodingFn(resultType);
-  if (failed(materializeEncodingInfo)) {
-    return rewriter.notifyMatchFailure(emptyOp, "unhandled result encoding");
-  }
   Location loc = emptyOp.getLoc();
+  if (failed(materializeEncodingInfo)) {
+    Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
+        loc, emptyOp.getMixedSizes(), resultType.getElementType());
+    return newEmptyOp;
+  }
+
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
       getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
                            materializeEncodingValueFn);
@@ -372,16 +414,24 @@
     auto packOp = lowerSetEncodingOpToPackOp(
         rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
         this->materializeEncodingValueFn);
-    if (failed(packOp))
-      return rewriter.notifyMatchFailure(encodingOp,
-                                         "failed to convert to pack op");
+    if (failed(packOp)) {
+      Value result = adaptor.getSource();
+      Type targetType =
+          getTypeConverter()->convertType(encodingOp.getResultType());
+      if (targetType != result.getType()) {
+        result = rewriter.create<tensor::CastOp>(encodingOp.getLoc(),
+                                                 targetType, result);
+      }
+      rewriter.replaceOp(encodingOp, result);
+      return success();
+    }
     rewriter.replaceOp(encodingOp, packOp->getResult());
     return success();
   }
 };
 
 /// Convert `unset_encoding` op to `unpack` op.
-struct UnsetEncodingOpToPackOpConversion
+struct UnsetEncodingOpToUnPackOpConversion
     : public OpMaterializeEncodingPattern<UnsetEncodingOp> {
   using OpMaterializeEncodingPattern<
       UnsetEncodingOp>::OpMaterializeEncodingPattern;
@@ -396,15 +446,25 @@
     auto unpackOp = lowerUnsetEncodingToUnpackOp(
         rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
         this->materializeEncodingValueFn);
-    if (failed(unpackOp))
-      return rewriter.notifyMatchFailure(encodingOp,
-                                         "failed to convert to unpack op");
+    if (failed(unpackOp)) {
+      Value result = adaptor.getSource();
+      Type targetType =
+          getTypeConverter()->convertType(encodingOp.getResultType());
+      if (targetType != result.getType()) {
+        result = rewriter.create<tensor::CastOp>(encodingOp.getLoc(),
+                                                 targetType, result);
+      }
+      rewriter.replaceOp(encodingOp, result);
+      return success();
+    }
     rewriter.replaceOp(encodingOp, unpackOp->getResult());
     return success();
   }
 };
 
-/// Convert `upper_bound_tile_size` op to `constant` op.
+/// Convert `upper_bound_tile_size` op to `constant` op. If the
+/// `materializeEncodingFn` returns a failure, the pattern replaces the op with
+/// constant 1s, i.e., tile sizes that leave the shape unchanged.
 struct UpperBoundTileSizeToConstantOpConversion
     : public OpRewritePattern<UpperBoundTileSizeOp> {
   UpperBoundTileSizeToConstantOpConversion(
@@ -418,8 +478,11 @@
     auto constants = lowerUpperBoundTileSizeOpToConstants(
         rewriter, upperBoundTileSizeOp, materializeEncodingFn);
     if (failed(constants)) {
-      return rewriter.notifyMatchFailure(upperBoundTileSizeOp,
-                                         "failed to convert to constant op");
+      SmallVector<Value> results(upperBoundTileSizeOp.getNumResults(),
+                                 rewriter.create<arith::ConstantIndexOp>(
+                                     upperBoundTileSizeOp.getLoc(), 1));
+      rewriter.replaceOp(upperBoundTileSizeOp, results);
+      return success();
     }
     rewriter.replaceOp(upperBoundTileSizeOp, *constants);
     return success();
@@ -560,7 +623,7 @@
                   MaterializeDPSOperation<linalg::BatchMatmulOp>,
                   MaterializeOperation<tensor::EmptyOp>,
                   SetEncodingOpToPackOpConversion,
-                  UnsetEncodingOpToPackOpConversion>(
+                  UnsetEncodingOpToUnPackOpConversion>(
       patterns.getContext(), typeConverter, materializeEncodingValueFn);
   ::mlir::memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 9c87e95..2ca2772 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -338,3 +338,30 @@
 // CHECK-SAME:       outs(%[[FILL]] :
 //      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
 //      CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @matmul_i32i32.i32(%arg0: tensor<2x4xi32>, %arg1: tensor<4x2xi32>) -> tensor<2x2xi32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c0_i32 = arith.constant 0 : i32
+  %padded = tensor.pad %arg0 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg2: index, %arg3: index):
+    tensor.yield %c0_i32 : i32
+  } : tensor<2x4xi32> to tensor<?x?xi32>
+  %0 = iree_linalg_ext.set_encoding %padded : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [i32, i32, i32], original_type = tensor<2x4xi32>>>
+  %padded_0 = tensor.pad %arg1 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg2: index, %arg3: index):
+    tensor.yield %c0_i32 : i32
+  } : tensor<4x2xi32> to tensor<?x?xi32>
+  %1 = iree_linalg_ext.set_encoding %padded_0 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [i32, i32, i32], original_type = tensor<4x2xi32>>>
+  %2 = tensor.empty(%c2, %c2) : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>
+  %3 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>
+  %4 = linalg.matmul ins(%0, %1 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [i32, i32, i32], original_type = tensor<2x4xi32>>>, tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [i32, i32, i32], original_type = tensor<4x2xi32>>>) outs(%3 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>>
+  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i32, i32, i32], original_type = tensor<2x2xi32>>> -> tensor<?x?xi32>
+  %extracted_slice = tensor.extract_slice %5[0, 0] [2, 2] [1, 1] : tensor<?x?xi32> to tensor<2x2xi32>
+  return %extracted_slice : tensor<2x2xi32>
+}
+// CHECK: func.func @matmul_i32i32.i32
+// CHECK:   linalg.fill
+// CHECK:   linalg.matmul