Adding a flag to force indirect command buffers on in non-reusable cases. (#18945)

Includes various fixes found during testing.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 05ab667..308c880 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -160,6 +160,17 @@
   }
 };
 
+// Returns the "abi" entry of the target configuration dictionary, or an empty
+// string if the target, its configuration, or the entry is absent.
+static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) {
+  if (!targetAttr) {
+    return "";
+  }
+  auto config = targetAttr.getConfiguration();
+  auto abiAttr = config ? config.getAs<StringAttr>("abi") : StringAttr();
+  return abiAttr ? abiAttr.getValue() : "";
+}
+
 static void dumpModuleToPath(StringRef path, StringRef baseName,
                              StringRef suffix, StringRef extension,
                              llvm::Module &module) {
@@ -585,8 +596,7 @@
 
     // Wrap the HSACO ELF binary in a Flatbuffers container.
     FailureOr<DenseIntElementsAttr> binaryContainer;
-    if (targetAttr.getConfiguration() &&
-        targetAttr.getConfiguration().getAs<StringAttr>("abi") == "amdgpu") {
+    if (getABI(targetAttr) == "amdgpu") {
       binaryContainer = serializeAMDGPUBinaryContainer(
           serializationOptions, variantOp, exportOps, targetHSACO);
     } else {
diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
index 7f83a5e..d12dbc5 100644
--- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
+++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
@@ -1100,8 +1100,12 @@
     return new Error("not a valid HAL executable");
   }
   auto binaryOp = binaryOps.front();
-  auto rawData = binaryOp.getData().getRawData();
-  output.outputStream->write(rawData.data(), rawData.size());
+  if (failed(cast<IREE::Util::SerializableAttrInterface>(binaryOp.getData())
+                 .serializeToStream(binaryOp.getLoc(), llvm::endianness::little,
+                                    *output.outputStream))) {
+    return new Error(
+        "data attribute failed to serialize: unsupported format or encoding");
+  }
   output.outputStream->flush();
   return output.getWriteError();
 }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 4ca2fdb..491bd87 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -32,6 +32,15 @@
     llvm::cl::init(true),
 };
 
+// TODO(benvanik): remove when we support capturing dynamic values for reuse.
+static llvm::cl::opt<bool> clForceIndirectCommandBuffers{
+    "iree-hal-force-indirect-command-buffers",
+    llvm::cl::desc("Forces indirect command buffers when they would otherwise "
+                   "not be chosen due to the values they capture. They may not "
+                   "be reusable but will still be outlined."),
+    llvm::cl::init(false),
+};
+
 struct ContextResolveOpPattern
     : public StreamConversionPattern<IREE::Stream::ContextResolveOp> {
   using StreamConversionPattern::StreamConversionPattern;
@@ -1002,7 +1011,9 @@
     // changes dispatches to use them for any dispatch we can - note that there
     // may still be some that slip through due to custom executables.
     const bool capturesDynamicUniformValues =
-        regionCapturesDynamicUniformValues(executeOp);
+        clForceIndirectCommandBuffers
+            ? false
+            : regionCapturesDynamicUniformValues(executeOp);
 
     // Calculate the indirect buffer references used within the command buffer
     // by analyzing captured resources. This analysis will be used by subsequent
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 2b2f23c..3f1e811 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -150,8 +150,6 @@
 def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">;
 def HAL_OrdinalArrayAttr : TypedArrayAttrBase<HAL_OrdinalAttr, "Array of index ordinal attributes">;
 
-def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>;
-
 def HAL_ElementType : TypeAlias<I32>;
 def HAL_ElementTypeAttr : SignlessIntegerAttrBase<
   I32, "element type attribute">;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index e3f05fc..c1d9a4e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -2593,7 +2593,7 @@
     OptionalAttr<StrAttr>:$sym_visibility,
     SymbolNameAttr:$sym_name,
     StrAttr:$format,
-    HAL_ExecutableDataAttr:$data,
+    Util_AnySerializableAttr:$data,
     OptionalAttr<StrAttr>:$mime_type
     // TODO(benvanik): add compatibility and versioning attributes.
   );
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
index 3044786..19b8d49 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Utils/StringUtils.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,6 +25,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/RegionUtils.h"
 
+#define DEBUG_TYPE "iree-hal-outline-memoize-regions"
+
 namespace mlir::iree_compiler::IREE::HAL {
 
 #define GEN_PASS_DEF_OUTLINEMEMOIZEREGIONSPASS
@@ -153,6 +156,8 @@
                                                          name, funcType);
   moduleSymbolTable.insert(funcOp);
   funcOp.setVisibility(SymbolTable::Visibility::Private);
+  funcOp.setInliningPolicyAttr(
+      moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
   auto funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock());
 
   // Remap any captured operands that have corresponding function arguments.
@@ -521,8 +526,11 @@
   // If we can't memoize the resources at initialization time then we need
   // to do it on-demand.
   if (!memoizeAnalysis.canRunAtInitializationTime()) {
-    memoizeOp.emitWarning(
-        "memoization failed: dynamic values captured at the call site");
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "memoization failed: dynamic values captured at the call site\n";
+      memoizeOp.dump();
+    });
     replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
     return;
   }
@@ -532,8 +540,11 @@
   auto deviceGlobals =
       deviceAnalysis.lookupDeviceGlobals(memoizeOp.getDevice());
   if (!deviceGlobals) {
-    memoizeOp.emitWarning("memoization failed: unable to analyze devices "
-                          "that may be used with memoized region");
+    LLVM_DEBUG({
+      llvm::dbgs() << "memoization failed: unable to analyze devices that may "
+                      "be used with memoized region\n";
+      memoizeOp.dump();
+    });
     replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
     return;
   }
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
index 9690bca..a71011b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
@@ -84,7 +84,6 @@
     // Note that we want to remove entire chains of unused calls and run this
     // as a pattern application.
     RewritePatternSet patterns(&getContext());
-    // patterns
     patterns.insert<EraseUnusedCallOp<IREE::VM::CallOp>,
                     EraseUnusedCallOp<IREE::VM::CallVariadicOp>>(
         &getContext(), noSideEffectsSymbols);
diff --git a/tests/compiler_driver/streams.mlir b/tests/compiler_driver/streams.mlir
index 03ebbc3..9e30a2e 100644
--- a/tests/compiler_driver/streams.mlir
+++ b/tests/compiler_driver/streams.mlir
@@ -51,10 +51,10 @@
     }
   }
 }
-// CHECK: vm.func private @simple_mul
+// CHECK: vm.func private @__simple_mul_memoize_apply
+// CHECK:   vm.call.variadic @hal.command_buffer.dispatch
 func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = arith.constant 4 : index
-  // CHECK: vm.call.variadic @hal.command_buffer.dispatch
   %ret0 = flow.dispatch @executable_0::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   return %ret0 : tensor<4xf32>
 }
@@ -98,10 +98,10 @@
     }
   }
 }
-// CHECK: vm.func private @simple_mul_inplace
+// CHECK: vm.func private @__simple_mul_inplace_memoize_apply
+// CHECK:   vm.call.variadic @hal.command_buffer.dispatch
 func.func @simple_mul_inplace(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = arith.constant 4 : index
-  // CHECK: vm.call.variadic @hal.command_buffer.dispatch
   %ret0 = flow.dispatch @executable_1::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> %arg0
   return %ret0 : tensor<4xf32>
 }