Adding `iree.tensor.trace` support for printf debugging. (#16746)

Any tensor level op prior to or during flow can be annotated with the
`iree.tensor.trace` attribute to have `flow.tensor.trace` ops for all
tensor operands and results generated by the pass. The attribute can
either be a unit attr to have the trace key chosen automatically or a
string attr to specify it. We run the pass once at the head of the
pipeline prior to dispatch region formation and again once after, but
users can also slice out IR at any phase, add the attributes, use
iree-opt to run the pass, and pipe it back through the pipeline to
continue compilation.

This is printf debugging: it's not great, but it gets the job done.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index db86928..cdebd05 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -49,6 +49,7 @@
         "FusionOfTensorOps.cpp",
         "InitializeEmptyTensors.cpp",
         "InjectDispatchTracing.cpp",
+        "InjectTensorTracing.cpp",
         "InsertDispatchDebugTargets.cpp",
         "InterchangeGenericOps.cpp",
         "InterchangeTransposeGenericOps.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index cc01b43..e777ca4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -48,6 +48,7 @@
     "FusionOfTensorOps.cpp"
     "InitializeEmptyTensors.cpp"
     "InjectDispatchTracing.cpp"
+    "InjectTensorTracing.cpp"
     "InsertDispatchDebugTargets.cpp"
     "InterchangeGenericOps.cpp"
     "InterchangeTransposeGenericOps.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp
new file mode 100644
index 0000000..768a907
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp
@@ -0,0 +1,100 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir::iree_compiler::IREE::Flow {
+
+static std::string inferTraceKey(Operation *op) {
+  return TypeSwitch<Operation *, std::string>(op)
+      .Case<IREE::Flow::DispatchOp>(
+          [&](auto op) { return op.getEntryPointName(); })
+      .Case<IREE::Util::CallOp>([&](auto op) { return op.getCallee().str(); })
+      .Default([](auto *op) { return op->getName().getStringRef().str(); });
+}
+
+static SmallVector<Value> filterTensorValues(ValueRange &&range) {
+  SmallVector<Value> result;
+  for (auto value : range) {
+    if (llvm::isa<TensorType>(value.getType()))
+      result.push_back(value);
+  }
+  return result;
+}
+
+static SmallVector<Value> getTensorOperands(Operation *op) {
+  if (auto dispatchRegionOp = dyn_cast<IREE::Flow::DispatchRegionOp>(op)) {
+    llvm::SetVector<Value> argumentsSet;
+    mlir::getUsedValuesDefinedAbove(dispatchRegionOp.getBody(), argumentsSet);
+    return filterTensorValues(argumentsSet.takeVector());
+  }
+  return filterTensorValues(op->getOperands());
+}
+
+static void injectTracingOnOp(Operation *op, StringRef traceKey) {
+  OpBuilder builder(op);
+  auto inputTensors = getTensorOperands(op);
+  if (!inputTensors.empty()) {
+    builder.create<IREE::Flow::TensorTraceOp>(
+        op->getLoc(), builder.getStringAttr(traceKey + " inputs"),
+        inputTensors);
+  }
+
+  builder.setInsertionPointAfter(op);
+  auto outputTensors = filterTensorValues(op->getResults());
+  if (!outputTensors.empty()) {
+    builder.create<IREE::Flow::TensorTraceOp>(
+        op->getLoc(), builder.getStringAttr(traceKey + " outputs"),
+        outputTensors);
+  }
+}
+
+class InjectTensorTracingPass
+    : public InjectTensorTracingBase<InjectTensorTracingPass> {
+public:
+  InjectTensorTracingPass() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, IREE::Flow::FlowDialect,
+                    tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    auto attrName = StringAttr::get(&getContext(), "iree.tensor.trace");
+    auto funcOp = getOperation();
+    funcOp.walk([&](Operation *op) {
+      if (auto attr = op->getAttr(attrName)) {
+        std::string traceKey;
+        if (auto stringAttr = dyn_cast<StringAttr>(attr))
+          traceKey = stringAttr.getValue().str();
+        else
+          traceKey = inferTraceKey(op);
+        injectTracingOnOp(op, traceKey);
+        op->removeAttr(attrName);
+      }
+    });
+  }
+};
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createInjectTensorTracingPass() {
+  return std::make_unique<InjectTensorTracingPass>();
+}
+
+} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 56edc9e..ea3ae99 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -126,6 +126,11 @@
   // Start of Flow pipeline, verify input legality.
   passManager.addPass(IREE::Flow::createVerifyInputLegalityPass());
 
+  // Inject tensor tracing early as we need to have the tracers in the IR
+  // prior to dispatch region formation where we may lose access to them.
+  FunctionLikeNest(passManager)
+      .addPass(IREE::Flow::createInjectTensorTracingPass);
+
   // Transform pad operations into linalg.fill + tensor.insert_slice.
   // This is a WAR for not having native pad handling.
   if (!clEnablePadHandling && !clEnableFusePaddingIntoLinalgProducerOps) {
@@ -270,6 +275,9 @@
       // match later stages.
       .addPredicatedPass(clTraceDispatchTensors,
                          IREE::Flow::createInjectDispatchTracingPass)
+      // Inject tensor tracing late for any attributes that were added by the
+      // passes above after we've formed dispatch regions.
+      .addPass(IREE::Flow::createInjectTensorTracingPass)
       // Cleanup the IR after we are done.
       .addPass(IREE::Flow::createCleanupTensorShapesPass)
       .addPass(mlir::createCanonicalizerPass)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 60cf678..467f334 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -166,6 +166,10 @@
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createInjectDispatchTracingPass();
 
+// Injects tensor tracing on ops annotated with `iree.tensor.trace`.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createInjectTensorTracingPass();
+
 // Crops the program and inserts trace markers at the specified symbols.
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createInsertDebugTargetAtSymbolPass(std::string breakDebugTarget = "",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 743b19e..5e9e0d4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -191,6 +191,12 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createInjectDispatchTracingPass()";
 }
 
+def InjectTensorTracing :
+    InterfacePass<"iree-flow-inject-tensor-tracing", "mlir::FunctionOpInterface"> {
+  let summary = "Injects tensor tracing on ops annotated with `iree.tensor.trace`.";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createInjectTensorTracingPass()";
+}
+
 def InsertDebugTargetAtSymbol :
     Pass<"iree-flow-insert-debug-target-at-symbol", "mlir::ModuleOp"> {
   let summary = "Crops and/or traces the program at the specified symbol";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index 8021ed7..f8dcdb3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -37,6 +37,7 @@
             "fusion_of_tensor_ops.mlir",
             "initialize_empty_tensors.mlir",
             "inject_dispatch_tracing.mlir",
+            "inject_tensor_tracing.mlir",
             "insert_dispatch_debug_targets.mlir",
             "interchange_generic_ops.mlir",
             "interchange_transpose_generic_ops.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index d1a8d7d..7fcf6e0 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -35,6 +35,7 @@
     "fusion_of_tensor_ops.mlir"
     "initialize_empty_tensors.mlir"
     "inject_dispatch_tracing.mlir"
+    "inject_tensor_tracing.mlir"
     "insert_dispatch_debug_targets.mlir"
     "interchange_generic_ops.mlir"
     "interchange_transpose_generic_ops.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/inject_tensor_tracing.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/inject_tensor_tracing.mlir
new file mode 100644
index 0000000..f03084b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/inject_tensor_tracing.mlir
@@ -0,0 +1,85 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(util.func(iree-flow-inject-tensor-tracing))' --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: util.func public @traceTensorOp
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<4xf32>, %[[ARG1:.+]]: tensor<4xf32>)
+util.func public @traceTensorOp(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+  //      CHECK: flow.tensor.trace "arith.addf inputs" = [%[[ARG0]] : tensor<4xf32>, %[[ARG1]] : tensor<4xf32>]
+  // CHECK-NEXT: %[[RESULT:.+]] = arith.addf
+  //  CHECK-NOT: iree.tensor.trace
+  %result = arith.addf %arg0, %arg1 {iree.tensor.trace} : tensor<4xf32>
+  // CHECK-NEXT: flow.tensor.trace "arith.addf outputs" = [%[[RESULT]] : tensor<4xf32>]
+  util.return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: util.func public @traceDispatchRegion
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<4xf32>, %[[ARG1:.+]]: tensor<4xi32>)
+util.func public @traceDispatchRegion(%arg0: tensor<4xf32>, %arg1: tensor<4xi32>) -> tensor<4xf32> {
+  //      CHECK: flow.tensor.trace "flow.dispatch.region inputs" = [%[[ARG0]] : tensor<4xf32>, %[[ARG1]] : tensor<4xi32>]
+  // CHECK-NEXT: %[[RESULT:.+]] = flow.dispatch.region
+  //  CHECK-NOT: iree.tensor.trace
+  %result = flow.dispatch.region -> (tensor<4xf32>) attributes {iree.tensor.trace} {
+    %0 = "some.op"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32>
+    flow.return %0 : tensor<4xf32>
+  }
+  //      CHECK: flow.tensor.trace "flow.dispatch.region outputs" = [%[[RESULT]] : tensor<4xf32>]
+  util.return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: util.func public @traceDispatch
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?xf32>, %[[ARG1:.+]]: tensor<?xi32>)
+util.func public @traceDispatch(%arg0: tensor<?xf32>, %arg1: tensor<?xi32>) -> (tensor<?xf32>, tensor<?xi16>) {
+  %c0 = arith.constant 0 : index
+  //  CHECK-DAG: %[[ARG0_D0:.+]] = tensor.dim %[[ARG0]], %c0
+  %arg0_d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  //  CHECK-DAG: %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %c0
+  %arg1_d0 = tensor.dim %arg1, %c0 : tensor<?xi32>
+  //      CHECK: flow.tensor.trace "ex::entry0 inputs" = [%[[ARG0]] : tensor<?xf32>{%[[ARG0_D0]]}, %[[ARG1]] : tensor<?xi32>{%[[ARG1_D0]]}]
+  // CHECK-NEXT: %[[RESULT:.+]]:2 = flow.dispatch @ex::@entry0
+  //  CHECK-NOT: iree.tensor.trace
+  %result:2 = flow.dispatch @ex::@entry0(%arg0, %arg1) {iree.tensor.trace} : (tensor<?xf32>{%arg0_d0}, tensor<?xi32>{%arg1_d0}) -> (%arg0 as tensor<?xf32>{%arg0_d0}, tensor<?xi16>{%arg1_d0})
+  // CHECK-NEXT: flow.tensor.trace "ex::entry0 outputs" = [%[[RESULT]]#0 : tensor<?xf32>{%[[ARG0_D0]]}, %[[RESULT]]#1 : tensor<?xi16>{%[[ARG1_D0]]}]
+  util.return %result#0, %result#1 : tensor<?xf32>, tensor<?xi16>
+}
+
+// -----
+
+util.func private @callee(%arg0: tensor<4xf32>, %arg1: tensor<4xi32>) -> tensor<4xf32>
+
+// CHECK-LABEL: util.func public @traceCall
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<4xf32>, %[[ARG1:.+]]: tensor<4xi32>)
+util.func public @traceCall(%arg0: tensor<4xf32>, %arg1: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>) {
+  //      CHECK: flow.tensor.trace "callee inputs" = [%[[ARG0]] : tensor<4xf32>, %[[ARG1]] : tensor<4xi32>]
+  // CHECK-NEXT: %[[RESULT0:.+]] = util.call @callee
+  //  CHECK-NOT: iree.tensor.trace
+  %result0 = util.call @callee(%arg0, %arg1) {iree.tensor.trace} : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32>
+  // CHECK-NEXT: flow.tensor.trace "callee outputs" = [%[[RESULT0]] : tensor<4xf32>]
+  //      CHECK: flow.tensor.trace "a key inputs" = [%[[ARG0]] : tensor<4xf32>, %[[ARG1]] : tensor<4xi32>]
+  // CHECK-NEXT: %[[RESULT1:.+]] = util.call @callee
+  //  CHECK-NOT: iree.tensor.trace
+  %result1 = util.call @callee(%arg0, %arg1) {iree.tensor.trace = "a key"} : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32>
+  // CHECK-NEXT: flow.tensor.trace "a key outputs" = [%[[RESULT1]] : tensor<4xf32>]
+  util.return %result0, %result1 : tensor<4xf32>, tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: util.func public @traceNested
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<4xf32>, %[[ARG1:.+]]: tensor<4xf32>
+util.func public @traceNested(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %cond: i1) -> tensor<4xf32> {
+  // CHECK: scf.if
+  %result = scf.if %cond -> tensor<4xf32> {
+    // CHECK-NEXT: flow.tensor.trace "arith.addf inputs" = [%[[ARG0]] : tensor<4xf32>, %[[ARG1]] : tensor<4xf32>]
+    // CHECK-NEXT: %[[RESULT:.+]] = arith.addf
+    //  CHECK-NOT: iree.tensor.trace
+    %0 = arith.addf %arg0, %arg1 {iree.tensor.trace} : tensor<4xf32>
+    // CHECK-NEXT: flow.tensor.trace "arith.addf outputs" = [%[[RESULT]] : tensor<4xf32>]
+    scf.yield %0 : tensor<4xf32>
+  } else {
+    scf.yield %arg0 : tensor<4xf32>
+  }
+  util.return %result : tensor<4xf32>
+}