[Flow][LinalgExt] Move bitwidth helpers from Flow to LinalgExt (#17855)

Move `isBitExtendOp` and `isBitTruncateOp` into linalgext.

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 7d2a75f..b8c317f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -20,10 +20,10 @@
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -451,11 +451,11 @@
   Type initElemType = getElementTypeOrSelf(init);
 
   if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
-    if (IREE::Flow::isBitExtendOp(lhsOp))
+    if (IREE::LinalgExt::isBitExtendOp(lhsOp))
       lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
   }
   if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
-    if (IREE::Flow::isBitExtendOp(rhsOp))
+    if (IREE::LinalgExt::isBitExtendOp(rhsOp))
       rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
index bc108c2..0e75587 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp
@@ -15,6 +15,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -53,7 +54,7 @@
         }
 
         // Do not fuse by expand if consumer is dequant.
-        if (isBitExtendOp(consumer)) {
+        if (LinalgExt::isBitExtendOp(consumer)) {
           return false;
         }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 77af6c9..e39096b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -15,6 +15,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
@@ -226,7 +227,7 @@
     return false;
   }
   // Dequantization-like ops get cloned into dispatches later.
-  if (isBitExtendOp(op)) {
+  if (LinalgExt::isBitExtendOp(op)) {
     return false;
   }
   // Any Linalg named op or generic op with reduction iterator types is a root
@@ -539,7 +540,7 @@
 
   // If consumer is a dequant operation, dont fuse it. These get cloned
   // into their consumers.
-  if (isBitExtendOp(consumer)) {
+  if (LinalgExt::isBitExtendOp(consumer)) {
     return false;
   }
 
@@ -874,7 +875,7 @@
       // materializing large tensors between dispatches.
       if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
                IREE::Encoding::SetEncodingOp>(op) ||
-          isa<linalg::FillOp>(op) || isBitExtendOp(&op)) {
+          isa<linalg::FillOp>(op) || LinalgExt::isBitExtendOp(&op)) {
         continue;
       }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
index c2ea01e..f31cfaa 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseMultiUseElementwiseProducer.cpp
@@ -14,7 +14,9 @@
 
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
@@ -156,7 +158,7 @@
 
         // Dequantization-like operations should be fused with consumers to keep
         // the smaller bit width on the dispatch boundary.
-        if (isBitExtendOp(genericOp)) {
+        if (LinalgExt::isBitExtendOp(genericOp)) {
           return;
         }
 
@@ -196,7 +198,7 @@
 
           // 7. Skip dequantization-like `producer` ops as we would rather fuse
           //    by cloning the producer instead of multi-use fusion.
-          if (isBitExtendOp(producer)) {
+          if (LinalgExt::isBitExtendOp(producer)) {
             return;
           }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
index 49d4859..f95614d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
@@ -12,6 +12,7 @@
 
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
@@ -162,7 +163,7 @@
 
     // Check if the producerOp is fusible
     if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
-        !isElementwise(producerOp) || !isBitExtendOp(producerOp)) {
+        !isElementwise(producerOp) || !LinalgExt::isBitExtendOp(producerOp)) {
       return rewriter.notifyMatchFailure(producerOp,
                                          "producer op is not fusible");
     }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
index 688f536..70f4a17 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.cpp
@@ -9,6 +9,7 @@
 
 #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
 #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 
 namespace mlir::iree_compiler::IREE::Flow {
@@ -57,8 +58,6 @@
     return false;
   }
 
-  std::optional<BitWidthChangeInfo> consumerBitwidthChangeInfo =
-      isBitExtendOrTruncateOp(consumerOp);
   // Do no fuse bitextend-like operations with producers. Such ops are cloned
   // into all their use dispatches. So fusing producer with consumer here would
   // then result in producer also getting cloned into many dispatches which is
@@ -66,8 +65,7 @@
   // (except for bit-extend ops). If the consumer has only one use, then this
   // fusion is fine since cloning wont result in redundant computation of the
   // producer. (Also note that the producer is always an elementwise operation).
-  if (consumerBitwidthChangeInfo &&
-      consumerBitwidthChangeInfo->isExtensionOp() && !consumerOp->hasOneUse()) {
+  if (LinalgExt::isBitExtendOp(consumerOp) && !consumerOp->hasOneUse()) {
     return false;
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3b6084f..4b1d017 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -9,6 +9,7 @@
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
@@ -528,107 +529,6 @@
 }
 
 //===---------------------------------------------------------------------===//
-// Classification of ops that change bit-widths
-//===---------------------------------------------------------------------===//
-
-Type BitWidthChangeInfo::getInputElementType() const {
-  return cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
-}
-
-std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op) {
-  auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp) {
-    return std::nullopt;
-  }
-  if (genericOp.getNumDpsInits() != 1) {
-    return std::nullopt;
-  }
-
-  // Check that the all loops are parallel
-  unsigned numLoops = genericOp.getNumLoops();
-  unsigned numParallelLoops = genericOp.getNumParallelLoops();
-  if (numLoops != numParallelLoops) {
-    return std::nullopt;
-  }
-
-  // Check that all operands that have the highest rank have bit width
-  // less than the output bit-width.
-  DenseMap<int64_t, SmallVector<OpOperand *>> rankBuckets;
-  int64_t maxOperandRank = 0;
-  for (OpOperand *input : genericOp.getDpsInputOperands()) {
-    auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
-    if (!inputType) {
-      continue;
-    }
-    int64_t currRank = inputType.getRank();
-    maxOperandRank = std::max(currRank, maxOperandRank);
-    rankBuckets[currRank].push_back(input);
-  }
-  if (maxOperandRank == 0 || rankBuckets[maxOperandRank].empty()) {
-    return std::nullopt;
-  }
-
-  unsigned int maxInputElementBitWidth = 0;
-  OpOperand *inputOperand;
-  for (OpOperand *operand : rankBuckets[maxOperandRank]) {
-    RankedTensorType tensorType =
-        cast<RankedTensorType>(operand->get().getType());
-    Type elementType = tensorType.getElementType();
-    if (!elementType.isIntOrFloat()) {
-      return std::nullopt;
-    }
-    unsigned elementBitWidth = Util::getTypeBitWidth(elementType);
-    if (elementBitWidth > maxInputElementBitWidth) {
-      maxInputElementBitWidth = elementBitWidth;
-      inputOperand = operand;
-    }
-  }
-  if (!inputOperand) {
-    return std::nullopt;
-  }
-  Type inputElementType =
-      cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
-
-  // Check that the identity input element bitwidth is smaller than the output
-  // element bitwidth.
-  RankedTensorType outputType =
-      dyn_cast<RankedTensorType>(genericOp->getResultTypes()[0]);
-  if (!outputType) {
-    return std::nullopt;
-  }
-  Type outputElementType = outputType.getElementType();
-  if (!outputElementType.isIntOrFloat()) {
-    return std::nullopt;
-  }
-
-  unsigned inputBitWidth = Util::getTypeBitWidth(inputElementType);
-  unsigned outputBitWidth = Util::getTypeBitWidth(outputElementType);
-  if (inputBitWidth == outputBitWidth) {
-    return std::nullopt;
-  }
-
-  // Checks specific to bit extend operations.
-  if (inputBitWidth < outputBitWidth) {
-    // Since these are cloned into dispatches, avoid expensive operations.
-    for (Operation &op : *genericOp.getBody()) {
-      if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
-        return std::nullopt;
-      }
-    }
-  }
-
-  // Checks specific to bit truncate operations.
-  if (outputBitWidth < inputBitWidth) {
-    // For now enforce that the input and output ranks match for truncates.
-    if (maxOperandRank != outputType.getRank()) {
-      return std::nullopt;
-    }
-  }
-
-  return BitWidthChangeInfo{inputOperand, outputElementType};
-}
-
-//===---------------------------------------------------------------------===//
 // Utilities to make a dispatch region isolated from above
 //===---------------------------------------------------------------------===//
 
@@ -643,7 +543,7 @@
           tensor::ExtractSliceOp, complex::CreateOp>(op)) {
     return true;
   }
-  if (isBitExtendOp(op)) {
+  if (LinalgExt::isBitExtendOp(op)) {
     return true;
   }
   if (isa<arith::ConstantOp>(op) || isa<complex::ConstantOp>(op)) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index df8792b..45a375f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -104,42 +104,6 @@
 /// into a dispatch region.
 bool isClonableIntoDispatchOp(Operation *op);
 
-/// Returns true if the operation increases/decreases bitwidths of tensors.
-/// This function checks that the genericOp:
-/// 1. Has only one output.
-/// 2. Has all parallel loops.
-/// 3. Compared to the element type of the input with highest rank,
-///    the output element type has either a higher or lower bitwidth.
-struct BitWidthChangeInfo {
-  // The operand the recognizer treats as the "input".
-  // Is guaranteed to be a `RankedTensorType`.
-  OpOperand *inputOperand = nullptr;
-  // The output element type is int or float type.
-  Type outputElementType = nullptr;
-
-  // Helper methods.
-  Type getInputElementType() const;
-  bool isExtensionOp() const {
-    return getInputElementType().getIntOrFloatBitWidth() <
-           outputElementType.getIntOrFloatBitWidth();
-  }
-  bool isTruncationOp() const {
-    return outputElementType.getIntOrFloatBitWidth() <
-           getInputElementType().getIntOrFloatBitWidth();
-  }
-};
-std::optional<BitWidthChangeInfo> isBitExtendOrTruncateOp(Operation *op);
-inline bool isBitExtendOp(Operation *op) {
-  std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
-      isBitExtendOrTruncateOp(op);
-  return bitWidthChangeInfo && bitWidthChangeInfo->isExtensionOp();
-}
-inline bool isBitTruncateOp(Operation *op) {
-  std::optional<BitWidthChangeInfo> bitWidthChangeInfo =
-      isBitExtendOrTruncateOp(op);
-  return bitWidthChangeInfo && bitWidthChangeInfo->isTruncationOp();
-}
-
 /// Collect all ops that should be cloned into the given dispatch region op.
 SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
index 9f0b383..3705f7b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp
@@ -17,6 +17,7 @@
 #include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -76,7 +77,7 @@
 
   // Do not sink reshapes across dequantize operations since they are
   // cloned into their consumers.
-  if (isBitExtendOp(consumer)) {
+  if (LinalgExt::isBitExtendOp(consumer)) {
     return false;
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
index 8cec265..0da302b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
@@ -28,6 +28,7 @@
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
index 7893e71..b4c519c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -24,6 +24,7 @@
     LLVMSupport
     MLIRArithDialect
     MLIRIR
+    MLIRLinalgDialect
     MLIRMemRefDialect
     MLIRSupport
     MLIRTensorDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index e6c3548..30dccb0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Builders.h"
@@ -132,4 +133,115 @@
   return result;
 }
 
+//===---------------------------------------------------------------------===//
+// Classification of ops that change bit-widths
+//===---------------------------------------------------------------------===//
+
+enum class BitWidthChangeInfo {
+  kNull,
+  kExtend,
+  kTruncate,
+};
+
+static BitWidthChangeInfo isBitExtendOrTruncateOp(Operation *op) {
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp) {
+    return BitWidthChangeInfo::kNull;
+  }
+
+  if (genericOp.getNumDpsInits() != 1) {
+    return BitWidthChangeInfo::kNull;
+  }
+
+  // Check that the all loops are parallel
+  unsigned numLoops = genericOp.getNumLoops();
+  unsigned numParallelLoops = genericOp.getNumParallelLoops();
+  if (numLoops != numParallelLoops) {
+    return BitWidthChangeInfo::kNull;
+  }
+
+  // Check that all operands that have the highest rank have bit width
+  // less than the output bit-width.
+  DenseMap<int64_t, SmallVector<OpOperand *>> rankBuckets;
+  int64_t maxOperandRank = 0;
+  for (OpOperand *input : genericOp.getDpsInputOperands()) {
+    auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
+    if (!inputType) {
+      continue;
+    }
+    int64_t currRank = inputType.getRank();
+    maxOperandRank = std::max(currRank, maxOperandRank);
+    rankBuckets[currRank].push_back(input);
+  }
+  if (maxOperandRank == 0 || rankBuckets[maxOperandRank].empty()) {
+    return BitWidthChangeInfo::kNull;
+  }
+
+  unsigned int maxInputElementBitWidth = 0;
+  OpOperand *inputOperand;
+  for (OpOperand *operand : rankBuckets[maxOperandRank]) {
+    RankedTensorType tensorType =
+        cast<RankedTensorType>(operand->get().getType());
+    Type elementType = tensorType.getElementType();
+    if (!elementType.isIntOrFloat()) {
+      return BitWidthChangeInfo::kNull;
+    }
+    unsigned elementBitWidth = elementType.getIntOrFloatBitWidth();
+    if (elementBitWidth > maxInputElementBitWidth) {
+      maxInputElementBitWidth = elementBitWidth;
+      inputOperand = operand;
+    }
+  }
+  if (!inputOperand) {
+    return BitWidthChangeInfo::kNull;
+  }
+  Type inputElementType =
+      cast<RankedTensorType>(inputOperand->get().getType()).getElementType();
+
+  // Check that the identity input element bitwidth is smaller than the output
+  // element bitwidth.
+  RankedTensorType outputType =
+      dyn_cast<RankedTensorType>(genericOp->getResultTypes()[0]);
+  if (!outputType) {
+    return BitWidthChangeInfo::kNull;
+  }
+  Type outputElementType = outputType.getElementType();
+  if (!outputElementType.isIntOrFloat()) {
+    return BitWidthChangeInfo::kNull;
+  }
+
+  unsigned inputBitWidth = inputElementType.getIntOrFloatBitWidth();
+  unsigned outputBitWidth = outputElementType.getIntOrFloatBitWidth();
+
+  // Checks specific to bit extend operations.
+  if (inputBitWidth < outputBitWidth) {
+    // Since these are cloned into dispatches, avoid expensive operations.
+    for (Operation &op : *genericOp.getBody()) {
+      if (op.getDialect() == op.getContext()->getLoadedDialect("math")) {
+        return BitWidthChangeInfo::kNull;
+      }
+    }
+    return BitWidthChangeInfo::kExtend;
+  }
+
+  // Checks specific to bit truncate operations.
+  if (outputBitWidth < inputBitWidth) {
+    // For now enforce that the input and output ranks match for truncates.
+    if (maxOperandRank != outputType.getRank()) {
+      return BitWidthChangeInfo::kNull;
+    }
+    return BitWidthChangeInfo::kTruncate;
+  }
+
+  return BitWidthChangeInfo::kNull;
+}
+
+bool isBitExtendOp(Operation *op) {
+  return isBitExtendOrTruncateOp(op) == BitWidthChangeInfo::kExtend;
+}
+
+bool isBitTruncateOp(Operation *op) {
+  return isBitExtendOrTruncateOp(op) == BitWidthChangeInfo::kTruncate;
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index eec973f..3c4e139 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -111,5 +111,21 @@
   }
 }
 
+/// Returns true if the operation increases bitwidths of tensors.
+/// This function checks that the genericOp:
+/// 1. Has only one output.
+/// 2. Has all parallel loops.
+/// 3. Compared to the element type of the input with highest rank,
+///    the output element type has a higher bitwidth.
+bool isBitExtendOp(Operation *op);
+
+/// Returns true if the operation decreases bitwidths of tensors.
+/// This function checks that the genericOp:
+/// 1. Has only one output.
+/// 2. Has all parallel loops.
+/// 3. Compared to the element type of the input with highest rank,
+///    the output element type has a lower bitwidth.
+bool isBitTruncateOp(Operation *op);
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_