Add a primitive Linalg TileAndDistributeOnTensors pass. (#4105)
The TileAndDistributeOnTensors pass tiles Linalg ops with tensor semantics and distributes them onto a 3-D grid of processors, each identified by a WorkgroupIdOp whose value ranges over [0, WorkgroupSizeOp).
This results in a parametric tiling that the current canonicalization patterns cannot yet simplify further.
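
For example, with tile-sizes=1,2 the matmul in the new test is expected to
lower to a loop nest of roughly this shape (a sketch with illustrative value
names; the test below checks the exact IR):

  %id_x = iree.workgroup_id {dimension = "x"} : index
  %size_x = iree.workgroup_size {dimension = "x"} : index
  %r = scf.for %i = %id_x to %c2 step %size_x
       iter_args(%t = %C) -> (tensor<2x4xf32>) {
    // Tile along "y" the same way, take subtensors, run linalg.matmul on
    // the tiles, then subtensor_insert the result back into %t.
  }
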
PiperOrigin-RevId: 345629018
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 0ef31cb..bcdfc24 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -25,6 +25,7 @@
"ConvertToLLVM.cpp",
"KernelDispatch.cpp",
"LinalgRewriteDestructiveUpdatesPass.cpp",
+ "LinalgTileAndDistributeOnTensorsPass.cpp",
"LinalgTileAndDistributePass.cpp",
"LinalgTileAndVectorizePass.cpp",
"Passes.cpp",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 46d3787..199d55e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -25,6 +25,7 @@
"ConvertToLLVM.cpp"
"KernelDispatch.cpp"
"LinalgRewriteDestructiveUpdatesPass.cpp"
+ "LinalgTileAndDistributeOnTensorsPass.cpp"
"LinalgTileAndDistributePass.cpp"
"LinalgTileAndVectorizePass.cpp"
"Passes.cpp"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
new file mode 100644
index 0000000..33b05e1
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
@@ -0,0 +1,146 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-tile-and-distribute-on-tensors"
+
+namespace mlir {
+namespace iree_compiler {
+
+struct LinalgTileAndDistributeOnTensorsPass
+ : public PassWrapper<LinalgTileAndDistributeOnTensorsPass,
+ OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, IREEDialect, AffineDialect,
+ scf::SCFDialect>();
+ }
+ LinalgTileAndDistributeOnTensorsPass() = default;
+ LinalgTileAndDistributeOnTensorsPass(
+ const LinalgTileAndDistributeOnTensorsPass &pass) {}
+ void runOnOperation() override;
+
+ private:
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+
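+// Builds the (workgroup id, workgroup size) op pair for dimension `dim`.
+// With the IREE dialect's printed form (checked in the test below), the
+// pair for "x" is expected to materialize as:
+//   %id = iree.workgroup_id {dimension = "x"} : index
+//   %size = iree.workgroup_size {dimension = "x"} : index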
+static std::pair<Value, Value> buildWorkgroupOpPair(OpBuilder &b,
+ StringRef dim) {
+ Type indexType = b.getIndexType();
+ StringAttr attr = b.getStringAttr(dim);
+ return {b.create<IREE::WorkgroupIdOp>(b.getInsertionPoint()->getLoc(),
+ indexType, attr),
+ b.create<IREE::WorkgroupSizeOp>(b.getInsertionPoint()->getLoc(),
+ indexType, attr)};
+}
+
+// Rewrite pattern to ensure only ops with tensor semantics are tiled.
+struct TileAndDistributeOnTensorsPattern
+ : public linalg::LinalgBaseTilingPattern {
+ using Base = linalg::LinalgBaseTilingPattern;
+ TileAndDistributeOnTensorsPattern(linalg::LinalgTilingOptions options,
+ linalg::LinalgMarker marker,
+ PatternBenefit benefit = 1)
+ : Base(options, marker, benefit) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+ if (!linalgOp || !linalgOp.hasTensorSemantics()) return failure();
+ SmallVector<Value, 4> tensorResults;
+ if (failed(Base::matchAndRewriteBase(op, rewriter, tensorResults)))
+ return failure();
+ // TODO: Wrap in sequentialized SPMD loops.
+ rewriter.replaceOp(op, tensorResults);
+ return success();
+ }
+};
+
+void LinalgTileAndDistributeOnTensorsPass::runOnOperation() {
+ if (tileSizes.empty()) return;
+ ModuleOp module = getOperation();
+ MLIRContext *context = module->getContext();
+
+ // Distribution strategy along at most 3 dimensions with WorkgroupIdOp in
+ // range [0, WorkgroupSizeOp).
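+ // Under cyclic distribution each tiled parallel loop takes the form
+ //   scf.for %iv = %id * tileSize to %ub step %size * tileSize
+ // so workgroups stride over the tiles (see the loop nest in the new test).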
+ static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+ [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+ // TODO: drop magic names.
+ std::array<StringRef, 3> dimStrs{"x", "y", "z"};
+ auto numParallelDims = parallelLoopRanges.size();
+ SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims);
+ for (unsigned dim = 0; dim < std::min<size_t>(numParallelDims, 3); ++dim) {
+ auto p = buildWorkgroupOpPair(builder, dimStrs[dim]);
+ procInfo[dim] = {p.first, p.second};
+ }
+ return procInfo;
+ },
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic}};
+
+ for (FuncOp funcOp : module.getOps<FuncOp>()) {
+ // TODO: maybe activate when put in a real pipeline.
+ // if (!isEntryPoint(funcOp)) continue;
+
+ OwningRewritePatternList patterns;
+ auto linalgTilingOptions =
+ linalg::LinalgTilingOptions()
+ .setDistributionOptions(workgroupDistributionOptions)
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizes(ArrayRef<int64_t>(tileSizes));
+ assert(linalgTilingOptions.distribution.hasValue());
+
+ // In the future, derive from LinalgTilingPattern to create sequentialized
+ // SPMD loops.
+ patterns.insert<TileAndDistributeOnTensorsPattern>(
+ linalgTilingOptions,
+ linalg::LinalgMarker(ArrayRef<Identifier>(),
+ Identifier::get(getWorkgroupMarker(), context)));
+ // Add canonicalization patterns.
+ linalg::populateLinalgTilingCanonicalizationPatterns(patterns, context);
+ patterns.insert<AffineMinCanonicalizationPattern>(context);
+ applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createLinalgTileAndDistributeOnTensorsPass() {
+ return std::make_unique<LinalgTileAndDistributeOnTensorsPass>();
+}
+
+static PassRegistration<LinalgTileAndDistributeOnTensorsPass> pass(
+ "iree-codegen-llvm-linalg-tile-and-distribute-on-tensors",
+ "Tile and distribute Linalg operations on tensors",
+ [] { return std::make_unique<LinalgTileAndDistributeOnTensorsPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 09edd29..42036b6 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -50,6 +50,11 @@
std::unique_ptr<OperationPass<FuncOp>>
createLinalgRewriteDestructiveUpdatesPass();

+/// Pass to perform tiling and distribution of Linalg ops with tensor semantics
+/// to sequentialized SPMD loops.
+std::unique_ptr<OperationPass<ModuleOp>>
+createLinalgTileAndDistributeOnTensorsPass();
+
/// Populates passes needed to lower a XLA HLO op to LLVM dialect via the
/// structured ops path. The pass manager `pm` in here should operate on the
/// module within the IREE::HAL::ExecutableOp.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute_on_tensors.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute_on_tensors.mlir
new file mode 100644
index 0000000..b869782
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute_on_tensors.mlir
@@ -0,0 +1,41 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-codegen-llvm-linalg-tile-and-distribute-on-tensors=tile-sizes="1,2" %s | IreeFileCheck %s
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (2, -d0 + 4)>
+
+// CHECK-LABEL: func @tensor
+func @tensor() -> tensor<2x4xf32> {
+ %A = iree.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+ %B = iree.unfoldable_constant dense<[[1.0, 2.0, 3.0, 4.0],
+ [5.0, 6.0, 7.0, 8.0],
+ [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
+ %C = iree.unfoldable_constant dense<1000.0> : tensor<2x4xf32>
+
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+ // CHECK-DAG: %[[C4:.*]] = constant 4 : index
+ // CHECK-DAG: %[[bix:.*]] = iree.workgroup_id {dimension = "x"} : index
+ // CHECK-DAG: %[[bdx:.*]] = iree.workgroup_size {dimension = "x"} : index
+ // CHECK-DAG: %[[biy:.*]] = iree.workgroup_id {dimension = "y"} : index
+ // CHECK-DAG: %[[bdy:.*]] = iree.workgroup_size {dimension = "y"} : index
+ // CHECK: %{{.*}} = scf.for %[[I:.*]] = %[[bix]] to %[[C2]] step %[[bdx]] iter_args(%arg1 = %2) -> (tensor<2x4xf32>) {
+ // CHECK-NEXT: %[[biy_scaled:.*]] = muli %[[biy]], %[[C2]] : index
+ // CHECK-NEXT: %[[bdy_scaled:.*]] = muli %[[bdy]], %[[C2]] : index
+ // CHECK-NEXT: %{{.*}} = scf.for %[[J:.*]] = %[[biy_scaled]] to %[[C4]] step %[[bdy_scaled]] iter_args(%arg3 = %arg1) -> (tensor<2x4xf32>) {
+ // CHECK-NEXT: subtensor %{{.*}}[%[[I]], 0] [1, 3] [1, 1] : tensor<2x3xf32> to tensor<1x3xf32>
+ //
+ // Canonicalizations not yet powerful enough here.
+ // CHECK-NEXT: %[[J_slice_1:.*]] = affine.min #[[$MAP]](%[[J]])
+ // CHECK-NEXT: subtensor %1[0, %[[J]]] [3, %[[J_slice_1]]] [1, 1] : tensor<3x4xf32> to tensor<3x?xf32>
+ //
+ // Canonicalizations not yet powerful enough here.
+ // CHECK-NEXT: %[[J_slice_2:.*]] = affine.min #[[$MAP]](%[[J]])
+ // CHECK-NEXT: subtensor %arg3[%[[I]], %[[J]]] [1, %[[J_slice_2]]] [1, 1] : tensor<2x4xf32> to tensor<1x?xf32>
+ // CHECK-NEXT: linalg.matmul
+ // CHECK-NEXT: subtensor_insert {{.*}} : tensor<1x?xf32> into tensor<2x4xf32>
+ // CHECK-NEXT: scf.yield %{{.*}} : tensor<2x4xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.yield %{{.*}} : tensor<2x4xf32>
+ %E = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
+ init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+ return %E : tensor<2x4xf32>
+}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 6157e00..e250550 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -66,6 +66,7 @@
// LinalgToLLVM
createConvImg2ColMatmulConversionPass();
createLinalgTileAndDistributePass();
+ createLinalgTileAndDistributeOnTensorsPass();
createLinalgTileAndVectorizeWorkgroupsPass();
createLinalgRewriteDestructiveUpdatesPass();
return true;