Add `InferIntDivisibilityOpInterface` for `arith.muli`. (#18994)
Fixes #18973
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
index 7a45116..930c9f5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
@@ -60,6 +60,7 @@
/// any un-propagated `tensor.expand_shape/tensor.collapse_shape` patterns.
struct BlockDynamicDimensionsPass final
: impl::BlockDynamicDimensionsPassBase<BlockDynamicDimensionsPass> {
+ using Base::Base; // Inherit the base-class constructors (presumably needed for the generated `test` option -- TODO confirm).
void runOnOperation() override; // Pass entry point; defined out of line below.
};
} // namespace
@@ -222,6 +223,33 @@
prunedOperandsList);
}
+/// Blocks dynamic dimensions of `operation`, choosing the strategy by op kind.
+///
+/// - `IREE::LinalgExt::AttentionOp` is forwarded to the attention-specific
+///   overload.
+/// - Any other op is left untouched (returns success) unless `test` is set,
+///   in which case the operand numbers of all ranked-tensor operands are
+///   handed to the operand-set overload. This catch-all path is only used
+///   for testing for now.
+static LogicalResult
+blockDynamicDimensions(RewriterBase &rewriter,
+                       const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+                       Operation *operation, bool test) {
+  if (auto attentionOp = dyn_cast<IREE::LinalgExt::AttentionOp>(operation)) {
+    return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp);
+  }
+
+  // Outside of test mode, ops without a dedicated handler are a no-op.
+  if (!test) {
+    return success();
+  }
+
+  // Test-only path: gather the operand numbers of all ranked-tensor operands.
+  llvm::SmallDenseSet<int64_t> rankedTensorOperands;
+  for (OpOperand &operand : operation->getOpOperands()) {
+    if (isa<RankedTensorType>(operand.get().getType())) {
+      rankedTensorOperands.insert(operand.getOperandNumber());
+    }
+  }
+  return blockDynamicDimensions(rewriter, dynamicDimAnalysis, operation,
+                                rankedTensorOperands);
+}
+
void BlockDynamicDimensionsPass::runOnOperation() {
Operation *operation = getOperation();
MLIRContext *context = &getContext();
@@ -231,12 +259,10 @@
}
IRRewriter rewriter(context);
- auto walkResult = operation->walk(
- [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult {
- rewriter.setInsertionPoint(attentionOp);
- return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
- attentionOp);
- });
+ auto walkResult = operation->walk([&](Operation *op) -> WalkResult {
+ rewriter.setInsertionPoint(op);
+ return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op, test);
+ });
if (walkResult.wasInterrupted()) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 5aa3ef4..bb0676f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -23,6 +23,9 @@
: Pass<"iree-codegen-block-dynamic-dimensions"> {
let summary = "Expand dynamic dimensions that are known to be multiples of "
"statically known values.";
+ let options = [
+ Option<"test", "test", "bool", /*default=*/"false", "Enable test mode">
+ ];
}
def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
index b0e7667..fa8f774 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
@@ -59,7 +59,8 @@
// Update solver info
auto *divisibilityState =
solver.lookupState<IREE::Util::IntegerDivisibilityLattice>(dynamicDim);
- if (divisibilityState && !divisibilityState->getValue().isUninitialized()) {
+ if (divisibilityState && !divisibilityState->getValue().isUninitialized() &&
+ divisibilityState->getValue().getValue().sdiv() != 1) {
updateDivisibilityInfo(divisibilityInfo, tensorValue, dimIndex,
divisibilityState->getValue().getValue());
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
index 819c412..4dab261 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions{test}, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
@@ -99,3 +99,36 @@
// CHECK: ins(%[[Q]], %[[K]], %[[V]], %{{.+}}, %[[MASK]] :
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT_BINDING]]
+
+// -----
+
+func.func @basic_blocking_test(%arg0 : index) -> tensor<?xf32> {
+ %0 = util.assume.int %arg0<umin = 0, umax = 1024, udiv = 16> : index // Dynamic size assumed divisible by 16.
+ %1 = tensor.empty(%0) : tensor<?xf32> // Expected to become a tensor<?x16xf32> empty that is collapsed back to tensor<?xf32>.
+ return %1 : tensor<?xf32>
+}
+// CHECK-LABEL: func @basic_blocking_test(
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EMPTY]]
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_blocking(%arg0 : index) -> tensor<?xf32> {
+ %1 = tensor.empty(%arg0) : tensor<?xf32> // No divisibility assumption on %arg0: the op is expected to stay unchanged.
+ return %1 : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_blocking(
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+// CHECK: return %[[EMPTY]]
+
+// -----
+
+func.func @no_unit_blocking(%arg0 : index) -> tensor<?xf32> {
+ %0 = util.assume.int %arg0<umin = 0, umax = 1024, udiv = 1> : index // A divisor of 1 carries no useful divisibility.
+ %1 = tensor.empty(%0) : tensor<?xf32> // Expected to stay tensor<?xf32>: a divisor of 1 must not trigger blocking.
+ return %1 : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_unit_blocking(
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+// CHECK: return %[[EMPTY]]
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
index 3e2235b..0c80b48 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
@@ -45,3 +45,15 @@
%1 = arith.remui %0, %cst : index
util.return %1 : index
}
+
+// -----
+
+util.func @muli_divisibility(%arg0 : index) -> index {
+ %cst = arith.constant 16 : index
+ %0 = arith.muli %arg0, %cst : index // Product with the constant 16.
+ %1 = arith.remui %0, %cst : index // Expected to be replaced by constant 0 once the product is known divisible by 16.
+ util.return %1 : index
+}
+// CHECK-LABEL: @muli_divisibility
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: return %[[C0]]
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index 2097cbf..1ce8e18 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
namespace mlir::iree_compiler {
@@ -45,6 +46,28 @@
}
};
+/// External model that infers the divisibility of an `arith.muli` result from
+/// a constant operand. When either operand matches a constant integer, the
+/// constant's zero-extended value is reported as the unsigned divisor and its
+/// absolute value as the signed divisor. The divisibility of the operands
+/// (`argDivs`) is not consulted, and nothing is reported when neither operand
+/// is a constant or when the constant does not fit in 64 bits.
+struct ArithMulIInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          ArithMulIInferIntDivisibilityOpInterface, arith::MulIOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto mulOp = cast<arith::MulIOp>(op);
+    // Prefer the LHS constant and fall back to the RHS. `intVal` is only read
+    // after one of the two matches has succeeded.
+    APInt intVal;
+    if (!matchPattern(mulOp.getLhs(), m_ConstantInt(&intVal)) &&
+        !matchPattern(mulOp.getRhs(), m_ConstantInt(&intVal))) {
+      return;
+    }
+
+    // The divisors are carried as uint64_t. A constant that needs more than
+    // 64 bits (which includes any negative value of a type wider than 64
+    // bits) cannot be extracted with getZExtValue(), which asserts in that
+    // case, so report nothing instead.
+    if (intVal.getActiveBits() > 64) {
+      return;
+    }
+
+    uint64_t udiv = intVal.getZExtValue();
+    // APInt::abs() is well defined for the minimum signed value, unlike
+    // std::abs on an int64_t, which is undefined behavior for INT64_MIN.
+    uint64_t sdiv = intVal.abs().getZExtValue();
+    // NOTE(review): a zero constant yields divisors of 0 here (unchanged from
+    // before); confirm that ConstantIntDivisibility consumers tolerate it.
+    setResultDivs(mulOp.getResult(),
+                  IREE::Util::ConstantIntDivisibility(udiv, sdiv));
+  }
+};
+
//===----------------------------------------------------------------------===//
// GlobalOpInterface
//===----------------------------------------------------------------------===//
@@ -328,6 +351,8 @@
arith::TruncIOp, arith::SIToFPOp, arith::UIToFPOp>(context);
arith::ConstantOp::attachInterface<
ArithConstantInferIntDivisibilityOpInterface>(*context);
+ arith::MulIOp::attachInterface<ArithMulIInferIntDivisibilityOpInterface>(
+ *context);
});
registry.addExtension(