[metal] Enable debug capture manager (6/n) (#3361)

This commit implements hal::DebugCaptureManager for the Metal HAL driver.
MetalCaptureManager wraps the shared MTLCaptureManager under the hood.
diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt
index 29b0f67..7cdb1f5 100644
--- a/iree/hal/metal/CMakeLists.txt
+++ b/iree/hal/metal/CMakeLists.txt
@@ -18,6 +18,25 @@
 
 iree_cc_library(
   NAME
+    metal_capture_manager
+  HDRS
+    "metal_capture_manager.h"
+  SRCS
+    "metal_capture_manager.mm"
+  DEPS
+    absl::memory  # absl::WrapUnique in metal_capture_manager.mm
+    iree::base::file_io
+    iree::base::logging
+    iree::base::status
+    iree::base::tracing
+    iree::hal::debug_capture_manager
+  LINKOPTS
+    "-framework Metal"
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     metal_command_buffer
   HDRS
     "metal_command_buffer.h"
@@ -63,6 +82,7 @@
   SRCS
     "metal_device.mm"
   DEPS
+    ::metal_capture_manager
     ::metal_command_buffer
     ::metal_command_queue
     ::metal_direct_allocator
@@ -112,6 +132,7 @@
   SRCS
     "metal_driver.mm"
   DEPS
+    ::metal_capture_manager
     ::metal_device
     iree::base::status
     iree::base::tracing
@@ -129,6 +150,7 @@
     "metal_driver_module.cc"
   DEPS
     ::metal_driver
+    absl::flags
     iree::base::init
     iree::base::status
     iree::hal::driver_registry
diff --git a/iree/hal/metal/metal_capture_manager.h b/iree/hal/metal/metal_capture_manager.h
new file mode 100644
index 0000000..90b8ac6
--- /dev/null
+++ b/iree/hal/metal/metal_capture_manager.h
@@ -0,0 +1,77 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
+#define IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
+
+#include <memory>
+#include <string>
+
+#import <Metal/Metal.h>
+
+#include "iree/base/status.h"
+#include "iree/hal/debug_capture_manager.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A DebugCaptureManager implementation for Metal that directly wraps a
+// MTLCaptureManager.
+class MetalCaptureManager final : public DebugCaptureManager {
+ public:
+  // Creates a capture manager that captures Metal commands to the given |capture_file| if not
+  // empty. Capture to Xcode otherwise.
+  static StatusOr<std::unique_ptr<MetalCaptureManager>> Create(const std::string& capture_file);
+  ~MetalCaptureManager() override;
+
+  // Connects to the shared MTLCaptureManager. Returns OK right away if already connected.
+  Status Connect() override;
+
+  // Stops any capture in progress and drops the reference to the MTLCaptureManager.
+  void Disconnect() override;
+
+  // Returns true while a MTLCaptureManager is held, i.e. between Connect() and Disconnect().
+  bool is_connected() const override;
+
+  // Sets the Metal object to capture commands for; must be called before StartCapture().
+  void SetCaptureObject(id object);
+
+  // Starts a capture of the capture object. The manager must be connected, must not already be
+  // capturing, and must have a capture object set; violations are fatal (IREE_CHECK).
+  void StartCapture() override;
+
+  // Stops the capture in progress. It is fatal (IREE_CHECK) to call this when not capturing.
+  void StopCapture() override;
+
+  // Returns true if connected and the MTLCaptureManager reports an active capture.
+  bool is_capturing() const override;
+
+ private:
+  // Adopts the caller's reference on |capture_file| (may be nil); released in the destructor.
+  explicit MetalCaptureManager(NSURL* capture_file);
+
+  // The shared capture manager; non-nil only while connected.
+  MTLCaptureManager* metal_handle_ = nil;
+  // The URL for storing the .gputrace file. nil means capturing to Xcode.
+  NSURL* capture_file_ = nil;
+  // The object whose commands are captured; set via SetCaptureObject().
+  id capture_object_ = nil;
+};
+
+}  // namespace metal
+}  // namespace hal
+}  // namespace iree
+
+#endif  // IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
diff --git a/iree/hal/metal/metal_capture_manager.mm b/iree/hal/metal/metal_capture_manager.mm
new file mode 100644
index 0000000..4437951
--- /dev/null
+++ b/iree/hal/metal/metal_capture_manager.mm
@@ -0,0 +1,151 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_capture_manager.h"
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "iree/base/file_io.h"
+#include "iree/base/logging.h"
+#include "iree/base/tracing.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// static
+StatusOr<std::unique_ptr<MetalCaptureManager>> MetalCaptureManager::Create(
+    const std::string& capture_file) {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::Create");
+  @autoreleasepool {
+    NSURL* capture_url = nil;
+    if (!capture_file.empty()) {
+      // Decode explicitly as UTF-8; the default C string encoding is system dependent and may
+      // mangle non-ASCII paths.
+      NSString* ns_string = [NSString stringWithUTF8String:capture_file.c_str()];
+      if (!ns_string) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << "Capture file path is not valid UTF-8: " << capture_file;
+      }
+      NSString* capture_path = ns_string.stringByStandardizingPath;
+      // Retained here; the MetalCaptureManager adopts this reference and releases it in its
+      // destructor.
+      capture_url = [[NSURL fileURLWithPath:capture_path isDirectory:NO] retain];
+    }
+    return absl::WrapUnique(new MetalCaptureManager(capture_url));
+  }
+}
+
+MetalCaptureManager::MetalCaptureManager(NSURL* capture_file) : capture_file_(capture_file) {}
+
+MetalCaptureManager::~MetalCaptureManager() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::dtor");
+  Disconnect();
+  // Messaging nil is a no-op, so neither release needs a nil check.
+  [capture_object_ release];
+  [capture_file_ release];
+}
+
+// Grabs the shared MTLCaptureManager and logs where captures will go. Idempotent.
+Status MetalCaptureManager::Connect() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::Connect");
+
+  if (metal_handle_) return OkStatus();
+
+  @autoreleasepool {
+    metal_handle_ = [[MTLCaptureManager sharedCaptureManager] retain];
+
+    if (capture_file_ &&
+        [metal_handle_ supportsDestination:MTLCaptureDestinationGPUTraceDocument]) {
+      IREE_LOG(INFO) << "Connected to shared Metal capture manager; writing capture to "
+                     << std::string([capture_file_.absoluteString UTF8String]);
+    } else {
+      IREE_LOG(INFO) << "Connected to shared Metal capture manager; capturing to Xcode";
+    }
+  }
+
+  return OkStatus();
+}
+
+// Stops any capture in progress and drops the reference to the shared capture manager.
+void MetalCaptureManager::Disconnect() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::Disconnect");
+
+  if (!metal_handle_) return;
+
+  if (is_capturing()) StopCapture();
+
+  [metal_handle_ release];
+  metal_handle_ = nil;
+}
+
+bool MetalCaptureManager::is_connected() const { return metal_handle_ != nil; }
+
+// Retains |object| so it stays valid until it is replaced or this manager is destroyed.
+void MetalCaptureManager::SetCaptureObject(id object) {
+  // Retain the new object before releasing the old one so re-setting the same object is safe.
+  [object retain];
+  [capture_object_ release];
+  capture_object_ = object;
+}
+
+// Starts capturing commands for the capture object. Capture goes to the .gputrace file when one
+// was given and the destination is supported, and to Xcode otherwise (matching what Connect()
+// logs). A failure to start is logged but not fatal.
+void MetalCaptureManager::StartCapture() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::StartCapture");
+
+  IREE_CHECK(is_connected()) << "Can't start capture when not connected";
+  IREE_CHECK(!is_capturing()) << "Capture is already started";
+  IREE_CHECK(capture_object_) << "Must set capture object before starting";
+
+  IREE_LOG(INFO) << "Starting Metal capture";
+  @autoreleasepool {
+    MTLCaptureDescriptor* capture_descriptor = [[[MTLCaptureDescriptor alloc] init] autorelease];
+    capture_descriptor.captureObject = capture_object_;
+    if (capture_file_ &&
+        [metal_handle_ supportsDestination:MTLCaptureDestinationGPUTraceDocument]) {
+      capture_descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
+      capture_descriptor.outputURL = capture_file_;
+    } else {
+      capture_descriptor.destination = MTLCaptureDestinationDeveloperTools;
+    }
+
+    // Initialize so the failure log never prints an indeterminate pointer.
+    NSError* error = nil;
+    if (![metal_handle_ startCaptureWithDescriptor:capture_descriptor error:&error]) {
+      NSLog(@"Failed to start capture, error %@", error);
+    }
+  }
+}
+
+// Ends the capture in progress; it is fatal to call this when nothing is being captured.
+void MetalCaptureManager::StopCapture() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::StopCapture");
+
+  IREE_CHECK(is_capturing()) << "Can't stop capture when not capturing";
+
+  IREE_LOG(INFO) << "Ending Metal capture";
+  [metal_handle_ stopCapture];
+}
+
+bool MetalCaptureManager::is_capturing() const {
+  if (!is_connected()) return false;
+  return metal_handle_.isCapturing;
+}
+
+}  // namespace metal
+}  // namespace hal
+}  // namespace iree
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm
index 5257bec..10ba84e 100644
--- a/iree/hal/metal/metal_command_buffer.mm
+++ b/iree/hal/metal/metal_command_buffer.mm
@@ -45,7 +45,9 @@
 MetalCommandBuffer::MetalCommandBuffer(CommandBufferModeBitfield mode,
                                        CommandCategoryBitfield command_categories,
                                        id<MTLCommandBuffer> command_buffer)
-    : CommandBuffer(mode, command_categories), metal_handle_([command_buffer retain]) {}
+    : CommandBuffer(mode, command_categories), metal_handle_([command_buffer retain]) {
+  metal_handle_.label = @"IREE MetalCommandBuffer";  // Debug label identifying this object.
+}
 
 MetalCommandBuffer::~MetalCommandBuffer() {
   IREE_TRACE_SCOPE0("MetalCommandBuffer::dtor");
@@ -313,6 +315,7 @@
 
       id<MTLArgumentEncoder> argument_encoder =
           [metal_kernel newArgumentEncoderWithBufferIndex:set_number];  // retained
+      argument_encoder.label = @"IREE MetalCommandBuffer::Dispatch ArgumentEncoder";
       if (!argument_encoder) {
         return InvalidArgumentErrorBuilder(IREE_LOC)
                << "Buffer index #" << set_number << " is not an argument buffer";
@@ -321,6 +324,8 @@
       __block id<MTLBuffer> argument_buffer =
           [metal_handle_.device newBufferWithLength:argument_encoder.encodedLength
                                             options:MTLResourceStorageModeShared];  // retained
+      // Label the buffer for debugging; a no-op if creation failed and it is nil.
+      argument_buffer.label = @"IREE MetalCommandBuffer::Dispatch ArgumentBuffer";
       if (!argument_buffer) {
         return InternalErrorBuilder(IREE_LOC)
                << "Failed to create argument buffer with length=" << argument_encoder.encodedLength;
diff --git a/iree/hal/metal/metal_command_queue.mm b/iree/hal/metal/metal_command_queue.mm
index 3729fd0..3f1b6e2 100644
--- a/iree/hal/metal/metal_command_queue.mm
+++ b/iree/hal/metal/metal_command_queue.mm
@@ -26,7 +26,9 @@
 
 MetalCommandQueue::MetalCommandQueue(std::string name, CommandCategoryBitfield supported_categories,
                                      id<MTLCommandQueue> queue)
-    : CommandQueue(std::move(name), supported_categories), metal_handle_([queue retain]) {}
+    : CommandQueue(std::move(name), supported_categories), metal_handle_([queue retain]) {
+  metal_handle_.label = @"IREE MetalQueue";  // Debug label identifying this object.
+}
 
 MetalCommandQueue::~MetalCommandQueue() { [metal_handle_ release]; }
 
@@ -37,6 +39,8 @@
       // Wait for semaphores blocking this batch.
       if (!batch.wait_semaphores.empty()) {
         id<MTLCommandBuffer> wait_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+        wait_buffer.label = @"IREE MetalCommandQueue::Submit Wait Semaphore CommandBuffer";
+
         for (const auto& semaphore : batch.wait_semaphores) {
           auto* event = static_cast<MetalSharedEvent*>(semaphore.semaphore);
           [wait_buffer encodeWaitForEvent:event->handle() value:semaphore.value];
@@ -53,6 +57,8 @@
       // Signal semaphores advanced by this batch.
       if (!batch.signal_semaphores.empty()) {
         id<MTLCommandBuffer> signal_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+        signal_buffer.label = @"IREE MetalCommandQueue::Submit Signal Semaphore CommandBuffer";
+
         for (const auto& semaphore : batch.signal_semaphores) {
           auto* event = static_cast<MetalSharedEvent*>(semaphore.semaphore);
           [signal_buffer encodeSignalEvent:event->handle() value:semaphore.value];
@@ -73,6 +79,7 @@
   // work has completed too.
   @autoreleasepool {
     id<MTLCommandBuffer> comand_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+    comand_buffer.label = @"IREE MetalCommandQueue::WaitIdle Command Buffer";
     __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
     [comand_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
       dispatch_semaphore_signal(work_done);
diff --git a/iree/hal/metal/metal_device.h b/iree/hal/metal/metal_device.h
index 729cbaf..f183efa 100644
--- a/iree/hal/metal/metal_device.h
+++ b/iree/hal/metal/metal_device.h
@@ -22,6 +22,7 @@
 #include "absl/types/span.h"
 #include "iree/base/memory.h"
 #include "iree/hal/allocator.h"
+#include "iree/hal/debug_capture_manager.h"
 #include "iree/hal/device.h"
 #include "iree/hal/driver.h"
 #include "iree/hal/semaphore.h"
@@ -35,8 +36,9 @@
  public:
   // Creates a device that retains the underlying Metal GPU device.
   // The DriverDeviceID in |device_info| is expected to be an id<MTLDevice>.
-  static StatusOr<ref_ptr<MetalDevice>> Create(ref_ptr<Driver> driver,
-                                               const DeviceInfo& device_info);
+  static StatusOr<ref_ptr<MetalDevice>> Create(
+      ref_ptr<Driver> driver, const DeviceInfo& device_info,
+      DebugCaptureManager* debug_capture_manager);
 
   ~MetalDevice() override;
 
@@ -81,7 +83,8 @@
   Status WaitIdle(Time deadline_ns) override;
 
  private:
-  MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info);
+  MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info,
+              DebugCaptureManager* debug_capture_manager);
 
   ref_ptr<Driver> driver_;
   id<MTLDevice> metal_handle_;
@@ -102,6 +105,8 @@
   // semaphore.
   dispatch_queue_t wait_notifier_;
   MTLSharedEventListener* event_listener_;
+
+  DebugCaptureManager* debug_capture_manager_ = nullptr;  // Not owned; may be null.
 };
 
 }  // namespace metal
diff --git a/iree/hal/metal/metal_device.mm b/iree/hal/metal/metal_device.mm
index 97dcba8..f11d4b8 100644
--- a/iree/hal/metal/metal_device.mm
+++ b/iree/hal/metal/metal_device.mm
@@ -22,6 +22,7 @@
 #include "iree/hal/allocator.h"
 #include "iree/hal/command_buffer_validation.h"
 #include "iree/hal/metal/dispatch_time_util.h"
+#include "iree/hal/metal/metal_capture_manager.h"
 #include "iree/hal/metal/metal_command_buffer.h"
 #include "iree/hal/metal/metal_command_queue.h"
 #include "iree/hal/metal/metal_direct_allocator.h"
@@ -35,14 +36,17 @@
 
 // static
 StatusOr<ref_ptr<MetalDevice>> MetalDevice::Create(ref_ptr<Driver> driver,
-                                                   const DeviceInfo& device_info) {
-  return assign_ref(new MetalDevice(std::move(driver), device_info));
+                                                   const DeviceInfo& device_info,
+                                                   DebugCaptureManager* debug_capture_manager) {
+  return assign_ref(new MetalDevice(std::move(driver), device_info, debug_capture_manager));
 }
 
-MetalDevice::MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info)
+MetalDevice::MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info,
+                         DebugCaptureManager* debug_capture_manager)
     : Device(device_info),
       driver_(std::move(driver)),
-      metal_handle_([(__bridge id<MTLDevice>)device_info.device_id() retain]) {
+      metal_handle_([(__bridge id<MTLDevice>)device_info.device_id() retain]),
+      debug_capture_manager_(debug_capture_manager) {
   IREE_TRACE_SCOPE0("MetalDevice::ctor");
 
   // Grab one queue for dispatch and transfer.
@@ -51,6 +55,12 @@
 
   allocator_ = MetalDirectAllocator::Create(metal_handle_, metal_queue);
 
+  if (debug_capture_manager_ && debug_capture_manager_->is_connected()) {
+    // Record a capture covering the duration of this device lifetime.
+    static_cast<MetalCaptureManager*>(debug_capture_manager_)->SetCaptureObject(metal_handle_);
+    debug_capture_manager_->StartCapture();
+  }
+
   command_queue_ = absl::make_unique<MetalCommandQueue>(
       name, CommandCategory::kDispatch | CommandCategory::kTransfer, metal_queue);
   common_queue_ = command_queue_.get();
@@ -63,8 +73,14 @@
 
 MetalDevice::~MetalDevice() {
   IREE_TRACE_SCOPE0("MetalDevice::dtor");
+
+  if (debug_capture_manager_ && debug_capture_manager_->is_capturing()) {
+    debug_capture_manager_->StopCapture();
+  }
+
   [event_listener_ release];
   dispatch_release(wait_notifier_);
+
   [metal_handle_ release];
 }
 
diff --git a/iree/hal/metal/metal_driver.h b/iree/hal/metal/metal_driver.h
index 7f458b6..f42e9ed 100644
--- a/iree/hal/metal/metal_driver.h
+++ b/iree/hal/metal/metal_driver.h
@@ -15,19 +15,31 @@
 #ifndef IREE_HAL_METAL_METAL_DRIVER_H_
 #define IREE_HAL_METAL_METAL_DRIVER_H_
 
+#include <memory>
+#include <string>
+
+#include "iree/hal/debug_capture_manager.h"
 #include "iree/hal/driver.h"
 
 namespace iree {
 namespace hal {
 namespace metal {
 
+struct MetalDriverOptions {
+  // Whether to enable Metal command capture (off by default).
+  bool enable_capture = false;
+  // The file to contain the Metal capture. Empty means capturing to Xcode.
+  std::string capture_file;
+};
+
 // A pseudo Metal GPU driver which retains all available Metal GPU devices
 // during its lifetime.
 //
 // It uses the DriverDeviceID to store the underlying id<MTLDevice>.
 class MetalDriver final : public Driver {
  public:
-  static StatusOr<ref_ptr<MetalDriver>> Create();
+  static StatusOr<ref_ptr<MetalDriver>> Create(
+      const MetalDriverOptions& options);
 
   ~MetalDriver() override;
 
@@ -38,9 +50,12 @@
   StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) override;
 
  private:
-  explicit MetalDriver(std::vector<DeviceInfo> devices);
+  MetalDriver(std::vector<DeviceInfo> devices,
+              std::unique_ptr<DebugCaptureManager> debug_capture_manager);
 
   std::vector<DeviceInfo> devices_;
+
+  std::unique_ptr<DebugCaptureManager> debug_capture_manager_;
 };
 
 }  // namespace metal
diff --git a/iree/hal/metal/metal_driver.mm b/iree/hal/metal/metal_driver.mm
index de3aca1..3742d3e 100644
--- a/iree/hal/metal/metal_driver.mm
+++ b/iree/hal/metal/metal_driver.mm
@@ -18,6 +18,7 @@
 
 #include "iree/base/status.h"
 #include "iree/base/tracing.h"
+#include "iree/hal/metal/metal_capture_manager.h"
 #include "iree/hal/metal/metal_device.h"
 
 namespace iree {
@@ -41,7 +42,7 @@
 }  // namespace
 
 // static
-StatusOr<ref_ptr<MetalDriver>> MetalDriver::Create() {
+StatusOr<ref_ptr<MetalDriver>> MetalDriver::Create(const MetalDriverOptions& options) {
   IREE_TRACE_SCOPE0("MetalDriver::Create");
 
   @autoreleasepool {
@@ -50,6 +51,13 @@
       return UnavailableErrorBuilder(IREE_LOC) << "no Metal GPU devices available";
     }
 
+    std::unique_ptr<MetalCaptureManager> metal_capture_manager;
+    if (options.enable_capture) {
+      IREE_ASSIGN_OR_RETURN(metal_capture_manager,
+                            MetalCaptureManager::Create(options.capture_file));
+      IREE_RETURN_IF_ERROR(metal_capture_manager->Connect());
+    }
+
     std::vector<DeviceInfo> device_infos;
     for (id<MTLDevice> device in devices) {
       std::string name = std::string([device.name UTF8String]);
@@ -57,12 +65,15 @@
       DriverDeviceID device_id = reinterpret_cast<DriverDeviceID>((__bridge void*)device);
       device_infos.emplace_back("metal", std::move(name), supported_features, device_id);
     }
-    return assign_ref(new MetalDriver(std::move(device_infos)));
+    return assign_ref(new MetalDriver(std::move(device_infos), std::move(metal_capture_manager)));
   }
 }
 
-MetalDriver::MetalDriver(std::vector<DeviceInfo> devices)
-    : Driver("metal"), devices_(std::move(devices)) {
+MetalDriver::MetalDriver(std::vector<DeviceInfo> devices,
+                         std::unique_ptr<DebugCaptureManager> debug_capture_manager)
+    : Driver("metal"),
+      devices_(std::move(devices)),
+      debug_capture_manager_(std::move(debug_capture_manager)) {
   // Retain all the retained Metal GPU devices.
   for (const auto& device : devices_) {
     [(__bridge id<MTLDevice>)device.device_id() retain];
@@ -97,7 +108,9 @@
   IREE_TRACE_SCOPE0("MetalDriver::CreateDevice");
 
   for (const DeviceInfo& info : devices_) {
-    if (info.device_id() == device_id) return MetalDevice::Create(add_ref(this), info);
+    if (info.device_id() == device_id) {
+      return MetalDevice::Create(add_ref(this), info, debug_capture_manager_.get());
+    }
   }
   return InvalidArgumentErrorBuilder(IREE_LOC) << "unknown driver device id: " << device_id;
 }
diff --git a/iree/hal/metal/metal_driver_module.cc b/iree/hal/metal/metal_driver_module.cc
index ba2ec58..48308ae 100644
--- a/iree/hal/metal/metal_driver_module.cc
+++ b/iree/hal/metal/metal_driver_module.cc
@@ -14,17 +14,31 @@
 
 #include <memory>
+#include <string>
 
+#include "absl/flags/flag.h"
 #include "iree/base/init.h"
 #include "iree/base/status.h"
 #include "iree/hal/driver_registry.h"
 #include "iree/hal/metal/metal_driver.h"
 
+ABSL_FLAG(bool, metal_capture, false, "Enables capturing Metal commands.");
+ABSL_FLAG(
+    std::string, metal_capture_to_file, "",
+    "Full path to store the GPU trace file (empty means capture to Xcode)");
+
 namespace iree {
 namespace hal {
 namespace metal {
 namespace {
 
-StatusOr<ref_ptr<Driver>> CreateMetalDriver() { return MetalDriver::Create(); }
+// Creates the Metal driver, configuring command capture from the
+// --metal_capture and --metal_capture_to_file flags.
+StatusOr<ref_ptr<Driver>> CreateMetalDriver() {
+  MetalDriverOptions options;
+  options.enable_capture = absl::GetFlag(FLAGS_metal_capture);
+  options.capture_file = absl::GetFlag(FLAGS_metal_capture_to_file);
+  return MetalDriver::Create(options);
+}
 
 }  // namespace
 }  // namespace metal