blob: e588dee37ee2e384a01319a19fbabfbfd58b5185 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/metal/metal_kernel_library.h"
#include "iree/base/memory.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
namespace iree {
namespace hal {
namespace metal {
// static
//
// Builds a MetalKernelLibrary from |metal_executable_def|. For each index i, MSL source string i
// is compiled into an MTLLibrary, entry point name i is looked up in that library as an
// MTLFunction, and an MTLComputePipelineState is created from the function.
//
// Returns InvalidArgument if the def has no entry points or no sources, if their counts differ,
// if a string is not valid UTF-8, or if compilation, function lookup, or pipeline creation fails.
// Every object created before a failure is released. On success the returned object holds one
// retain on |device| and on every library, function, and pipeline state created here.
//
// |mode| is not used by this implementation.
StatusOr<ref_ptr<MetalKernelLibrary>> MetalKernelLibrary::Create(
    id<MTLDevice> device, ExecutableCachingModeBitfield mode,
    const MetalExecutableDef& metal_executable_def) {
  IREE_TRACE_SCOPE0("MetalKernelLibrary::Create");
  if (!metal_executable_def.entry_points() || metal_executable_def.entry_points()->size() == 0) {
    return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
  }
  if (!metal_executable_def.shader_sources() ||
      metal_executable_def.shader_sources()->size() == 0) {
    return InvalidArgumentErrorBuilder(IREE_LOC) << "No MSL source string present";
  }
  const auto& entry_points = *metal_executable_def.entry_points();
  const auto& msl_sources = *metal_executable_def.shader_sources();
  if (entry_points.size() != msl_sources.size()) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Number of entry points and source strings mismatch";
  }

  // Compile each MSL source string into an MTLLibrary and get the MTLFunction for the entry point
  // to build the pipeline state object.
  absl::InlinedVector<id<MTLLibrary>, 1> libraries;
  absl::InlinedVector<id<MTLFunction>, 1> functions;
  absl::InlinedVector<id<MTLComputePipelineState>, 1> states;

  MTLCompileOptions* msl_compile_options = [MTLCompileOptions new];
  msl_compile_options.languageVersion = MTLLanguageVersion2_0;

  // Releases everything still held in the local vectors, plus the compile options, when this
  // function returns. On success the vectors are cleared after ownership is handed to the new
  // object, so then only the compile options are released.
  auto cleanup = MakeCleanup([&]() {
    for (id<MTLComputePipelineState> state : states) [state release];
    for (id<MTLFunction> function : functions) [function release];
    for (id<MTLLibrary> library : libraries) [library release];
    [msl_compile_options release];
  });

  // TODO(antiagainst): We are performing synchronous compilation at runtime here. This is good for
  // debugging purposes but bad for performance. Enable offline compilation and make it the
  // default.
  for (int i = 0; i < static_cast<int>(msl_sources.size()); ++i) {
    @autoreleasepool {
      NSError* error = nil;
      // FlatBuffers strings are UTF-8; +defaultCStringEncoding is not guaranteed to be UTF-8 and
      // could silently garble non-ASCII bytes (e.g. in shader comments).
      NSString* shader_source = [NSString stringWithUTF8String:msl_sources[i]->c_str()];
      if (!shader_source) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "MSL source index " << i << " is not valid UTF-8";
      }
      id<MTLLibrary> library = [device newLibraryWithSource:shader_source
                                                    options:msl_compile_options
                                                      error:&error];
      if (!library) {
        NSLog(@"Failed to create MTLLibrary: %@", error);
#ifndef NDEBUG
        NSLog(@"Original MSL source: %@", shader_source);
#endif
        return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid MSL source";
      }
      libraries.push_back(library);

      NSString* entry_point = [NSString stringWithUTF8String:entry_points[i]->c_str()];
      if (!entry_point) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Entry point name index " << i << " is not valid UTF-8";
      }
      id<MTLFunction> function = [library newFunctionWithName:entry_point];
      if (!function) {
        NSLog(@"Failed to create MTLFunction");
#ifndef NDEBUG
        NSLog(@"Original MSL source: %@", shader_source);
#endif
        // Stream the characters, not the flatbuffers::String pointer (which prints an address).
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Cannot find entry point '" << entry_points[i]->c_str()
               << "' in shader source index " << i;
      }
      functions.push_back(function);

      id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:function
                                                                              error:&error];
      if (!pso) {
        NSLog(@"Failed to create MTLComputePipelineState: %@", error);
#ifndef NDEBUG
        NSLog(@"Original MSL source: %@", shader_source);
#endif
        return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid MSL source";
      }
      states.push_back(pso);
    }
  }

  std::string tag = metal_executable_def.tag() ? metal_executable_def.tag()->str() : "";
  auto* kernel_library =
      new MetalKernelLibrary([device retain], std::move(libraries), std::move(functions),
                             std::move(states), std::move(tag));
  // The new object now owns every retained library, function, and pipeline state. Clear the local
  // vectors explicitly so the cleanup above cannot release them, rather than relying on the
  // moved-from state of InlinedVector.
  libraries.clear();
  functions.clear();
  states.clear();
  return assign_ref(kernel_library);
}
// Stores the device, compiled libraries, entry point functions, pipeline states, and tag. The
// by-value containers are moved into the members rather than copied. The caller (Create) passes
// objects it has already retained; this constructor performs no additional retains.
MetalKernelLibrary::MetalKernelLibrary(
    id<MTLDevice> device, absl::InlinedVector<id<MTLLibrary>, 1> libraries,
    absl::InlinedVector<id<MTLFunction>, 1> functions,
    absl::InlinedVector<id<MTLComputePipelineState>, 1> pipelines, std::string tag)
    : tag_(std::move(tag)),
      device_(device),
      libraries_(std::move(libraries)),
      functions_(std::move(functions)),
      pipelines_(std::move(pipelines)) {}
// Releases the retains this object holds on its pipeline states, functions, and libraries, and on
// the device it was constructed with.
MetalKernelLibrary::~MetalKernelLibrary() {
  IREE_TRACE_SCOPE0("MetalKernelLibrary::dtor");
  for (id<MTLComputePipelineState> pso : pipelines_) [pso release];
  for (id<MTLFunction> function : functions_) [function release];
  for (id<MTLLibrary> library : libraries_) [library release];
  // Balances the [device retain] that Create performs when constructing this object; without this
  // every kernel library leaks one retain on the device.
  [device_ release];
}
// Looks up the compute pipeline state for entry point |ordinal|. Yields an OutOfRange error when
// |ordinal| is negative or not less than the number of pipelines. The returned object is not
// retained on behalf of the caller.
StatusOr<id<MTLComputePipelineState>> MetalKernelLibrary::GetPipelineStateForEntryPoint(
    int ordinal) const {
  // Negative values are rejected before the cast, so the unsigned comparison is safe.
  const bool in_range = ordinal >= 0 && static_cast<size_t>(ordinal) < pipelines_.size();
  if (!in_range) {
    return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal;
  }
  return pipelines_[ordinal];
}
} // namespace metal
} // namespace hal
} // namespace iree