[LinalgExt] Remove default implementation for getStaticLoopRanges (#18745)

The default implementation of getStaticLoopRanges is dangerous and
causes unexpected bugs. It only works for operands with distinct loop
ranges as dimensions. It's better to have operations specify it.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 5599f47..8e6eb45 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -672,7 +672,12 @@
 
   // Get iteration domain bounds.
   OpBuilder b(op);
-  SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
+  FailureOr<SmallVector<int64_t>> maybeBounds = op.getStaticLoopRanges();
+  if (failed(maybeBounds)) {
+    return failure();
+  }
+
+  ArrayRef<int64_t> bounds = maybeBounds.value();
 
   auto opInfo =
       IREE::LinalgExt::AttentionOpDetail::get(op.getIndexingMapsArray())
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 3acbf0c..f28a004 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -83,8 +83,10 @@
     return (llvm::cast<ConcreteType>(op).getNumLoops());
   }
 
-  SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
-    return (llvm::cast<ConcreteType>(op).getStaticLoopRanges());
+  FailureOr<SmallVector<int64_t>>
+  getStaticLoopRanges(mlir::Operation *op) const {
+    return SmallVector<int64_t>(
+        llvm::cast<ConcreteType>(op).getStaticLoopRanges());
   }
 
   AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
@@ -128,6 +130,12 @@
         }));
   }
 
+  FailureOr<SmallVector<int64_t>> getStaticLoopRanges(Operation *op) const {
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    // Softmax loop range is the input shape.
+    return SmallVector<int64_t>(softmaxOp.getInputOperandType().getShape());
+  }
+
   AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
                                          OpResult result) const {
     return getIndexingMapsForResults(op)[result.getResultNumber()];
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index 9607c9e..3942201 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -110,18 +110,12 @@
       /*desc=*/[{
         Return the static loop ranges.
       }],
-      /*retTy=*/"SmallVector<int64_t, 4>",
+      /*retTy=*/"FailureOr<SmallVector<int64_t>>",
       /*methodName=*/"getStaticLoopRanges",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-          SmallVector<int64_t, 4> loopRanges;
-          llvm::for_each($_op.getOperands(), [&](Value operand) {
-          if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
-            llvm::append_range(loopRanges, shapedType.getShape());
-            }
-          });
-          return loopRanges;
+        return failure();
       }]
     >,
     InterfaceMethod<
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 500c203..1e076a5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -252,6 +252,11 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+FailureOr<SmallVector<int64_t>> ScatterOp::getStaticLoopRanges() {
+  // Scatter loop ranges are loop ranges for update.
+  return SmallVector<int64_t>(getUpdateType().getShape());
+}
+
 SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
   Builder builder(getContext());
   return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
@@ -1321,8 +1326,8 @@
       getIndexingMaps().getAsValueRange<AffineMapAttr>());
 }
 
-SmallVector<int64_t, 4> AttentionOp::getStaticLoopRanges() {
-  SmallVector<int64_t, 4> bounds(getIterationDomainRank());
+FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
+  SmallVector<int64_t> bounds(getIterationDomainRank());
   SmallVector<bool> dimsFound(getIterationDomainRank(), false);
 
   // batch(s), m, k1
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index f00ce46..c7d6e98 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -94,7 +94,9 @@
 
 def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
     [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
-     DeclareOpInterfaceMethods<LinalgFusionInterface>,
+     DeclareOpInterfaceMethods<LinalgFusionInterface,
+      ["getIndexingMapsForResults", "getIndexingMapsForOperands",
+       "getStaticLoopRanges"]>,
      DeclareOpInterfaceMethods<TilingInterface,
         ["generateScalarImplementation",
          "getIterationDomain",
@@ -469,7 +471,8 @@
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DestinationStyleOpInterface, LinalgExtInterface,
      DeclareOpInterfaceMethods<LinalgFusionInterface,
-      ["getIndexingMapsForResults", "getIndexingMapsForOperands"]>,
+      ["getIndexingMapsForResults", "getIndexingMapsForOperands",
+       "getStaticLoopRanges"]>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
@@ -528,8 +531,6 @@
 
     SmallVector<AffineMap> getIndexingMapsArray();
 
-    SmallVector<int64_t, 4> getStaticLoopRanges();
-
     AffineMap getQueryMap() {
       return cast<AffineMap>(getIndexingMapsArray()[0]);
     }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 450493f..9f9b755 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -67,8 +67,12 @@
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand);
-  SmallVector<int64_t, 4> originalLoopRange = op.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  FailureOr<SmallVector<int64_t>> originalLoopRange = op.getStaticLoopRanges();
+  if (failed(originalLoopRange)) {
+    return failure();
+  }
+  originalLoopExtent.assign(originalLoopRange->begin(),
+                            originalLoopRange->end());
 
   reassociation.clear();
   expandedShapeMap.clear();
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index 3aba7fe..e866022 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -568,9 +568,15 @@
   // TODO(#12664): This is unnecessary requirement, but we need a better config
   // to tile the consumer with a larger iteration space.
   if (!options.aggressiveFusion) {
-    auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
-    auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
-    if (producerIterationSpace.size() < consumerIterationSpace.size()) {
+    FailureOr<SmallVector<int64_t>> producerIterationSpace =
+        producerFusionOp.getStaticLoopRanges();
+    FailureOr<SmallVector<int64_t>> consumerIterationSpace =
+        consumerFusionOp.getStaticLoopRanges();
+    if (failed(producerIterationSpace) || failed(consumerIterationSpace)) {
+      return false;
+    }
+    if (producerIterationSpace.value().size() <
+        consumerIterationSpace.value().size()) {
       return false;
     }
   }