[spirv] Extract common SPIR-V target functionalities out (NFC) (#3140)
This commit just extracts some common SPIR-V target functionalities
into a new class so that later we can reuse them for a Metal-SPIRV
compiler target backend. It's pure refactoring plus removing
unnecessary header includes and library dependencies, so NFC.
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
new file mode 100644
index 0000000..9c55ead
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
@@ -0,0 +1,52 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+ return()
+endif()
+""",
+)
+
+cc_library(
+ name = "SPIRVCommon",
+ srcs = [
+ "SPIRVTarget.cpp",
+ ],
+ hdrs = [
+ "SPIRVTarget.h",
+ ],
+ deps = [
+ "//iree/compiler/Conversion/LinalgToSPIRV",
+ "//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/HAL/Target",
+ "//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SPIRVDialect",
+ "@llvm-project//mlir:SPIRVLowering",
+ "@llvm-project//mlir:SPIRVSerialization",
+ "@llvm-project//mlir:Support",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
new file mode 100644
index 0000000..0922d4a
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
@@ -0,0 +1,41 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+ return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ SPIRVCommon
+ HDRS
+ "SPIRVTarget.h"
+ SRCS
+ "SPIRVTarget.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRSPIRV
+ MLIRSPIRVSerialization
+ MLIRSPIRVTransforms
+ MLIRSupport
+ iree::compiler::Conversion::LinalgToSPIRV
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::Shape::IR
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
new file mode 100644
index 0000000..2017d60
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -0,0 +1,299 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Records a full execution barrier that forces visibility of all buffers.
+static void recordFullExecutionBarrier(Value commandBuffer, Location loc,
+ OpBuilder &builder) {
+ Value memoryBarrier = builder.create<IREE::HAL::MakeMemoryBarrierOp>(
+ loc, IREE::HAL::AccessScopeBitfield::DispatchWrite,
+ IREE::HAL::AccessScopeBitfield::DispatchRead);
+ builder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
+ loc, commandBuffer, IREE::HAL::ExecutionStageBitfield::Dispatch,
+ IREE::HAL::ExecutionStageBitfield::Dispatch,
+ ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
+}
+
+/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
+/// function. This function has arguments the !shapex.ranked_shape for all the
+/// input and output shaped types. Using this the function returns the number of
+/// workgroups to use. To use this function on the host side, generate the
+/// !shapex.ranked_shape values that describe the shape of the inputs and
+/// outputs of the dispatch region and "inline" the function body.
+static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
+ Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
+ ArrayRef<Optional<TensorRewriteAdaptor>> operands,
+ ArrayRef<Optional<TensorRewriteAdaptor>> results, OpBuilder &builder) {
+ std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
+ // TODO: This is really just inlining a function. For now assume that the
+ // `numWorkgroupsFn` has a single block to make inlining easier.
+ if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
+ return returnValue;
+ SmallVector<SmallVector<Value, 4>, 4> shapeValues;
+ shapeValues.reserve(operands.size() + results.size());
+ auto getShapeValuesFn =
+ [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
+ for (auto val : values) {
+ if (!val) continue;
+ Optional<SmallVector<Value, 4>> shape = val->getShapeDims(builder);
+ if (!shape) return emitError(loc, "shape computation for operand failed");
+ shapeValues.push_back(shape.getValue());
+ }
+ return success();
+ };
+ if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
+ return returnValue;
+ BlockAndValueMapping mapper;
+ for (Operation &op : numWorkgroupsFn.front()) {
+ if (isa<mlir::ReturnOp>(op)) {
+ for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
+ i != e; ++i) {
+ returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
+ }
+ break;
+ }
+ if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
+ if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
+ auto &dimValues = shapeValues[arg.getArgNumber()];
+ mapper.map(arg, dimValues[shapeOp.getIndex()]);
+ continue;
+ }
+ return returnValue;
+ }
+ // If all its operands are mapped, clone it.
+ if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
+ return mapper.contains(operand);
+ })) {
+ builder.clone(op, mapper);
+ continue;
+ }
+ }
+ return returnValue;
+}
+
+SPIRVTargetBackend::SPIRVTargetBackend(SPIRVCodegenOptions options)
+ : spvCodeGenOptions_(std::move(options)) {}
+
+void SPIRVTargetBackend::declareTargetOpsForEnv(
+ IREE::Flow::ExecutableOp sourceOp, IREE::HAL::ExecutableOp executableOp,
+ spirv::TargetEnvAttr spvTargetEnv) {
+ auto targetBuilder = OpBuilder::atBlockTerminator(&executableOp.getBlock());
+ auto targetOp = targetBuilder.create<IREE::HAL::ExecutableTargetOp>(
+ sourceOp.getLoc(), name(), filter_pattern());
+
+ auto containerBuilder = OpBuilder::atBlockTerminator(&targetOp.getBlock());
+ auto innerModuleOp = containerBuilder.create<ModuleOp>(sourceOp.getLoc());
+
+ // Attach SPIR-V target environment to the target's ModuleOp.
+ // If we had multiple target environments we would generate one target op
+ // per environment, with each setting its own environment attribute.
+ innerModuleOp.setAttr(spirv::getTargetEnvAttrName(), spvTargetEnv);
+}
+
+void SPIRVTargetBackend::buildTranslationPassPipeline(
+ IREE::HAL::ExecutableTargetOp targetOp, OpPassManager &passManager) {
+ buildSPIRVTransformPassPipeline(passManager, spvCodeGenOptions_);
+}
+
+LogicalResult SPIRVTargetBackend::recordDispatch(
+ Location loc, DispatchState dispatchState,
+ DeviceSwitchBuilder &switchBuilder) {
+ // Multiple entry points might be generated for a single dispatch function.
+ // Under such circumstances, we will have a special attribute indicating the
+ // schedule of the split entry points. Try to see if we can find such
+ // schedule attribute first.
+ ArrayAttr entryPointScheduleAttr;
+ spirv::ModuleOp spvModuleOp;
+ IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
+ for (auto executableTargetOp :
+ executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
+ if (matchPattern(executableTargetOp.target_backend_filter(),
+ filter_pattern())) {
+ ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
+ auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
+ assert(llvm::hasSingleElement(spvModuleOps));
+ spvModuleOp = *spvModuleOps.begin();
+ entryPointScheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
+ iree_compiler::getEntryPointScheduleAttrName());
+ break;
+ }
+ }
+ if (!spvModuleOp) return executableOp.emitError("unable to find spv.module");
+
+ SmallVector<spirv::FuncOp, 2> spvEntryPointFns;
+ if (!entryPointScheduleAttr) {
+ for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
+ if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
+ SymbolTable::Visibility::Public)
+ spvEntryPointFns.push_back(spvFuncOp);
+ }
+ if (!llvm::hasSingleElement(spvEntryPointFns)) {
+ return spvModuleOp.emitError(
+ "expected a single entry point function, found ")
+ << spvEntryPointFns.size();
+ }
+ } else {
+ llvm::StringMap<spirv::FuncOp> publicFns;
+ for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
+ if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
+ SymbolTable::Visibility::Public)
+ publicFns[spvFuncOp.sym_name()] = spvFuncOp;
+ }
+ for (Attribute entryNameAttr : entryPointScheduleAttr) {
+ StringRef entryName = entryNameAttr.cast<StringAttr>().getValue();
+ spirv::FuncOp spvFuncOp = publicFns.lookup(entryName);
+ if (!spvFuncOp)
+ return spvModuleOp.emitError("unable to find entry point function ")
+ << entryName;
+ spvEntryPointFns.push_back(spvFuncOp);
+ }
+ }
+
+ auto *region = switchBuilder.addConditionRegion(
+ IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
+ {
+ dispatchState.workload,
+ dispatchState.commandBuffer,
+ dispatchState.executable,
+ });
+
+ auto &entryBlock = region->front();
+ auto builder = OpBuilder::atBlockBegin(&entryBlock);
+ auto workload = entryBlock.getArgument(0);
+ auto commandBuffer = entryBlock.getArgument(1);
+ auto executable = entryBlock.getArgument(2);
+
+ // We have multiple entry points to dispatch. Record in the order
+ // specified by entry point schedule and insert barrier between sequential
+ // ones.
+ for (auto it : llvm::enumerate(spvEntryPointFns)) {
+ spirv::FuncOp spvFuncOp = it.value();
+ auto workgroupSize = calculateDispatchWorkgroupSize(
+ loc, spvModuleOp, spvFuncOp.sym_name(), workload, builder);
+
+ FlatSymbolRefAttr numWorkgroupsFnAttr =
+ spvFuncOp.getAttrOfType<FlatSymbolRefAttr>(
+ getNumWorkgroupsFnAttrName());
+
+ std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
+ if (numWorkgroupsFnAttr) {
+ FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
+ spvFuncOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
+ if (!numWorkgroupsFn) return failure();
+ workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
+ loc, numWorkgroupsFn, executableOp.getInterfaceOp(),
+ dispatchState.operands, dispatchState.results, builder);
+ } else {
+ workgroupCount = calculateDispatchWorkgroupCount(loc, workload,
+ workgroupSize, builder);
+ }
+
+ if (llvm::any_of(workgroupCount,
+ [](Value v) -> bool { return v == nullptr; }))
+ return spvFuncOp.emitError("unable to find workgroup count");
+
+ // Ordinals are fixed based on the precomputed schedule, so use
+ // CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
+ builder.create<IREE::HAL::CommandBufferDispatchOp>(
+ loc, commandBuffer, executable,
+ builder.getI32IntegerAttr(/*entryPointOrdinal=*/it.index()),
+ workgroupCount[0], workgroupCount[1], workgroupCount[2]);
+ if (it.index() + 1 != spvEntryPointFns.size()) {
+ recordFullExecutionBarrier(commandBuffer, loc, builder);
+ }
+ }
+
+ builder.create<IREE::HAL::ReturnOp>(loc);
+ return success();
+}
+
+// Finds the spv.ExecutionMode operation to get the workgroup size from.
+// TODO(ravishankarm): This might not be the only way this is specified. You
+// could also have a spec constant, but that is not generated in the
+// spv.module right now.
+// TODO(ravishankarm): change workgroup size calculation to something we can
+// query independently so that we don't need to lookup the value here.
+std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+ OpBuilder &builder) {
+ // TODO(ravishankarm): possibly emit different recordDispatch logic if the
+ // workgroup sizes differ among targets.
+ spirv::ModuleOp spvModuleOp;
+ for (auto executableTargetOp :
+ executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
+ if (matchPattern(executableTargetOp.target_backend_filter(),
+ filter_pattern())) {
+ ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
+ assert(!innerModuleOp.getAttr(
+ iree_compiler::getEntryPointScheduleAttrName()));
+ auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
+ assert(llvm::hasSingleElement(spvModuleOps));
+ spvModuleOp = *spvModuleOps.begin();
+ break;
+ }
+ }
+ return calculateDispatchWorkgroupSize(
+ loc, spvModuleOp, entryPointOp.sym_name(), workload, builder);
+}
+
+std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
+ Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
+ Value workload, OpBuilder &builder) {
+ std::array<Value, 3> workgroupSize;
+ for (auto executionModeOp :
+ spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
+ if (executionModeOp.fn() == entryPointName &&
+ executionModeOp.execution_mode() == spirv::ExecutionMode::LocalSize) {
+ for (int i = 0; i < executionModeOp.values().size(); ++i) {
+ workgroupSize[i] =
+ builder.create<ConstantIndexOp>(loc, executionModeOp.values()[i]
+ .cast<IntegerAttr>()
+ .getValue()
+ .getZExtValue());
+ }
+ break;
+ }
+ }
+
+ // Pad out the workgroup size with 1's (if the original rank was < 3).
+ for (int i = 0; i < workgroupSize.size(); ++i) {
+ if (!workgroupSize[i]) {
+ workgroupSize[i] = builder.create<ConstantIndexOp>(loc, 1);
+ }
+ }
+
+ return workgroupSize;
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
new file mode 100644
index 0000000..d24409b
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
@@ -0,0 +1,64 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_SPIRVCOMMON_SPIRVTARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_SPIRVCOMMON_SPIRVTARGET_H_
+
+#include <string>
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// A SPIR-V target backend that shares common overrides for Vulkan and Metal.
+class SPIRVTargetBackend : public TargetBackend {
+ public:
+ explicit SPIRVTargetBackend(SPIRVCodegenOptions options);
+
+ void declareTargetOpsForEnv(IREE::Flow::ExecutableOp sourceOp,
+ IREE::HAL::ExecutableOp executableOp,
+ spirv::TargetEnvAttr spvTargetEnv);
+
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetOp targetOp,
+ OpPassManager &passManager) override;
+
+ LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
+ DeviceSwitchBuilder &switchBuilder) override;
+
+ // Finds the spv.ExecutionMode operation to get the workgroup size from.
+ std::array<Value, 3> calculateDispatchWorkgroupSize(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+ OpBuilder &builder) override;
+
+ private:
+ std::array<Value, 3> calculateDispatchWorkgroupSize(
+ Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
+ Value workload, OpBuilder &builder);
+
+ SPIRVCodegenOptions spvCodeGenOptions_;
+};
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_SPIRVCOMMON_SPIRVTARGET_H_
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index 97f0662..97d6bc2 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -37,11 +37,10 @@
"VulkanSPIRVTarget.h",
],
deps = [
- "//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
- "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/HAL/Target/SPIRVCommon",
"//iree/compiler/Dialect/Vulkan/IR",
"//iree/compiler/Dialect/Vulkan/Utils",
"//iree/schemas:spirv_executable_def_cc_fbs",
@@ -52,13 +51,9 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Parser",
- "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SPIRVDialect",
- "@llvm-project//mlir:SPIRVLowering",
"@llvm-project//mlir:SPIRVSerialization",
"@llvm-project//mlir:Support",
- "@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
- "@org_tensorflow//tensorflow/compiler/mlir/hlo",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index fdc9911..5452237 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -32,22 +32,17 @@
MLIRIR
MLIRLinalgOps
MLIRParser
- MLIRPass
MLIRSPIRV
MLIRSPIRVSerialization
- MLIRSPIRVTransforms
MLIRSupport
- MLIRTransforms
MLIRVector
flatbuffers
- iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::HAL::Target::SPIRVCommon
iree::compiler::Dialect::Vulkan::IR
iree::compiler::Dialect::Vulkan::Utils
iree::schemas::spirv_executable_def_cc_fbs
- tensorflow::mlir_hlo
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index f18c2b8..4a9ac26 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -14,25 +14,20 @@
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
-#include <map>
-
#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
#include "iree/schemas/spirv_executable_def_generated.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
@@ -40,14 +35,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
@@ -117,79 +105,11 @@
return entryPointNames;
}
-// Records a full execution barrier that forces visibility of all buffers.
-static void recordFullExecutionBarrier(Value commandBuffer, Location loc,
- OpBuilder &builder) {
- Value memoryBarrier = builder.create<IREE::HAL::MakeMemoryBarrierOp>(
- loc, IREE::HAL::AccessScopeBitfield::DispatchWrite,
- IREE::HAL::AccessScopeBitfield::DispatchRead);
- builder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
- loc, commandBuffer, IREE::HAL::ExecutionStageBitfield::Dispatch,
- IREE::HAL::ExecutionStageBitfield::Dispatch,
- ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
-}
-
-/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
-/// function. This function has arguments the !shapex.ranked_shape for all the
-/// input and output shaped types. Using this the function returns the number of
-/// workgroups to use. To use this function on the host side, generate the
-/// !shapex.ranked_shape values that describe the shape of the inputs and
-/// outputs of the dispatch region and "inline" the function body.
-static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
- Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
- ArrayRef<Optional<TensorRewriteAdaptor>> operands,
- ArrayRef<Optional<TensorRewriteAdaptor>> results, OpBuilder &builder) {
- std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
- // TODO: This is really just inlining a function. For now assume that the
- // `numWorkgroupsFn` has a single block to make inlining easier.
- if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
- return returnValue;
- SmallVector<SmallVector<Value, 4>, 4> shapeValues;
- shapeValues.reserve(operands.size() + results.size());
- auto getShapeValuesFn =
- [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
- for (auto val : values) {
- if (!val) continue;
- Optional<SmallVector<Value, 4>> shape = val->getShapeDims(builder);
- if (!shape) return emitError(loc, "shape computation for operand failed");
- shapeValues.push_back(shape.getValue());
- }
- return success();
- };
- if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
- return returnValue;
- BlockAndValueMapping mapper;
- for (Operation &op : numWorkgroupsFn.front()) {
- if (isa<mlir::ReturnOp>(op)) {
- for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
- i != e; ++i) {
- returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
- }
- break;
- }
- if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
- if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
- auto &dimValues = shapeValues[arg.getArgNumber()];
- mapper.map(arg, dimValues[shapeOp.getIndex()]);
- continue;
- }
- return returnValue;
- }
- // If all its operands are mapped, clone it.
- if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
- return mapper.contains(operand);
- })) {
- builder.clone(op, mapper);
- continue;
- }
- }
- return returnValue;
-}
-
-class VulkanSPIRVTargetBackend : public TargetBackend {
+class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
public:
VulkanSPIRVTargetBackend(VulkanSPIRVTargetOptions options)
- : options_(std::move(options)) {}
+ : SPIRVTargetBackend(options.codegenOptions),
+ options_(std::move(options)) {}
// NOTE: we could vary these based on the options such as 'vulkan-v1.1'.
std::string name() const override { return "vulkan_spirv"; }
@@ -209,194 +129,9 @@
void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
IREE::HAL::ExecutableOp executableOp) override {
- OpBuilder targetBuilder(&executableOp.getBlock().back());
- auto targetOp = targetBuilder.create<IREE::HAL::ExecutableTargetOp>(
- sourceOp.getLoc(), name(), filter_pattern());
- OpBuilder containerBuilder(&targetOp.getBlock().back());
-
- auto innerModuleOp = containerBuilder.create<ModuleOp>(sourceOp.getLoc());
- // Attach SPIR-V target environment to the target's ModuleOp.
- // If we had multiple target environments we would generate one target op
- // per environment, with each setting its own environment attribute.
spirv::TargetEnvAttr spvTargetEnv =
getSPIRVTargetEnv(options_.vulkanTargetEnv, sourceOp.getContext());
- innerModuleOp.setAttr(spirv::getTargetEnvAttrName(), spvTargetEnv);
- }
-
- void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetOp targetOp,
- OpPassManager &passManager) override {
- buildSPIRVTransformPassPipeline(passManager, options_.codegenOptions);
- }
-
- LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
- DeviceSwitchBuilder &switchBuilder) override {
- // Multiple entry points might be generated for a single dispatch function.
- // Under such circumstances, we will have a special attribute indicating the
- // schedule of the split entry points. Try to see if we can find such
- // schedule attribute first.
- ArrayAttr entryPointScheduleAttr;
- spirv::ModuleOp spvModuleOp;
- IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
- for (auto executableTargetOp :
- executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
- if (matchPattern(executableTargetOp.target_backend_filter(),
- filter_pattern())) {
- ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
- auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
- assert(llvm::hasSingleElement(spvModuleOps));
- spvModuleOp = *spvModuleOps.begin();
- entryPointScheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
- iree_compiler::getEntryPointScheduleAttrName());
- break;
- }
- }
- if (!spvModuleOp)
- return executableOp.emitError("unable to find spv.module");
-
- SmallVector<spirv::FuncOp, 2> spvEntryPointFns;
- if (!entryPointScheduleAttr) {
- for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
- if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
- SymbolTable::Visibility::Public)
- spvEntryPointFns.push_back(spvFuncOp);
- }
- if (!llvm::hasSingleElement(spvEntryPointFns)) {
- return spvModuleOp.emitError(
- "expected a single entry point function, found ")
- << spvEntryPointFns.size();
- }
- } else {
- llvm::StringMap<spirv::FuncOp> publicFns;
- for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
- if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
- SymbolTable::Visibility::Public)
- publicFns[spvFuncOp.sym_name()] = spvFuncOp;
- }
- for (Attribute entryNameAttr : entryPointScheduleAttr) {
- StringRef entryName = entryNameAttr.cast<StringAttr>().getValue();
- spirv::FuncOp spvFuncOp = publicFns.lookup(entryName);
- if (!spvFuncOp)
- return spvModuleOp.emitError("unable to find entry point function ")
- << entryName;
- spvEntryPointFns.push_back(spvFuncOp);
- }
- }
-
- auto *region = switchBuilder.addConditionRegion(
- IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
- {
- dispatchState.workload,
- dispatchState.commandBuffer,
- dispatchState.executable,
- });
-
- auto &entryBlock = region->front();
- auto builder = OpBuilder::atBlockBegin(&entryBlock);
- auto workload = entryBlock.getArgument(0);
- auto commandBuffer = entryBlock.getArgument(1);
- auto executable = entryBlock.getArgument(2);
-
- // We have multiple entry points to dispatch. Record in the order
- // specified by entry point schedule and insert barrier between sequential
- // ones.
- for (auto it : llvm::enumerate(spvEntryPointFns)) {
- spirv::FuncOp spvFuncOp = it.value();
- auto workgroupSize = calculateDispatchWorkgroupSize(
- loc, spvModuleOp, spvFuncOp.sym_name(), workload, builder);
-
- FlatSymbolRefAttr numWorkgroupsFnAttr =
- spvFuncOp.getAttrOfType<FlatSymbolRefAttr>(
- getNumWorkgroupsFnAttrName());
-
- std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
- if (numWorkgroupsFnAttr) {
- FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
- spvFuncOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
- if (!numWorkgroupsFn) return failure();
- workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
- loc, numWorkgroupsFn, executableOp.getInterfaceOp(),
- dispatchState.operands, dispatchState.results, builder);
- } else {
- workgroupCount = calculateDispatchWorkgroupCount(
- loc, workload, workgroupSize, builder);
- }
-
- if (llvm::any_of(workgroupCount,
- [](Value v) -> bool { return v == nullptr; }))
- return spvFuncOp.emitError("unable to find workgroup count");
-
- // Ordinals are fixed based on the precomputed schedule, so use
- // CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
- builder.create<IREE::HAL::CommandBufferDispatchOp>(
- loc, commandBuffer, executable,
- builder.getI32IntegerAttr(/*entryPointOrdinal=*/it.index()),
- workgroupCount[0], workgroupCount[1], workgroupCount[2]);
- if (it.index() + 1 != spvEntryPointFns.size()) {
- recordFullExecutionBarrier(commandBuffer, loc, builder);
- }
- }
-
- builder.create<IREE::HAL::ReturnOp>(loc);
- return success();
- }
-
- // Finds the spv.ExecutionMode operation to get the workgroup size from.
- // TODO(ravishankarm): This might not be the only way this is specified. You
- // could also have a spec constant, but that is not generated in the
- // spv.module right now.
- // TODO(ravishankarm): change workgroup size calculation to something we can
- // query independently so that we don't need to lookup the value here.
- std::array<Value, 3> calculateDispatchWorkgroupSize(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
- OpBuilder &builder) override {
- // TODO(ravishankarm): possibly emit different recordDispatch logic if the
- // workgroup sizes differ among targets.
- spirv::ModuleOp spvModuleOp;
- for (auto executableTargetOp :
- executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
- if (matchPattern(executableTargetOp.target_backend_filter(),
- filter_pattern())) {
- ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
- assert(!innerModuleOp.getAttr(
- iree_compiler::getEntryPointScheduleAttrName()));
- auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
- assert(llvm::hasSingleElement(spvModuleOps));
- spvModuleOp = *spvModuleOps.begin();
- break;
- }
- }
- return calculateDispatchWorkgroupSize(
- loc, spvModuleOp, entryPointOp.sym_name(), workload, builder);
- }
-
- std::array<Value, 3> calculateDispatchWorkgroupSize(
- Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
- Value workload, OpBuilder &builder) {
- std::array<Value, 3> workgroupSize;
- for (auto executionModeOp :
- spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
- if (executionModeOp.fn() == entryPointName &&
- executionModeOp.execution_mode() == spirv::ExecutionMode::LocalSize) {
- for (int i = 0; i < executionModeOp.values().size(); ++i) {
- workgroupSize[i] =
- builder.create<ConstantIndexOp>(loc, executionModeOp.values()[i]
- .cast<IntegerAttr>()
- .getValue()
- .getZExtValue());
- }
- break;
- }
- }
-
- // Pad out the workgroup size with 1's (if the original rank was < 3).
- for (int i = 0; i < workgroupSize.size(); ++i) {
- if (!workgroupSize[i]) {
- workgroupSize[i] = builder.create<ConstantIndexOp>(loc, 1);
- }
- }
-
- return workgroupSize;
+ declareTargetOpsForEnv(sourceOp, executableOp, spvTargetEnv);
}
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
index 4eecf81..8382b5d 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
@@ -15,11 +15,10 @@
#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_VULKANSPIRV_VULKANSPIRVTARGET_H_
#define IREE_COMPILER_DIALECT_HAL_TARGET_VULKANSPIRV_VULKANSPIRVTARGET_H_
+#include <functional>
#include <string>
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
-#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
-#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace iree_compiler {