[spirv] Extract common SPIR-V target functionalities out (NFC) (#3140)

This commit just extracts some common SPIR-V target functionalities
into a new class so that later we can reuse them for a Metal-SPIRV
compiler target backend. It's pure refactoring plus removing
unnecessary header includes and library dependencies, so NFC.
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
new file mode 100644
index 0000000..9c55ead
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
@@ -0,0 +1,52 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_cmake_extra_content(
+    content = """
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+  return()
+endif()
+""",
+)
+
+cc_library(
+    name = "SPIRVCommon",
+    srcs = [
+        "SPIRVTarget.cpp",
+    ],
+    hdrs = [
+        "SPIRVTarget.h",
+    ],
+    deps = [
+        "//iree/compiler/Conversion/LinalgToSPIRV",
+        "//iree/compiler/Dialect/Flow/IR",
+        "//iree/compiler/Dialect/HAL/Target",
+        "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Dialect/Shape/IR",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:SPIRVDialect",
+        "@llvm-project//mlir:SPIRVLowering",
+        "@llvm-project//mlir:SPIRVSerialization",
+        "@llvm-project//mlir:Support",
+    ],
+)
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
new file mode 100644
index 0000000..0922d4a
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
@@ -0,0 +1,41 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+  return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    SPIRVCommon
+  HDRS
+    "SPIRVTarget.h"
+  SRCS
+    "SPIRVTarget.cpp"
+  DEPS
+    LLVMSupport
+    MLIRIR
+    MLIRSPIRV
+    MLIRSPIRVSerialization
+    MLIRSPIRVTransforms
+    MLIRSupport
+    iree::compiler::Conversion::LinalgToSPIRV
+    iree::compiler::Dialect::Flow::IR
+    iree::compiler::Dialect::HAL::Target
+    iree::compiler::Dialect::IREE::IR
+    iree::compiler::Dialect::Shape::IR
+  PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
new file mode 100644
index 0000000..2017d60
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -0,0 +1,299 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Records a full execution barrier that forces visibility of all buffers.
+static void recordFullExecutionBarrier(Value commandBuffer, Location loc,
+                                       OpBuilder &builder) {
+  Value memoryBarrier = builder.create<IREE::HAL::MakeMemoryBarrierOp>(
+      loc, IREE::HAL::AccessScopeBitfield::DispatchWrite,
+      IREE::HAL::AccessScopeBitfield::DispatchRead);
+  builder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
+      loc, commandBuffer, IREE::HAL::ExecutionStageBitfield::Dispatch,
+      IREE::HAL::ExecutionStageBitfield::Dispatch,
+      ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
+}
+
+/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
+/// function. This function has arguments the !shapex.ranked_shape for all the
+/// input and output shaped types. Using this the function returns the number of
+/// workgroups to use. To use this function on the host side, generate the
+/// !shapex.ranked_shape values that describe the shape of the inputs and
+/// outputs of the dispatch region and "inline" the function body.
+static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
+    Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
+    ArrayRef<Optional<TensorRewriteAdaptor>> operands,
+    ArrayRef<Optional<TensorRewriteAdaptor>> results, OpBuilder &builder) {
+  std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
+  // TODO: This is really just inlining a function. For now assume that the
+  // `numWorkgroupsFn` has a single block to make inlining easier.
+  if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
+    return returnValue;
+  SmallVector<SmallVector<Value, 4>, 4> shapeValues;
+  shapeValues.reserve(operands.size() + results.size());
+  auto getShapeValuesFn =
+      [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
+    for (auto val : values) {
+      if (!val) continue;
+      Optional<SmallVector<Value, 4>> shape = val->getShapeDims(builder);
+      if (!shape) return emitError(loc, "shape computation for operand failed");
+      shapeValues.push_back(shape.getValue());
+    }
+    return success();
+  };
+  if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
+    return returnValue;
+  BlockAndValueMapping mapper;
+  for (Operation &op : numWorkgroupsFn.front()) {
+    if (isa<mlir::ReturnOp>(op)) {
+      for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
+           i != e; ++i) {
+        returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
+      }
+      break;
+    }
+    if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
+      if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
+        auto &dimValues = shapeValues[arg.getArgNumber()];
+        mapper.map(arg, dimValues[shapeOp.getIndex()]);
+        continue;
+      }
+      return returnValue;
+    }
+    // If all its operands are mapped, clone it.
+    if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
+          return mapper.contains(operand);
+        })) {
+      builder.clone(op, mapper);
+      continue;
+    }
+  }
+  return returnValue;
+}
+
+SPIRVTargetBackend::SPIRVTargetBackend(SPIRVCodegenOptions options)
+    : spvCodeGenOptions_(std::move(options)) {}
+
+void SPIRVTargetBackend::declareTargetOpsForEnv(
+    IREE::Flow::ExecutableOp sourceOp, IREE::HAL::ExecutableOp executableOp,
+    spirv::TargetEnvAttr spvTargetEnv) {
+  auto targetBuilder = OpBuilder::atBlockTerminator(&executableOp.getBlock());
+  auto targetOp = targetBuilder.create<IREE::HAL::ExecutableTargetOp>(
+      sourceOp.getLoc(), name(), filter_pattern());
+
+  auto containerBuilder = OpBuilder::atBlockTerminator(&targetOp.getBlock());
+  auto innerModuleOp = containerBuilder.create<ModuleOp>(sourceOp.getLoc());
+
+  // Attach SPIR-V target environment to the target's ModuleOp.
+  // If we had multiple target environments we would generate one target op
+  // per environment, with each setting its own environment attribute.
+  innerModuleOp.setAttr(spirv::getTargetEnvAttrName(), spvTargetEnv);
+}
+
+void SPIRVTargetBackend::buildTranslationPassPipeline(
+    IREE::HAL::ExecutableTargetOp targetOp, OpPassManager &passManager) {
+  buildSPIRVTransformPassPipeline(passManager, spvCodeGenOptions_);
+}
+
+LogicalResult SPIRVTargetBackend::recordDispatch(
+    Location loc, DispatchState dispatchState,
+    DeviceSwitchBuilder &switchBuilder) {
+  // Multiple entry points might be generated for a single dispatch function.
+  // Under such circumstances, we will have a special attribute indicating the
+  // schedule of the split entry points. Try to see if we can find such
+  // schedule attribute first.
+  ArrayAttr entryPointScheduleAttr;
+  spirv::ModuleOp spvModuleOp;
+  IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
+  for (auto executableTargetOp :
+       executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
+    if (matchPattern(executableTargetOp.target_backend_filter(),
+                     filter_pattern())) {
+      ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
+      auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
+      assert(llvm::hasSingleElement(spvModuleOps));
+      spvModuleOp = *spvModuleOps.begin();
+      entryPointScheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
+          iree_compiler::getEntryPointScheduleAttrName());
+      break;
+    }
+  }
+  if (!spvModuleOp) return executableOp.emitError("unable to find spv.module");
+
+  SmallVector<spirv::FuncOp, 2> spvEntryPointFns;
+  if (!entryPointScheduleAttr) {
+    for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
+      if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
+          SymbolTable::Visibility::Public)
+        spvEntryPointFns.push_back(spvFuncOp);
+    }
+    if (!llvm::hasSingleElement(spvEntryPointFns)) {
+      return spvModuleOp.emitError(
+                 "expected a single entry point function, found ")
+             << spvEntryPointFns.size();
+    }
+  } else {
+    llvm::StringMap<spirv::FuncOp> publicFns;
+    for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
+      if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
+          SymbolTable::Visibility::Public)
+        publicFns[spvFuncOp.sym_name()] = spvFuncOp;
+    }
+    for (Attribute entryNameAttr : entryPointScheduleAttr) {
+      StringRef entryName = entryNameAttr.cast<StringAttr>().getValue();
+      spirv::FuncOp spvFuncOp = publicFns.lookup(entryName);
+      if (!spvFuncOp)
+        return spvModuleOp.emitError("unable to find entry point function ")
+               << entryName;
+      spvEntryPointFns.push_back(spvFuncOp);
+    }
+  }
+
+  auto *region = switchBuilder.addConditionRegion(
+      IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
+      {
+          dispatchState.workload,
+          dispatchState.commandBuffer,
+          dispatchState.executable,
+      });
+
+  auto &entryBlock = region->front();
+  auto builder = OpBuilder::atBlockBegin(&entryBlock);
+  auto workload = entryBlock.getArgument(0);
+  auto commandBuffer = entryBlock.getArgument(1);
+  auto executable = entryBlock.getArgument(2);
+
+  // We have multiple entry points to dispatch. Record in the order
+  // specified by entry point schedule and insert barrier between sequential
+  // ones.
+  for (auto it : llvm::enumerate(spvEntryPointFns)) {
+    spirv::FuncOp spvFuncOp = it.value();
+    auto workgroupSize = calculateDispatchWorkgroupSize(
+        loc, spvModuleOp, spvFuncOp.sym_name(), workload, builder);
+
+    FlatSymbolRefAttr numWorkgroupsFnAttr =
+        spvFuncOp.getAttrOfType<FlatSymbolRefAttr>(
+            getNumWorkgroupsFnAttrName());
+
+    std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
+    if (numWorkgroupsFnAttr) {
+      FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
+          spvFuncOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
+      if (!numWorkgroupsFn) return failure();
+      workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
+          loc, numWorkgroupsFn, executableOp.getInterfaceOp(),
+          dispatchState.operands, dispatchState.results, builder);
+    } else {
+      workgroupCount = calculateDispatchWorkgroupCount(loc, workload,
+                                                       workgroupSize, builder);
+    }
+
+    if (llvm::any_of(workgroupCount,
+                     [](Value v) -> bool { return v == nullptr; }))
+      return spvFuncOp.emitError("unable to find workgroup count");
+
+    // Ordinals are fixed based on the precomputed schedule, so use
+    // CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
+    builder.create<IREE::HAL::CommandBufferDispatchOp>(
+        loc, commandBuffer, executable,
+        builder.getI32IntegerAttr(/*entryPointOrdinal=*/it.index()),
+        workgroupCount[0], workgroupCount[1], workgroupCount[2]);
+    if (it.index() + 1 != spvEntryPointFns.size()) {
+      recordFullExecutionBarrier(commandBuffer, loc, builder);
+    }
+  }
+
+  builder.create<IREE::HAL::ReturnOp>(loc);
+  return success();
+}
+
+// Finds the spv.ExecutionMode operation to get the workgroup size from.
+// TODO(ravishankarm): This might not be the only way this is specified. You
+// could also have a spec constant, but that is not generated in the
+// spv.module right now.
+// TODO(ravishankarm): change workgroup size calculation to something we can
+// query independently so that we don't need to lookup the value here.
+std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
+    Location loc, IREE::HAL::ExecutableOp executableOp,
+    IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+    OpBuilder &builder) {
+  // TODO(ravishankarm): possibly emit different recordDispatch logic if the
+  // workgroup sizes differ among targets.
+  spirv::ModuleOp spvModuleOp;
+  for (auto executableTargetOp :
+       executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
+    if (matchPattern(executableTargetOp.target_backend_filter(),
+                     filter_pattern())) {
+      ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
+      assert(!innerModuleOp.getAttr(
+          iree_compiler::getEntryPointScheduleAttrName()));
+      auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
+      assert(llvm::hasSingleElement(spvModuleOps));
+      spvModuleOp = *spvModuleOps.begin();
+      break;
+    }
+  }
+  return calculateDispatchWorkgroupSize(
+      loc, spvModuleOp, entryPointOp.sym_name(), workload, builder);
+}
+
+std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
+    Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
+    Value workload, OpBuilder &builder) {
+  std::array<Value, 3> workgroupSize;
+  for (auto executionModeOp :
+       spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
+    if (executionModeOp.fn() == entryPointName &&
+        executionModeOp.execution_mode() == spirv::ExecutionMode::LocalSize) {
+      for (int i = 0; i < executionModeOp.values().size(); ++i) {
+        workgroupSize[i] =
+            builder.create<ConstantIndexOp>(loc, executionModeOp.values()[i]
+                                                     .cast<IntegerAttr>()
+                                                     .getValue()
+                                                     .getZExtValue());
+      }
+      break;
+    }
+  }
+
+  // Pad out the workgroup size with 1's (if the original rank was < 3).
+  for (int i = 0; i < workgroupSize.size(); ++i) {
+    if (!workgroupSize[i]) {
+      workgroupSize[i] = builder.create<ConstantIndexOp>(loc, 1);
+    }
+  }
+
+  return workgroupSize;
+}
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
new file mode 100644
index 0000000..d24409b
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
@@ -0,0 +1,64 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_SPIRVCOMMON_SPIRVTARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_SPIRVCOMMON_SPIRVTARGET_H_
+
+#include <string>
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// A SPIR-V target backend that shares common overrides for Vulkan and Metal.
+class SPIRVTargetBackend : public TargetBackend {
+ public:
+  explicit SPIRVTargetBackend(SPIRVCodegenOptions options);
+
+  void declareTargetOpsForEnv(IREE::Flow::ExecutableOp sourceOp,
+                              IREE::HAL::ExecutableOp executableOp,
+                              spirv::TargetEnvAttr spvTargetEnv);
+
+  void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetOp targetOp,
+                                    OpPassManager &passManager) override;
+
+  LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
+                               DeviceSwitchBuilder &switchBuilder) override;
+
+  // Finds the spv.ExecutionMode operation to get the workgroup size from.
+  std::array<Value, 3> calculateDispatchWorkgroupSize(
+      Location loc, IREE::HAL::ExecutableOp executableOp,
+      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+      OpBuilder &builder) override;
+
+ private:
+  std::array<Value, 3> calculateDispatchWorkgroupSize(
+      Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
+      Value workload, OpBuilder &builder);
+
+  SPIRVCodegenOptions spvCodeGenOptions_;
+};
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_HAL_TARGET_SPIRVCOMMON_SPIRVTARGET_H_
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index 97f0662..97d6bc2 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -37,11 +37,10 @@
         "VulkanSPIRVTarget.h",
     ],
     deps = [
-        "//iree/compiler/Conversion/HLOToLinalg",
         "//iree/compiler/Conversion/LinalgToSPIRV",
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/Target",
-        "//iree/compiler/Dialect/Shape/IR",
+        "//iree/compiler/Dialect/HAL/Target/SPIRVCommon",
         "//iree/compiler/Dialect/Vulkan/IR",
         "//iree/compiler/Dialect/Vulkan/Utils",
         "//iree/schemas:spirv_executable_def_cc_fbs",
@@ -52,13 +51,9 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SPIRVDialect",
-        "@llvm-project//mlir:SPIRVLowering",
         "@llvm-project//mlir:SPIRVSerialization",
         "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorOps",
-        "@org_tensorflow//tensorflow/compiler/mlir/hlo",
     ],
 )
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index fdc9911..5452237 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -32,22 +32,17 @@
     MLIRIR
     MLIRLinalgOps
     MLIRParser
-    MLIRPass
     MLIRSPIRV
     MLIRSPIRVSerialization
-    MLIRSPIRVTransforms
     MLIRSupport
-    MLIRTransforms
     MLIRVector
     flatbuffers
-    iree::compiler::Conversion::HLOToLinalg
     iree::compiler::Conversion::LinalgToSPIRV
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::Target
-    iree::compiler::Dialect::Shape::IR
+    iree::compiler::Dialect::HAL::Target::SPIRVCommon
     iree::compiler::Dialect::Vulkan::IR
     iree::compiler::Dialect::Vulkan::Utils
     iree::schemas::spirv_executable_def_cc_fbs
-    tensorflow::mlir_hlo
   PUBLIC
 )
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index f18c2b8..4a9ac26 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -14,25 +14,20 @@
 
 #include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
 
-#include <map>
-
 #include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
 #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
 #include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
 #include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
 #include "iree/schemas/spirv_executable_def_generated.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
@@ -40,14 +35,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
 #include "mlir/Parser.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -117,79 +105,11 @@
   return entryPointNames;
 }
 
-// Records a full execution barrier that forces visibility of all buffers.
-static void recordFullExecutionBarrier(Value commandBuffer, Location loc,
-                                       OpBuilder &builder) {
-  Value memoryBarrier = builder.create<IREE::HAL::MakeMemoryBarrierOp>(
-      loc, IREE::HAL::AccessScopeBitfield::DispatchWrite,
-      IREE::HAL::AccessScopeBitfield::DispatchRead);
-  builder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
-      loc, commandBuffer, IREE::HAL::ExecutionStageBitfield::Dispatch,
-      IREE::HAL::ExecutionStageBitfield::Dispatch,
-      ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
-}
-
-/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
-/// function. This function has arguments the !shapex.ranked_shape for all the
-/// input and output shaped types. Using this the function returns the number of
-/// workgroups to use. To use this function on the host side, generate the
-/// !shapex.ranked_shape values that describe the shape of the inputs and
-/// outputs of the dispatch region and "inline" the function body.
-static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
-    Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
-    ArrayRef<Optional<TensorRewriteAdaptor>> operands,
-    ArrayRef<Optional<TensorRewriteAdaptor>> results, OpBuilder &builder) {
-  std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
-  // TODO: This is really just inlining a function. For now assume that the
-  // `numWorkgroupsFn` has a single block to make inlining easier.
-  if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
-    return returnValue;
-  SmallVector<SmallVector<Value, 4>, 4> shapeValues;
-  shapeValues.reserve(operands.size() + results.size());
-  auto getShapeValuesFn =
-      [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
-    for (auto val : values) {
-      if (!val) continue;
-      Optional<SmallVector<Value, 4>> shape = val->getShapeDims(builder);
-      if (!shape) return emitError(loc, "shape computation for operand failed");
-      shapeValues.push_back(shape.getValue());
-    }
-    return success();
-  };
-  if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
-    return returnValue;
-  BlockAndValueMapping mapper;
-  for (Operation &op : numWorkgroupsFn.front()) {
-    if (isa<mlir::ReturnOp>(op)) {
-      for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
-           i != e; ++i) {
-        returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
-      }
-      break;
-    }
-    if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
-      if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
-        auto &dimValues = shapeValues[arg.getArgNumber()];
-        mapper.map(arg, dimValues[shapeOp.getIndex()]);
-        continue;
-      }
-      return returnValue;
-    }
-    // If all its operands are mapped, clone it.
-    if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
-          return mapper.contains(operand);
-        })) {
-      builder.clone(op, mapper);
-      continue;
-    }
-  }
-  return returnValue;
-}
-
-class VulkanSPIRVTargetBackend : public TargetBackend {
+class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
  public:
   VulkanSPIRVTargetBackend(VulkanSPIRVTargetOptions options)
-      : options_(std::move(options)) {}
+      : SPIRVTargetBackend(options.codegenOptions),
+        options_(std::move(options)) {}
 
   // NOTE: we could vary these based on the options such as 'vulkan-v1.1'.
   std::string name() const override { return "vulkan_spirv"; }
@@ -209,194 +129,9 @@
 
   void declareTargetOps(IREE::Flow::ExecutableOp sourceOp,
                         IREE::HAL::ExecutableOp executableOp) override {
-    OpBuilder targetBuilder(&executableOp.getBlock().back());
-    auto targetOp = targetBuilder.create<IREE::HAL::ExecutableTargetOp>(
-        sourceOp.getLoc(), name(), filter_pattern());
-    OpBuilder containerBuilder(&targetOp.getBlock().back());
-
-    auto innerModuleOp = containerBuilder.create<ModuleOp>(sourceOp.getLoc());
-    // Attach SPIR-V target environment to the target's ModuleOp.
-    // If we had multiple target environments we would generate one target op
-    // per environment, with each setting its own environment attribute.
     spirv::TargetEnvAttr spvTargetEnv =
         getSPIRVTargetEnv(options_.vulkanTargetEnv, sourceOp.getContext());
-    innerModuleOp.setAttr(spirv::getTargetEnvAttrName(), spvTargetEnv);
-  }
-
-  void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetOp targetOp,
-                                    OpPassManager &passManager) override {
-    buildSPIRVTransformPassPipeline(passManager, options_.codegenOptions);
-  }
-
-  LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
-                               DeviceSwitchBuilder &switchBuilder) override {
-    // Multiple entry points might be generated for a single dispatch function.
-    // Under such circumstances, we will have a special attribute indicating the
-    // schedule of the split entry points. Try to see if we can find such
-    // schedule attribute first.
-    ArrayAttr entryPointScheduleAttr;
-    spirv::ModuleOp spvModuleOp;
-    IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
-    for (auto executableTargetOp :
-         executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
-      if (matchPattern(executableTargetOp.target_backend_filter(),
-                       filter_pattern())) {
-        ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
-        auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
-        assert(llvm::hasSingleElement(spvModuleOps));
-        spvModuleOp = *spvModuleOps.begin();
-        entryPointScheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
-            iree_compiler::getEntryPointScheduleAttrName());
-        break;
-      }
-    }
-    if (!spvModuleOp)
-      return executableOp.emitError("unable to find spv.module");
-
-    SmallVector<spirv::FuncOp, 2> spvEntryPointFns;
-    if (!entryPointScheduleAttr) {
-      for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
-        if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
-            SymbolTable::Visibility::Public)
-          spvEntryPointFns.push_back(spvFuncOp);
-      }
-      if (!llvm::hasSingleElement(spvEntryPointFns)) {
-        return spvModuleOp.emitError(
-                   "expected a single entry point function, found ")
-               << spvEntryPointFns.size();
-      }
-    } else {
-      llvm::StringMap<spirv::FuncOp> publicFns;
-      for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
-        if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
-            SymbolTable::Visibility::Public)
-          publicFns[spvFuncOp.sym_name()] = spvFuncOp;
-      }
-      for (Attribute entryNameAttr : entryPointScheduleAttr) {
-        StringRef entryName = entryNameAttr.cast<StringAttr>().getValue();
-        spirv::FuncOp spvFuncOp = publicFns.lookup(entryName);
-        if (!spvFuncOp)
-          return spvModuleOp.emitError("unable to find entry point function ")
-                 << entryName;
-        spvEntryPointFns.push_back(spvFuncOp);
-      }
-    }
-
-    auto *region = switchBuilder.addConditionRegion(
-        IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
-        {
-            dispatchState.workload,
-            dispatchState.commandBuffer,
-            dispatchState.executable,
-        });
-
-    auto &entryBlock = region->front();
-    auto builder = OpBuilder::atBlockBegin(&entryBlock);
-    auto workload = entryBlock.getArgument(0);
-    auto commandBuffer = entryBlock.getArgument(1);
-    auto executable = entryBlock.getArgument(2);
-
-    // We have multiple entry points to dispatch. Record in the order
-    // specified by entry point schedule and insert barrier between sequential
-    // ones.
-    for (auto it : llvm::enumerate(spvEntryPointFns)) {
-      spirv::FuncOp spvFuncOp = it.value();
-      auto workgroupSize = calculateDispatchWorkgroupSize(
-          loc, spvModuleOp, spvFuncOp.sym_name(), workload, builder);
-
-      FlatSymbolRefAttr numWorkgroupsFnAttr =
-          spvFuncOp.getAttrOfType<FlatSymbolRefAttr>(
-              getNumWorkgroupsFnAttrName());
-
-      std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
-      if (numWorkgroupsFnAttr) {
-        FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
-            spvFuncOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
-        if (!numWorkgroupsFn) return failure();
-        workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
-            loc, numWorkgroupsFn, executableOp.getInterfaceOp(),
-            dispatchState.operands, dispatchState.results, builder);
-      } else {
-        workgroupCount = calculateDispatchWorkgroupCount(
-            loc, workload, workgroupSize, builder);
-      }
-
-      if (llvm::any_of(workgroupCount,
-                       [](Value v) -> bool { return v == nullptr; }))
-        return spvFuncOp.emitError("unable to find workgroup count");
-
-      // Ordinals are fixed based on the precomputed schedule, so use
-      // CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
-      builder.create<IREE::HAL::CommandBufferDispatchOp>(
-          loc, commandBuffer, executable,
-          builder.getI32IntegerAttr(/*entryPointOrdinal=*/it.index()),
-          workgroupCount[0], workgroupCount[1], workgroupCount[2]);
-      if (it.index() + 1 != spvEntryPointFns.size()) {
-        recordFullExecutionBarrier(commandBuffer, loc, builder);
-      }
-    }
-
-    builder.create<IREE::HAL::ReturnOp>(loc);
-    return success();
-  }
-
-  // Finds the spv.ExecutionMode operation to get the workgroup size from.
-  // TODO(ravishankarm): This might not be the only way this is specified. You
-  // could also have a spec constant, but that is not generated in the
-  // spv.module right now.
-  // TODO(ravishankarm): change workgroup size calculation to something we can
-  // query independently so that we don't need to lookup the value here.
-  std::array<Value, 3> calculateDispatchWorkgroupSize(
-      Location loc, IREE::HAL::ExecutableOp executableOp,
-      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
-      OpBuilder &builder) override {
-    // TODO(ravishankarm): possibly emit different recordDispatch logic if the
-    // workgroup sizes differ among targets.
-    spirv::ModuleOp spvModuleOp;
-    for (auto executableTargetOp :
-         executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
-      if (matchPattern(executableTargetOp.target_backend_filter(),
-                       filter_pattern())) {
-        ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
-        assert(!innerModuleOp.getAttr(
-            iree_compiler::getEntryPointScheduleAttrName()));
-        auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
-        assert(llvm::hasSingleElement(spvModuleOps));
-        spvModuleOp = *spvModuleOps.begin();
-        break;
-      }
-    }
-    return calculateDispatchWorkgroupSize(
-        loc, spvModuleOp, entryPointOp.sym_name(), workload, builder);
-  }
-
-  std::array<Value, 3> calculateDispatchWorkgroupSize(
-      Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
-      Value workload, OpBuilder &builder) {
-    std::array<Value, 3> workgroupSize;
-    for (auto executionModeOp :
-         spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
-      if (executionModeOp.fn() == entryPointName &&
-          executionModeOp.execution_mode() == spirv::ExecutionMode::LocalSize) {
-        for (int i = 0; i < executionModeOp.values().size(); ++i) {
-          workgroupSize[i] =
-              builder.create<ConstantIndexOp>(loc, executionModeOp.values()[i]
-                                                       .cast<IntegerAttr>()
-                                                       .getValue()
-                                                       .getZExtValue());
-        }
-        break;
-      }
-    }
-
-    // Pad out the workgroup size with 1's (if the original rank was < 3).
-    for (int i = 0; i < workgroupSize.size(); ++i) {
-      if (!workgroupSize[i]) {
-        workgroupSize[i] = builder.create<ConstantIndexOp>(loc, 1);
-      }
-    }
-
-    return workgroupSize;
+    declareTargetOpsForEnv(sourceOp, executableOp, spvTargetEnv);
   }
 
   LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
index 4eecf81..8382b5d 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
@@ -15,11 +15,10 @@
 #ifndef IREE_COMPILER_DIALECT_HAL_TARGET_VULKANSPIRV_VULKANSPIRVTARGET_H_
 #define IREE_COMPILER_DIALECT_HAL_TARGET_VULKANSPIRV_VULKANSPIRVTARGET_H_
 
+#include <functional>
 #include <string>
 
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
-#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
-#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 namespace iree_compiler {