[metal] Enable debug capture manager (6/n) (#3361)
This commit implements hal::DebugCaptureManager for the Metal HAL driver.
The implementation is a thin wrapper around the shared MTLCaptureManager.
diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt
index 29b0f67..7cdb1f5 100644
--- a/iree/hal/metal/CMakeLists.txt
+++ b/iree/hal/metal/CMakeLists.txt
@@ -18,6 +18,24 @@
iree_cc_library(
NAME
+ metal_capture_manager
+ HDRS
+ "metal_capture_manager.h"
+ SRCS
+ "metal_capture_manager.mm"
+ DEPS
+ absl::memory
+ iree::base::file_io
+ iree::base::logging
+ iree::base::status
+ iree::base::tracing
+ iree::hal::debug_capture_manager
+ LINKOPTS
+ "-framework Metal"
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
metal_command_buffer
HDRS
"metal_command_buffer.h"
@@ -63,6 +81,7 @@
SRCS
"metal_device.mm"
DEPS
+ ::metal_capture_manager
::metal_command_buffer
::metal_command_queue
::metal_direct_allocator
@@ -112,6 +131,7 @@
SRCS
"metal_driver.mm"
DEPS
+ ::metal_capture_manager
::metal_device
iree::base::status
iree::base::tracing
@@ -129,6 +149,7 @@
"metal_driver_module.cc"
DEPS
::metal_driver
+ absl::flags
iree::base::init
iree::base::status
iree::hal::driver_registry
diff --git a/iree/hal/metal/metal_capture_manager.h b/iree/hal/metal/metal_capture_manager.h
new file mode 100644
index 0000000..90b8ac6
--- /dev/null
+++ b/iree/hal/metal/metal_capture_manager.h
@@ -0,0 +1,65 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
+#define IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
+
+#include <memory>
+#include <string>
+
+#import <Metal/Metal.h>
+
+#include "iree/base/status.h"
+#include "iree/hal/debug_capture_manager.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A DebugCaptureManager implementation for Metal that directly wraps the
+// shared MTLCaptureManager.
+//
+// Instances hold manually retained Objective-C handles that are released on
+// destruction, so they are not copyable; own them through the std::unique_ptr
+// returned by Create().
+class MetalCaptureManager final : public DebugCaptureManager {
+ public:
+  // Creates a capture manager that captures Metal commands to the given
+  // |capture_file| if it is not empty, and to Xcode otherwise.
+  static StatusOr<std::unique_ptr<MetalCaptureManager>> Create(const std::string& capture_file);
+  ~MetalCaptureManager() override;
+
+  // Copying would release the retained handles twice.
+  MetalCaptureManager(const MetalCaptureManager&) = delete;
+  MetalCaptureManager& operator=(const MetalCaptureManager&) = delete;
+
+  // Retains the shared MTLCaptureManager. Does nothing if already connected.
+  Status Connect() override;
+
+  // Stops any capture in progress and releases the MTLCaptureManager handle.
+  // Does nothing if not connected.
+  void Disconnect() override;
+
+  // Returns true while the MTLCaptureManager handle is held.
+  bool is_connected() const override;
+
+  // Sets the object to capture. The object is stored without being retained,
+  // so the caller must keep it alive for as long as it may be captured.
+  void SetCaptureObject(id object);
+
+  // Starts a capture of the capture object. Requires (via IREE_CHECK) that the
+  // manager is connected, that no capture is in progress, and that a capture
+  // object has been set. A failure to start the capture is only logged.
+  void StartCapture() override;
+
+  // Stops the capture in progress. Requires (via IREE_CHECK) that a capture is
+  // in progress.
+  void StopCapture() override;
+
+  // Returns true if connected and the MTLCaptureManager reports an active
+  // capture.
+  bool is_capturing() const override;
+
+ private:
+  // Adopts |capture_file| as-is; Create() passes in an already-retained URL
+  // (or nil) and the destructor releases it.
+  explicit MetalCaptureManager(NSURL* capture_file);
+
+  // Retained while connected; nil otherwise.
+  MTLCaptureManager* metal_handle_ = nil;
+  // The path for storing the .gputrace file. Empty means capturing to Xcode.
+  NSURL* capture_file_ = nil;
+  // The object to capture; not retained.
+  id capture_object_ = nil;
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
diff --git a/iree/hal/metal/metal_capture_manager.mm b/iree/hal/metal/metal_capture_manager.mm
new file mode 100644
index 0000000..4437951
--- /dev/null
+++ b/iree/hal/metal/metal_capture_manager.mm
@@ -0,0 +1,128 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_capture_manager.h"
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "iree/base/file_io.h"
+#include "iree/base/logging.h"
+#include "iree/base/tracing.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// static
+// Builds a capture manager for |capture_file|. An empty path yields a manager
+// without a capture URL (capture to Xcode); otherwise the standardized path is
+// turned into a retained file URL that the manager releases on destruction.
+// Returns an InvalidArgument error when the path cannot be decoded as UTF-8.
+StatusOr<std::unique_ptr<MetalCaptureManager>> MetalCaptureManager::Create(
+    const std::string& capture_file) {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::Create");
+  @autoreleasepool {
+    NSURL* capture_url = nil;
+    if (!capture_file.empty()) {
+      // Decode explicitly as UTF-8: the default C string encoding depends on
+      // the user's environment and can mangle non-ASCII paths.
+      NSString* ns_string = [NSString stringWithUTF8String:capture_file.c_str()];
+      if (!ns_string) {
+        // A nil path must not reach -fileURLWithPath:isDirectory:.
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << "Capture file path is not valid UTF-8: " << capture_file;
+      }
+      NSString* capture_path = ns_string.stringByStandardizingPath;
+      // Retain so that the URL outlives this autorelease pool.
+      capture_url = [[NSURL fileURLWithPath:capture_path isDirectory:NO] retain];
+    }
+    return absl::WrapUnique(new MetalCaptureManager(capture_url));
+  }
+}
+
+// Stores |capture_file| without retaining it again; the destructor releases it.
+MetalCaptureManager::MetalCaptureManager(NSURL* capture_file)
+    : capture_file_(capture_file) {}
+
+// Disconnects (ending any capture in progress) and releases the capture URL.
+MetalCaptureManager::~MetalCaptureManager() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::dtor");
+  Disconnect();
+  // Messaging nil is a no-op, so no explicit nil check is needed here.
+  [capture_file_ release];
+}
+
+// Retains the shared MTLCaptureManager and logs where captures will be
+// written. Returns OkStatus() immediately when already connected.
+Status MetalCaptureManager::Connect() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::Connect");
+
+  if (is_connected()) return OkStatus();
+
+  @autoreleasepool {
+    metal_handle_ = [[MTLCaptureManager sharedCaptureManager] retain];
+
+    // File output needs both a destination URL and GPU trace document support.
+    const bool capture_to_file =
+        capture_file_ != nil &&
+        [metal_handle_ supportsDestination:MTLCaptureDestinationGPUTraceDocument];
+    if (!capture_to_file) {
+      IREE_LOG(INFO) << "Connected to shared Metal capture manager; capturing to Xcode";
+    } else {
+      IREE_LOG(INFO) << "Connected to shared Metal capture manager; writing capture to "
+                     << std::string([capture_file_.absoluteString UTF8String]);
+    }
+  }
+
+  return OkStatus();
+}
+
+// Releases the MTLCaptureManager handle, stopping any capture in progress
+// first. Does nothing when not connected.
+void MetalCaptureManager::Disconnect() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::Disconnect");
+
+  if (is_connected()) {
+    // End a running capture while the handle is still held.
+    if (is_capturing()) StopCapture();
+
+    [metal_handle_ release];
+    metal_handle_ = nil;
+  }
+}
+
+// Connected means the MTLCaptureManager handle is currently held.
+bool MetalCaptureManager::is_connected() const {
+  return metal_handle_ != nil;
+}
+
+// Records the object to capture. The pointer is stored without being retained.
+void MetalCaptureManager::SetCaptureObject(id object) {
+  capture_object_ = object;
+}
+
+// Starts capturing the previously set capture object.
+//
+// Preconditions (enforced with IREE_CHECK): the manager is connected, no
+// capture is in progress, and a capture object has been set. The capture is
+// written to the capture file only when one was given and the capture manager
+// supports GPU trace documents; otherwise it goes to the developer tools
+// (Xcode), which is the same rule Connect() uses for its log message. A
+// failure to start the capture is logged and otherwise ignored.
+void MetalCaptureManager::StartCapture() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::StartCapture");
+
+  IREE_CHECK(is_connected()) << "Can't start capture when not connected";
+  IREE_CHECK(!is_capturing()) << "Capture is already started";
+  IREE_CHECK(capture_object_) << "Must set capture object before starting";
+
+  IREE_LOG(INFO) << "Starting Metal capture";
+  @autoreleasepool {
+    MTLCaptureDescriptor* capture_descriptor = [[[MTLCaptureDescriptor alloc] init] autorelease];
+    capture_descriptor.captureObject = capture_object_;
+    // Do not request a destination the capture manager cannot serve.
+    if (capture_file_ &&
+        [metal_handle_ supportsDestination:MTLCaptureDestinationGPUTraceDocument]) {
+      capture_descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
+      capture_descriptor.outputURL = capture_file_;
+    } else {
+      capture_descriptor.destination = MTLCaptureDestinationDeveloperTools;
+    }
+
+    // Initialized so the failure log never reads an indeterminate pointer.
+    NSError* error = nil;
+    if (![metal_handle_ startCaptureWithDescriptor:capture_descriptor error:&error]) {
+      NSLog(@"Failed to start capture, error %@", error);
+    }
+  }
+}
+
+// Ends the capture in progress. Calling this while no capture is running
+// violates the IREE_CHECK precondition below.
+void MetalCaptureManager::StopCapture() {
+  IREE_TRACE_SCOPE0("MetalCaptureManager::StopCapture");
+
+  IREE_CHECK(is_capturing()) << "Can't stop capture when not capturing";
+
+  IREE_LOG(INFO) << "Ending Metal capture";
+  [metal_handle_ stopCapture];
+}
+
+// Reports whether a capture is running. Always false when not connected,
+// since there is then no capture manager handle to query.
+bool MetalCaptureManager::is_capturing() const {
+  return is_connected() && metal_handle_.isCapturing;
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm
index 5257bec..10ba84e 100644
--- a/iree/hal/metal/metal_command_buffer.mm
+++ b/iree/hal/metal/metal_command_buffer.mm
@@ -45,7 +45,9 @@
MetalCommandBuffer::MetalCommandBuffer(CommandBufferModeBitfield mode,
CommandCategoryBitfield command_categories,
id<MTLCommandBuffer> command_buffer)
- : CommandBuffer(mode, command_categories), metal_handle_([command_buffer retain]) {}
+ : CommandBuffer(mode, command_categories), metal_handle_([command_buffer retain]) {
+ // Give the retained command buffer a recognizable label.
+ [metal_handle_ setLabel:@"IREE MetalCommandBuffer"];
+}
MetalCommandBuffer::~MetalCommandBuffer() {
IREE_TRACE_SCOPE0("MetalCommandBuffer::dtor");
@@ -313,6 +315,7 @@
id<MTLArgumentEncoder> argument_encoder =
[metal_kernel newArgumentEncoderWithBufferIndex:set_number]; // retained
+ argument_encoder.label = @"IREE MetalCommandBuffer::Dispatch ArgumentEncoder";
if (!argument_encoder) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "Buffer index #" << set_number << " is not an argument buffer";
@@ -321,6 +324,7 @@
__block id<MTLBuffer> argument_buffer =
[metal_handle_.device newBufferWithLength:argument_encoder.encodedLength
options:MTLResourceStorageModeShared]; // retained
+ // Label the newly created buffer itself, not the encoder that sized it.
+ argument_buffer.label = @"IREE MetalCommandBuffer::Dispatch ArgumentBuffer";
if (!argument_buffer) {
return InternalErrorBuilder(IREE_LOC)
<< "Failed to create argument buffer with length=" << argument_encoder.encodedLength;
diff --git a/iree/hal/metal/metal_command_queue.mm b/iree/hal/metal/metal_command_queue.mm
index 3729fd0..3f1b6e2 100644
--- a/iree/hal/metal/metal_command_queue.mm
+++ b/iree/hal/metal/metal_command_queue.mm
@@ -26,7 +26,9 @@
MetalCommandQueue::MetalCommandQueue(std::string name, CommandCategoryBitfield supported_categories,
id<MTLCommandQueue> queue)
- : CommandQueue(std::move(name), supported_categories), metal_handle_([queue retain]) {}
+ : CommandQueue(std::move(name), supported_categories), metal_handle_([queue retain]) {
+ // Give the retained queue a recognizable label.
+ [metal_handle_ setLabel:@"IREE MetalQueue"];
+}
MetalCommandQueue::~MetalCommandQueue() { [metal_handle_ release]; }
@@ -37,6 +39,8 @@
// Wait for semaphores blocking this batch.
if (!batch.wait_semaphores.empty()) {
id<MTLCommandBuffer> wait_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+ wait_buffer.label = @"IREE MetalCommandQueue::Submit Wait Semaphore CommandBuffer";
+
for (const auto& semaphore : batch.wait_semaphores) {
auto* event = static_cast<MetalSharedEvent*>(semaphore.semaphore);
[wait_buffer encodeWaitForEvent:event->handle() value:semaphore.value];
@@ -53,6 +57,8 @@
// Signal semaphores advanced by this batch.
if (!batch.signal_semaphores.empty()) {
id<MTLCommandBuffer> signal_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+ signal_buffer.label = @"IREE MetalCommandQueue::Submit Signal Semaphore CommandBuffer";
+
for (const auto& semaphore : batch.signal_semaphores) {
auto* event = static_cast<MetalSharedEvent*>(semaphore.semaphore);
[signal_buffer encodeSignalEvent:event->handle() value:semaphore.value];
@@ -73,6 +79,7 @@
// work has completed too.
@autoreleasepool {
id<MTLCommandBuffer> comand_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+ comand_buffer.label = @"IREE MetalCommandQueue::WaitIdle Command Buffer";
__block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
[comand_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
dispatch_semaphore_signal(work_done);
diff --git a/iree/hal/metal/metal_device.h b/iree/hal/metal/metal_device.h
index 729cbaf..f183efa 100644
--- a/iree/hal/metal/metal_device.h
+++ b/iree/hal/metal/metal_device.h
@@ -22,6 +22,7 @@
#include "absl/types/span.h"
#include "iree/base/memory.h"
#include "iree/hal/allocator.h"
+#include "iree/hal/debug_capture_manager.h"
#include "iree/hal/device.h"
#include "iree/hal/driver.h"
#include "iree/hal/semaphore.h"
@@ -35,8 +36,9 @@
public:
// Creates a device that retains the underlying Metal GPU device.
// The DriverDeviceID in |device_info| is expected to be an id<MTLDevice>.
- static StatusOr<ref_ptr<MetalDevice>> Create(ref_ptr<Driver> driver,
- const DeviceInfo& device_info);
+ static StatusOr<ref_ptr<MetalDevice>> Create(
+ ref_ptr<Driver> driver, const DeviceInfo& device_info,
+ DebugCaptureManager* debug_capture_manager);
~MetalDevice() override;
@@ -81,7 +83,8 @@
Status WaitIdle(Time deadline_ns) override;
private:
- MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info);
+ MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info,
+ DebugCaptureManager* debug_capture_manager);
ref_ptr<Driver> driver_;
id<MTLDevice> metal_handle_;
@@ -102,6 +105,8 @@
// semaphore.
dispatch_queue_t wait_notifier_;
MTLSharedEventListener* event_listener_;
+
+ DebugCaptureManager* debug_capture_manager_ = nullptr;
};
} // namespace metal
diff --git a/iree/hal/metal/metal_device.mm b/iree/hal/metal/metal_device.mm
index 97dcba8..f11d4b8 100644
--- a/iree/hal/metal/metal_device.mm
+++ b/iree/hal/metal/metal_device.mm
@@ -22,6 +22,7 @@
#include "iree/hal/allocator.h"
#include "iree/hal/command_buffer_validation.h"
#include "iree/hal/metal/dispatch_time_util.h"
+#include "iree/hal/metal/metal_capture_manager.h"
#include "iree/hal/metal/metal_command_buffer.h"
#include "iree/hal/metal/metal_command_queue.h"
#include "iree/hal/metal/metal_direct_allocator.h"
@@ -35,14 +36,17 @@
// static
StatusOr<ref_ptr<MetalDevice>> MetalDevice::Create(ref_ptr<Driver> driver,
- const DeviceInfo& device_info) {
- return assign_ref(new MetalDevice(std::move(driver), device_info));
+ const DeviceInfo& device_info,
+ DebugCaptureManager* debug_capture_manager) {
+ // |debug_capture_manager| is forwarded unchanged to the constructor.
+ auto* device = new MetalDevice(std::move(driver), device_info, debug_capture_manager);
+ return assign_ref(device);
}
-MetalDevice::MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info)
+MetalDevice::MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info,
+ DebugCaptureManager* debug_capture_manager)
: Device(device_info),
driver_(std::move(driver)),
- metal_handle_([(__bridge id<MTLDevice>)device_info.device_id() retain]) {
+ metal_handle_([(__bridge id<MTLDevice>)device_info.device_id() retain]),
+ debug_capture_manager_(debug_capture_manager) {
IREE_TRACE_SCOPE0("MetalDevice::ctor");
// Grab one queue for dispatch and transfer.
@@ -51,6 +55,12 @@
allocator_ = MetalDirectAllocator::Create(metal_handle_, metal_queue);
+ // Record a capture covering the lifetime of the device that starts it. Skip
+ // this when a capture is already running (e.g. one started for another
+ // device of the same driver): StartCapture() IREE_CHECKs that no capture is
+ // in progress.
+ if (debug_capture_manager_ && debug_capture_manager_->is_connected() &&
+ !debug_capture_manager_->is_capturing()) {
+ static_cast<MetalCaptureManager*>(debug_capture_manager_)->SetCaptureObject(metal_handle_);
+ debug_capture_manager_->StartCapture();
+ }
+
command_queue_ = absl::make_unique<MetalCommandQueue>(
name, CommandCategory::kDispatch | CommandCategory::kTransfer, metal_queue);
common_queue_ = command_queue_.get();
@@ -63,8 +73,14 @@
MetalDevice::~MetalDevice() {
IREE_TRACE_SCOPE0("MetalDevice::dtor");
+
+ // End any capture still running before the Metal handles below are released.
+ const bool capture_running =
+ debug_capture_manager_ != nullptr && debug_capture_manager_->is_capturing();
+ if (capture_running) debug_capture_manager_->StopCapture();
+
[event_listener_ release];
dispatch_release(wait_notifier_);
+
[metal_handle_ release];
}
diff --git a/iree/hal/metal/metal_driver.h b/iree/hal/metal/metal_driver.h
index 7f458b6..f42e9ed 100644
--- a/iree/hal/metal/metal_driver.h
+++ b/iree/hal/metal/metal_driver.h
@@ -15,19 +15,31 @@
#ifndef IREE_HAL_METAL_METAL_DRIVER_H_
#define IREE_HAL_METAL_METAL_DRIVER_H_
+#include <memory>
+#include <string>
+
+#include "iree/hal/debug_capture_manager.h"
#include "iree/hal/driver.h"
namespace iree {
namespace hal {
namespace metal {
+// Options controlling how a MetalDriver is created.
+struct MetalDriverOptions {
+  // Whether to enable Metal command capture. Defaults to off so that a
+  // default-constructed options struct never holds an indeterminate value.
+  bool enable_capture = false;
+  // The file to contain the Metal capture. Empty means capturing to Xcode.
+  std::string capture_file;
+};
+
// A pseudo Metal GPU driver which retains all available Metal GPU devices
// during its lifetime.
//
// It uses the DriverDeviceID to store the underlying id<MTLDevice>.
class MetalDriver final : public Driver {
public:
- static StatusOr<ref_ptr<MetalDriver>> Create();
+ static StatusOr<ref_ptr<MetalDriver>> Create(
+ const MetalDriverOptions& options);
~MetalDriver() override;
@@ -38,9 +50,12 @@
StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) override;
private:
- explicit MetalDriver(std::vector<DeviceInfo> devices);
+ MetalDriver(std::vector<DeviceInfo> devices,
+ std::unique_ptr<DebugCaptureManager> debug_capture_manager);
std::vector<DeviceInfo> devices_;
+
+ std::unique_ptr<DebugCaptureManager> debug_capture_manager_;
};
} // namespace metal
diff --git a/iree/hal/metal/metal_driver.mm b/iree/hal/metal/metal_driver.mm
index de3aca1..3742d3e 100644
--- a/iree/hal/metal/metal_driver.mm
+++ b/iree/hal/metal/metal_driver.mm
@@ -18,6 +18,7 @@
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+#include "iree/hal/metal/metal_capture_manager.h"
#include "iree/hal/metal/metal_device.h"
namespace iree {
@@ -41,7 +42,7 @@
} // namespace
// static
-StatusOr<ref_ptr<MetalDriver>> MetalDriver::Create() {
+StatusOr<ref_ptr<MetalDriver>> MetalDriver::Create(const MetalDriverOptions& options) {
IREE_TRACE_SCOPE0("MetalDriver::Create");
@autoreleasepool {
@@ -50,6 +51,13 @@
return UnavailableErrorBuilder(IREE_LOC) << "no Metal GPU devices available";
}
+ std::unique_ptr<MetalCaptureManager> metal_capture_manager;
+ if (options.enable_capture) {
+ IREE_ASSIGN_OR_RETURN(metal_capture_manager,
+ MetalCaptureManager::Create(options.capture_file));
+ IREE_RETURN_IF_ERROR(metal_capture_manager->Connect());
+ }
+
std::vector<DeviceInfo> device_infos;
for (id<MTLDevice> device in devices) {
std::string name = std::string([device.name UTF8String]);
@@ -57,12 +65,15 @@
DriverDeviceID device_id = reinterpret_cast<DriverDeviceID>((__bridge void*)device);
device_infos.emplace_back("metal", std::move(name), supported_features, device_id);
}
- return assign_ref(new MetalDriver(std::move(device_infos)));
+ return assign_ref(new MetalDriver(std::move(device_infos), std::move(metal_capture_manager)));
}
}
-MetalDriver::MetalDriver(std::vector<DeviceInfo> devices)
- : Driver("metal"), devices_(std::move(devices)) {
+MetalDriver::MetalDriver(std::vector<DeviceInfo> devices,
+ std::unique_ptr<DebugCaptureManager> debug_capture_manager)
+ : Driver("metal"),
+ devices_(std::move(devices)),
+ debug_capture_manager_(std::move(debug_capture_manager)) {
// Retain all the retained Metal GPU devices.
for (const auto& device : devices_) {
[(__bridge id<MTLDevice>)device.device_id() retain];
@@ -97,7 +108,9 @@
IREE_TRACE_SCOPE0("MetalDriver::CreateDevice");
for (const DeviceInfo& info : devices_) {
- if (info.device_id() == device_id) return MetalDevice::Create(add_ref(this), info);
+ if (info.device_id() == device_id) {
+ return MetalDevice::Create(add_ref(this), info, debug_capture_manager_.get());
+ }
}
return InvalidArgumentErrorBuilder(IREE_LOC) << "unknown driver device id: " << device_id;
}
diff --git a/iree/hal/metal/metal_driver_module.cc b/iree/hal/metal/metal_driver_module.cc
index ba2ec58..48308ae 100644
--- a/iree/hal/metal/metal_driver_module.cc
+++ b/iree/hal/metal/metal_driver_module.cc
@@ -14,17 +14,28 @@
#include <memory>
+#include "absl/flags/flag.h"
#include "iree/base/init.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/metal/metal_driver.h"
+ABSL_FLAG(bool, metal_capture, false, "Enables capturing Metal commands.");
+ABSL_FLAG(
+ std::string, metal_capture_to_file, "",
+ "Full path to store the GPU trace file (empty means capture to Xcode)");
+
namespace iree {
namespace hal {
namespace metal {
namespace {
-StatusOr<ref_ptr<Driver>> CreateMetalDriver() { return MetalDriver::Create(); }
+// Creates the Metal driver, configured from the --metal_capture and
+// --metal_capture_to_file flags.
+StatusOr<ref_ptr<Driver>> CreateMetalDriver() {
+  MetalDriverOptions driver_options;
+  driver_options.capture_file = absl::GetFlag(FLAGS_metal_capture_to_file);
+  driver_options.enable_capture = absl::GetFlag(FLAGS_metal_capture);
+  return MetalDriver::Create(driver_options);
+}
} // namespace
} // namespace metal