[DT] Drop the data-tiling hint after encodings are set. (#21724)
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.h
index 8ddc6af..c8fe810 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.h
@@ -26,6 +26,12 @@
op->setAttr(kDataTilingHint, UnitAttr::get(op->getContext()));
}
+/// Removes the attribute with `kDataTilingHint` key from the operation, if it
+/// exists.
+inline void removeDataTilingHint(Operation *op) {
+ (void)op->removeAttr(kDataTilingHint);
+}
+
/// Returns the encoding attribute from the type if there is an encoding that
/// implements SerializableAttr. Otherwise, returns null.
SerializableAttr getSerializableAttr(RankedTensorType type);
diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index d6025fe..b2f9a56 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -449,6 +449,7 @@
SmallVector<linalg::LinalgOp> candidates =
getDataTilingCandidates(funcOp);
for (linalg::LinalgOp linalgOp : candidates) {
+ IREE::Encoding::removeDataTilingHint(linalgOp);
if (failed(
setDataTilingEncodings(rewriter, linalgOp, encodingOption))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
index 5161ab1..68d2b07 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
@@ -30,6 +30,7 @@
// CHECK-ALL: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-ALL-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
// CHECK-ALL: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-NOT: iree.opt.data_tiling
// CHECK-ALL-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-ALL-SAME: outs(%[[OUTS]] :
// CHECK-ALL: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<100x500xf32, #[[OUT_ENCODING]]> -> tensor<100x500xf32>