Handle LinalgExt ops in dispatch region pretty names (#13048)

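Dispatch regions whose most interesting op is an iree_linalg_ext op now get a
heuristic cost estimate and a shape-based summary in their outlined names,
rather than falling through to the TODO'd default cases. The cost of a
LinalgExt op is approximated by the iteration domain of its largest tensor
operand or result, scaled by a constant fudge factor; the summary is the op
name plus that tensor's shape and element type, e.g.
@main_dispatch_0_softmax_7xf32 for a softmax on tensor<7xf32>.
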
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index ad0e693..601ae33 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -31,18 +31,59 @@
 namespace Flow {
 namespace {
 
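+// Returns the number of iteration-space points in `domain`, or INT64_MAX if
+// any dimension is dynamic (treated as arbitrarily large).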
+static int64_t costOfDomain(ArrayRef<int64_t> domain) {
+  int64_t product = 1;
+  for (int64_t size : domain) {
+    if (size == mlir::ShapedType::kDynamic) return INT64_MAX;
+    product *= size;
+  }
+  return product;
+}
+
 // Estimates the evaluation cost of a linalg op using a heuristic cost model.
 static int64_t estimateLinalgOpCost(linalg::LinalgOp op) {
-  if (op.hasDynamicShape()) {
-    // Note: bounded dynamic shapes would be interesting, if the compiler used
-    // them. For now just treat dynamic shapes as arbitrarily large.
-    return INT64_MAX;
-  }
+  // For linalg ops we know the iteration domain, so return the total number
+  // of points in that domain (or INT64_MAX if any dimension is dynamic).
+  int64_t cost = costOfDomain(op.getStaticLoopRanges());
+  LLVM_DEBUG(llvm::dbgs() << "// " << op->getName() << " cost: " << cost
+                          << "\n");
+  return cost;
+}
 
-  int64_t cost = 1;
-  for (auto loopRange : op.getStaticLoopRanges()) {
-    cost *= loopRange;
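+// Returns the tensor operand or result of `op` with the most expensive
+// iteration domain, as a stand-in for the overall cost of the op.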
+static TensorType getMainTensorForLinalgExtOp(Operation *op) {
+  TensorType main;
+  auto operandTypes = llvm::to_vector(op->getOperandTypes());
+  auto resultTypes = llvm::to_vector(op->getResultTypes());
+  for (Type t : llvm::concat<Type>(operandTypes, resultTypes)) {
+    auto tensorType = t.dyn_cast<TensorType>();
+    if (!tensorType) continue;
+    if (!main) {
+      main = tensorType;
+    } else if (costOfDomain(tensorType.getShape()) >
+               costOfDomain(main.getShape())) {
+      main = tensorType;
+    }
   }
+  return main;
+}
+
+// Estimates the evaluation cost of a LinalgExt op using a heuristic cost
+// model.
+static int64_t estimateLinalgExtOpCost(Operation *op) {
+  TensorType mainTensor = getMainTensorForLinalgExtOp(op);
+  // Approximate the cost of the LinalgExt op by the cost of traversing its
+  // biggest tensor. This is a very coarse approximation.
+  int64_t cost = mainTensor ? costOfDomain(mainTensor.getShape()) : 1;
+  // Scale by a semi-arbitrary factor to capture that LinalgExt ops are
+  // "somewhat more expensive" than simply traversing the main tensor: think
+  // the extra log(N) factor for a sort or FFT, or the extra work a softmax
+  // does vs a cheap elementwise op on a tensor of the same shape. Skip the
+  // INT64_MAX dynamic-shape sentinel to avoid signed overflow.
+  if (cost != INT64_MAX) cost *= 10;
   LLVM_DEBUG(llvm::dbgs() << "// " << op->getName() << " cost: " << cost
                           << "\n");
   return cost;
@@ -123,6 +164,23 @@
          (opTypes.empty() ? "" : "_" + opTypes);
 }
 
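+// Summarizes a LinalgExt op as its stripped op name plus the shape and
+// element type of its main tensor, e.g. "softmax_7xf32".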
+static std::string summarizeLinalgExtOp(Operation *op) {
+  auto opName = op->getName().getStringRef();
+  if (!opName.consume_front("iree_linalg_ext.")) return "";
+  std::string suffix;
+  if (TensorType mainTensor = getMainTensorForLinalgExtOp(op)) {
+    llvm::raw_string_ostream sstream(suffix);
+    sstream << "_";
+    sstream << loopRangesToString(mainTensor.getShape());
+    sstream << "x";
+    mainTensor.getElementType().print(sstream);
+    sstream.flush();
+  }
+  return opName.str() + suffix;
+}
+
 // Summarizes the contents of a dispatch into a short string.
 // This uses heuristics to aid developer debugging.
 static std::string summarizeDispatchWorkgroupsOp(
@@ -165,7 +223,14 @@
                          << "// new best op: '" << bestOp->getName()
                          << "', cost: " << bestEstimatedCost << "\n");
             })
-        // TODO(scotttodd): IREE::LinalgExt::LinalgExtOp
+        .Case<IREE::LinalgExt::LinalgExtOp>([&](auto op) {
+          int64_t estimatedCost = estimateLinalgExtOpCost(op);
+          if (estimatedCost < bestEstimatedCost) return;
+          bestEstimatedCost = estimatedCost;
+          bestOp = op;
+          LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName()
+                                  << "', cost: " << bestEstimatedCost << "\n");
+        })
         .Default([&](Operation *op) {
           // No cost estimation implemented, skip.
         });
@@ -191,7 +256,8 @@
         bestSummary =
             opName + "_" + encoding.str() + "_" + loopRangesToString(shape);
       })
-      // TODO(scotttodd): IREE::LinalgExt::LinalgExtOp
+      .Case<IREE::LinalgExt::LinalgExtOp>(
+          [&](auto op) { bestSummary = summarizeLinalgExtOp(op); })
       .Default([&](Operation *op) {
         // No summarization implemented, default to the op's name.
         bestSummary = op->getName().getStringRef().str();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 6e95292..4b548b1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -321,3 +321,26 @@
   }
   return %3, %arg1, %arg2 : tensor<?x?xf32>, index, index
 }
+
+// -----
+
+// iree_linalg_ext ops get a heuristics-driven summary in their name.
+
+//      CHECK: flow.executable private @main_dispatch_0 {
+// CHECK-NEXT:   flow.executable.export public @main_dispatch_0_softmax_7xf32
+//      CHECK: func.func @main_dispatch_0_softmax_7xf32(
+func.func @main(%arg0: tensor<7xf32>) -> tensor<7xf32> {
+  %c7 = arith.constant 7 : index
+  %0 = flow.dispatch.workgroups[%c7](%arg0) : (tensor<7xf32>) -> tensor<7xf32> =
+      (%arg1: !flow.dispatch.tensor<readonly:tensor<7xf32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<7xf32>>) {
+    %1 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [7], strides = [1] : !flow.dispatch.tensor<readonly:tensor<7xf32>> -> tensor<7xf32>
+    %2 = tensor.empty() : tensor<7xf32>
+    %3 = iree_linalg_ext.softmax dimension(0) ins(%1 : tensor<7xf32>) outs(%2 : tensor<7xf32>) -> tensor<7xf32>
+    flow.dispatch.tensor.store %3, %arg2, offsets = [0], sizes = [7], strides = [1] : tensor<7xf32> -> !flow.dispatch.tensor<writeonly:tensor<7xf32>>
+    flow.return
+  } count(%arg1: index) -> (index, index, index) {
+    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1
+    flow.return %x, %y, %z : index, index, index
+  }
+  return %0 : tensor<7xf32>
+}