[LinalgExt] Remove default implementation for getStaticLoopRanges (#18745)
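The default implementation of getStaticLoopRanges is dangerous and causes
unexpected bugs: it concatenates the shapes of all shaped operands, which
is only correct when every operand dimension is a distinct loop dimension.
Make the default return failure and have each operation specify its own
loop ranges instead.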
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 5599f47..8e6eb45 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -672,7 +672,12 @@
// Get iteration domain bounds.
OpBuilder b(op);
- SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
+ FailureOr<SmallVector<int64_t>> maybeBounds = op.getStaticLoopRanges();
+ if (failed(maybeBounds)) {
+ return failure();
+ }
+
+ ArrayRef<int64_t> bounds = maybeBounds.value();
auto opInfo =
IREE::LinalgExt::AttentionOpDetail::get(op.getIndexingMapsArray())
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 3acbf0c..f28a004 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -83,8 +83,10 @@
return (llvm::cast<ConcreteType>(op).getNumLoops());
}
- SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
- return (llvm::cast<ConcreteType>(op).getStaticLoopRanges());
+ FailureOr<SmallVector<int64_t>>
+ getStaticLoopRanges(mlir::Operation *op) const {
+ return SmallVector<int64_t>(
+ llvm::cast<ConcreteType>(op).getStaticLoopRanges());
}
AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
@@ -128,6 +130,12 @@
}));
}
+ FailureOr<SmallVector<int64_t>> getStaticLoopRanges(Operation *op) const {
+ auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+ // Softmax loop range is the input shape.
+ return SmallVector<int64_t>(softmaxOp.getInputOperandType().getShape());
+ }
+
AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
OpResult result) const {
return getIndexingMapsForResults(op)[result.getResultNumber()];
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index 9607c9e..3942201 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -110,18 +110,12 @@
/*desc=*/[{
Return the static loop ranges.
}],
- /*retTy=*/"SmallVector<int64_t, 4>",
+ /*retTy=*/"FailureOr<SmallVector<int64_t>>",
/*methodName=*/"getStaticLoopRanges",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- SmallVector<int64_t, 4> loopRanges;
- llvm::for_each($_op.getOperands(), [&](Value operand) {
- if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
- llvm::append_range(loopRanges, shapedType.getShape());
- }
- });
- return loopRanges;
+ return failure();
}]
>,
InterfaceMethod<
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 500c203..1e076a5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -252,6 +252,11 @@
.reifyResultShapes(b, reifiedReturnShapes);
}
+FailureOr<SmallVector<int64_t>> ScatterOp::getStaticLoopRanges() {
+ // Scatter loop ranges are the static shape of the update operand.
+ return SmallVector<int64_t>(getUpdateType().getShape());
+}
+
SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
Builder builder(getContext());
return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
@@ -1321,8 +1326,8 @@
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}
-SmallVector<int64_t, 4> AttentionOp::getStaticLoopRanges() {
- SmallVector<int64_t, 4> bounds(getIterationDomainRank());
+FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
+ SmallVector<int64_t> bounds(getIterationDomainRank());
SmallVector<bool> dimsFound(getIterationDomainRank(), false);
// batch(s), m, k1
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index f00ce46..c7d6e98 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -94,7 +94,9 @@
def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
- DeclareOpInterfaceMethods<LinalgFusionInterface>,
+ DeclareOpInterfaceMethods<LinalgFusionInterface,
+ ["getIndexingMapsForResults", "getIndexingMapsForOperands",
+ "getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
@@ -469,7 +471,8 @@
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<LinalgFusionInterface,
- ["getIndexingMapsForResults", "getIndexingMapsForOperands"]>,
+ ["getIndexingMapsForResults", "getIndexingMapsForOperands",
+ "getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
@@ -528,8 +531,6 @@
SmallVector<AffineMap> getIndexingMapsArray();
- SmallVector<int64_t, 4> getStaticLoopRanges();
-
AffineMap getQueryMap() {
return cast<AffineMap>(getIndexingMapsArray()[0]);
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 450493f..9f9b755 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -67,8 +67,12 @@
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand);
- SmallVector<int64_t, 4> originalLoopRange = op.getStaticLoopRanges();
- originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+ FailureOr<SmallVector<int64_t>> originalLoopRange = op.getStaticLoopRanges();
+ if (failed(originalLoopRange)) {
+ return failure();
+ }
+ originalLoopExtent.assign(originalLoopRange->begin(),
+ originalLoopRange->end());
reassociation.clear();
expandedShapeMap.clear();
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index 3aba7fe..e866022 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -568,9 +568,15 @@
// TODO(#12664): This is unnecessary requirement, but we need a better config
// to tile the consumer with a larger iteration space.
if (!options.aggressiveFusion) {
- auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
- auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
- if (producerIterationSpace.size() < consumerIterationSpace.size()) {
+ FailureOr<SmallVector<int64_t>> producerIterationSpace =
+ producerFusionOp.getStaticLoopRanges();
+ FailureOr<SmallVector<int64_t>> consumerIterationSpace =
+ consumerFusionOp.getStaticLoopRanges();
+ if (failed(producerIterationSpace) || failed(consumerIterationSpace)) {
+ return false;
+ }
+ if (producerIterationSpace.value().size() <
+ consumerIterationSpace.value().size()) {
return false;
}
}