Add pass to tile and fuse linalg operations.

The existing tiling pass is incorporated as a pattern, so if there is
nothing to fuse, the op is just tiled. The fallback of executing the
op sequentially is also incorporated as a pattern of this pass. The
pass added here also updates the workgroup size when the tiling
pattern succeeds, which makes the UpdateWorkGroupSizePass unnecessary.
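
The pass is registered as -iree-linalg-tile-and-fuse and can be
exercised directly, e.g.:

  iree-opt -split-input-file -iree-linalg-tile-and-fuse input.mlir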

PiperOrigin-RevId: 303810983
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
index 8de2527..18abb32 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -20,7 +20,8 @@
 cc_library(
     name = "LinalgToSPIRV",
     srcs = [
-        "GPUKernelOutlining.cpp",
+        "GPUKernelOutliningPass.cpp",
+        "LinalgTileAndFusePass.cpp",
         "LowerToSPIRV.cpp",
     ],
     hdrs = [
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
index 370c17b..dfbf6c3 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -21,7 +21,8 @@
     "LowerToSPIRV.h"
     "Passes.h"
   SRCS
-    "GPUKernelOutlining.cpp"
+    "GPUKernelOutliningPass.cpp"
+    "LinalgTileAndFusePass.cpp"
     "LowerToSPIRV.cpp"
   DEPS
     LLVMSupport
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutlining.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutliningPass.cpp
similarity index 97%
rename from iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutlining.cpp
rename to iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutliningPass.cpp
index 65b8d6f..f520b68 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutlining.cpp
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutliningPass.cpp
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-//===- GPUKernelOutlining.cpp - Generate GPU device-side code -------------===//
+//===- GPUKernelOutliningPass.cpp - Generate GPU device-side code ---------===//
 //
 // Implements a pass to convert a launch operation into a device-side code. Uses
 // a separate pass since the pass from core puts the gpu.module at the module
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgTileAndFusePass.cpp
new file mode 100644
index 0000000..fb290de
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -0,0 +1,293 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LinalgTileAndFusePass.cpp - Tile and fuse Linalg on Buffers --------===//
+//
+// Implements a pass to tile and fuse linalg operations on buffers.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static StringRef getWorkGroupMarker() { return "spirv_workgroup"; }
+
+static constexpr unsigned kMaxWorkgroupRank = 3;
+
+/// Returns the tile sizes to use by default based on the number of parallel
+/// dimensions.
+static void getDefaultTileSizes(unsigned numDims,
+                                SmallVectorImpl<int64_t> &tileSizes) {
+  tileSizes.clear();
+  switch (numDims) {
+    case 0:
+      return;
+    case 1:
+      tileSizes.push_back(32);
+      return;
+    case 2:
+      tileSizes.push_back(4);
+      tileSizes.push_back(32);
+      return;
+    default:
+      break;
+  }
+  tileSizes.push_back(2);
+  tileSizes.push_back(2);
+  tileSizes.push_back(32);
+}
+
+/// Returns the number of "outer" parallel loops in `linalgOp`.
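+/// For example, an op with iterator_types = ["parallel", "parallel",
+/// "reduction"] has two outer parallel loops. linalg.conv is special-cased to
+/// return 0 so that it falls back to sequential execution.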
+static unsigned getNumOuterParallelLoops(linalg::LinalgOp linalgOp) {
+  if (isa<linalg::ConvOp>(linalgOp.getOperation())) return 0;
+  return linalgOp.iterator_types()
+      .getValue()
+      .take_while([](Attribute attr) {
+        return attr.cast<StringAttr>().getValue() ==
+               getParallelIteratorTypeName();
+      })
+      .size();
+}
+
+/// Computes the tile sizes to use for a linalg operation from `workGroupSize`,
+/// if provided, or from the defaults otherwise.
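+/// For example, with two outer parallel loops and workGroupSize = {32, 4, 1},
+/// the trailing 1 is dropped and {32, 4} is reversed to give tile sizes
+/// {4, 32}, i.e. the outer loop is tiled by 4 and the inner loop by 32.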
+static void getTileSizes(unsigned numParallelLoops,
+                         ArrayRef<int64_t> workGroupSize,
+                         SmallVectorImpl<int64_t> &tileSizes) {
+  tileSizes.clear();
+  numParallelLoops = std::min(numParallelLoops, kMaxWorkgroupRank);
+  if (!workGroupSize.empty()) {
+    workGroupSize = dropTrailingOnes(workGroupSize);
+    auto rev = reverse(workGroupSize.take_front(numParallelLoops));
+    tileSizes.assign(rev.begin(), rev.end());
+    tileSizes.resize(numParallelLoops, 0);
+  } else {
+    getDefaultTileSizes(numParallelLoops, tileSizes);
+  }
+  // The Linalg convention is to use 0 for no tiling. If the workgroup size is
+  // 1, don't tile along that dimension, so override 1 with 0.
+  for (auto &tileSize : tileSizes)
+    if (tileSize == 1) tileSize = 0;
+}
+
+/// Checks if an operation already has an attribute with this marker. If set,
+/// it implies this op shouldn't be tiled again with the same marker.
+static bool hasMarker(Operation *op, StringRef marker) {
+  auto tilingAttr = op->getAttrOfType<StringAttr>(
+      linalg::LinalgTransforms::kLinalgTransformMarker);
+  return tilingAttr && tilingAttr.getValue() == marker;
+}
+
+namespace {
+/// Function pass that implements tiling and fusion in Linalg on buffers.
+struct LinalgTileAndFusePass : public FunctionPass<LinalgTileAndFusePass> {
+  LinalgTileAndFusePass(ArrayRef<int64_t> workGroupSize = {})
+      : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+  void runOnFunction() override;
+
+ private:
+  SmallVector<int64_t, 3> workGroupSize;
+};
+
+/// Base class for Linalg tiling patterns. All classes that derive from this
+/// need to implement an `apply` method that tiles the operation, with the
+/// following signature:
+///
+/// LogicalResult apply(LinalgOp op, ArrayRef<int64_t> tileSizes,
+///                     PatternRewriter &rewriter) const
+template <typename DerivedClass, typename LinalgOp>
+struct LinalgTilingPattern : public OpRewritePattern<LinalgOp> {
+  LinalgTilingPattern(MLIRContext *context, ArrayRef<int64_t> tileSizes,
+                      PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgOp>(context, benefit), tileSizes(tileSizes) {}
+
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (!linalgOp.hasBufferSemantics()) return failure();
+    // Currently we are only doing one-level tiling, so a single marker is
+    // enough. This might need to move into derived classes.
+    if (hasMarker(linalgOp.getOperation(), getWorkGroupMarker()))
+      return failure();
+
+    if (failed(static_cast<const DerivedClass *>(this)->apply(
+            linalgOp, tileSizes, rewriter)))
+      return failure();
+    rewriter.eraseOp(linalgOp);
+    return success();
+  }
+
+ private:
+  ArrayRef<int64_t> tileSizes;
+};
+
+/// If the linalg op has no outer parallel loops, inserts dummy one-trip loops
+/// around it to execute it sequentially within a thread.
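+/// The rewritten code is roughly
+///
+///   loop.for %iv0 = %c0 to %c1 step %c1 {
+///     loop.for %iv1 = %c0 to %c1 step %c1 {
+///       <cloned linalg op, tagged with the workgroup marker>
+///     }
+///   }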
+template <typename LinalgOp>
+struct SequentialExecutionPattern
+    : public LinalgTilingPattern<SequentialExecutionPattern<LinalgOp>,
+                                 LinalgOp> {
+  using LinalgTilingPattern<SequentialExecutionPattern<LinalgOp>,
+                            LinalgOp>::LinalgTilingPattern;
+  LogicalResult apply(LinalgOp linalgOp, ArrayRef<int64_t> tileSizes,
+                      PatternRewriter &rewriter) const {
+    if (!tileSizes.empty()) return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto indexType = rewriter.getIndexType();
+    auto loc = linalgOp.getLoc();
+    auto zero =
+        rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(indexType, 0));
+    auto one =
+        rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(indexType, 1));
+    auto outerLoop = rewriter.create<loop::ForOp>(loc, zero, one, one);
+    rewriter.setInsertionPoint(outerLoop.getBody(),
+                               std::prev(outerLoop.getBody()->end()));
+    auto innerLoop = rewriter.create<loop::ForOp>(loc, zero, one, one);
+    rewriter.setInsertionPoint(innerLoop.getBody(),
+                               std::prev(innerLoop.getBody()->end()));
+    Operation *clonedOp = rewriter.clone(*linalgOp.getOperation());
+    clonedOp->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
+                      rewriter.getStringAttr(getWorkGroupMarker()));
+    return success();
+  }
+};
+
+/// If there is nothing to fuse the linalg op with, this pattern just tiles it.
+template <typename LinalgOp>
+struct TileLinalgOpPattern
+    : public LinalgTilingPattern<TileLinalgOpPattern<LinalgOp>, LinalgOp> {
+  using LinalgTilingPattern<TileLinalgOpPattern<LinalgOp>,
+                            LinalgOp>::LinalgTilingPattern;
+  LogicalResult apply(LinalgOp linalgOp, ArrayRef<int64_t> tileSizes,
+                      PatternRewriter &rewriter) const {
+    // Check that all inputs and outputs have a single use (this op). In that
+    // case there is nothing to fuse with, so just tile the op.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(linalgOp.getOperation());
+    if (!llvm::all_of(linalgOp.getInputsAndOutputBuffers(),
+                      [](Value arg) { return arg.hasOneUse(); }))
+      return failure();
+    return linalg::tileLinalgOpAndSetMarker(rewriter, linalgOp.getOperation(),
+                                            tileSizes, getWorkGroupMarker(),
+                                            /*permutation=*/{});
+  }
+};
+
+/// Tile and fuse linalg operations.
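+/// Operands whose buffer has more than one use (e.g. written by a previous
+/// linalg op and read here) are the candidates for fusion.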
+template <typename LinalgOp>
+struct TileAndFuseLinalgOpPattern
+    : public LinalgTilingPattern<TileAndFuseLinalgOpPattern<LinalgOp>,
+                                 LinalgOp> {
+  using LinalgTilingPattern<TileAndFuseLinalgOpPattern<LinalgOp>,
+                            LinalgOp>::LinalgTilingPattern;
+  LogicalResult apply(LinalgOp linalgOp, ArrayRef<int64_t> tileSizes,
+                      PatternRewriter &rewriter) const {
+    SmallVector<int64_t, 1> operandIndicesToFuse;
+    for (auto buffer : llvm::enumerate(linalgOp.getInputsAndOutputBuffers())) {
+      // If a buffer has multiple uses, then it is a candidate for fusion.
+      if (!buffer.value().hasOneUse())
+        operandIndicesToFuse.push_back(buffer.index());
+    }
+    if (operandIndicesToFuse.empty()) return failure();
+    return linalg::tileAndFuseLinalgOpAndSetMarker(
+        rewriter, linalgOp, tileSizes, operandIndicesToFuse,
+        getWorkGroupMarker());
+  }
+};
+}  // namespace
+
+void LinalgTileAndFusePass::runOnFunction() {
+  MLIRContext *context = &getContext();
+  FuncOp funcOp = getFunction();
+  if (!isDispatchFunction(funcOp)) return;
+
+  Region &body = funcOp.getBody();
+  // Only handle single block functions.
+  if (body.getBlocks().size() != 1) {
+    funcOp.emitError("unhandled dispatch function with multiple blocks");
+    return signalPassFailure();
+  }
+  Block &block = body.front();
+  auto linalgOps = block.getOps<linalg::LinalgOp>();
+  if (linalgOps.empty()) return;
+
+  // Compute the minimum number of outer parallel loops across linalg
+  // operations. This gives the dimensionality of the tiling to use.
+  unsigned numParallelLoops = kMaxWorkgroupRank;
+  for (linalg::LinalgOp op : linalgOps)
+    numParallelLoops = std::min(numParallelLoops, getNumOuterParallelLoops(op));
+
+  // Get the tile sizes to use for the lowering.
+  SmallVector<int64_t, 3> tileSizes;
+  getTileSizes(numParallelLoops, workGroupSize, tileSizes);
+
+  OwningRewritePatternList patterns;
+  patterns.insert<SequentialExecutionPattern<linalg::ConvOp>,
+                  SequentialExecutionPattern<linalg::GenericOp>,
+                  SequentialExecutionPattern<linalg::IndexedGenericOp>,
+                  TileLinalgOpPattern<linalg::GenericOp>,
+                  TileLinalgOpPattern<linalg::IndexedGenericOp>,
+                  TileLinalgOpPattern<linalg::MatmulOp>,
+                  TileAndFuseLinalgOpPattern<linalg::GenericOp>>(context,
+                                                                 tileSizes);
+  applyPatternsGreedily(getOperation(), patterns);
+
+  // Check that there is a single perfectly nested loop.for nest at the
+  // topmost level that will get mapped to thread blocks/workgroups.
+  auto forLoops = block.getOps<loop::ForOp>();
+  if (!mlir::has_single_element(forLoops)) {
+    funcOp.emitError(
+        "unable to fuse operations within a dispatch region to get a single "
+        "outer parallel loop nest to map to workgroups");
+    return signalPassFailure();
+  }
+  // TODO(ravishankarm): Also need to check that the loops are perfectly nested
+  // and that there are as many as numParallelLoops. That check is more
+  // involved, so come back to it after moving to loop.parallel, at which point
+  // it is just a check of the number of induction variables.
+
+  // Update the workgroup size to be consistent with the tile sizes used. Note
+  // that the tile sizes are ordered from the outermost loop to the innermost.
+  // The heuristic is to map the innermost loop to x, the next outer (if it
+  // exists) to y, and the next outer (if it exists) to z. So the tile sizes
+  // are reversed to get the workgroup size.
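+  // For example, tileSizes = {2, 2, 32} (outermost to innermost) becomes
+  // workgroup_size = {32, 2, 2} (x, y, z).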
+  SmallVector<int64_t, 3> updatedWorkGroupSize(reverse(tileSizes));
+  updatedWorkGroupSize.resize(3, 1);
+  auto attrs = functional::map(
+      [&context](int64_t v) -> Attribute {
+        return IntegerAttr::get(IndexType::get(context), v);
+      },
+      updatedWorkGroupSize);
+  // TODO(b/150312935): Switch to update the HAL interface directly.
+  funcOp.setAttr("iree.executable.workgroup_size",
+                 ArrayAttr::get(attrs, context));
+}
+
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgTileAndFusePass(
+    ArrayRef<int64_t> workGroupSize) {
+  return std::make_unique<LinalgTileAndFusePass>(workGroupSize);
+}
+
+static PassRegistration<LinalgTileAndFusePass> pass(
+    "iree-linalg-tile-and-fuse", "Tile and fuse Linalg operations on buffers",
+    [] { return std::make_unique<LinalgTileAndFusePass>(); });
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
index ea020a3..c4bde7b 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
@@ -77,99 +77,8 @@
   }
 }
 
-/// Gets the number of outer parallel loops in a linalg operation.
-unsigned getNumOuterParallelLoops(linalg::LinalgOp linalgOp) {
-  // Find the number of leading parallel loops in the generic op
-  unsigned numOuterParallelLoops = 0;
-  for (auto iteratorType : linalgOp.iterator_types()) {
-    if (iteratorType.cast<StringAttr>().getValue() !=
-        getParallelIteratorTypeName()) {
-      break;
-    }
-    numOuterParallelLoops++;
-  }
-  return numOuterParallelLoops;
-}
-
 namespace {
 
-/// To be able to use the workgroup size from the dispatch function attribute
-/// within the linalg tiling pass, need to actually implement a pass to retrieve
-/// the attribute value from the function and pass it along.
-// TODO(ravishankarm): Move this into Linalg dialect.
-struct IREETileLinalgPass : public FunctionPass<IREETileLinalgPass> {
-  void runOnFunction() override {
-    FuncOp funcOp = getFunction();
-    SmallVector<int64_t, 3> workGroupSizeVec;
-    workGroupSizeVec.reserve(3);
-    if (failed(getWorkGroupSize(funcOp, workGroupSizeVec))) return;
-    ArrayRef<int64_t> workGroupSize = dropTrailingOnes(workGroupSizeVec);
-
-    OpBuilder builder(funcOp);
-    OperationFolder folder(funcOp.getContext());
-    Region &body = funcOp.getBody();
-    if (!mlir::has_single_element(body)) {
-      funcOp.emitError(
-          "unhandled dispatch function that doesn't have a single block");
-      return signalPassFailure();
-    }
-    auto linalgOps = body.front().getOps<linalg::LinalgOp>();
-    if (linalgOps.empty()) {
-      // Nothing to do. Return.
-      return;
-    }
-    if (!mlir::has_single_element(linalgOps)) {
-      funcOp.emitError(
-          "unhandled tiling multiple linalg ops in one dispatch function");
-      return signalPassFailure();
-    }
-    linalg::LinalgOp linalgOp = *linalgOps.begin();
-    if (!linalgOp.hasBufferSemantics()) {
-      linalgOp.emitError(
-          "expected linalg op with buffer semantics during SPIR-V "
-          "code generation");
-      return signalPassFailure();
-    }
-
-    // TODO(ravishankarm): Tile conv op.
-    bool isConvOp = isa<linalg::ConvOp>(linalgOp.getOperation());
-    unsigned numOuterParallelLoops = getNumOuterParallelLoops(linalgOp);
-    if (isConvOp || !numOuterParallelLoops) {
-      // There are no outer parallel loops to partition. So just create dummy
-      // 1-trip loops that will be "split" across workgroups and workitems.
-      builder.setInsertionPoint(linalgOp);
-      auto indexType = builder.getIndexType();
-      auto loc = linalgOp.getLoc();
-      auto zero =
-          builder.create<ConstantOp>(loc, builder.getIntegerAttr(indexType, 0));
-      auto one =
-          builder.create<ConstantOp>(loc, builder.getIntegerAttr(indexType, 1));
-      auto outerLoop = builder.create<loop::ForOp>(loc, zero, one, one);
-      OpBuilder outerLoopBuilder = outerLoop.getBodyBuilder();
-      auto innerLoop =
-          outerLoopBuilder.create<loop::ForOp>(loc, zero, one, one);
-      OpBuilder innerLoopBuilder = innerLoop.getBodyBuilder();
-      innerLoopBuilder.clone(*linalgOp.getOperation());
-      linalgOp.erase();
-      return;
-    }
-
-    // Tile sizes to use are reverse of the workGroupSize.
-    SmallVector<int64_t, 3> tileSizes(reverse(workGroupSize));
-    // Linalg convention is to use 0 for no tiling. If the workgroup size is
-    // 1, then dont tile along that dimension. So overriding 1 to 0.
-    for (auto &tileSize : tileSizes)
-      if (tileSize == 1) tileSize = 0;
-    tileSizes.resize(numOuterParallelLoops, 0);
-    if (linalg::tileLinalgOp(builder, linalgOp, tileSizes, {}, &folder)) {
-      linalgOp.erase();
-      return;
-    }
-    linalgOp.emitError("unable to tile linalg op for SPIR-V code generation");
-    return signalPassFailure();
-  }
-};
-
 /// To be able to use the workgroup size from the dispatch function attribute to
 /// convert loops to GPU kernel, need to actually implement a pass to retrieve
 /// the attribute value from the function and pass it along.
@@ -242,72 +151,12 @@
     }
   }
 };
-
-/// Pass to override the workgroup_size attribute value of a dispatch function.
-struct UpdateWorkGroupSizePass : FunctionPass<UpdateWorkGroupSizePass> {
-  UpdateWorkGroupSizePass(ArrayRef<int64_t> workGroupSize)
-      : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
-  void runOnFunction() {
-    FuncOp funcOp = getFunction();
-    if (!isDispatchFunction(funcOp)) return;
-
-    if (workGroupSize.empty()) {
-      // By default look at the number of "parallel" loops in the generic op.
-      Region &body = funcOp.getBody();
-      // Only handle single block functions.
-      if (body.getBlocks().size() != 1) {
-        funcOp.emitError("unhandled dispatch function with multiple blocks");
-        return signalPassFailure();
-      }
-      Block &block = body.front();
-      auto linalgOps = block.getOps<linalg::LinalgOp>();
-      if (linalgOps.empty()) {
-        // Nothing to update. Return.
-        return;
-      }
-      if (!mlir::has_single_element(linalgOps)) {
-        funcOp.emitError(
-            "unhandled updating workgroup size of dispatch function with "
-            "multiple linalg ops");
-        return signalPassFailure();
-      }
-      // Find the number of leading parallel loops in the generic op
-      unsigned numOuterParallelLoops =
-          getNumOuterParallelLoops(*linalgOps.begin());
-      workGroupSize.resize(3, 1);
-      if (numOuterParallelLoops > 0) {
-        workGroupSize[0] = 32;
-      }
-      if (numOuterParallelLoops > 1) {
-        workGroupSize[1] = 4;
-      }
-      if (numOuterParallelLoops > 2) {
-        // Change workGroupsSize[1] such that the total size is equal to 128,
-        // which is the minimum gauranteed by Vulkan spec.
-        workGroupSize[1] = 2;
-        workGroupSize[2] = 2;
-      }
-      // TODO(ravishankarm): The current code-generation will "serialize" all
-      // the inner loops that are more than 3 deep. We can potentially "fold"
-      // all the parallel loops so that they all executed on different
-      // workitems.
-    }
-    OpBuilder builder(&getContext());
-    assert(workGroupSize.size() == 3);
-    funcOp.setAttr("iree.executable.workgroup_size",
-                   builder.getIndexArrayAttr(workGroupSize));
-  }
-
- private:
-  SmallVector<int64_t, 3> workGroupSize;
-};
 }  // namespace
 
 void addLinalgToSPIRVPasses(OpPassManager &pm,
                             ArrayRef<int64_t> workGroupSize) {
   // Linalg to loops.
-  pm.addPass(std::make_unique<UpdateWorkGroupSizePass>(workGroupSize));
-  pm.addPass(std::make_unique<IREETileLinalgPass>());
+  pm.addPass(createLinalgTileAndFusePass(workGroupSize));
   pm.addPass(createConvertLinalgToLoopsPass());
   pm.addPass(createLowerAffinePass());
   pm.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
index ef1f419..58fd328 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
@@ -23,6 +23,15 @@
 /// Pass to get gpu.module from a gpu.launch operation.
 std::unique_ptr<OpPassBase<ModuleOp>> createIREEGpuKernelOutliningPass();
 
+/// Pass to tile and fuse linalg operations on buffers. The pass takes as
+/// argument the `workGroupSize` that the tiling should use. Note that the
+/// tile sizes are the reverse of the workgroup size: the workgroup size along
+/// "x" is used to tile the innermost loop, the one along "y" the next
+/// innermost (if it exists), and the one along "z" the next loop (if it
+/// exists). The workgroup size is expected to have at most 3 entries.
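+/// For example, given a 2-D parallel loop nest and a workgroup size of
+/// [32, 4, 1], the loops are tiled with tile sizes [4, 32], outermost first.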
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgTileAndFusePass(
+    ArrayRef<int64_t> workGroupSize = {});
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
new file mode 100644
index 0000000..1d8d07c
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for the Linalg to SPIR-V transforms.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-opt",
+    ],
+)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
new file mode 100644
index 0000000..17d53b9
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "${_GLOB_X_MLIR}"
+  DATA
+    iree::tools::IreeFileCheck
+    iree::tools::iree-opt
+)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
new file mode 100644
index 0000000..2cd316a
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -0,0 +1,123 @@
+// RUN: iree-opt -split-input-file -iree-linalg-tile-and-fuse %s | IreeFileCheck %s
+
+module {
+  // CHECK-LABEL: func @tile_only
+  //  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<4x8xi32>
+  //  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<4x8xi32>
+  //  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<4x8xi32>
+  //       CHECK: loop.for
+  //       CHECK:   loop.for
+  //       CHECK:     %[[VIEW0:.*]] = subview %[[ARG0]]
+  //       CHECK:     %[[VIEW1:.*]] = subview %[[ARG1]]
+  //       CHECK:     %[[VIEW2:.*]] = subview %[[ARG2]]
+  //       CHECK:     linalg.generic
+  //  CHECK-SAME:       %[[VIEW0]]
+  //  CHECK-SAME:       %[[VIEW1]]
+  //  CHECK-SAME:       %[[VIEW2]]
+  //       CHECK: return
+  func @tile_only(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
+                  %arg2: memref<4x8xi32>)
+  attributes
+    {iree.executable.export,
+     iree.executable.workgroup_size = dense<[32, 4, 1]> : vector<3xi32>} {
+    linalg.generic
+      {args_in = 2 : i64, args_out = 1 : i64,
+       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
+    ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
+      %0 = addi %arg3, %arg4 : i32
+      linalg.yield %0 : i32
+    }: memref<4x8xi32>, memref<4x8xi32>, memref<4x8xi32>
+    return
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func @sequential
+  //  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<10xi32>
+  //  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<i32>
+  //  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<i32>
+  func @sequential(%arg0: memref<10xi32>, %arg1: memref<i32>,
+                   %arg2: memref<i32>)
+  attributes
+    {iree.executable.export,
+     iree.executable.workgroup_size = dense<1> : vector<3xi32>} {
+    //      CHECK: %[[C0:.*]] = constant 0 : index
+    //      CHECK: %[[C1:.*]] = constant 1 : index
+    //      CHECK: loop.for %{{.*}} = %[[C0]] to %[[C1]]
+    //      CHECK:   loop.for %{{.*}} = %[[C0]] to %[[C1]]
+    //      CHECK:     linalg.indexed_generic
+    // CHECK-SAME:       %[[ARG0]]
+    // CHECK-SAME:       %[[ARG1]]
+    // CHECK-SAME:       %[[ARG2]]
+    linalg.indexed_generic
+      {args_in = 2 : i64, args_out = 1 : i64,
+       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>,
+                        affine_map<(d0) -> ()>],
+       iterator_types = ["reduction"]} %arg0, %arg1, %arg2 {
+    ^bb0(%arg3: index, %arg4: i32, %arg5: i32, %arg6: i32):
+      %c0 = constant 0 : index
+      %cst = constant true
+      %0 = cmpi "eq", %arg3, %c0 : index
+      %1 = and %cst, %0 : i1
+      %2 = select %1, %arg5, %arg6 : i32
+      %3 = addi %arg4, %2 : i32
+      linalg.yield %3 : i32
+    }: memref<10xi32>, memref<i32>, memref<i32>
+    return
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func @tile_and_fuse
+  //  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+  //  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+  //  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+  //  CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+  //       CHECK: loop.for
+  //       CHECK:   loop.for
+  //   CHECK-DAG:     %[[VIEW0:.*]] = subview %[[ARG0]]
+  //   CHECK-DAG:     %[[VIEW1:.*]] = subview %[[ARG1]]
+  //   CHECK-DAG:     %[[VIEW2READ:.*]] = subview %[[ARG2]]
+  //   CHECK-DAG:     %[[VIEW2WRITE:.*]] = subview %[[ARG2]]
+  //   CHECK-DAG:     %[[VIEW3:.*]] = subview %[[ARG3]]
+  //       CHECK:     linalg.generic
+  //  CHECK-SAME:       %[[VIEW0]]
+  //  CHECK-SAME:       %[[VIEW1]]
+  //  CHECK-SAME:       %[[VIEW2WRITE]]
+  //       CHECK:     linalg.generic
+  //  CHECK-SAME:       %[[VIEW2READ]]
+  //  CHECK-SAME:       %[[VIEW3]]
+  func @tile_and_fuse(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                      %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>)
+  attributes
+    {iree.executable.export,
+     iree.executable.workgroup_size = dense<[32, 4, 1]> : vector<3xi32>} {
+    linalg.generic
+      {args_in = 2 : i64, args_out = 1 : i64,
+       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
+    ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+      %0 = addf %arg4, %arg5 : f32
+      linalg.yield %0 : f32
+    }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+    linalg.generic
+      {args_in = 1 : i64, args_out = 1 : i64,
+       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]} %arg2, %arg3 {
+    ^bb0(%arg7: f32, %arg8: f32):
+      %1 = mulf %arg7, %arg7 : f32
+      linalg.yield %1 : f32
+    }: memref<?x?xf32>, memref<?x?xf32>
+    return
+  }
+}