Add pass to tile and fuse linalg operations.

The existing tiling pass is incorporated as a pattern, so if there is
nothing to fuse, the op is just tiled. The fallback of executing the
op sequentially is also incorporated as a pattern of this pass. The
pass added here also updates the workgroup size when the tiling
pattern succeeds, which makes the UpdateWorkGroupSizePass unnecessary.
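
For illustration, a minimal sketch of how the new pass is wired in (both
uses below appear in this change; the input file name is a placeholder):

    // In the lowering pipeline (LowerToSPIRV.cpp):
    pm.addPass(createLinalgTileAndFusePass(workGroupSize));

    // Standalone, via the registered flag:
    //   iree-opt -iree-linalg-tile-and-fuse input.mlir
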
PiperOrigin-RevId: 303810983
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
index 8de2527..18abb32 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -20,7 +20,8 @@
cc_library(
name = "LinalgToSPIRV",
srcs = [
- "GPUKernelOutlining.cpp",
+ "GPUKernelOutliningPass.cpp",
+ "LinalgTileAndFusePass.cpp",
"LowerToSPIRV.cpp",
],
hdrs = [
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
index 370c17b..dfbf6c3 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -21,7 +21,8 @@
"LowerToSPIRV.h"
"Passes.h"
SRCS
- "GPUKernelOutlining.cpp"
+ "GPUKernelOutliningPass.cpp"
+ "LinalgTileAndFusePass.cpp"
"LowerToSPIRV.cpp"
DEPS
LLVMSupport
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutlining.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutliningPass.cpp
similarity index 97%
rename from iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutlining.cpp
rename to iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutliningPass.cpp
index 65b8d6f..f520b68 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutlining.cpp
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/GPUKernelOutliningPass.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//===- GPUKernelOutlining.cpp - Generate GPU device-side code -------------===//
+//===- GPUKernelOutliningPass.cpp - Generate GPU device-side code ---------===//
//
// Implements a pass to convert a launch operation into a device-side code. Uses
// a separate pass since the pass from core puts the gpu.module at the module
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgTileAndFusePass.cpp
new file mode 100644
index 0000000..fb290de
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -0,0 +1,293 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LinalgTileAndFusePass.cpp - Tile and fuse Linalg on Buffers --------===//
+//
+// Implements a pass to tile and fuse linalg operations on buffers.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static StringRef getWorkGroupMarker() { return "spirv_workgroup"; }
+
+static constexpr unsigned kMaxWorkgroupRank = 3;
+
+/// Returns the tile sizes to use by default based on the number of dimensions
+/// of parallelism.
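+/// For example, numDims == 1 gives {32}, numDims == 2 gives {4, 32}, and
+/// numDims >= 3 gives {2, 2, 32} (listed outermost to innermost).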
+static void getDefaultTileSizes(unsigned numDims,
+ SmallVectorImpl<int64_t> &tileSizes) {
+ tileSizes.clear();
+ switch (numDims) {
+ case 0:
+ return;
+ case 1:
+ tileSizes.push_back(32);
+ return;
+ case 2:
+ tileSizes.push_back(4);
+ tileSizes.push_back(32);
+ return;
+ default:
+ break;
+ }
+ tileSizes.push_back(2);
+ tileSizes.push_back(2);
+ tileSizes.push_back(32);
+}
+
+/// Returns the number of "outer" parallel loops specified in the `linalgOp`.
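+/// For example, iterator_types = ["parallel", "parallel", "reduction"] has
+/// two outer parallel loops. linalg.conv is special-cased to return 0.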
+static unsigned getNumOuterParallelLoops(linalg::LinalgOp linalgOp) {
+ if (isa<linalg::ConvOp>(linalgOp.getOperation())) return 0;
+ return linalgOp.iterator_types()
+ .getValue()
+ .take_while([](Attribute attr) {
+ return attr.cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName();
+ })
+ .size();
+}
+
+/// Returns the tile sizes to use for a linalg operation, derived from
+/// `workGroupSize` if provided, or from the defaults otherwise.
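+/// For example, workGroupSize = {32, 4, 1} with two parallel loops drops the
+/// trailing 1 and reverses the rest to give tileSizes = {4, 32}.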
+static void getTileSizes(unsigned numParallelLoops,
+ ArrayRef<int64_t> workGroupSize,
+ SmallVectorImpl<int64_t> &tileSizes) {
+ tileSizes.clear();
+ numParallelLoops = std::min(numParallelLoops, kMaxWorkgroupRank);
+ if (!workGroupSize.empty()) {
+ workGroupSize = dropTrailingOnes(workGroupSize);
+ auto rev = reverse(workGroupSize.take_front(numParallelLoops));
+ tileSizes.assign(rev.begin(), rev.end());
+ tileSizes.resize(numParallelLoops, 0);
+ } else {
+ getDefaultTileSizes(numParallelLoops, tileSizes);
+ }
+  // The Linalg convention is to use 0 for no tiling. If the workgroup size
+  // along a dimension is 1, don't tile along that dimension, so override 1
+  // with 0.
+ for (auto &tileSize : tileSizes)
+ if (tileSize == 1) tileSize = 0;
+}
+
+/// Checks if an operation already has an attribute with this marker. If set,
+/// it implies this op shouldn't be tiled again with the same marker.
+static bool hasMarker(Operation *op, StringRef marker) {
+ auto tilingAttr = op->getAttrOfType<StringAttr>(
+ linalg::LinalgTransforms::kLinalgTransformMarker);
+ return tilingAttr && tilingAttr.getValue() == marker;
+}
+
+namespace {
+/// Function pass that implements tiling and fusion in Linalg on buffers.
+struct LinalgTileAndFusePass : public FunctionPass<LinalgTileAndFusePass> {
+ LinalgTileAndFusePass(ArrayRef<int64_t> workGroupSize = {})
+ : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+ void runOnFunction() override;
+
+ private:
+ SmallVector<int64_t, 3> workGroupSize;
+};
+
+/// Base class for Linalg tiling patterns. All classes that derive from this
+/// need to implement an `apply` method that tiles the operation, with the
+/// following signature:
+///
+/// LogicalResult apply(LinalgOp op, SmallVectorImpl<int64_t> &tileSizes,
+/// PatternRewriter &rewriter) const
+template <typename DerivedClass, typename LinalgOp>
+struct LinalgTilingPattern : public OpRewritePattern<LinalgOp> {
+ LinalgTilingPattern(MLIRContext *context, ArrayRef<int64_t> tileSizes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<LinalgOp>(context, benefit), tileSizes(tileSizes) {}
+
+ LogicalResult matchAndRewrite(LinalgOp linalgOp,
+ PatternRewriter &rewriter) const override {
+ if (!linalgOp.hasBufferSemantics()) return failure();
+ // Currently we are only doing one-level tiling, so a single marker is
+ // enough. This might need to move into derived classes.
+ if (hasMarker(linalgOp.getOperation(), getWorkGroupMarker()))
+ return failure();
+
+ if (failed(static_cast<const DerivedClass *>(this)->apply(
+ linalgOp, tileSizes, rewriter)))
+ return failure();
+ rewriter.eraseOp(linalgOp);
+ return success();
+ }
+
+ private:
+ ArrayRef<int64_t> tileSizes;
+};
+
+/// If the linalg op has no outer parallel loops, inserts dummy one-trip loops
+/// around it to execute it sequentially within a thread.
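+/// The generated structure is roughly:
+///
+///   loop.for %i = %c0 to %c1 step %c1 {
+///     loop.for %j = %c0 to %c1 step %c1 {
+///       <cloned linalg op, tagged with the workgroup marker>
+///     }
+///   }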
+template <typename LinalgOp>
+struct SequentialExecutionPattern
+ : public LinalgTilingPattern<SequentialExecutionPattern<LinalgOp>,
+ LinalgOp> {
+ using LinalgTilingPattern<SequentialExecutionPattern<LinalgOp>,
+ LinalgOp>::LinalgTilingPattern;
+ LogicalResult apply(LinalgOp linalgOp, ArrayRef<int64_t> tileSizes,
+ PatternRewriter &rewriter) const {
+ if (!tileSizes.empty()) return failure();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto indexType = rewriter.getIndexType();
+ auto loc = linalgOp.getLoc();
+ auto zero =
+ rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(indexType, 0));
+ auto one =
+ rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(indexType, 1));
+ auto outerLoop = rewriter.create<loop::ForOp>(loc, zero, one, one);
+ rewriter.setInsertionPoint(outerLoop.getBody(),
+ std::prev(outerLoop.getBody()->end()));
+ auto innerLoop = rewriter.create<loop::ForOp>(loc, zero, one, one);
+ rewriter.setInsertionPoint(innerLoop.getBody(),
+ std::prev(innerLoop.getBody()->end()));
+ Operation *clonedOp = rewriter.clone(*linalgOp.getOperation());
+ clonedOp->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
+ rewriter.getStringAttr(getWorkGroupMarker()));
+ return success();
+ }
+};
+
+/// If there is nothing to fuse the linalg op with, just tiles it.
+template <typename LinalgOp>
+struct TileLinalgOpPattern
+ : public LinalgTilingPattern<TileLinalgOpPattern<LinalgOp>, LinalgOp> {
+ using LinalgTilingPattern<TileLinalgOpPattern<LinalgOp>,
+ LinalgOp>::LinalgTilingPattern;
+ LogicalResult apply(LinalgOp linalgOp, ArrayRef<int64_t> tileSizes,
+ PatternRewriter &rewriter) const {
+    // Check that all inputs and outputs have a single use (this op). In that
+    // case there is nothing to fuse with, so just tile.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(linalgOp.getOperation());
+ if (!llvm::all_of(linalgOp.getInputsAndOutputBuffers(),
+ [](Value arg) { return arg.hasOneUse(); }))
+ return failure();
+ return linalg::tileLinalgOpAndSetMarker(rewriter, linalgOp.getOperation(),
+ tileSizes, getWorkGroupMarker(),
+ /*permutation=*/{});
+ }
+};
+
+/// Tile and fuse linalg operations.
+template <typename LinalgOp>
+struct TileAndFuseLinalgOpPattern
+ : public LinalgTilingPattern<TileAndFuseLinalgOpPattern<LinalgOp>,
+ LinalgOp> {
+ using LinalgTilingPattern<TileAndFuseLinalgOpPattern<LinalgOp>,
+ LinalgOp>::LinalgTilingPattern;
+ LogicalResult apply(LinalgOp linalgOp, ArrayRef<int64_t> tileSizes,
+ PatternRewriter &rewriter) const {
+ SmallVector<int64_t, 1> operandIndicesToFuse;
+ for (auto buffer : llvm::enumerate(linalgOp.getInputsAndOutputBuffers())) {
+ // If a buffer has multiple uses, then it is a candidate for fusion.
+ if (!buffer.value().hasOneUse())
+ operandIndicesToFuse.push_back(buffer.index());
+ }
+ if (operandIndicesToFuse.empty()) return failure();
+ return linalg::tileAndFuseLinalgOpAndSetMarker(
+ rewriter, linalgOp, tileSizes, operandIndicesToFuse,
+ getWorkGroupMarker());
+ }
+};
+} // namespace
+
+void LinalgTileAndFusePass::runOnFunction() {
+ MLIRContext *context = &getContext();
+ FuncOp funcOp = getFunction();
+ if (!isDispatchFunction(funcOp)) return;
+
+ Region &body = funcOp.getBody();
+ // Only handle single block functions.
+ if (body.getBlocks().size() != 1) {
+ funcOp.emitError("unhandled dispatch function with multiple blocks");
+ return signalPassFailure();
+ }
+ Block &block = body.front();
+ auto linalgOps = block.getOps<linalg::LinalgOp>();
+ if (linalgOps.empty()) return;
+
+  // Compute the minimum number of outer parallel loops across linalg
+  // operations. This gives the dimensionality of tiling to be used.
+ unsigned numParallelLoops = kMaxWorkgroupRank;
+ for (linalg::LinalgOp op : linalgOps)
+ numParallelLoops = std::min(numParallelLoops, getNumOuterParallelLoops(op));
+
+ // Get the tile sizes to use for the lowering.
+ SmallVector<int64_t, 3> tileSizes;
+ getTileSizes(numParallelLoops, workGroupSize, tileSizes);
+
+ OwningRewritePatternList patterns;
+ patterns.insert<SequentialExecutionPattern<linalg::ConvOp>,
+ SequentialExecutionPattern<linalg::GenericOp>,
+ SequentialExecutionPattern<linalg::IndexedGenericOp>,
+ TileLinalgOpPattern<linalg::GenericOp>,
+ TileLinalgOpPattern<linalg::IndexedGenericOp>,
+ TileLinalgOpPattern<linalg::MatmulOp>,
+ TileAndFuseLinalgOpPattern<linalg::GenericOp>>(context,
+ tileSizes);
+ applyPatternsGreedily(getOperation(), patterns);
+
+  // Check that there is a single perfectly nested loop.for operation at the
+  // topmost level that will get mapped to thread blocks/workgroups.
+ auto forLoops = block.getOps<loop::ForOp>();
+ if (!mlir::has_single_element(forLoops)) {
+ funcOp.emitError(
+ "unable to fuse operations within a dispatch region to get a single "
+ "outer parallel loop nest to map to workgroups");
+ return signalPassFailure();
+ }
+  // TODO(ravishankarm): Also need to check that the loops are perfectly nested
+  // and that there are as many as numParallelLoops. That check is more
+  // involved, so come back to it after moving to loop.parallel, at which point
+  // it reduces to checking the number of induction variables.
+
+  // Update the workgroup size to be consistent with the tile sizes used. Note
+  // that the tile sizes are ordered from outermost to innermost loop. The
+  // heuristic is to map the innermost loop to x, the next outer loop (if it
+  // exists) to y, and the next one (if it exists) to z. So the tile sizes are
+  // reversed to get the workgroup size.
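+  // For example, tileSizes = {4, 32} results in a workgroup size of
+  // {32, 4, 1}.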
+ SmallVector<int64_t, 3> updatedWorkGroupSize(reverse(tileSizes));
+ updatedWorkGroupSize.resize(3, 1);
+ auto attrs = functional::map(
+ [&context](int64_t v) -> Attribute {
+ return IntegerAttr::get(IndexType::get(context), v);
+ },
+ updatedWorkGroupSize);
+ // TODO(b/150312935): Switch to update the HAL interface directly.
+ funcOp.setAttr("iree.executable.workgroup_size",
+ ArrayAttr::get(attrs, context));
+}
+
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgTileAndFusePass(
+ ArrayRef<int64_t> workGroupSize) {
+ return std::make_unique<LinalgTileAndFusePass>(workGroupSize);
+}
+
+static PassRegistration<LinalgTileAndFusePass> pass(
+ "iree-linalg-tile-and-fuse", "Tile and fuse Linalg operations on buffers",
+ [] { return std::make_unique<LinalgTileAndFusePass>(); });
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
index ea020a3..c4bde7b 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
@@ -77,99 +77,8 @@
}
}
-/// Gets the number of outer parallel loops in a linalg operation.
-unsigned getNumOuterParallelLoops(linalg::LinalgOp linalgOp) {
- // Find the number of leading parallel loops in the generic op
- unsigned numOuterParallelLoops = 0;
- for (auto iteratorType : linalgOp.iterator_types()) {
- if (iteratorType.cast<StringAttr>().getValue() !=
- getParallelIteratorTypeName()) {
- break;
- }
- numOuterParallelLoops++;
- }
- return numOuterParallelLoops;
-}
-
namespace {
-/// To be able to use the workgroup size from the dispatch function attribute
-/// within the linalg tiling pass, need to actually implement a pass to retrieve
-/// the attribute value from the function and pass it along.
-// TODO(ravishankarm): Move this into Linalg dialect.
-struct IREETileLinalgPass : public FunctionPass<IREETileLinalgPass> {
- void runOnFunction() override {
- FuncOp funcOp = getFunction();
- SmallVector<int64_t, 3> workGroupSizeVec;
- workGroupSizeVec.reserve(3);
- if (failed(getWorkGroupSize(funcOp, workGroupSizeVec))) return;
- ArrayRef<int64_t> workGroupSize = dropTrailingOnes(workGroupSizeVec);
-
- OpBuilder builder(funcOp);
- OperationFolder folder(funcOp.getContext());
- Region &body = funcOp.getBody();
- if (!mlir::has_single_element(body)) {
- funcOp.emitError(
- "unhandled dispatch function that doesn't have a single block");
- return signalPassFailure();
- }
- auto linalgOps = body.front().getOps<linalg::LinalgOp>();
- if (linalgOps.empty()) {
- // Nothing to do. Return.
- return;
- }
- if (!mlir::has_single_element(linalgOps)) {
- funcOp.emitError(
- "unhandled tiling multiple linalg ops in one dispatch function");
- return signalPassFailure();
- }
- linalg::LinalgOp linalgOp = *linalgOps.begin();
- if (!linalgOp.hasBufferSemantics()) {
- linalgOp.emitError(
- "expected linalg op with buffer semantics during SPIR-V "
- "code generation");
- return signalPassFailure();
- }
-
- // TODO(ravishankarm): Tile conv op.
- bool isConvOp = isa<linalg::ConvOp>(linalgOp.getOperation());
- unsigned numOuterParallelLoops = getNumOuterParallelLoops(linalgOp);
- if (isConvOp || !numOuterParallelLoops) {
- // There are no outer parallel loops to partition. So just create dummy
- // 1-trip loops that will be "split" across workgroups and workitems.
- builder.setInsertionPoint(linalgOp);
- auto indexType = builder.getIndexType();
- auto loc = linalgOp.getLoc();
- auto zero =
- builder.create<ConstantOp>(loc, builder.getIntegerAttr(indexType, 0));
- auto one =
- builder.create<ConstantOp>(loc, builder.getIntegerAttr(indexType, 1));
- auto outerLoop = builder.create<loop::ForOp>(loc, zero, one, one);
- OpBuilder outerLoopBuilder = outerLoop.getBodyBuilder();
- auto innerLoop =
- outerLoopBuilder.create<loop::ForOp>(loc, zero, one, one);
- OpBuilder innerLoopBuilder = innerLoop.getBodyBuilder();
- innerLoopBuilder.clone(*linalgOp.getOperation());
- linalgOp.erase();
- return;
- }
-
- // Tile sizes to use are reverse of the workGroupSize.
- SmallVector<int64_t, 3> tileSizes(reverse(workGroupSize));
- // Linalg convention is to use 0 for no tiling. If the workgroup size is
- // 1, then dont tile along that dimension. So overriding 1 to 0.
- for (auto &tileSize : tileSizes)
- if (tileSize == 1) tileSize = 0;
- tileSizes.resize(numOuterParallelLoops, 0);
- if (linalg::tileLinalgOp(builder, linalgOp, tileSizes, {}, &folder)) {
- linalgOp.erase();
- return;
- }
- linalgOp.emitError("unable to tile linalg op for SPIR-V code generation");
- return signalPassFailure();
- }
-};
-
/// To be able to use the workgroup size from the dispatch function attribute to
/// convert loops to GPU kernel, need to actually implement a pass to retrieve
/// the attribute value from the function and pass it along.
@@ -242,72 +151,12 @@
}
}
};
-
-/// Pass to override the workgroup_size attribute value of a dispatch function.
-struct UpdateWorkGroupSizePass : FunctionPass<UpdateWorkGroupSizePass> {
- UpdateWorkGroupSizePass(ArrayRef<int64_t> workGroupSize)
- : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
- void runOnFunction() {
- FuncOp funcOp = getFunction();
- if (!isDispatchFunction(funcOp)) return;
-
- if (workGroupSize.empty()) {
- // By default look at the number of "parallel" loops in the generic op.
- Region &body = funcOp.getBody();
- // Only handle single block functions.
- if (body.getBlocks().size() != 1) {
- funcOp.emitError("unhandled dispatch function with multiple blocks");
- return signalPassFailure();
- }
- Block &block = body.front();
- auto linalgOps = block.getOps<linalg::LinalgOp>();
- if (linalgOps.empty()) {
- // Nothing to update. Return.
- return;
- }
- if (!mlir::has_single_element(linalgOps)) {
- funcOp.emitError(
- "unhandled updating workgroup size of dispatch function with "
- "multiple linalg ops");
- return signalPassFailure();
- }
- // Find the number of leading parallel loops in the generic op
- unsigned numOuterParallelLoops =
- getNumOuterParallelLoops(*linalgOps.begin());
- workGroupSize.resize(3, 1);
- if (numOuterParallelLoops > 0) {
- workGroupSize[0] = 32;
- }
- if (numOuterParallelLoops > 1) {
- workGroupSize[1] = 4;
- }
- if (numOuterParallelLoops > 2) {
- // Change workGroupsSize[1] such that the total size is equal to 128,
- // which is the minimum gauranteed by Vulkan spec.
- workGroupSize[1] = 2;
- workGroupSize[2] = 2;
- }
- // TODO(ravishankarm): The current code-generation will "serialize" all
- // the inner loops that are more than 3 deep. We can potentially "fold"
- // all the parallel loops so that they all executed on different
- // workitems.
- }
- OpBuilder builder(&getContext());
- assert(workGroupSize.size() == 3);
- funcOp.setAttr("iree.executable.workgroup_size",
- builder.getIndexArrayAttr(workGroupSize));
- }
-
- private:
- SmallVector<int64_t, 3> workGroupSize;
-};
} // namespace
void addLinalgToSPIRVPasses(OpPassManager &pm,
ArrayRef<int64_t> workGroupSize) {
// Linalg to loops.
- pm.addPass(std::make_unique<UpdateWorkGroupSizePass>(workGroupSize));
- pm.addPass(std::make_unique<IREETileLinalgPass>());
+ pm.addPass(createLinalgTileAndFusePass(workGroupSize));
pm.addPass(createConvertLinalgToLoopsPass());
pm.addPass(createLowerAffinePass());
pm.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
index ef1f419..58fd328 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
@@ -23,6 +23,15 @@
/// Pass to get gpu.module from a gpu.launch operation.
std::unique_ptr<OpPassBase<ModuleOp>> createIREEGpuKernelOutliningPass();
+/// Pass to tile and fuse linalg operations on buffers. The pass takes as
+/// argument the `workGroupSize` that the tiling should use. Note that the
+/// tile sizes are the reverse of the workgroup size: the workgroup size along
+/// "x" is used to tile the innermost loop, the size along "y" the next
+/// innermost loop (if it exists), and the size along "z" the next loop (if it
+/// exists). The workgroup size is expected to have at most 3 entries.
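+/// For example, workGroupSize = {32, 4} tiles the innermost loop with size 32
+/// and the next outer loop with size 4.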
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgTileAndFusePass(
+ ArrayRef<int64_t> workGroupSize = {});
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
new file mode 100644
index 0000000..1d8d07c
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for the Linalg to SPIR-V lowering passes.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
new file mode 100644
index 0000000..17d53b9
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
new file mode 100644
index 0000000..2cd316a
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -0,0 +1,123 @@
+// RUN: iree-opt -split-input-file -iree-linalg-tile-and-fuse %s | IreeFileCheck %s
+
+module {
+ // CHECK-LABEL: func @tile_only
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<4x8xi32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<4x8xi32>
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<4x8xi32>
+ // CHECK: loop.for
+ // CHECK: loop.for
+ // CHECK: %[[VIEW0:.*]] = subview %[[ARG0]]
+ // CHECK: %[[VIEW1:.*]] = subview %[[ARG1]]
+ // CHECK: %[[VIEW2:.*]] = subview %[[ARG2]]
+ // CHECK: linalg.generic
+ // CHECK-SAME: %[[VIEW0]]
+ // CHECK-SAME: %[[VIEW1]]
+ // CHECK-SAME: %[[VIEW2]]
+ // CHECK: return
+ func @tile_only(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
+ %arg2: memref<4x8xi32>)
+ attributes
+ {iree.executable.export,
+ iree.executable.workgroup_size = dense<[32, 4, 1]> : vector<3xi32>} {
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
+ %0 = addi %arg3, %arg4 : i32
+ linalg.yield %0 : i32
+ }: memref<4x8xi32>, memref<4x8xi32>, memref<4x8xi32>
+ return
+ }
+}
+
+// -----
+
+module {
+ // CHECK-LABEL: func @sequential
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<10xi32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<i32>
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<i32>
+ func @sequential(%arg0: memref<10xi32>, %arg1: memref<i32>,
+ %arg2: memref<i32>)
+ attributes
+ {iree.executable.export,
+ iree.executable.workgroup_size = dense<1> : vector<3xi32>} {
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: loop.for %{{.*}} = %[[C0]] to %[[C1]]
+ // CHECK: loop.for %{{.*}} = %[[C0]] to %[[C1]]
+ // CHECK: linalg.indexed_generic
+ // CHECK-SAME: %[[ARG0]]
+ // CHECK-SAME: %[[ARG1]]
+ // CHECK-SAME: %[[ARG2]]
+ linalg.indexed_generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: index, %arg4: i32, %arg5: i32, %arg6: i32):
+ %c0 = constant 0 : index
+ %cst = constant true
+ %0 = cmpi "eq", %arg3, %c0 : index
+ %1 = and %cst, %0 : i1
+ %2 = select %1, %arg5, %arg6 : i32
+ %3 = addi %arg4, %2 : i32
+ linalg.yield %3 : i32
+ }: memref<10xi32>, memref<i32>, memref<i32>
+ return
+ }
+}
+
+// -----
+
+module {
+ // CHECK-LABEL: func @tile_and_fuse
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+ // CHECK: loop.for
+ // CHECK: loop.for
+ // CHECK-DAG: %[[VIEW0:.*]] = subview %[[ARG0]]
+ // CHECK-DAG: %[[VIEW1:.*]] = subview %[[ARG1]]
+ // CHECK-DAG: %[[VIEW2READ:.*]] = subview %[[ARG2]]
+ // CHECK-DAG: %[[VIEW2WRITE:.*]] = subview %[[ARG2]]
+ // CHECK-DAG: %[[VIEW3:.*]] = subview %[[ARG3]]
+ // CHECK: linalg.generic
+ // CHECK-SAME: %[[VIEW0]]
+ // CHECK-SAME: %[[VIEW1]]
+ // CHECK-SAME: %[[VIEW2WRITE]]
+ // CHECK: linalg.generic
+ // CHECK-SAME: %[[VIEW2READ]]
+ // CHECK-SAME: %[[VIEW3]]
+ func @tile_and_fuse(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+ %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>)
+ attributes
+ {iree.executable.export,
+ iree.executable.workgroup_size = dense<[32, 4, 1]> : vector<3xi32>} {
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+ %0 = addf %arg4, %arg5 : f32
+ linalg.yield %0 : f32
+ }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ linalg.generic
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]} %arg2, %arg3 {
+ ^bb0(%arg7: f32, %arg8: f32):
+ %1 = mulf %arg7, %arg7 : f32
+ linalg.yield %1 : f32
+ }: memref<?x?xf32>, memref<?x?xf32>
+ return
+ }
+}