Pull in passes to convert from XLA HLO ops to SPIR-V.

Pulls in passes to convert xla-hlo to linalg.generic on tensors,
followed by conversion from linalg.generic on tensors to
linalg.generic on buffers. Then the linalg to SPIR-V pass pipeline can
be used to generate SPIR-V dialect.
Also adds some patterns to IREELinalgTensorToBuffer to eliminate
IREE::*Ops
PiperOrigin-RevId: 293252317
diff --git a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt
index 1bb1bee..4543971 100644
--- a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt
+++ b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla/CMakeLists.txt
@@ -43,11 +43,13 @@
     "transforms/lower_general_dot.cc"
     "transforms/materialize_broadcasts.cc"
     "transforms/unfuse_batch_norm.cc"
+    "transforms/xla_legalize_to_linalg.cc"
   HDRS
     "convert_op_folder.h"
     "ir/hlo_ops.h"
     "ir/hlo_utils.h"
     "ir/lhlo_ops.h"
+    "transforms/map_xla_to_scalar_op.h"
     "transforms/passes.h"
     "transforms/rewriters.h"
   COPTS
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index eb51823..1c5ac34 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -30,6 +30,7 @@
         "//iree/compiler/Dialect/HAL/Target:ExecutableTarget",
         "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Translation/SPIRV/EmbeddedKernels",
+        "//iree/compiler/Translation/SPIRV/LinalgToSPIRV",
         "//iree/compiler/Translation/SPIRV/XLAToSPIRV",
         "//iree/schemas:spirv_executable_def_cc_fbs",
         "@com_github_google_flatbuffers//:flatbuffers",
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index 6c962e5..917dfa4 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -24,6 +24,7 @@
     iree::compiler::Dialect::HAL::Target::ExecutableTarget
     iree::compiler::Dialect::IREE::IR
     iree::compiler::Translation::SPIRV::EmbeddedKernels
+    iree::compiler::Translation::SPIRV::LinalgToSPIRV
     iree::compiler::Translation::SPIRV::XLAToSPIRV
     iree::schemas::spirv_executable_def_cc_fbs
     flatbuffers
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 75a4a83..ac1d0d8 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -20,6 +20,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/Target/LegacyUtil.h"
 #include "iree/compiler/Translation/SPIRV/EmbeddedKernels/EmbeddedKernels.h"
+#include "iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h"
 #include "iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.h"
 #include "iree/schemas/spirv_executable_def_generated.h"
 #include "llvm/ADT/STLExtras.h"
@@ -43,6 +44,12 @@
 // static llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
 //     "IREE Vulkan/SPIR-V backend options");
 
+static llvm::cl::opt<bool> useLinalgPathForCodegen(
+    "iree-use-linalg-to-spirv-path",
+    llvm::cl::desc(
+        "Flag to use the XLA-HLO to Linalg To SPIR-V pass pipeline."),
+    llvm::cl::init(false));
+
 VulkanSPIRVTargetOptions getVulkanSPIRVTargetOptionsFromFlags() {
   VulkanSPIRVTargetOptions targetOptions;
   // TODO(benvanik): flags.
@@ -140,7 +147,11 @@
 
     // Lower module to spirv::ModuleOp.
     PassManager conversionPassManager(moduleOp.getContext());
-    addIREEToSPIRVPasses(conversionPassManager);
+    if (useLinalgPathForCodegen) {
+      addLowerToSPIRVPasses(conversionPassManager);
+    } else {
+      addIREEToSPIRVPasses(conversionPassManager);
+    }
     if (failed(conversionPassManager.run(moduleOp))) {
       return moduleOp.emitError() << "failed to run conversion passes";
     }
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
index d45ecaf..d787c2e 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -20,18 +20,26 @@
 cc_library(
     name = "LinalgToSPIRV",
     srcs = [
-        "LinalgToSPIRV.cpp",
+        "LowerToSPIRV.cpp",
+    ],
+    hdrs = [
+        "LowerToSPIRV.h",
     ],
     deps = [
+        "//iree/compiler/Translation/XLAToLinalg:IREELinalgTensorToBuffer",
+        "//iree/compiler/Utils",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:AffineOps",
         "@llvm-project//mlir:AffineToStandardTransforms",
+        "@llvm-project//mlir:EDSC",
+        "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToSPIRVTransforms",
         "@llvm-project//mlir:GPUTransforms",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
-        "@llvm-project//mlir:LoopsToGPUPass",
+        "@llvm-project//mlir:LoopOps",
+        "@llvm-project//mlir:LoopsToGPU",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SPIRVDialect",
         "@llvm-project//mlir:SPIRVLowering",
@@ -39,6 +47,8 @@
         "@llvm-project//mlir:StandardToSPIRVConversions",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
     ],
     alwayslink = 1,
 )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
index 20f8392..6a13ae2 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -17,17 +17,23 @@
 iree_cc_library(
   NAME
     LinalgToSPIRV
+  HDRS
+    "LowerToSPIRV.h"
   SRCS
-    "LinalgToSPIRV.cpp"
+    "LowerToSPIRV.cpp"
   DEPS
+    iree::compiler::Translation::XLAToLinalg::IREELinalgTensorToBuffer
+    iree::compiler::Utils
     LLVMSupport
     MLIRAffineOps
     MLIRAffineToStandard
+    MLIREDSC
     MLIRGPUtoSPIRVTransforms
     MLIRGPU
     MLIRIR
     MLIRLinalgOps
     MLIRLinalgTransforms
+    MLIRLoopOps
     MLIRLoopsToGPU
     MLIRPass
     MLIRSPIRV
@@ -35,6 +41,7 @@
     MLIRStandardToSPIRVTransforms
     MLIRSupport
     MLIRTransforms
+    tensorflow::mlir_xla
   ALWAYSLINK
   PUBLIC
 )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp
deleted file mode 100644
index 62e1986..0000000
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ /dev/null
@@ -1,104 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- LinalgToSPIRV.cpp - Linalg dialect to SPIR-V dialect----------------===//
-//
-// Implementation of conversion from Linalg To SPIRV
-//
-//===----------------------------------------------------------------------===//
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
-#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
-#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
-#include "mlir/Dialect/GPU/Passes.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Pass/PassOptions.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-struct LinalgToSPIRVPassOptions
-    : public PassPipelineOptions<LinalgToSPIRVPassOptions> {
-  ListOption<int64_t> numWorkGroups{
-      *this, "num-workgroups",
-      llvm::cl::desc(
-          "Number of workgroups in the SPIR-V module for x, followed by y, "
-          "followed by z dimension of the dispatch (others will be ignored)"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<int64_t> workGroupSize{
-      *this, "workgroup-size",
-      llvm::cl::desc(
-          "Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
-          "by z dimension of the dispatch (others will be ignored)"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-};
-}  // namespace
-
-static void addLinalgToSPIRVPasses(OpPassManager &pm,
-                                   const LinalgToSPIRVPassOptions &options) {
-  // TODO(ravishankarm): For now only evaluated with 2D tiling. So set the
-  // workgroup size and numworkgroups to size 2
-  SmallVector<int64_t, 2> numWorkGroups, workGroupSize;
-  numWorkGroups.assign(options.numWorkGroups.begin(),
-                       options.numWorkGroups.end());
-  numWorkGroups.resize(2, 1);
-  workGroupSize.assign(options.workGroupSize.begin(),
-                       options.workGroupSize.end());
-  workGroupSize.resize(2, 1);
-
-  // Linalg to loops.
-  pm.addPass(createLinalgTilingPass(workGroupSize));
-  pm.addPass(createConvertLinalgToLoopsPass());
-  pm.addPass(createLowerAffinePass());
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
-
-  // Loops to GPU.
-  pm.addPass(createLoopToGPUPass(numWorkGroups, workGroupSize));
-  pm.addPass(createGpuKernelOutliningPass());
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
-  pm.addPass(createLowerAffinePass());
-
-  // GPU to SPIR-V.
-  pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
-  pm.addPass(createConvertGPUToSPIRVPass(workGroupSize));
-
-  // SPIR-V passes for lowering attributes.
-  OpPassManager &spirvModulePM = pm.nest<spirv::ModuleOp>();
-  spirvModulePM.addPass(spirv::createLowerABIAttributesPass());
-  spirvModulePM.addPass(createCanonicalizerPass());
-  spirvModulePM.addPass(createCSEPass());
-}
-
-static PassPipelineRegistration<LinalgToSPIRVPassOptions> linalgToSPIRVPipeline(
-    "iree-linalg-to-spirv",
-    "Runs the progressive lowering pipeline from Linalg to SPIR-V",
-    [](OpPassManager &passManager, const LinalgToSPIRVPassOptions &options) {
-      addLinalgToSPIRVPasses(passManager, options);
-    });
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
new file mode 100644
index 0000000..e84fd48
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
@@ -0,0 +1,289 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LowerToSPIRV.cpp - Lower from XLA to Linalg to SPIR-V dialect-------===//
+//
+// Implementation of conversion from XLA-HLO to Linalg to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h"
+#include "iree/compiler/Utils/IREECodegenUtils.h"
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+/// These options are only for testing purposes. For actual execution with IREE,
+/// these are computed by IREE/Backends automatically.
+struct WorkGroupOptions : public PassPipelineOptions<WorkGroupOptions> {
+  ListOption<int64_t> workGroupSize{
+      *this, "workgroup-size",
+      llvm::cl::desc(
+          "Number of workgroups to dispatch for the SPIR-V module; at most "
+          "three integers standing for the x, y, and z dimension; "
+          "additional arguments will be ignored (used only for testing)"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+}  // namespace
+
+static DenseIntElementsAttr getDenseIntElementsAttrVal(
+    Builder *builder, ArrayRef<int64_t> value) {
+  SmallVector<int32_t, 3> vector;
+  vector.reserve(3);
+  for (auto val : value) {
+    vector.emplace_back(val);
+  }
+  vector.resize(3, 1);
+  return builder->getI32VectorAttr(vector);
+}
+
+/// Helper function to create a std.constant of index type to initialize the
+/// workgroup size as a SSA value.
+static void createConstantsInFunc(FuncOp funcOp, ArrayRef<int64_t> intVal,
+                                  SmallVectorImpl<Value> &constVal) {
+  OpBuilder builder(funcOp.getBody());
+  MLIRContext *context = funcOp.getContext();
+  for (auto val : intVal) {
+    constVal.push_back(builder.create<ConstantOp>(
+        funcOp.getLoc(), IntegerAttr::get(IndexType::get(context), val)));
+  }
+}
+
+namespace {
+
+/// To be able to use the workgroup size from the dispatch function attribute
+/// within the linalg tiling pass, need to actually implement a pass to retrieve
+/// the attribute value from the function and pass it along.
+// TODO(ravishankarm): Move this into Linalg dialect.
+struct IREETileLinalgPass : public FunctionPass<IREETileLinalgPass> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    SmallVector<int64_t, 3> workGroupSize;
+    workGroupSize.reserve(3);
+    if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) {
+      return;
+    }
+    OpBuilder builder(funcOp);
+    OperationFolder folder(funcOp.getContext());
+    funcOp.walk([&workGroupSize, &builder, &folder](linalg::LinalgOp op) {
+      if (!op.hasBufferSemantics()) {
+        return;
+      }
+      SmallVector<int64_t, 3> tileSizes;
+      auto nLoops = op.getNumLoops();
+      tileSizes.assign(workGroupSize.begin(), workGroupSize.end());
+      // Linalg convention is to use 0 for no tiling. If the workgroup size is
+      // 1, then dont tile along that dimension. So overriding 1 to 0.
+      for (auto &tileSize : tileSizes) {
+        if (tileSize == 1) tileSize = 0;
+      }
+      tileSizes.resize(nLoops, 0);
+      auto tiledOp = linalg::tileLinalgOp(builder, op, tileSizes, {}, &folder);
+      if (tiledOp) {
+        op.erase();
+      }
+    });
+  }
+};
+
+/// To be able to use the workgroup size from the dispatch function attribute to
+/// convert loops to GPU kernel, need to actually implement a pass to retrieve
+/// the attribute value from the function and pass it along.
+// TODO(ravishankarm): Structure the Loops to GPU pass in MLIR so that we dont
+// have to do this. Maybe make it an OpPassBase<loop::ForOp> ?
+struct LoopsToGPUPass : public FunctionPass<LoopsToGPUPass> {
+  void runOnFunction() override {
+    // Get the workgroup size from the attributes.
+    FuncOp funcOp = getFunction();
+    SmallVector<int64_t, 3> workGroupSize;
+    workGroupSize.reserve(3);
+    if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) {
+      return;
+    }
+    // TODO(ravishankarm): Currently evaluating only 2D tiling. Generalize this.
+    workGroupSize.resize(2);
+    // The Loop To GPU pass expects the numWorkGroups only to create the
+    // host-side launch operation. We don't care about that, so just pass {1, 1,
+    // 1} for that.
+    SmallVector<int64_t, 3> numWorkGroups(workGroupSize.size(), 1);
+    SmallVector<Value, 3> numWorkGroupsVal, workGroupSizeVal;
+    numWorkGroupsVal.reserve(3);
+    workGroupSizeVal.reserve(3);
+    createConstantsInFunc(funcOp, numWorkGroups, numWorkGroupsVal);
+    createConstantsInFunc(funcOp, workGroupSize, workGroupSizeVal);
+    for (Block &block : getFunction()) {
+      for (Operation &op : llvm::make_early_inc_range(block)) {
+        if (auto forOp = dyn_cast<loop::ForOp>(&op)) {
+          if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
+                                            workGroupSizeVal))) {
+            return signalPassFailure();
+          }
+        }
+      }
+    }
+  }
+};
+
+/// To be able to use the workgroup size from the dispatch function attribute to
+/// convert GPU kernel into SPIR-V kernel, need to actually implement a pass to
+/// retrieve the attribute value from the function and pass it along.
+// TODO(ravishankarm): Move this into MLIR core.
+struct IREEGPUToSPIRVPass : public ModulePass<IREEGPUToSPIRVPass> {
+  void runOnModule() {
+    MLIRContext *context = &getContext();
+    ModuleOp moduleOp = getModule();
+    FuncOp funcOp = nullptr;
+    auto walkResult = moduleOp.walk([&funcOp](FuncOp fOp) -> WalkResult {
+      if (fOp.getAttr("iree.executable.export")) {
+        if (funcOp) {
+          return WalkResult::interrupt();
+        }
+        funcOp = fOp;
+      }
+      return WalkResult::advance();
+    });
+    if (!funcOp || walkResult.wasInterrupted()) {
+      moduleOp.emitError("expected a single dispatch function within module");
+      return signalPassFailure();
+    }
+    SmallVector<Operation *, 1> kernelModules;
+    OpBuilder builder(context);
+    builder.setInsertionPoint(funcOp.getOperation());
+
+    // Clone the GPU module into the funcop to convert into a SPIR-V module.
+    funcOp.walk(
+        [&builder, &moduleOp, &kernelModules](gpu::LaunchFuncOp gpuLaunchOp) {
+          auto kernelModuleName = gpuLaunchOp.getKernelModuleName();
+          auto gpuModuleOp =
+              moduleOp.lookupSymbol<gpu::GPUModuleOp>(kernelModuleName);
+          kernelModules.push_back(builder.clone(*gpuModuleOp.getOperation()));
+        });
+    SPIRVTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    SmallVector<int64_t, 3> workGroupSize;
+    if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) {
+      return;
+    }
+    populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
+    populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+
+    std::unique_ptr<ConversionTarget> target =
+        spirv::SPIRVConversionTarget::get(
+            spirv::lookupTargetEnvOrDefault(funcOp), context);
+    target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType());
+    });
+
+    if (failed(applyFullConversion(kernelModules, *target, patterns,
+                                   &typeConverter))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+/// Pass to override the workgroup_size attribute value of a dispatch function.
+// TODO(ravishankarm): Use a more coherent strategy than just setting it to {2,
+// 2}.
+struct UpdateWorkGroupSizePass : FunctionPass<UpdateWorkGroupSizePass> {
+  UpdateWorkGroupSizePass(ArrayRef<int64_t> workGroupSize)
+      : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+  void runOnFunction() {
+    FuncOp funcOp = getFunction();
+    if (!funcOp.getAttr("iree.executable.export")) {
+      return;
+    }
+    if (workGroupSize.empty()) {
+      workGroupSize = {2, 2};
+    }
+    workGroupSize.resize(3, 1);
+    OpBuilder builder(&getContext());
+    funcOp.setAttr("iree.executable.workgroup_size",
+                   getDenseIntElementsAttrVal(&builder, workGroupSize));
+  }
+
+ private:
+  SmallVector<int64_t, 3> workGroupSize;
+};
+}  // namespace
+
+static void addLinalgToSPIRVPasses(OpPassManager &pm) {
+  // Linalg to loops.
+  pm.addPass(std::make_unique<IREETileLinalgPass>());
+  pm.addPass(createConvertLinalgToLoopsPass());
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  pm.addPass(std::make_unique<LoopsToGPUPass>());
+  pm.addPass(createGpuKernelOutliningPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+  pm.addPass(createLowerAffinePass());
+
+  // GPU to SPIR-V.
+  pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+  pm.addPass(std::make_unique<IREEGPUToSPIRVPass>());
+
+  // SPIR-V passes for lowering attributes.
+  OpPassManager &spirvModulePM = pm.nest<spirv::ModuleOp>();
+  spirvModulePM.addPass(spirv::createLowerABIAttributesPass());
+  spirvModulePM.addPass(createCanonicalizerPass());
+  spirvModulePM.addPass(createCSEPass());
+}
+
+void addLowerToSPIRVPasses(OpPassManager &pm, ArrayRef<int64_t> workGroupSize) {
+  pm.addPass(xla_hlo::createLegalizeHloToLinalgPass());
+  pm.addPass(createLinalgTensorToBufferConversionPass());
+  pm.addPass(std::make_unique<UpdateWorkGroupSizePass>(workGroupSize));
+  addLinalgToSPIRVPasses(pm);
+}
+
+static PassPipelineRegistration<WorkGroupOptions> xlaToLinalgSPIRVPipeline(
+    "iree-xla-to-linalg-to-spirv",
+    "Runs the progressive lowering pipeline from XLA HLO to Linalg to SPIR-V",
+    [](OpPassManager &passManager, const WorkGroupOptions &options) {
+      SmallVector<int64_t, 2> workGroupSize;
+      workGroupSize.assign(options.workGroupSize.begin(),
+                           options.workGroupSize.end());
+      addLowerToSPIRVPasses(passManager, workGroupSize);
+    });
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h
new file mode 100644
index 0000000..8dc6c25
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.h
@@ -0,0 +1,30 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_LOWERTOSPIRV_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_LOWERTOSPIRV_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Populates passes needed to lower a XLA HLO op to SPIR-V dialect.
+void addLowerToSPIRVPasses(OpPassManager &pm,
+                           ArrayRef<int64_t> workGroupSize = {});
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_LOWERTOSPIRV_H
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
index 1d8d07c..f63a472 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
@@ -27,5 +27,6 @@
     data = [
         "//iree/tools:IreeFileCheck",
         "//iree/tools:iree-opt",
+        "//iree/tools:iree-run-mlir",
     ],
 )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
index 2fafb16..e10a865 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
@@ -16,8 +16,10 @@
   NAME
     lit
   SRCS
-    "single_pw_op.mlir"
+    "pw_add.mlir"
+    "pw_add_e2e.mlir"
   DATA
     iree::tools::IreeFileCheck
     iree::tools::iree-opt
+    iree::tools::iree-run-mlir
 )
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add.mlir
new file mode 100644
index 0000000..95812a6
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add.mlir
@@ -0,0 +1,13 @@
+// RUN: iree-opt -pass-pipeline='iree-xla-to-linalg-to-spirv' %s | IreeFileCheck %s
+
+module {
+  func @simple_load_store(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>, %arg2 : memref<4x8xi32>)
+  attributes  {iree.executable.export, iree.executable.workload = dense<[8, 4, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[2, 2, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    // CHECK: spv.module
+    %0 = iree.load_input(%arg0 : memref<4x8xi32>) : tensor<4x8xi32>
+    %1 = iree.load_input(%arg1 : memref<4x8xi32>) : tensor<4x8xi32>
+    %2 = "xla_hlo.add"(%0, %1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+    iree.store_output(%2 : tensor<4x8xi32>, %arg2 : memref<4x8xi32>)
+    iree.return
+  }
+}
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir
deleted file mode 100644
index d875f14..0000000
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: iree-opt -pass-pipeline='iree-linalg-to-spirv{workgroup-size=2,2 num-workgroups=2,2}' %s
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-
-module {
-  func @fmul(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
-    linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):   // no predecessors
-      %0 = mulf %arg3, %arg4 : f32
-      linalg.yield %0 : f32
-    }: memref<12x4xf32>, memref<12x4xf32>, memref<12x4xf32>
-    return
-  }
-}
diff --git a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp
index 9ba3583..1d0d3f9 100644
--- a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp
+++ b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.cpp
@@ -22,6 +22,17 @@
 namespace iree_compiler {
 namespace {
 
+/// Remove IREE::LoadInputOp operations
+struct RemoveLoadInputOpPattern : OpConversionPattern<IREE::LoadInputOp> {
+  using OpConversionPattern<IREE::LoadInputOp>::OpConversionPattern;
+  PatternMatchResult matchAndRewrite(
+      IREE::LoadInputOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOp(op, op.getOperand());
+    return matchSuccess();
+  }
+};
+
 /// Convert from a linalg.generic on tensors to linalg.generic on buffers. In
 /// IREE it is expected that each dispatch region will become a single
 /// linalg.generic op on tensors (after XLA-HLO -> Linalg conversion and
@@ -41,11 +52,11 @@
 };
 
 /// Remove IREE::StoreOutputOp operations.
-struct RemoveDeadStorePattern : OpConversionPattern<IREE::StoreOutputOp> {
+struct RemoveStoreOutputOpPattern : OpConversionPattern<IREE::StoreOutputOp> {
   using OpConversionPattern<IREE::StoreOutputOp>::OpConversionPattern;
   PatternMatchResult matchAndRewrite(
       IREE::StoreOutputOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const {
+      ConversionPatternRewriter &rewriter) const override {
     rewriter.eraseOp(op);
     return matchSuccess();
   }
@@ -61,7 +72,6 @@
     return matchSuccess();
   }
 };
-
 }  // namespace
 
 PatternMatchResult LinalgTensorToBufferConverter::matchAndRewrite(
@@ -69,17 +79,7 @@
     ConversionPatternRewriter &rewriter) const {
   // TODO(ravishankarm): Find a way to write this using Matchers, but need to
   // figure out how to match operations with variadic operands.
-  SmallVector<Value, 2> memrefArgs;
-  for (auto arg : op.getOperands()) {
-    if (!arg.getType().isa<RankedTensorType>()) {
-      return matchFailure();
-    }
-    auto definingOp = dyn_cast_or_null<IREE::LoadInputOp>(arg.getDefiningOp());
-    if (!definingOp) {
-      return matchFailure();
-    }
-    memrefArgs.push_back(definingOp.getOperand());
-  }
+  SmallVector<Value, 2> memrefArgs(operands.begin(), operands.end());
   // For result, check that there is a single use in an iree::store_output op.
   for (auto result : op.getResults()) {
     if (!result.hasOneUse()) {
@@ -121,8 +121,9 @@
 
 void populateLinalgTensorToBufferConversionPattern(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<LinalgTensorToBufferConverter, RemoveDeadStorePattern,
-                  IREEReturnOpLowering>(context);
+  patterns.insert<IREEReturnOpLowering, LinalgTensorToBufferConverter,
+                  RemoveLoadInputOpPattern, RemoveStoreOutputOpPattern>(
+      context);
 }
 
 struct LinalgTensorToBufferConversionPass
@@ -132,7 +133,6 @@
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
-    target.addLegalOp<IREE::LoadInputOp>();
     target.addLegalOp<FuncOp>();
     target.addDynamicallyLegalOp<linalg::GenericOp>([&](linalg::GenericOp op) {
       return llvm::all_of(op.getOperands(),
diff --git a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h
index 60e5246..c1be661 100644
--- a/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h
+++ b/iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD
index c254864..a76ef85 100644
--- a/iree/compiler/Utils/BUILD
+++ b/iree/compiler/Utils/BUILD
@@ -35,6 +35,7 @@
         "//iree/compiler/Dialect/IREE/IR",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
diff --git a/iree/compiler/Utils/CMakeLists.txt b/iree/compiler/Utils/CMakeLists.txt
index d8db64d..c361f20 100644
--- a/iree/compiler/Utils/CMakeLists.txt
+++ b/iree/compiler/Utils/CMakeLists.txt
@@ -27,6 +27,7 @@
     iree::compiler::Dialect::IREE::IR
     LLVMSupport
     MLIRIR
+    MLIRPass
     MLIRStandardOps
     MLIRSupport
     MLIRTransformUtils
diff --git a/iree/compiler/Utils/IREECodegenUtils.cpp b/iree/compiler/Utils/IREECodegenUtils.cpp
index 98bd878..b0b1f3b 100644
--- a/iree/compiler/Utils/IREECodegenUtils.cpp
+++ b/iree/compiler/Utils/IREECodegenUtils.cpp
@@ -48,8 +48,9 @@
 }
 
 /// Gets the workgroup size.
+template <typename intType>
 LogicalResult getLegacyWorkGroupSize(FuncOp funcOp,
-                                     SmallVectorImpl<int32_t> &workGroupSize) {
+                                     SmallVectorImpl<intType> &workGroupSize) {
   if (!funcOp.getAttr("iree.executable.export")) {
     return funcOp.emitError(
         "expected operation to be in dispatch function to get launch size");
@@ -68,5 +69,10 @@
   return success();
 }
 
+template LogicalResult getLegacyWorkGroupSize<int32_t>(
+    FuncOp funcOp, SmallVectorImpl<int32_t> &workGroupSize);
+template LogicalResult getLegacyWorkGroupSize<int64_t>(
+    FuncOp funcOp, SmallVectorImpl<int64_t> &workGroupSize);
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Utils/IREECodegenUtils.h b/iree/compiler/Utils/IREECodegenUtils.h
index 3983156..908de04 100644
--- a/iree/compiler/Utils/IREECodegenUtils.h
+++ b/iree/compiler/Utils/IREECodegenUtils.h
@@ -29,10 +29,10 @@
 LogicalResult getLegacyLaunchSize(FuncOp funcOp,
                                   SmallVectorImpl<int64_t> &launchSize);
 
-// TODO(ravishankarm): remove this; it is not safe for variable sizes.
-/// Gets the workgroup size.
+/// Gets the workgroup size. Has to be a static constant.
+template <typename intType>
 LogicalResult getLegacyWorkGroupSize(FuncOp funcOp,
-                                     SmallVectorImpl<int32_t> &workGroupSize);
+                                     SmallVectorImpl<intType> &workGroupSize);
 
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index f6ddae0..67b08c8 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -117,8 +117,8 @@
       iree::compiler::Dialect::VM::Transforms
       iree::compiler::Translation::Interpreter::Transforms
       iree::compiler::Translation::IREEVM
-      iree::compiler::Translation::SPIRV::XLAToSPIRV
       iree::compiler::Translation::SPIRV::LinalgToSPIRV
+      iree::compiler::Translation::SPIRV::XLAToSPIRV
       iree::compiler::Translation::XLAToLinalg
       iree::compiler::Translation::XLAToLinalg::IREELinalgTensorToBuffer
       LLVMSupport