[Flow][LinalgExt] Move bitwidth helpers from Flow to LinalgExt (#17855)
Move `isBitExtendOp` and `isBitTruncateOp` into linalgext.
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 7d2a75f..b8c317f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -20,10 +20,10 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -451,11 +451,11 @@
Type initElemType = getElementTypeOrSelf(init);
if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
- if (IREE::Flow::isBitExtendOp(lhsOp))
+ if (IREE::LinalgExt::isBitExtendOp(lhsOp))
lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
}
if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
- if (IREE::Flow::isBitExtendOp(rhsOp))
+ if (IREE::LinalgExt::isBitExtendOp(rhsOp))
rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
index bc108c2..0e75587 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -53,7 +54,7 @@
}
// Do not fuse by expand if consumer is dequant.
- if (isBitExtendOp(consumer)) {
+ if (LinalgExt::isBitExtendOp(consumer)) {
return false;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 77af6c9..e39096b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
@@ -226,7 +227,7 @@
return false;
}
// Dequantization-like ops get cloned into dispatches later.
- if (isBitExtendOp(op)) {
+ if (LinalgExt::isBitExtendOp(op)) {
return false;
}
// Any Linalg named op or generic op with reduction iterator types is a root
@@ -539,7 +540,7 @@
// If consumer is a dequant operation, dont fuse it. These get cloned
// into their consumers.
- if (isBitExtendOp(consumer)) {
+ if (LinalgExt::isBitExtendOp(consumer)) {
return false;
}
@@ -874,7 +875,7 @@
// materializing large tensors between dispatches.
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
IREE::Encoding::SetEncodingOp>(op) ||
- isa<linalg::FillOp>(op) || isBitExtendOp(&op)) {
+ isa<linalg::FillOp>(op) || LinalgExt::isBitExtendOp(&op)) {
continue;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
index c2ea01e..f31cfaa 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
@@ -14,7 +14,9 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -156,7 +158,7 @@
// Dequantization-like operations should be fused with consumers to keep
// the smaller bit width on the dispatch boundary.
- if (isBitExtendOp(genericOp)) {
+ if (LinalgExt::isBitExtendOp(genericOp)) {
return;
}
@@ -196,7 +198,7 @@
// 7. Skip dequantization-like `producer` ops as we would rather fuse
// by cloning the producer instead of multi-use fusion.
- if (isBitExtendOp(producer)) {
+ if (LinalgExt::isBitExtendOp(producer)) {
return;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
index 49d4859..f95614d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
@@ -162,7 +163,7 @@
// Check if the producerOp is fusible
if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
- !isElementwise(producerOp) || !isBitExtendOp(producerOp)) {
+ !isElementwise(producerOp) || !LinalgExt::isBitExtendOp(producerOp)) {
return rewriter.notifyMatchFailure(producerOp,
"producer op is not fusible");
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
index 688f536..70f4a17 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -9,6 +9,7 @@
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
namespace mlir::iree_compiler::IREE::Flow {
@@ -57,8 +58,6 @@
return false;
}
- std::optional<BitWidthChangeInfo> consumerBitwidthChangeInfo =
- isBitExtendOrTruncateOp(consumerOp);
// Do no fuse bitextend-like operations with producers. Such ops are cloned
// into all their use dispatches. So fusing producer with consumer here would
// then result in producer also getting cloned into many dispatches which is
@@ -66,8 +65,7 @@
// (except for bit-extend ops). If the consumer has only one use, then this
// fusion is fine since cloning wont result in redundant computation of the
// producer. (Also note that the producer is always an elementwise operation).
- if (consumerBitwidthChangeInfo &&
- consumerBitwidthChangeInfo->isExtensionOp() && !consumerOp->hasOneUse()) {
+ if (LinalgExt::isBitExtendOp(consumerOp) && !consumerOp->hasOneUse()) {
return false;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3b6084f..4b1d017 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
@@ -528,107 +529,6 @@
}
//===---------------------------------------------------------------------===//
-// Classification of ops that change bit-widths
-//===---------------------------------------------------------------------===//
-
-Type BitWidthChangeInfo::getInputElementType() const {
- return cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
-}
-
-std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op) {
- auto genericOp = dyn_cast<linalg::GenericOp>(op);
- if (!genericOp) {
- return std::nullopt;
- }
- if (genericOp.getNumDpsInits() != 1) {
- return std::nullopt;
- }
-
- // Check that the all loops are parallel
- unsigned numLoops = genericOp.getNumLoops();
- unsigned numParallelLoops = genericOp.getNumParallelLoops();
- if (numLoops != numParallelLoops) {
- return std::nullopt;
- }
-
- // Check that all operands that have the highest rank have bit width
- // less than the output bit-width.
- DenseMap<int64_t, SmallVector<OpOperand *>> rankBuckets;
- int64_t maxOperandRank = 0;
- for (OpOperand *input : genericOp.getDpsInputOperands()) {
- auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
- if (!inputType) {
- continue;
- }
- int64_t currRank = inputType.getRank();
- maxOperandRank = std::max(currRank, maxOperandRank);
- rankBuckets[currRank].push_back(input);
- }
- if (maxOperandRank == 0 || rankBuckets[maxOperandRank].empty()) {
- return std::nullopt;
- }
-
- unsigned int maxInputElementBitWidth = 0;
- OpOperand *inputOperand;
- for (OpOperand *operand : rankBuckets[maxOperandRank]) {
- RankedTensorType tensorType =
- cast<RankedTensorType>(operand->get().getType());
- Type elementType = tensorType.getElementType();
- if (!elementType.isIntOrFloat()) {
- return std::nullopt;
- }
- unsigned elementBitWidth = Util::getTypeBitWidth(elementType);
- if (elementBitWidth > maxInputElementBitWidth) {
- maxInputElementBitWidth = elementBitWidth;
- inputOperand = operand;
- }
- }
- if (!inputOperand) {
- return std::nullopt;
- }
- Type inputElementType =
- cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
-
- // Check that the identity input element bitwidth is smaller than the output
- // element bitwidth.
- RankedTensorType outputType =
- dyn_cast<RankedTensorType>(genericOp->getResultTypes()[0]);
- if (!outputType) {
- return std::nullopt;
- }
- Type outputElementType = outputType.getElementType();
- if (!outputElementType.isIntOrFloat()) {
- return std::nullopt;
- }
-
- unsigned inputBitWidth = Util::getTypeBitWidth(inputElementType);
- unsigned outputBitWidth = Util::getTypeBitWidth(outputElementType);
- if (inputBitWidth == outputBitWidth) {
- return std::nullopt;
- }
-
- // Checks specific to bit extend operations.
- if (inputBitWidth < outputBitWidth) {
- // Since these are cloned into dispatches, avoid expensive operations.
- for (Operation &op : *genericOp.getBody()) {
- if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
- return std::nullopt;
- }
- }
- }
-
- // Checks specific to bit truncate operations.
- if (outputBitWidth < inputBitWidth) {
- // For now enforce that the input and output ranks match for truncates.
- if (maxOperandRank != outputType.getRank()) {
- return std::nullopt;
- }
- }
-
- return BitWidthChangeInfo{inputOperand, outputElementType};
-}
-
-//===---------------------------------------------------------------------===//
// Utilities to make a dispatch region isolated from above
//===---------------------------------------------------------------------===//
@@ -643,7 +543,7 @@
tensor::ExtractSliceOp, complex::CreateOp>(op)) {
return true;
}
- if (isBitExtendOp(op)) {
+ if (LinalgExt::isBitExtendOp(op)) {
return true;
}
if (isa<arith::ConstantOp>(op) || isa<complex::ConstantOp>(op)) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index df8792b..45a375f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -104,42 +104,6 @@
/// into a dispatch region.
bool isClonableIntoDispatchOp(Operation *op);
-/// Returns true if the operation increases/decreases bitwidths of tensors.
-/// This function checks that the genericOp:
-/// 1. Has only one output.
-/// 2. Has all parallel loops.
-/// 3. Compared to the element type of the input with highest rank,
-/// the output element type has either a higher or lower bitwidth.
-struct BitWidthChangeInfo {
- // The operand the recognizer treats as the "input".
- // Is guaranteed to be a `RankedTensorType`.
- OpOperand *inputOperand = nullptr;
- // The output element type is int or float type.
- Type outputElementType = nullptr;
-
- // Helper methods.
- Type getInputElementType() const;
- bool isExtensionOp() const {
- return getInputElementType().getIntOrFloatBitWidth() <
- outputElementType.getIntOrFloatBitWidth();
- }
- bool isTruncationOp() const {
- return outputElementType.getIntOrFloatBitWidth() <
- getInputElementType().getIntOrFloatBitWidth();
- }
-};
-std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op);
-inline bool isBitExtendOp(Operation *op) {
- std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
- isBitExtendOrTruncateOp(op);
- return bitWidthChangeInfo && bitWidthChangeInfo->isExtensionOp();
-}
-inline bool isBitTruncateOp(Operation *op) {
- std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
- isBitExtendOrTruncateOp(op);
- return bitWidthChangeInfo && bitWidthChangeInfo->isTruncationOp();
-}
-
/// Collect all ops that should be cloned into the given dispatch region op.
SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
index 9f0b383..3705f7b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -76,7 +77,7 @@
// Do not sink reshapes across dequantize operations since they are
// cloned into their consumers.
- if (isBitExtendOp(consumer)) {
+ if (LinalgExt::isBitExtendOp(consumer)) {
return false;
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
index 8cec265..0da302b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
@@ -28,6 +28,7 @@
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
index 7893e71..b4c519c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -24,6 +24,7 @@
LLVMSupport
MLIRArithDialect
MLIRIR
+ MLIRLinalgDialect
MLIRMemRefDialect
MLIRSupport
MLIRTensorDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index e6c3548..30dccb0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -8,6 +8,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
@@ -132,4 +133,115 @@
return result;
}
+//===---------------------------------------------------------------------===//
+// Classification of ops that change bit-widths
+//===---------------------------------------------------------------------===//
+
+enum class BitWidthChangeInfo {
+ kNull,
+ kExtend,
+ kTruncate,
+};
+
+static BitWidthChangeInfo isBitExtendOrTruncateOp(Operation *op) {
+ auto genericOp = dyn_cast<linalg::GenericOp>(op);
+ if (!genericOp) {
+ return BitWidthChangeInfo::kNull;
+ }
+
+ if (genericOp.getNumDpsInits() != 1) {
+ return BitWidthChangeInfo::kNull;
+ }
+
+ // Check that the all loops are parallel
+ unsigned numLoops = genericOp.getNumLoops();
+ unsigned numParallelLoops = genericOp.getNumParallelLoops();
+ if (numLoops != numParallelLoops) {
+ return BitWidthChangeInfo::kNull;
+ }
+
+ // Check that all operands that have the highest rank have bit width
+ // less than the output bit-width.
+ DenseMap<int64_t, SmallVector<OpOperand *>> rankBuckets;
+ int64_t maxOperandRank = 0;
+ for (OpOperand *input : genericOp.getDpsInputOperands()) {
+ auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
+ if (!inputType) {
+ continue;
+ }
+ int64_t currRank = inputType.getRank();
+ maxOperandRank = std::max(currRank, maxOperandRank);
+ rankBuckets[currRank].push_back(input);
+ }
+ if (maxOperandRank == 0 || rankBuckets[maxOperandRank].empty()) {
+ return BitWidthChangeInfo::kNull;
+ }
+
+ unsigned int maxInputElementBitWidth = 0;
+ OpOperand *inputOperand;
+ for (OpOperand *operand : rankBuckets[maxOperandRank]) {
+ RankedTensorType tensorType =
+ cast<RankedTensorType>(operand->get().getType());
+ Type elementType = tensorType.getElementType();
+ if (!elementType.isIntOrFloat()) {
+ return BitWidthChangeInfo::kNull;
+ }
+ unsigned elementBitWidth = elementType.getIntOrFloatBitWidth();
+ if (elementBitWidth > maxInputElementBitWidth) {
+ maxInputElementBitWidth = elementBitWidth;
+ inputOperand = operand;
+ }
+ }
+ if (!inputOperand) {
+ return BitWidthChangeInfo::kNull;
+ }
+ Type inputElementType =
+ cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
+
+ // Check that the identity input element bitwidth is smaller than the output
+ // element bitwidth.
+ RankedTensorType outputType =
+ dyn_cast<RankedTensorType>(genericOp->getResultTypes()[0]);
+ if (!outputType) {
+ return BitWidthChangeInfo::kNull;
+ }
+ Type outputElementType = outputType.getElementType();
+ if (!outputElementType.isIntOrFloat()) {
+ return BitWidthChangeInfo::kNull;
+ }
+
+ unsigned inputBitWidth = inputElementType.getIntOrFloatBitWidth();
+ unsigned outputBitWidth = outputElementType.getIntOrFloatBitWidth();
+
+ // Checks specific to bit extend operations.
+ if (inputBitWidth < outputBitWidth) {
+ // Since these are cloned into dispatches, avoid expensive operations.
+ for (Operation &op : *genericOp.getBody()) {
+ if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
+ return BitWidthChangeInfo::kNull;
+ }
+ }
+ return BitWidthChangeInfo::kExtend;
+ }
+
+ // Checks specific to bit truncate operations.
+ if (outputBitWidth < inputBitWidth) {
+ // For now enforce that the input and output ranks match for truncates.
+ if (maxOperandRank != outputType.getRank()) {
+ return BitWidthChangeInfo::kNull;
+ }
+ return BitWidthChangeInfo::kTruncate;
+ }
+
+ return BitWidthChangeInfo::kNull;
+}
+
+bool isBitExtendOp(Operation *op) {
+ return isBitExtendOrTruncateOp(op) == BitWidthChangeInfo::kExtend;
+}
+
+bool isBitTruncateOp(Operation *op) {
+ return isBitExtendOrTruncateOp(op) == BitWidthChangeInfo::kTruncate;
+}
+
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index eec973f..3c4e139 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -111,5 +111,21 @@
}
}
+/// Returns true if the operation increases bitwidths of tensors.
+/// This function checks that the genericOp:
+/// 1. Has only one output.
+/// 2. Has all parallel loops.
+/// 3. Compared to the element type of the input with highest rank,
+/// the output element type has a higher bitwidth.
+bool isBitExtendOp(Operation *op);
+
+/// Returns true if the operation decreases bitwidths of tensors.
+/// This function checks that the genericOp:
+/// 1. Has only one output.
+/// 2. Has all parallel loops.
+/// 3. Compared to the element type of the input with highest rank,
+/// the output element type has a lower bitwidth.
+bool isBitTruncateOp(Operation *op);
+
} // namespace mlir::iree_compiler::IREE::LinalgExt
#endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_