[Stream] Update more op folders to verify matching types (#16070)
Fixes following updates to stricter folder verification rolled in with
#16012
Additionally cleans up some surrounding code and removes unused
includes.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index ab8a9f2..810f547 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -5,22 +5,15 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <algorithm>
-#include <numeric>
#include <optional>
-#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/StringExtras.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
@@ -28,7 +21,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
@@ -1223,7 +1215,8 @@
// If operand comes from import with the same properties then fold.
// These checks are conservative, since encoding changes may be meaningful.
auto importOp = getSource().getDefiningOp<TensorImportOp>();
- if (importOp && getSourceEncoding() == importOp.getResultEncoding() &&
+ if (importOp && importOp.getSource().getType() == getType() &&
+ getSourceEncoding() == importOp.getResultEncoding() &&
getSourceEncodingDims() == importOp.getResultEncodingDims() &&
getSourceSize() == importOp.getResultSize() &&
getAffinity() == importOp.getAffinity()) {
@@ -1358,10 +1351,8 @@
// stream.tensor.clone
//===----------------------------------------------------------------------===//
-OpFoldResult TensorCloneOp::fold(FoldAdaptor operands) {
- auto users = getResult().getUsers();
- if (!users.empty() && std::next(users.begin()) == users.end()) {
- // If the second user is the end it means there's one user.
+OpFoldResult TensorCloneOp::fold(FoldAdaptor) {
+ if (getResult().hasOneUse() && getType() == getSource().getType()) {
return getSource();
}
return {};
@@ -1374,7 +1365,8 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TensorCloneOp cloneOp,
PatternRewriter &rewriter) const override {
- if (!IREE::Util::TiedOpInterface::hasAnyTiedUses(cloneOp.getResult())) {
+ if (cloneOp.getType() == cloneOp.getSource().getType() &&
+ !IREE::Util::TiedOpInterface::hasAnyTiedUses(cloneOp.getResult())) {
rewriter.replaceOp(cloneOp, cloneOp.getSource());
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
index 4ce5944..1d6aadf 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
@@ -24,6 +24,18 @@
return %1 : !hal.buffer_view
}
+// -----
+
+// CHECK-LABEL: @NofoldTensorExportOpBufferToView
+func.func @NofoldTensorExportOpBufferToView(%arg0: !hal.buffer, %arg1: index) -> !hal.buffer_view {
+ // CHECK: %[[IMPORT:.+]] = stream.tensor.import
+ // CHECK: %[[EXPORT:.+]] = stream.tensor.export %[[IMPORT]]
+ // CHECK: return %[[EXPORT]] : !hal.buffer_view
+ %c20 = arith.constant 20 : index
+ %0 = stream.tensor.import %arg0 : !hal.buffer -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
+ %1 = stream.tensor.export %0 : tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20} -> !hal.buffer_view
+ return %1 : !hal.buffer_view
+}
// -----
@@ -172,6 +184,16 @@
// -----
+// CHECK-LABEL: @NofoldTensorCloneOp
+func.func @NofoldTensorCloneOp(%arg0: !stream.resource<external>, %arg1: index) -> !stream.resource<*> {
+ // CHECK: %[[CLONE:.+]] = stream.tensor.clone
+ %0 = stream.tensor.clone %arg0 : tensor<2x2xf32> in !stream.resource<external>{%arg1} -> tensor<2x2xf32> in !stream.resource<*>{%arg1}
+ // CHECK: return %[[CLONE]] : !stream.resource<*>
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
// CHECK-LABEL: @ElideUnneededTensorClones
func.func @ElideUnneededTensorClones(%arg0: !stream.resource<*>, %arg1: index) -> f32 {
%c0 = arith.constant 0 : index