Handle LinalgExt ops in dispatch region pretty names (#13048)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index ad0e693..601ae33 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -31,18 +31,55 @@
namespace Flow {
namespace {
+static int64_t costOfDomain(ArrayRef<int64_t> domain) {
+  int64_t product = 1;
+  for (int64_t size : domain) {
+    if (size == mlir::ShapedType::kDynamic) return INT64_MAX;
+    product *= size;
+  }
+  return product;
+}
+
// Estimates the evaluation cost of a linalg op using a heuristic cost model.
static int64_t estimateLinalgOpCost(linalg::LinalgOp op) {
- if (op.hasDynamicShape()) {
- // Note: bounded dynamic shapes would be interesting, if the compiler used
- // them. For now just treat dynamic shapes as arbitrarily large.
- return INT64_MAX;
- }
+  // For linalg ops we know the iteration domain, so return the number
+  // of iterations of the iteration domain (or INT64_MAX for dynamic).
+ int64_t cost = costOfDomain(op.getStaticLoopRanges());
+ LLVM_DEBUG(llvm::dbgs() << "// " << op->getName() << " cost: " << cost
+ << "\n");
+ return cost;
+}
- int64_t cost = 1;
- for (auto loopRange : op.getStaticLoopRanges()) {
- cost *= loopRange;
+static TensorType getMainTensorForLinalgExtOp(Operation *op) {
+ TensorType main;
+ auto operandTypes = llvm::to_vector(op->getOperandTypes());
+ auto resultTypes = llvm::to_vector(op->getResultTypes());
+ for (Type t : llvm::concat<Type>(operandTypes, resultTypes)) {
+ auto tensorType = t.dyn_cast<TensorType>();
+ if (!tensorType) continue;
+ if (!main) {
+ main = tensorType;
+ } else if (costOfDomain(tensorType.getShape()) >
+ costOfDomain(main.getShape())) {
+ main = tensorType;
+ }
}
+ return main;
+}
+
+// Estimates the evaluation cost of a LinalgExt op using a heuristic cost
+// model.
+static int64_t estimateLinalgExtOpCost(Operation *op) {
+  TensorType mainTensor = getMainTensorForLinalgExtOp(op);
+  // Use the cost of the biggest tensor of the LinalgExt op as an approximation.
+  // This is a very, very coarse approximation.
+  auto cost = mainTensor ? costOfDomain(mainTensor.getShape()) : 1;
+  // Multiply by a semi-arbitrarily chosen factor to capture that LinalgExt ops
+  // are "somewhat more expensive" than simply traversing the main tensor.
+  // This is something like the extra log(N) factor for a sort or FFT, or
+  // the amount of work done by a softmax vs a cheap elementwise on a tensor
+  // of the same shape. Saturate to avoid signed overflow when dynamic shapes
+  cost = (cost > INT64_MAX / 10) ? INT64_MAX : cost * 10;  // yield INT64_MAX.
LLVM_DEBUG(llvm::dbgs() << "// " << op->getName() << " cost: " << cost
<< "\n");
return cost;
@@ -123,6 +160,21 @@
(opTypes.empty() ? "" : "_" + opTypes);
}
+static std::string summarizeLinalgExtOp(Operation *op) {
+  auto opName = op->getName().getStringRef();
+  // Ops outside the iree_linalg_ext dialect produce no summary.
+  if (!opName.consume_front("iree_linalg_ext.")) return "";
+  // Append "_<shape>x<element-type>" derived from the op's main tensor.
+  std::string suffix = "";
+  if (TensorType mainTensor = getMainTensorForLinalgExtOp(op)) {
+    llvm::raw_string_ostream os(suffix);
+    os << "_" << loopRangesToString(mainTensor.getShape()) << "x";
+    mainTensor.getElementType().print(os);
+    os.flush();
+  }
+  return opName.str() + suffix;
+}
+
// Summarizes the contents of a dispatch into a short string.
// This uses heuristics to aid developer debugging.
static std::string summarizeDispatchWorkgroupsOp(
@@ -165,7 +217,14 @@
<< "// new best op: '" << bestOp->getName()
<< "', cost: " << bestEstimatedCost << "\n");
})
- // TODO(scotttodd): IREE::LinalgExt::LinalgExtOp
+ .Case<IREE::LinalgExt::LinalgExtOp>([&](auto op) {
+ int64_t estimatedCost = estimateLinalgExtOpCost(op);
+ if (estimatedCost < bestEstimatedCost) return;
+ bestEstimatedCost = estimatedCost;
+ bestOp = op;
+ LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName()
+ << "', cost: " << bestEstimatedCost << "\n");
+ })
.Default([&](Operation *op) {
// No cost estimation implemented, skip.
});
@@ -191,7 +250,8 @@
bestSummary =
opName + "_" + encoding.str() + "_" + loopRangesToString(shape);
})
- // TODO(scotttodd): IREE::LinalgExt::LinalgExtOp
+ .Case<IREE::LinalgExt::LinalgExtOp>(
+ [&](auto op) { bestSummary = summarizeLinalgExtOp(op); })
.Default([&](Operation *op) {
// No summarization implemented, default to the op's name.
bestSummary = op->getName().getStringRef().str();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 6e95292..4b548b1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -321,3 +321,26 @@
}
return %3, %arg1, %arg2 : tensor<?x?xf32>, index, index
}
+
+// -----
+
+// iree_linalg_ext ops get a heuristics-driven summary in their name.
+
+// CHECK: flow.executable private @main_dispatch_0 {
+// CHECK-NEXT: flow.executable.export public @main_dispatch_0_softmax_7xf32
+// CHECK: func.func @main_dispatch_0_softmax_7xf32(
+func.func @main(%arg0: tensor<7xf32>) -> tensor<7xf32> {
+ %c7 = arith.constant 7 : index
+ %0 = flow.dispatch.workgroups[%c7](%arg0) : (tensor<7xf32>) -> tensor<7xf32> =
+ (%arg1: !flow.dispatch.tensor<readonly:tensor<7xf32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<7xf32>>) {
+ %1 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [7], strides = [1] : !flow.dispatch.tensor<readonly:tensor<7xf32>> -> tensor<7xf32>
+ %2 = tensor.empty() : tensor<7xf32>
+ %3 = iree_linalg_ext.softmax dimension(0) ins(%1 : tensor<7xf32>) outs(%2 : tensor<7xf32>) -> tensor<7xf32>
+ flow.dispatch.tensor.store %3, %arg2, offsets = [0], sizes = [7], strides = [1] : tensor<7xf32> -> !flow.dispatch.tensor<writeonly:tensor<7xf32>>
+ flow.return
+ } count(%arg1: index) -> (index, index, index) {
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1
+ flow.return %x, %y, %z : index, index, index
+ }
+ return %0 : tensor<7xf32>
+}