[NFC] Switch to use upstream mlir::verifyCompatibleShape method. (#12243)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
index 33fd63f..a7288cb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
@@ -28,22 +28,6 @@
namespace IREE {
namespace Codegen {
-/// Returns true if the dimensions of ShapedType are compatible.
-static bool isShapedTypeDimCompatible(int64_t lhs, int64_t rhs) {
- return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic ||
- lhs == rhs;
-}
-
-/// Returns true if the dimensions of ShapedType are compatible.
-static bool areShapesCompatible(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
- if (lhs.size() != rhs.size()) {
- return false;
- }
- return llvm::all_of(llvm::zip(lhs, rhs), [](std::tuple<int64_t, int64_t> it) {
- return isShapedTypeDimCompatible(std::get<0>(it), std::get<1>(it));
- });
-}
-
/// Helper method to generate a function declaration at a module scope,
/// and a call to that function
static FailureOr<func::CallOp> createFunctionCall(RewriterBase &rewriter,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index b6d691e..602eaa0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -87,22 +87,6 @@
.Default([&](Type t) { return nullptr; });
}
-/// Returns true if the dimensions of ShapedType are compatible.
-static bool isShapedTypeDimCompatible(int64_t lhs, int64_t rhs) {
- return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic ||
- lhs == rhs;
-}
-
-/// Returns true if the dimensions of ShapedType are compatible.
-static bool areShapesCompatible(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
- if (lhs.size() != rhs.size()) {
- return false;
- }
- return llvm::all_of(llvm::zip(lhs, rhs), [](std::tuple<int64_t, int64_t> it) {
- return isShapedTypeDimCompatible(std::get<0>(it), std::get<1>(it));
- });
-}
-
/// Return true if `dimsPos` is invalid. It is invalid when: a) it contains
/// duplicate. b) At least one dimension is out of bound (`dimPos` is >= 0 and <
/// rank). c) the number of elements in `dimsPos` is > than `rank`.
@@ -1285,13 +1269,11 @@
// Input indicies and values must have the same shape.
if (auto inputIndices = indices()) {
auto inputIndicesType = inputIndices->getType().cast<ShapedType>();
- if (!areShapesCompatible(inputValuesType.getShape(),
- inputIndicesType.getShape()))
+ if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
return op->emitOpError("input indices/values shape must match");
}
// Output indicies and values must have the same shape.
- if (!areShapesCompatible(outputValuesType.getShape(),
- outputIndicesType.getShape()))
+ if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
return op->emitOpError("output indices/values shape must match");
// Input shape must match the output shape except for the dimension()
uint64_t dim = getDimension();
@@ -1302,8 +1284,8 @@
return true;
}
std::tuple<int64_t, int64_t> s = e.value();
- return isShapedTypeDimCompatible(std::get<0>(s),
- std::get<1>(s));
+ return succeeded(verifyCompatibleShape(std::get<0>(s),
+ std::get<1>(s)));
})) {
return op->emitOpError("incompatible input/output shapes");
}
@@ -2480,7 +2462,7 @@
if (isNchw()) {
permute<Permutation::TTNCHW_TO_TTNHWC>(expectedOutputShape);
}
- if (!areShapesCompatible(expectedOutputShape, outputShape)) {
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
@@ -2642,7 +2624,7 @@
expectedOutputShape[outputIndex] = outputTileSize * inputShape[i];
}
}
- if (!areShapesCompatible(expectedOutputShape, outputShape)) {
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
@@ -2762,7 +2744,7 @@
auto outputType = output().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShape = inputType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
- if (!areShapesCompatible(inputShape, outputShape)) {
+ if (failed(verifyCompatibleShape(inputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
int64_t inputRank = getInputOperandRank();
@@ -2861,11 +2843,11 @@
ArrayRef<int64_t> keyShape = keyType.getShape();
ArrayRef<int64_t> valueShape = valueType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
- if (!areShapesCompatible(queryShape, keyShape))
+ if (failed(verifyCompatibleShape(queryShape, keyShape)))
return op->emitOpError("incompatible key shape");
- if (!areShapesCompatible(queryShape, valueShape))
+ if (failed(verifyCompatibleShape(queryShape, valueShape)))
return op->emitOpError("incompatible value shape");
- if (!areShapesCompatible(queryShape, outputShape))
+ if (failed(verifyCompatibleShape(queryShape, outputShape)))
return op->emitOpError("incompatible output shape");
return success();
}
@@ -3019,8 +3001,7 @@
// The source and result must have the same rank.
if (getResultType().getRank() != getSourceType().getRank())
return emitOpError("cannot change the rank of the tensor");
- if (!areShapesCompatible(getResultType().getShape(),
- getSourceType().getShape()))
+ if (failed(verifyCompatibleShape(getResultType(), getSourceType())))
return emitOpError("expected to preserve the logical shape of the tensor");
return success();
}
@@ -3059,8 +3040,7 @@
// The source and result must have the same rank.
if (getResultType().getRank() != getSourceType().getRank())
return emitOpError("cannot change the rank of the tensor");
- if (!areShapesCompatible(getResultType().getShape(),
- getSourceType().getShape()))
+ if (failed(verifyCompatibleShape(getResultType(), getSourceType())))
return emitOpError("expected to preserve the logical shape of the tensor");
return success();
}