[NFC] Switch to use upstream mlir::verifyCompatibleShape method. (#12243)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp index 33fd63f..a7288cb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
@@ -28,22 +28,6 @@ namespace IREE { namespace Codegen { -/// Returns true if the dimensions of ShapedType are compatible. -static bool isShapedTypeDimCompatible(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || - lhs == rhs; -} - -/// Returns true if the dimensions of ShapedType are compatible. -static bool areShapesCompatible(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - return llvm::all_of(llvm::zip(lhs, rhs), [](std::tuple<int64_t, int64_t> it) { - return isShapedTypeDimCompatible(std::get<0>(it), std::get<1>(it)); - }); -} - /// Helper method to generate a function declaration at a module scope, /// and a call to that function static FailureOr<func::CallOp> createFunctionCall(RewriterBase &rewriter,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp index b6d691e..602eaa0 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -87,22 +87,6 @@ .Default([&](Type t) { return nullptr; }); } -/// Returns true if the dimensions of ShapedType are compatible. -static bool isShapedTypeDimCompatible(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || - lhs == rhs; -} - -/// Returns true if the dimensions of ShapedType are compatible. -static bool areShapesCompatible(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - return llvm::all_of(llvm::zip(lhs, rhs), [](std::tuple<int64_t, int64_t> it) { - return isShapedTypeDimCompatible(std::get<0>(it), std::get<1>(it)); - }); -} - /// Return true if `dimsPos` is invalid. It is invalid when: a) it contains /// duplicate. b) At least one dimension is out of bound (`dimPos` is >= 0 and < /// rank). c) the number of elements in `dimsPos` is > than `rank`. @@ -1285,13 +1269,11 @@ // Input indicies and values must have the same shape. if (auto inputIndices = indices()) { auto inputIndicesType = inputIndices->getType().cast<ShapedType>(); - if (!areShapesCompatible(inputValuesType.getShape(), - inputIndicesType.getShape())) + if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) return op->emitOpError("input indices/values shape must match"); } // Output indicies and values must have the same shape. - if (!areShapesCompatible(outputValuesType.getShape(), - outputIndicesType.getShape())) + if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType))) return op->emitOpError("output indices/values shape must match"); // Input shape must match the output shape except for the dimension() uint64_t dim = getDimension(); @@ -1302,8 +1284,8 @@ return true; } std::tuple<int64_t, int64_t> s = e.value(); - return isShapedTypeDimCompatible(std::get<0>(s), - std::get<1>(s)); + return succeeded(verifyCompatibleShape(std::get<0>(s), + std::get<1>(s))); })) { return op->emitOpError("incompatible input/output shapes"); } @@ -2480,7 +2462,7 @@ if (isNchw()) { permute<Permutation::TTNCHW_TO_TTNHWC>(expectedOutputShape); } - if (!areShapesCompatible(expectedOutputShape, outputShape)) { + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return op->emitOpError("incompatible output shape"); } return success(); @@ -2642,7 +2624,7 @@ expectedOutputShape[outputIndex] = outputTileSize * inputShape[i]; } } - if (!areShapesCompatible(expectedOutputShape, outputShape)) { + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return op->emitOpError("incompatible output shape"); } return success(); @@ -2762,7 +2744,7 @@ auto outputType = output().getType().cast<ShapedType>(); ArrayRef<int64_t> inputShape = inputType.getShape(); ArrayRef<int64_t> outputShape = outputType.getShape(); - if (!areShapesCompatible(inputShape, outputShape)) { + if (failed(verifyCompatibleShape(inputShape, outputShape))) { return op->emitOpError("incompatible output shape"); } int64_t inputRank = getInputOperandRank(); @@ -2861,11 +2843,11 @@ ArrayRef<int64_t> keyShape = keyType.getShape(); ArrayRef<int64_t> valueShape = valueType.getShape(); ArrayRef<int64_t> outputShape = outputType.getShape(); - if (!areShapesCompatible(queryShape, keyShape)) + if (failed(verifyCompatibleShape(queryShape, keyShape))) return op->emitOpError("incompatible key shape"); - if (!areShapesCompatible(queryShape, valueShape)) + if (failed(verifyCompatibleShape(queryShape, valueShape))) return op->emitOpError("incompatible value shape"); - if (!areShapesCompatible(queryShape, outputShape)) + if (failed(verifyCompatibleShape(queryShape, outputShape))) return op->emitOpError("incompatible output shape"); return success(); } @@ -3019,8 +3001,7 @@ // The source and result must have the same rank. if (getResultType().getRank() != getSourceType().getRank()) return emitOpError("cannot change the rank of the tensor"); - if (!areShapesCompatible(getResultType().getShape(), - getSourceType().getShape())) + if (failed(verifyCompatibleShape(getResultType(), getSourceType()))) return emitOpError("expected to preserve the logical shape of the tensor"); return success(); } @@ -3059,8 +3040,7 @@ // The source and result must have the same rank. if (getResultType().getRank() != getSourceType().getRank()) return emitOpError("cannot change the rank of the tensor"); - if (!areShapesCompatible(getResultType().getShape(), - getSourceType().getShape())) + if (failed(verifyCompatibleShape(getResultType(), getSourceType()))) return emitOpError("expected to preserve the logical shape of the tensor"); return success(); }