blob: cdfbf4f469fe0d020dd9747bd341998049c1fc58 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
#include <map>
#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/LegacyUtil.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
#include "iree/compiler/Translation/CodegenPasses/Passes.h"
#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
#include "iree/compiler/Translation/SPIRV/EmbeddedKernels/EmbeddedKernels.h"
#include "iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h"
#include "iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.h"
#include "iree/schemas/spirv_executable_def_generated.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Builds a VulkanSPIRVTargetOptions from the command line flags declared in
// this function.
//
// The flags are function-local statics, so they are constructed the first time
// this function runs; later calls only read their current values.
VulkanSPIRVTargetOptions getVulkanSPIRVTargetOptionsFromFlags() {
  // TODO(antiagainst): Enable option categories once the following bug is
  // fixed: https://bugs.llvm.org/show_bug.cgi?id=44223
  //   static llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
  //       "IREE Vulkan/SPIR-V backend options");

  // TODO(ravishankarm): Flags to test the Linalg To SPIR-V path. Need a better
  // way to handle these options.
  static llvm::cl::opt<bool> useLinalgPathFlag(
      "iree-use-linalg-to-spirv-path",
      llvm::cl::desc("Use the XLA-HLO to Linalg To SPIR-V pass pipeline"),
      llvm::cl::init(false));

  static llvm::cl::list<unsigned> workgroupSizeFlag(
      "iree-linalg-to-spirv-workgroup-size",
      llvm::cl::desc(
          "Workgroup size to use for XLA-HLO to Linalg to SPIR-V path"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);

  static llvm::cl::opt<std::string> targetEnvFlag(
      "iree-vulkan-target-env",
      llvm::cl::desc(
          "Vulkan target environment as #vk.target_env attribute assembly"),
      llvm::cl::init(Vulkan::swiftShaderTargetEnvAssembly));

  // Snapshot the flag values into a plain options struct.
  VulkanSPIRVTargetOptions options;
  options.useLinalgToSPIRVPath = useLinalgPathFlag;
  options.vulkanTargetEnv = targetEnvFlag;
  for (unsigned workgroupDim : workgroupSizeFlag) {
    options.linalgToSPIRVWorkgroupSize.push_back(workgroupDim);
  }
  return options;
}
// Parses `vulkanTargetEnv` as attribute assembly and, when it yields a
// Vulkan::TargetEnvAttr, returns the result of convertTargetEnv() on it.
//
// If the string does not parse, or parses to some other kind of attribute, an
// error is emitted at an unknown location and a default-constructed attribute
// is returned.
static spirv::TargetEnvAttr getSPIRVTargetEnv(
    const std::string &vulkanTargetEnv, MLIRContext *context) {
  Attribute parsedAttr = mlir::parseAttribute(vulkanTargetEnv, context);
  if (parsedAttr) {
    auto vkTargetEnv = parsedAttr.dyn_cast<Vulkan::TargetEnvAttr>();
    if (vkTargetEnv) {
      return convertTargetEnv(vkTargetEnv);
    }
  }

  // Either parsing failed or the attribute has the wrong type.
  Location unknownLoc = Builder(context).getUnknownLoc();
  emitError(unknownLoc)
      << "cannot parse vulkan target environment as #vk.target_env attribute ";
  return {};
}
/// Decides whether `moduleOp` should be compiled through the Linalg path.
///
/// Returns true when `targetOptions.useLinalgToSPIRVPath` is set, or when the
/// module contains at least one xla_hlo reduce, conv or dot op.
static bool useLinalgPath(ModuleOp moduleOp,
                          VulkanSPIRVTargetOptions const &targetOptions) {
  if (targetOptions.useLinalgToSPIRVPath) return true;

  // Ops whose presence anywhere in the module selects the Linalg path.
  const auto needsLinalgPath = [](Operation *op) {
    return isa<xla_hlo::ReduceOp>(op) || isa<xla_hlo::ConvOp>(op) ||
           isa<xla_hlo::DotOp>(op);
  };

  // Interrupt the walk at the first matching op; an interrupted walk therefore
  // means "use the Linalg path".
  return moduleOp
      .walk([&](Operation *op) -> WalkResult {
        return needsLinalgPath(op) ? WalkResult::interrupt()
                                   : WalkResult::advance();
      })
      .wasInterrupted();
}
// Collects the function name referenced by every spv.EntryPoint op in
// `spvModuleOp`, in the order the ops are visited by walk(). The position of a
// name in the returned list is what callers use as the export ordinal.
static std::vector<std::string> populateEntryPointNames(
    spirv::ModuleOp spvModuleOp) {
  std::vector<std::string> names;
  spvModuleOp.walk([&names](spirv::EntryPointOp op) {
    names.emplace_back(op.fn().str());
  });
  return names;
}
// HAL target backend that lowers executables to SPIR-V and serializes them as
// SpirVExecutableDef flatbuffers tagged with the SpirV executable format.
class VulkanSPIRVTargetBackend : public TargetBackend {
 public:
  // Takes ownership of `options`; they are used for every executable handled
  // by this backend instance.
  explicit VulkanSPIRVTargetBackend(VulkanSPIRVTargetOptions options)
      : options_(std::move(options)) {}

  // NOTE: we could vary this based on the options such as 'vulkan-v1.1'.
  std::string name() const override { return "vulkan*"; }

  // Creates the hal.executable.target op for this backend inside
  // `executableOp`, clones the source inner module into it and annotates the
  // clone with the SPIR-V target environment. The target op is additionally
  // tagged with "vkspv.use_linalg" when the Linalg path should be used.
  void constructTargetOps(IREE::Flow::ExecutableOp sourceOp,
                          IREE::HAL::ExecutableOp executableOp) override {
    // Attach SPIR-V target environment to the hal.executable.target op.
    // If we had multiple target environments we would generate one target op
    // per environment.
    // NOTE(review): getSPIRVTargetEnv returns a default-constructed attribute
    // (after emitting an error) when the option string is invalid; it is
    // attached below as-is.
    spirv::TargetEnvAttr spvTargetEnv =
        getSPIRVTargetEnv(options_.vulkanTargetEnv, sourceOp.getContext());

    OpBuilder targetBuilder(&executableOp.getBlock().back());
    auto targetOp = targetBuilder.create<IREE::HAL::ExecutableTargetOp>(
        sourceOp.getLoc(), name());

    OpBuilder containerBuilder(&targetOp.getBlock().back());
    auto innerModuleOp =
        containerBuilder.clone(*sourceOp.getInnerModule().getOperation());
    innerModuleOp->setAttr(spirv::getTargetEnvAttrName(), spvTargetEnv);

    // Record the path decision on the target op so that
    // buildTranslationPassPipeline can pick the matching pipeline.
    if (useLinalgPath(sourceOp.getInnerModule(), options_)) {
      targetOp.setAttr("vkspv.use_linalg",
                       UnitAttr::get(sourceOp.getContext()));
    }
  }

  // Adds the passes that translate the target's inner module to SPIR-V. The
  // pipeline is selected by the "vkspv.use_linalg" attribute on `targetOp`.
  void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetOp targetOp,
                                    OpPassManager &passManager) override {
    passManager.addPass(createHALInterfaceToMemrefPass());
    if (targetOp.getAttr("vkspv.use_linalg")) {
      addHLOToLinalgToSPIRVPasses(passManager,
                                  options_.linalgToSPIRVWorkgroupSize);
    } else {
      addIREEToSPIRVPasses(passManager);
    }
  }

  // Finds the spv.ExecutionMode operation to get the workgroup size from.
  //
  // Returns one constant index value per dimension. Dimensions that are not
  // specified by a LocalSize execution mode for `entryPointOp` are set to 1.
  // If no spv.module can be found for a "vulkan*" target an error is emitted
  // on `executableOp` and all three dimensions are set to 1.
  //
  // TODO(ravishankarm): This might not be the only way this is specified. You
  // could also have a spec constant, but that is not generated in the
  // spv.module right now.
  // TODO(ravishankarm): change workgroup size calculation to something we can
  // query independently so that we don't need to lookup the value here.
  std::array<Value, 3> calculateDispatchWorkgroupSize(
      Location loc, IREE::HAL::ExecutableOp executableOp,
      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
      OpBuilder &builder) override {
    // TODO(ravishankarm): possibly emit different recordDispatch logic if the
    // workgroup sizes differ among targets.
    spirv::ModuleOp spvModuleOp;
    for (auto executableTargetOp :
         executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
      if (matchPattern(executableTargetOp.target_backend(), "vulkan*")) {
        auto spvModuleOps =
            executableTargetOp.getInnerModule().getOps<spirv::ModuleOp>();
        // Only dereference the range when it actually holds a module; the
        // missing-module case is diagnosed below.
        if (!spvModuleOps.empty()) {
          spvModuleOp = *spvModuleOps.begin();
        }
        break;
      }
    }

    std::array<Value, 3> workgroupSize;
    if (!spvModuleOp) {
      executableOp.emitError()
          << "unable to find a spv.module in a vulkan* target to query the "
             "workgroup size from";
    } else {
      for (auto executionModeOp :
           spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
        if (executionModeOp.fn() == entryPointOp.sym_name() &&
            executionModeOp.execution_mode() ==
                spirv::ExecutionMode::LocalSize) {
          ArrayAttr localSize = executionModeOp.values();
          // Never write past the three dimensions we can return, even if the
          // attribute unexpectedly carries more values.
          for (unsigned i = 0;
               i < localSize.size() && i < workgroupSize.size(); ++i) {
            workgroupSize[i] = builder.create<ConstantIndexOp>(
                loc,
                localSize[i].cast<IntegerAttr>().getValue().getZExtValue());
          }
          break;
        }
      }
    }

    // Pad out the workgroup size with 1's (if the original rank was < 3).
    for (Value &dimSize : workgroupSize) {
      if (!dimSize) {
        dimSize = builder.create<ConstantIndexOp>(loc, 1);
      }
    }
    return workgroupSize;
  }

  // Serializes the first spv.module nested in `targetOp` into a
  // SpirVExecutableDef flatbuffer and emits it through `executableBuilder` as
  // a hal.executable.binary op.
  //
  // Returns failure (with an error emitted on `targetOp`) when there is no
  // spv.module, when SPIR-V serialization fails or when it yields no code.
  LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
                                    OpBuilder &executableBuilder) override {
    iree::SpirVExecutableDefT spirvExecutableDef;

    auto spvModuleOps = targetOp.getInnerModule().getOps<spirv::ModuleOp>();
    if (spvModuleOps.empty()) {
      return targetOp.emitError()
             << "expected a spv.module in the target executable";
    }
    spirv::ModuleOp spvModuleOp = *spvModuleOps.begin();

    // The sequencer and runtime use ordinals instead of names. We provide the
    // list of entry point names here that are then passed in
    // VkShaderModuleCreateInfo.
    spirvExecutableDef.entry_points = populateEntryPointNames(spvModuleOp);

    // Serialize the spirv::ModuleOp into the binary that we will embed in the
    // final flatbuffer.
    SmallVector<uint32_t, 256> spvBinary;
    if (failed(spirv::serialize(spvModuleOp, spvBinary))) {
      return targetOp.emitError() << "failed to serialize spv.module";
    }
    spirvExecutableDef.code = {spvBinary.begin(), spvBinary.end()};
    if (spirvExecutableDef.code.empty()) {
      return targetOp.emitError()
             << "failed to translate and serialize SPIR-V executable";
    }

    // Pack the executable definition and get the bytes with the proper header.
    // The header is used to verify the contents at runtime.
    ::flatbuffers::FlatBufferBuilder fbb;
    auto executableOffset =
        iree::SpirVExecutableDef::Pack(fbb, &spirvExecutableDef);
    iree::FinishSpirVExecutableDefBuffer(fbb, executableOffset);
    const uint8_t *bufferBegin = fbb.GetBufferPointer();
    std::vector<uint8_t> bytes(bufferBegin, bufferBegin + fbb.GetSize());

    // Add the binary data to the target executable.
    executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
        targetOp.getLoc(),
        static_cast<uint32_t>(IREE::HAL::ExecutableFormat::SpirV),
        std::move(bytes));
    return success();
  }

 protected:
  // Options this backend was constructed with.
  VulkanSPIRVTargetOptions options_;
};
// Registers the "vulkan-spirv" target backend. The registration object is a
// function-local static, so it is constructed on the first call only; the
// factory it holds calls `queryOptions` whenever it is invoked to build a
// backend instance.
void registerVulkanSPIRVTargetBackends(
    std::function<VulkanSPIRVTargetOptions()> queryOptions) {
  // Called purely so that the flags declared inside it get constructed; the
  // returned options are intentionally discarded.
  static_cast<void>(getVulkanSPIRVTargetOptionsFromFlags());

  static TargetBackendRegistration registration(
      "vulkan-spirv", [queryOptions]() {
        return std::make_unique<VulkanSPIRVTargetBackend>(queryOptions());
      });
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir