Pass to block dynamic dimensions of operands of `iree_linalg_ext.attention`. (#18874)

`IntegerRangeAnalysis` and `IntegerDivisibilityAnalysis` provide range
and divisibility information for constants passed to the dispatch. This
can be used to infer the range and divisibility of the dynamic
dimensions of all tensor values in the dispatch. This PR adds an
analysis (`TensorDynamicDimAnalysis`) to do this.
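
For example (adapted from the `block_dynamic_dims.mlir` test added in this
PR), divisibility information typically enters the dispatch through
`util.assume.int` and flows to tensor dimensions through
`flow.dispatch.workload.ordinal` and `flow.dispatch.tensor.load`:

```mlir
%m_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
%0 = util.assume.int %m_in<umin = 16, umax = 4080, udiv = 16> : index
%m = flow.dispatch.workload.ordinal %0, 0 : index
%q = flow.dispatch.tensor.load %q_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1]
    : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%m} -> tensor<4x?x32x128xf16>
// The analysis records that dimension 1 of %q is divisible by 16.
```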

This analysis is then used to expand the dimensions of operands of the
attention operation that are dynamic but known to be divisible by a
compile-time static value. This brings the operation into a form that
the AMDGPU backend can compile to target the MFMA intrinsics.
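
As a sketch of the effect on the query operand (shapes taken from the
`block_dynamic_dims.mlir` test below; `%q_blocked` is an illustrative name,
and indexing maps are elided):

```mlir
// Before: dimension 1 is dynamic, but known to be a multiple of 16.
%attn = iree_linalg_ext.attention ...
    ins(%q, ... : tensor<4x?x32x128xf16>, ...) ...

// After blocking and reshape propagation: the factor 16 is a static
// dimension, and the dynamic size is 1/16th of the original.
%attn = iree_linalg_ext.attention ...
    ins(%q_blocked, ... : tensor<4x?x16x32x128xf16>, ...) ...
```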

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>

---------

Signed-off-by: MaheshRavishankar <mravisha@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 7aca986..d9d23b2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -86,6 +86,7 @@
     name = "Common",
     srcs = [
         "AddFastMathFlags.cpp",
+        "BlockDynamicDimensions.cpp",
         "BubbleUpOrdinalOps.cpp",
         "BufferizationAnalysis.cpp",
         "BufferizeCopyOnlyDispatchesPass.cpp",
@@ -137,6 +138,7 @@
         "RemoveSingleIterationLoop.cpp",
         "ReplaceSlowMinMaxOps.cpp",
         "SplitFullPartialTransferPass.cpp",
+        "TensorDynamicDimAnalysis.cpp",
         "TensorToVectorVectorizePad.cpp",
         "TestExecutablePreprocessing.cpp",
         "TestPartitionableLoopsInterface.cpp",
@@ -155,6 +157,7 @@
         "ExtractAddressComputation.h",
         "PassUtils.h",
         "Passes.h",
+        "TensorDynamicDimAnalysis.h",
         "TileSizeSelection.h",
         "Transforms.h",
         "UserConfig.h",
@@ -176,6 +179,7 @@
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
+        "//compiler/src/iree/compiler/Dialect/Util/Analysis",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//compiler/src/iree/compiler/Utils",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
@@ -191,6 +195,7 @@
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:BufferizationInterfaces",
         "@llvm-project//mlir:BufferizationTransforms",
+        "@llvm-project//mlir:DestinationStyleOpInterface",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncTransforms",
diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
new file mode 100644
index 0000000..7a45116
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
@@ -0,0 +1,302 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-codegen-block-dynamic-dimensions"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_BLOCKDYNAMICDIMENSIONSPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+using TensorDivisibilityInfo =
+    llvm::SmallDenseMap<unsigned, IREE::Util::ConstantIntDivisibility>;
+
+namespace {
+
+struct RemoveOptimizationBarrier final
+    : public OpRewritePattern<IREE::Util::OptimizationBarrierOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOp(barrierOp, barrierOp.getOperands());
+    return success();
+  }
+};
+
+/// This pass materializes information about the dynamic dimensions of
+/// `tensor` operands of an operation in the IR. If a dynamic dimension is
+/// known to be a multiple of a compile-time constant value, this pass
+/// expands the shape of the operands. For example, if a `tensor` operand
+/// is of shape `tensor<...x?x...>` and that dimension is known to be a
+/// multiple of 16, the operand is expanded to `tensor<...x?x16x...>`, where
+/// the size of the new dynamic dimension is 1/16th the size of the original
+/// dynamic dimension. This is done in two steps.
+/// 1) Replace operands with such dynamic dimensions with the result of a
+///    `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the
+///    new static dimension and immediately fold it away. An optimization
+///    barrier is added in between to prevent these operations from being
+///    folded away.
+/// 2) Use patterns that propagate the `tensor.collapse_shape` down to
+///    manipulate the operation appropriately. This allows reusing the
+///    (fairly complex) logic for expanding dimensions of operations that is
+///    implemented in the propagation patterns.
+/// At the end of the pass the optimization barriers are removed to fold away
+/// any un-propagated `tensor.expand_shape`/`tensor.collapse_shape` pairs.
+struct BlockDynamicDimensionsPass final
+    : impl::BlockDynamicDimensionsPassBase<BlockDynamicDimensionsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Retrieve the divisibility information for dynamic dimensions of `v` if
+/// known.
+static TensorDivisibilityInfo
+getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+                          Value v) {
+  TensorDivisibilityInfo divisibilityInfo;
+  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
+  if (!tensorType) {
+    return divisibilityInfo;
+  }
+
+  for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
+    if (!tensorType.isDynamicDim(index))
+      continue;
+    std::optional<IREE::Util::ConstantIntDivisibility> dimDivisibility =
+        dynamicDimAnalysis.getDivisibilityInfo(v, index);
+    if (!dimDivisibility)
+      continue;
+    divisibilityInfo[index] = std::move(dimDivisibility.value());
+  }
+
+  return divisibilityInfo;
+}
+
+/// For a value `v`, if a dynamic dimension is known to be a multiple of a
+/// compile-time static value, insert
+///
+/// ```mlir
+/// %v_expand = tensor.expand_shape %v
+/// %barrier = util.optimization.barrier %v_expand
+/// %v_collapse = tensor.collapse_shape %barrier
+/// ```
+///
+/// where the generated `tensor.expand_shape` and `tensor.collapse_shape` are
+/// inverses of each other. The `util.optimization.barrier` prevents these
+/// from being folded away during reshape propagation. Return the result of
+/// the generated `tensor.collapse_shape`.
+static std::optional<Value>
+blockDynamicDimensionsOfValue(RewriterBase &rewriter,
+                              const TensorDivisibilityInfo &divisibilityInfo,
+                              Value v) {
+  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
+  if (!tensorType) {
+    return std::nullopt;
+  }
+
+  // Build the expanded shape, splitting any dynamic dimension for which
+  // divisibility information is known.
+  SmallVector<OpFoldResult> outputShape;
+  SmallVector<ReassociationIndices> reassociation;
+  Location loc = v.getLoc();
+
+  for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
+    reassociation.emplace_back(ReassociationIndices{});
+
+    // Check if this needs division.
+    if (!tensorType.isDynamicDim(index) || !divisibilityInfo.contains(index)) {
+      reassociation.back().push_back(outputShape.size());
+      outputShape.push_back(rewriter.getIndexAttr(dim));
+      continue;
+    }
+
+    // Split the dynamic dimension based on the divisibility info.
+    IREE::Util::ConstantIntDivisibility currDivisibility =
+        divisibilityInfo.lookup(index);
+    uint64_t factor = currDivisibility.sdiv();
+    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+    AffineExpr divExpr = s0.floorDiv(factor);
+    Value sourceDim = rewriter.create<tensor::DimOp>(loc, v, index).getResult();
+    OpFoldResult newDynamicDim = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, divExpr, ArrayRef<OpFoldResult>{sourceDim});
+    OpFoldResult newStaticDim = rewriter.getIndexAttr(factor);
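+    // E.g., with a factor of 16, a dynamic size `d` is replaced by the pair
+    // (d floordiv 16, 16) in the expanded shape.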
+
+    reassociation.back().push_back(outputShape.size());
+    reassociation.back().push_back(outputShape.size() + 1);
+
+    outputShape.push_back(newDynamicDim);
+    outputShape.push_back(newStaticDim);
+  }
+
+  auto staticOutputShape =
+      llvm::map_to_vector(outputShape, [](OpFoldResult ofr) {
+        if (auto staticShapeAttr = dyn_cast<Attribute>(ofr)) {
+          return cast<IntegerAttr>(staticShapeAttr).getInt();
+        }
+        return ShapedType::kDynamic;
+      });
+  auto outputType = RankedTensorType::get(
+      staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());
+
+  Value expandShape = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, v, reassociation, outputShape);
+  Value barrier =
+      rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, expandShape)
+          .getResult(0);
+  Value collapseShape = rewriter.create<tensor::CollapseShapeOp>(
+      loc, tensorType, barrier, reassociation);
+  return collapseShape;
+}
+
+/// For an operation, replace the operands at indices specified in
+/// `limitToOperandIndices` with the result of a
+/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the
+/// information about dynamic dimensions that are known to be a multiple of a
+/// compile-time static value. For example,
+///
+/// ```mlir
+/// %1 = <some_op>(..., %0, ...) : ... , tensor<4x?x6xf32>
+/// ```
+///
+/// If the dynamic dimension is known to be a multiple of 16, then generate
+///
+/// ```mlir
+/// %expanded = tensor.expand_shape %0 :
+///    tensor<4x?x6xf32> into tensor<4x?x16x6xf32>
+/// %barrier = util.optimization.barrier %expanded
+/// %collapsed = tensor.collapse_shape %barrier
+///     : tensor<4x?x16x6xf32> into tensor<4x?x6xf32>
+/// %1 = <some_op>(..., %collapsed, ...) : ... , tensor<4x?x6xf32>
+/// ```
+static LogicalResult blockDynamicDimensions(
+    RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+    Operation *operation, llvm::SmallDenseSet<int64_t> limitToOperandIndices) {
+  OpBuilder::InsertionGuard g(rewriter);
+
+  for (OpOperand &operand : operation->getOpOperands()) {
+    if (!limitToOperandIndices.contains(operand.getOperandNumber()))
+      continue;
+    if (operand.get().getDefiningOp<tensor::CollapseShapeOp>())
+      continue;
+    TensorDivisibilityInfo operandDivisibilityInfo =
+        getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get());
+    if (operandDivisibilityInfo.empty())
+      continue;
+    std::optional<Value> newOperand = blockDynamicDimensionsOfValue(
+        rewriter, operandDivisibilityInfo, operand.get());
+    if (newOperand) {
+      rewriter.modifyOpInPlace(operation,
+                               [&]() { operand.set(newOperand.value()); });
+    }
+  }
+  return success();
+}
+
+/// Insert `tensor.expand_shape` operations to materialize in the IR
+/// information about dynamic dimensions that are known to be a multiple of a
+/// compile-time known value, for the operands of `iree_linalg_ext.attention`.
+static LogicalResult
+blockDynamicDimensions(RewriterBase &rewriter,
+                       const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+                       IREE::LinalgExt::AttentionOp attentionOp) {
+  // Only block the q and k values.
+  llvm::SmallDenseSet<int64_t> prunedOperandsList;
+  prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber());
+  prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber());
+  return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp,
+                                prunedOperandsList);
+}
+
+void BlockDynamicDimensionsPass::runOnOperation() {
+  Operation *operation = getOperation();
+  MLIRContext *context = &getContext();
+  TensorDynamicDimAnalysis dynamicDimAnalysis(operation);
+  if (failed(dynamicDimAnalysis.run())) {
+    return signalPassFailure();
+  }
+
+  IRRewriter rewriter(context);
+  auto walkResult = operation->walk(
+      [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult {
+        rewriter.setInsertionPoint(attentionOp);
+        return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
+                                      attentionOp);
+      });
+  if (walkResult.wasInterrupted()) {
+    return signalPassFailure();
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "After blocking dimensions:\n";
+    operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n";
+  });
+
+  {
+    RewritePatternSet bubbleExpandShapePatterns(context);
+    // Add patterns to "push down" the `tensor.collapse_shape` operations
+    // (these are the dual of the patterns that "bubble up"
+    // `tensor.expand_shape` operations).
+    linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
+    linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
+                                                      controlFn);
+    IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
+        bubbleExpandShapePatterns, controlFn);
+    // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operations
+    // and "pushed-down" `tensor.collapse_shape` operations with their
+    // interface bindings or `tensor.empty` operations.
+    populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
+    tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
+    // Add some additional patterns that can simplify the IR and remove dead
+    // operations.
+    memref::populateResolveRankedShapedTypeResultDimsPatterns(
+        bubbleExpandShapePatterns);
+    populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            operation, std::move(bubbleExpandShapePatterns)))) {
+      operation->emitOpError(
+          "failed in application of bubble up expand shape patterns");
+      return signalPassFailure();
+    }
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "After reshape propagation:\n";
+    operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n";
+  });
+
+  // Delete the optimization barriers and run some further cleanup.
+  {
+    RewritePatternSet removeBarrierOpsPatterns(context);
+    removeBarrierOpsPatterns.insert<RemoveOptimizationBarrier>(context);
+    tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
+                                                       context);
+    tensor::CollapseShapeOp::getCanonicalizationPatterns(
+        removeBarrierOpsPatterns, context);
+    if (failed(applyPatternsAndFoldGreedily(
+            operation, std::move(removeBarrierOpsPatterns)))) {
+      operation->emitOpError("failed in cleanup patterns");
+      return signalPassFailure();
+    }
+  }
+
+  return;
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 764bc25..ee7c406 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -72,11 +72,13 @@
     "ExtractAddressComputation.h"
     "PassUtils.h"
     "Passes.h"
+    "TensorDynamicDimAnalysis.h"
     "TileSizeSelection.h"
     "Transforms.h"
     "UserConfig.h"
   SRCS
     "AddFastMathFlags.cpp"
+    "BlockDynamicDimensions.cpp"
     "BubbleUpOrdinalOps.cpp"
     "BufferizationAnalysis.cpp"
     "BufferizeCopyOnlyDispatchesPass.cpp"
@@ -128,6 +130,7 @@
     "RemoveSingleIterationLoop.cpp"
     "ReplaceSlowMinMaxOps.cpp"
     "SplitFullPartialTransferPass.cpp"
+    "TensorDynamicDimAnalysis.cpp"
     "TensorToVectorVectorizePad.cpp"
     "TestExecutablePreprocessing.cpp"
     "TestPartitionableLoopsInterface.cpp"
@@ -154,6 +157,7 @@
     MLIRArithUtils
     MLIRBufferizationDialect
     MLIRBufferizationTransforms
+    MLIRDestinationStyleOpInterface
     MLIRFuncDialect
     MLIRFuncTransforms
     MLIRFunctionInterfaces
@@ -203,6 +207,7 @@
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::LinalgExt::Transforms
+    iree::compiler::Dialect::Util::Analysis
     iree::compiler::Dialect::Util::IR
     iree::compiler::Utils
   PUBLIC
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 6a5a9b5..5aa3ef4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -19,6 +19,12 @@
                 "given a floating-point mode.";
 }
 
+def BlockDynamicDimensionsPass
+    : Pass<"iree-codegen-block-dynamic-dimensions"> {
+  let summary = "Expand dynamic dimensions that are known to be multiples of "
+                "statically known values.";
+}
+
 def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
   let summary = "Bubbles op ordinal ops to allow for workgroup count computation";
   let description = [{
diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
new file mode 100644
index 0000000..b0e7667
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
@@ -0,0 +1,236 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#define DEBUG_TYPE "iree-codegen-dynamic-dim-analysis"
+
+namespace mlir::iree_compiler {
+
+//===---------------------------------------------------------------------===//
+// Helper functions to update tensor dynamic dimension info
+//===---------------------------------------------------------------------===//
+
+static void
+updateRangeInfo(TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo,
+                Value v, unsigned dim, const ConstantIntRanges &range) {
+  assert(!rangeInfo.contains({v, dim}) &&
+         "overwriting existing dim range info");
+  rangeInfo.insert({{v, dim},
+                    ConstantIntRanges(range.umin(), range.umax(), range.smin(),
+                                      range.smax())});
+}
+
+static void updateDivisibilityInfo(
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    Value v, unsigned dim,
+    const IREE::Util::ConstantIntDivisibility &divisibility) {
+  assert(!divisibilityInfo.contains({v, dim}) &&
+         "overwriting existing dim divisibility info");
+  divisibilityInfo[{v, dim}] = divisibility;
+}
+
+// Update the dynamic dim analysis to record the range/divisibility information
+// for `tensorValue` at dimension `dimIndex` based on the range/divisibility
+// information of an integer/index value `dynamicDim`.
+static void updateTensorDimInfo(
+    Value tensorValue, unsigned dimIndex, Value dynamicDim,
+    const DataFlowSolver &solver,
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) {
+  // Update range info.
+  auto *rangeState =
+      solver.lookupState<dataflow::IntegerValueRangeLattice>(dynamicDim);
+  if (rangeState && !rangeState->getValue().isUninitialized()) {
+    updateRangeInfo(rangeInfo, tensorValue, dimIndex,
+                    rangeState->getValue().getValue());
+  }
+
+  // Update divisibility info.
+  auto *divisibilityState =
+      solver.lookupState<IREE::Util::IntegerDivisibilityLattice>(dynamicDim);
+  if (divisibilityState && !divisibilityState->getValue().isUninitialized()) {
+    updateDivisibilityInfo(divisibilityInfo, tensorValue, dimIndex,
+                           divisibilityState->getValue().getValue());
+  }
+}
+
+//===---------------------------------------------------------------------===//
+// Transfer functions for updating dynamic dimension of results of operation.
+//===---------------------------------------------------------------------===//
+
+// Helper function to transfer the range and divisibility information from
+// the `source` value to the `dest` value.
+static void transferTensorDimInfo(
+    Value source, Value dest, const DataFlowSolver &solver,
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) {
+  // Expected that `source` and `dest` are of `RankedTensorType` and of the
+  // same type.
+  assert(source.getType() == dest.getType());
+  auto sourceType = cast<RankedTensorType>(source.getType());
+  for (auto index : llvm::seq<unsigned>(0, sourceType.getRank())) {
+    // Transfer range info
+    auto rangeIt = rangeInfo.find({source, index});
+    if (rangeIt != rangeInfo.end()) {
+      updateRangeInfo(rangeInfo, dest, index, rangeIt->second);
+    }
+
+    auto divisibilityIt = divisibilityInfo.find({source, index});
+    if (divisibilityIt != divisibilityInfo.end()) {
+      updateDivisibilityInfo(divisibilityInfo, dest, index,
+                             divisibilityIt->second);
+    }
+  }
+}
+
+// Update the tensor dimension information for the result of a
+// `flow.dispatch.tensor.load` operation.
+static void updateTensorDimInfo(
+    IREE::Flow::DispatchTensorLoadOp flowLoadOp, const DataFlowSolver &solver,
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) {
+  // If there are no dynamic dimensions, nothing to do.
+  if (flowLoadOp.getType().hasStaticShape()) {
+    return;
+  }
+  // Check that all strides are 1. Abort otherwise.
+  if (llvm::any_of(flowLoadOp.getMixedStrides(),
+                   [](OpFoldResult s) { return !isConstantIntValue(s, 1); })) {
+    return;
+  }
+
+  Value result = flowLoadOp.getResult();
+  for (auto [index, size] : llvm::enumerate(flowLoadOp.getMixedSizes())) {
+    auto dynamicDim = dyn_cast<Value>(size);
+    if (!dynamicDim) {
+      continue;
+    }
+    updateTensorDimInfo(result, index, dynamicDim, solver, divisibilityInfo,
+                        rangeInfo);
+  }
+}
+
+// Update the tensor dimension information for the result of a `tensor.empty`
+// operation.
+static void updateTensorDimInfo(
+    tensor::EmptyOp emptyOp, const DataFlowSolver &solver,
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) {
+  auto dimOperands = emptyOp.getOperands();
+  if (dimOperands.empty()) {
+    return;
+  }
+
+  Value result = emptyOp.getResult();
+  auto resultType = cast<RankedTensorType>(result.getType());
+  int dimOperandIndex = 0;
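+  // `tensor.empty` has one index operand per dynamic dimension, in order.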
+  for (auto [index, shape] : llvm::enumerate(resultType.getShape())) {
+    if (!ShapedType::isDynamic(shape))
+      continue;
+    updateTensorDimInfo(result, index, dimOperands[dimOperandIndex++], solver,
+                        divisibilityInfo, rangeInfo);
+  }
+}
+
+// Update the tensor dimension information for results of an operation that
+// implements the `DestinationStyleOpInterface`.
+static void updateTensorDimInfo(
+    DestinationStyleOpInterface dstStyleOp, const DataFlowSolver &solver,
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) {
+  for (auto [index, result] : llvm::enumerate(dstStyleOp->getResults())) {
+    auto resultTensorType = dyn_cast<RankedTensorType>(result.getType());
+    if (!resultTensorType || resultTensorType.hasStaticShape()) {
+      continue;
+    }
+    Value source = dstStyleOp.getDpsInitOperand(index)->get();
+    transferTensorDimInfo(source, result, solver, divisibilityInfo, rangeInfo);
+  }
+}
+
+// Dispatch to the method that updates the dimension information for an
+// operation.
+static void updateTensorDimInfo(
+    Operation *op, const DataFlowSolver &solver,
+    TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo,
+    TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Start updating op\n";
+    op->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n";
+  });
+
+  TypeSwitch<Operation *, void>(op)
+      .Case<IREE::Flow::DispatchTensorLoadOp, tensor::EmptyOp>([&](auto op) {
+        updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo);
+      })
+      .Case<DestinationStyleOpInterface>([&](auto op) {
+        updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo);
+      });
+
+  LLVM_DEBUG({
+    for (auto [resultIndex, result] : llvm::enumerate(op->getResults())) {
+      auto tensorType = dyn_cast<RankedTensorType>(result.getType());
+      if (!tensorType)
+        continue;
+      for (auto index : llvm::seq<unsigned>(0, tensorType.getRank())) {
+        std::optional<ConstantIntRanges> range;
+        std::optional<IREE::Util::ConstantIntDivisibility> divisibility;
+        auto rangeIt = rangeInfo.find({result, index});
+        if (rangeIt != rangeInfo.end()) {
+          range = rangeIt->second;
+        }
+        auto divisibilityIt = divisibilityInfo.find({result, index});
+        if (divisibilityIt != divisibilityInfo.end()) {
+          divisibility = divisibilityIt->second;
+        }
+        if (!range && !divisibility) {
+          continue;
+        }
+        llvm::dbgs() << "\tDim Info: Result number : " << resultIndex
+                     << ", dim " << index;
+        if (range) {
+          llvm::dbgs() << " : Range " << range.value();
+        }
+        if (divisibility) {
+          llvm::dbgs() << " : Divisibility " << divisibility.value();
+        }
+        llvm::dbgs() << "\n";
+      }
+    }
+  });
+}
+
+TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp)
+    : rootOperation(rootOp) {
+  solver.load<mlir::dataflow::DeadCodeAnalysis>();
+  solver.load<mlir::dataflow::IntegerRangeAnalysis>();
+  solver.load<IREE::Util::IntegerDivisibilityAnalysis>();
+}
+
+LogicalResult TensorDynamicDimAnalysis::run() {
+  if (failed(solver.initializeAndRun(rootOperation))) {
+    return failure();
+  }
+
+  // Walk the IR in pre-order, forward, updating the dynamic dimension
+  // information for each tensor value.
+  rootOperation->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo);
+  });
+
+  return success();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h
new file mode 100644
index 0000000..13bdb5c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h
@@ -0,0 +1,65 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+
+namespace mlir::iree_compiler {
+
+/// Analysis to compute information about dynamic dimensions of tensors.
+///
+/// Using the IntegerRangeAnalysis and the IntegerDivisibilityAnalysis,
+/// this analysis builds information about the range and divisibility of
+/// dynamic dimensions of tensor operands in the program. The analysis can
+/// then be queried for the range and divisibility info of any dynamic
+/// dimension of any tensor value.
+/// TODO: This is not a dataflow analysis and does not update its information
+/// on IR changes. It could be potentially expensive, and is really meant to
+/// be used before any transformations to the dispatch. If this needs to be
+/// more efficient, it needs to be converted to a data flow solver.
+class TensorDynamicDimAnalysis {
+public:
+  explicit TensorDynamicDimAnalysis(Operation *rootOperation);
+
+  LogicalResult run();
+
+  using TensorDimDivisibilityInfo =
+      DenseMap<std::tuple<Value, unsigned>,
+               IREE::Util::ConstantIntDivisibility>;
+  using TensorDimRangeInfo =
+      DenseMap<std::tuple<Value, unsigned>, ConstantIntRanges>;
+
+  std::optional<ConstantIntRanges> getRangeInfo(Value v,
+                                                unsigned dimIndex) const {
+    auto it = rangeInfo.find({v, dimIndex});
+    if (it == rangeInfo.end()) {
+      return std::nullopt;
+    }
+    return it->second;
+  }
+
+  std::optional<IREE::Util::ConstantIntDivisibility>
+  getDivisibilityInfo(Value v, unsigned dimIndex) const {
+    auto it = divisibilityInfo.find({v, dimIndex});
+    if (it == divisibilityInfo.end()) {
+      return std::nullopt;
+    }
+    return it->second;
+  }
+
+private:
+  DataFlowSolver solver;
+
+  // Operation scope within which the analysis is run.
+  Operation *rootOperation;
+
+  // Map of tensor value to integer divisibility information for each dimension.
+  TensorDimDivisibilityInfo divisibilityInfo;
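+  // Map of tensor value to constant range information for each dimension.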
+  TensorDimRangeInfo rangeInfo;
+};
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_COMMON_TENSORDYNAMICDIMANALYSIS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 7879b58..ab1a76a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -21,6 +21,7 @@
             "add_fmfs.mlir",
             "affinemin_canonicalization.mlir",
             "batch_matmuls.mlir",
+            "block_dynamic_dims.mlir",
             "bubble_up_ordinal_ops.mlir",
             "bufferize_copy_only_dispatches.mlir",
             "canonicalize_interface_load_store.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 832319e..3ac6423 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -17,6 +17,7 @@
     "add_fmfs.mlir"
     "affinemin_canonicalization.mlir"
     "batch_matmuls.mlir"
+    "block_dynamic_dims.mlir"
     "bubble_up_ordinal_ops.mlir"
     "bufferize_copy_only_dispatches.mlir"
     "canonicalize_interface_load_store.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
new file mode 100644
index 0000000..819c412
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
@@ -0,0 +1,101 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
+
+#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
+    #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+    #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+    #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+    #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+    #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @block_attention_dims() {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 8.837890e-02 : f16
+  %m_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %k2_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+  %0:2 = util.assume.int
+      %m_in<umin = 16, umax = 4080, udiv = 16>,
+      %k2_in<umin = 16, umax = 4080, udiv = 32>
+    : index, index
+  %m = flow.dispatch.workload.ordinal %0#0, 0 : index
+  %k2 = flow.dispatch.workload.ordinal %0#1, 1 : index
+  %q_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect")
+      : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%m}
+  %key_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect")
+      : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%k2}
+  %value_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect")
+      : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%k2}
+  %mask_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect")
+      : !flow.dispatch.tensor<readonly:tensor<4x32x?x?xf16>>{%m, %k2}
+  %output_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) flags(Indirect)
+      : !flow.dispatch.tensor<writeonly:tensor<4x?x32x128xf16>>{%m}
+  %q = flow.dispatch.tensor.load %q_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%m} -> tensor<4x?x32x128xf16>
+  %key = flow.dispatch.tensor.load %key_in, offsets = [0, 0, 0, 0], sizes = [4, %k2, 32, 128], strides = [1, 1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%k2} -> tensor<4x?x32x128xf16>
+  %value = flow.dispatch.tensor.load %value_in, offsets = [0, 0, 0, 0], sizes = [4, %k2, 32, 128], strides = [1, 1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<4x?x32x128xf16>>{%k2} -> tensor<4x?x32x128xf16>
+  %mask = flow.dispatch.tensor.load %mask_in, offsets = [0, 0, 0, 0], sizes = [4, 32, %m, %k2], strides = [1, 1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<4x32x?x?xf16>>{%m, %k2} -> tensor<4x32x?x?xf16>
+  %1 = tensor.empty(%m) : tensor<4x?x32x128xf16>
+  %2 = tensor.empty(%m) : tensor<4x32x?x128xf16>
+  %attn = iree_linalg_ext.attention {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
+                       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]}
+      ins(%q, %key, %value, %cst, %mask : tensor<4x?x32x128xf16>, tensor<4x?x32x128xf16>, tensor<4x?x32x128xf16>, f16, tensor<4x32x?x?xf16>)
+      outs(%2 : tensor<4x32x?x128xf16>) {
+    ^bb0(%b0 : f16) :
+      iree_linalg_ext.yield %b0 : f16
+  }-> tensor<4x32x?x128xf16>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%attn : tensor<4x32x?x128xf16>) outs(%1 : tensor<4x?x32x128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<4x?x32x128xf16>
+  flow.dispatch.tensor.store %result, %output_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1]
+      : tensor<4x?x32x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<4x?x32x128xf16>>{%m}
+  return
+}
+// CHECK-LABEL: func @block_attention_dims()
+//   CHECK-DAG:   %[[C32:.+]] = arith.constant 32 : index
+//   CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
+//   CHECK-DAG:   %[[M:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 0 : index
+//   CHECK-DAG:   %[[K2:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 1 : index
+//   CHECK-DAG:   %[[M_DYNAMIC:.+]] = arith.divui %[[M]], %[[C16]]
+//       CHECK:   %[[Q_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(0)
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x?x16x32x128xf16>>{%[[M_DYNAMIC]]}
+//       CHECK:   %[[K2_DYNAMIC:.+]] = arith.divui %[[K2]], %[[C32]]
+//       CHECK:   %[[K_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(1)
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x?x32x32x128xf16>>{%[[K2_DYNAMIC]]}
+//       CHECK:   %[[V_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(2)
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x?x32x32x128xf16>>{%[[K2_DYNAMIC]]}
+//       CHECK:   %[[MASK_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(3)
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x32x?x16x?x32xf16>>{%[[M_DYNAMIC]], %[[K2_DYNAMIC]]}
+//       CHECK:   %[[OUTPUT_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(4)
+//  CHECK-SAME:       !flow.dispatch.tensor<writeonly:tensor<4x?x16x32x128xf16>>{%[[M_DYNAMIC]]}
+//       CHECK:   %[[Q:.+]] = flow.dispatch.tensor.load %[[Q_BINDING]]
+//  CHECK-SAME:       sizes = [4, %[[M_DYNAMIC]], 16, 32, 128]
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x?x16x32x128xf16>>{%[[M_DYNAMIC]]}
+//       CHECK:   %[[K:.+]] = flow.dispatch.tensor.load %[[K_BINDING]]
+//  CHECK-SAME:       sizes = [4, %[[K2_DYNAMIC]], 32, 32, 128]
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x?x32x32x128xf16>>{%[[K2_DYNAMIC]]}
+//       CHECK:   %[[V:.+]] = flow.dispatch.tensor.load %[[V_BINDING]]
+//  CHECK-SAME:       sizes = [4, %[[K2_DYNAMIC]], 32, 32, 128]
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x?x32x32x128xf16>>{%[[K2_DYNAMIC]]}
+//       CHECK:   %[[MASK:.+]] = flow.dispatch.tensor.load %[[MASK_BINDING]]
+//  CHECK-SAME:       sizes = [4, 32, %[[M_DYNAMIC]], 16, %[[K2_DYNAMIC]], 32]
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<4x32x?x16x?x32xf16>>{%[[M_DYNAMIC]], %[[K2_DYNAMIC]]}
+//       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//       CHECK:       ins(%[[Q]], %[[K]], %[[V]], %{{.+}}, %[[MASK]] :
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//       CHECK:   flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index aab73c9..86f65e1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1179,6 +1179,9 @@
     funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
     addCommonTargetExecutablePreprocessingPasses(funcPassManager);
     addEncodingToNopPasses(funcPassManager);
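+    // Materialize known divisibility of dynamic dimensions before the
+    // lowering strategy is selected, so that ops like attention can use the
+    // MFMA-based lowerings.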
+    funcPassManager.addPass(createBlockDynamicDimensionsPass);
+    funcPassManager.addPass(createCanonicalizerPass);
+    funcPassManager.addPass(createCSEPass);
   }
   modulePassManager.addPass(createMaterializeUserConfigsPass());
   modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass());
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 5e5b2ef..f19a665 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1348,6 +1348,27 @@
 }
 
 //===----------------------------------------------------------------------===//
+// flow.dispatch.workload.ordinal
+//===----------------------------------------------------------------------===//
+
+void DispatchWorkloadOrdinalOp::inferResultDivisibility(
+    ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+    IREE::Util::SetIntDivisibilityFn setResultDivisibility) {
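+  // This op just forwards its operand, so the result inherits the operand's
+  // divisibility (defaulting to the trivial divisibility (1, 1) if unknown).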
+  if (argDivs[0].isUninitialized()) {
+    setResultDivisibility(getResult(),
+                          IREE::Util::ConstantIntDivisibility(1, 1));
+    return;
+  }
+  setResultDivisibility(getResult(), argDivs[0].getValue());
+}
+
+void DispatchWorkloadOrdinalOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  assert(!argRanges.empty() && "expected range of input to be set");
+  setResultRange(getResult(), argRanges[0]);
+}
+
+//===----------------------------------------------------------------------===//
 // flow.executable
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 301ce8b..69d8cc4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -16,6 +16,7 @@
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1741,7 +1742,11 @@
 }
 
 def FLOW_DispatchWorkloadOrdinalOp :
-    FLOW_PureOp<"dispatch.workload.ordinal"> {
+    FLOW_PureOp<"dispatch.workload.ordinal", [
+      DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>,
+      DeclareOpInterfaceMethods<InferIntRangeInterface,
+        ["inferResultRanges"]>
+    ]> {
   let arguments = (ins
     Index:$operand,
     IndexAttr:$ordinal