Add a primitive Linalg TileAndDistributeOnTensors pass. (#4105)

The TileAndDistributeOnTensorsPass tiles Linalg ops with tensor semantics and distributes the tiles across a 3-D grid of processors, each identified by a WorkgroupIdOp ranging over [0, WorkgroupSizeOp).
This results in a parametric tiling whose loop bounds are not yet simplified by the available canonicalizations.
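
For example, with tile-sizes=1,2, a linalg.matmul on tensors is rewritten
into cyclically distributed loops of roughly the following shape (an
abridged, illustrative sketch of the IR checked in the new test; SSA names
are not literal):

    %id_x = iree.workgroup_id {dimension = "x"} : index
    %sz_x = iree.workgroup_size {dimension = "x"} : index
    %0 = scf.for %i = %id_x to %c2 step %sz_x iter_args(%t = %C)
        -> (tensor<2x4xf32>) {
      // Analogous scf.for over "y", then subtensor, a tiled linalg.matmul,
      // and a subtensor_insert into the iteration argument.
      scf.yield %updated : tensor<2x4xf32>
    }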

PiperOrigin-RevId: 345629018
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 0ef31cb..bcdfc24 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -25,6 +25,7 @@
         "ConvertToLLVM.cpp",
         "KernelDispatch.cpp",
         "LinalgRewriteDestructiveUpdatesPass.cpp",
+        "LinalgTileAndDistributeOnTensorsPass.cpp",
         "LinalgTileAndDistributePass.cpp",
         "LinalgTileAndVectorizePass.cpp",
         "Passes.cpp",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 46d3787..199d55e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -25,6 +25,7 @@
     "ConvertToLLVM.cpp"
     "KernelDispatch.cpp"
     "LinalgRewriteDestructiveUpdatesPass.cpp"
+    "LinalgTileAndDistributeOnTensorsPass.cpp"
     "LinalgTileAndDistributePass.cpp"
     "LinalgTileAndVectorizePass.cpp"
     "Passes.cpp"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
new file mode 100644
index 0000000..33b05e1
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributeOnTensorsPass.cpp
@@ -0,0 +1,143 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-tile-and-distribute-on-tensors"
+
+namespace mlir {
+namespace iree_compiler {
+
+struct LinalgTileAndDistributeOnTensorsPass
+    : public PassWrapper<LinalgTileAndDistributeOnTensorsPass,
+                         OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, IREEDialect, AffineDialect,
+                    scf::SCFDialect>();
+  }
+  LinalgTileAndDistributeOnTensorsPass() = default;
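+  // Explicit copy constructor: the ListOption below is not copyable, so
+  // cloning the pass re-registers the `tile-sizes` option on the new copy.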
+  LinalgTileAndDistributeOnTensorsPass(
+      const LinalgTileAndDistributeOnTensorsPass &pass) {}
+  void runOnOperation() override;
+
+ private:
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+
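+// Returns the (WorkgroupIdOp, WorkgroupSizeOp) pair for dimension `dim`,
+// used as the processor id and processor count during distribution.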
+static std::pair<Value, Value> buildWorkgroupOpPair(OpBuilder &b,
+                                                    StringRef dim) {
+  Type indexType = b.getIndexType();
+  StringAttr attr = b.getStringAttr(dim);
+  Location loc = b.getInsertionPoint()->getLoc();
+  return {b.create<IREE::WorkgroupIdOp>(loc, indexType, attr),
+          b.create<IREE::WorkgroupSizeOp>(loc, indexType, attr)};
+}
+
+// Rewrite pattern to ensure only ops with tensor semantics are tiled.
+struct TileAndDistributeOnTensorsPattern
+    : public linalg::LinalgBaseTilingPattern {
+  using Base = linalg::LinalgBaseTilingPattern;
+  TileAndDistributeOnTensorsPattern(linalg::LinalgTilingOptions options,
+                                    linalg::LinalgMarker marker,
+                                    PatternBenefit benefit = 1)
+      : Base(options, marker, benefit) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+    if (!linalgOp || !linalgOp.hasTensorSemantics()) return failure();
+    SmallVector<Value, 4> tensorResults;
+    if (failed(Base::matchAndRewriteBase(op, rewriter, tensorResults)))
+      return failure();
+    // TODO: Wrap in sequentialized SPMD loops.
+    rewriter.replaceOp(op, tensorResults);
+    return success();
+  }
+};
+
+void LinalgTileAndDistributeOnTensorsPass::runOnOperation() {
+  if (tileSizes.empty()) return;
+  ModuleOp module = getOperation();
+  MLIRContext *context = module->getContext();
+
+  // Cyclic distribution strategy along at most 3 dimensions: each dimension
+  // is mapped to a WorkgroupIdOp in the range [0, WorkgroupSizeOp), and each
+  // processor handles every WorkgroupSizeOp-th tile along that dimension.
+  static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+      [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+        // TODO: drop magic names.
+        std::array<StringRef, 3> dimStrs{"x", "y", "z"};
+        auto numParallelDims = parallelLoopRanges.size();
+        SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims);
+        for (unsigned dim = 0; dim < std::min<size_t>(numParallelDims, 3);
+             ++dim) {
+          auto p = buildWorkgroupOpPair(builder, dimStrs[dim]);
+          procInfo[dim] = {p.first, p.second};
+        }
+        return procInfo;
+      },
+      {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+       linalg::DistributionMethod::Cyclic}};
+
+  for (FuncOp funcOp : module.getOps<FuncOp>()) {
+    // TODO: maybe activate when put in a real pipeline.
+    // if (!isEntryPoint(funcOp)) continue;
+
+    OwningRewritePatternList patterns;
+    auto linalgTilingOptions =
+        linalg::LinalgTilingOptions()
+            .setDistributionOptions(workgroupDistributionOptions)
+            .setLoopType(linalg::LinalgTilingLoopType::Loops)
+            .setTileSizes(ArrayRef<int64_t>(tileSizes));
+    assert(linalgTilingOptions.distribution.hasValue());
+
+    // In the future, derive from LinalgTilingPattern to create sequentialized
+    // SPMD loops.
+    patterns.insert<TileAndDistributeOnTensorsPattern>(
+        linalgTilingOptions,
+        linalg::LinalgMarker(ArrayRef<Identifier>(),
+                             Identifier::get(getWorkgroupMarker(), context)));
+    // Add canonicalization patterns.
+    linalg::populateLinalgTilingCanonicalizationPatterns(patterns, context);
+    patterns.insert<AffineMinCanonicalizationPattern>(context);
+    applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createLinalgTileAndDistributeOnTensorsPass() {
+  return std::make_unique<LinalgTileAndDistributeOnTensorsPass>();
+}
+
+static PassRegistration<LinalgTileAndDistributeOnTensorsPass> pass(
+    "iree-codegen-llvm-linalg-tile-and-distribute-on-tensors",
+    "Tile and distribute Linalg operations on tensors",
+    [] { return std::make_unique<LinalgTileAndDistributeOnTensorsPass>(); });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 09edd29..42036b6 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -50,6 +50,11 @@
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgRewriteDestructiveUpdatesPass();
 
+/// Pass to tile Linalg ops with tensor semantics and distribute them to a
+/// 3-D grid of workgroups (wrapping in sequentialized SPMD loops is a TODO).
+std::unique_ptr<OperationPass<ModuleOp>>
+createLinalgTileAndDistributeOnTensorsPass();
+
 /// Populates passes needed to lower a XLA HLO op to LLVM dialect via the
 /// structured ops path. The pass manager `pm` in here should operate on the
 /// module within the IREE::HAL::ExecutableOp.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute_on_tensors.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute_on_tensors.mlir
new file mode 100644
index 0000000..b869782
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute_on_tensors.mlir
@@ -0,0 +1,41 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-codegen-llvm-linalg-tile-and-distribute-on-tensors=tile-sizes="1,2" %s | IreeFileCheck %s
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (2, -d0 + 4)>
+
+// CHECK-LABEL: func @tensor
+func @tensor() -> tensor<2x4xf32> {
+  %A = iree.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+  %B = iree.unfoldable_constant dense<[[1.0, 2.0, 3.0, 4.0],
+                       [5.0, 6.0, 7.0, 8.0],
+                       [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
+  %C = iree.unfoldable_constant dense<1000.0> : tensor<2x4xf32>
+
+  //  CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  //  CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  //  CHECK-DAG: %[[C4:.*]] = constant 4 : index
+  //  CHECK-DAG: %[[bix:.*]] = iree.workgroup_id {dimension = "x"} : index
+  //  CHECK-DAG: %[[bdx:.*]] = iree.workgroup_size {dimension = "x"} : index
+  //  CHECK-DAG: %[[biy:.*]] = iree.workgroup_id {dimension = "y"} : index
+  //  CHECK-DAG: %[[bdy:.*]] = iree.workgroup_size {dimension = "y"} : index
+  //      CHECK: %{{.*}} = scf.for %[[I:.*]] = %[[bix]] to %[[C2]] step %[[bdx]] iter_args(%arg1 = %2) -> (tensor<2x4xf32>) {
+  // CHECK-NEXT:   %[[biy_scaled:.*]] = muli %[[biy]], %[[C2]] : index
+  // CHECK-NEXT:   %[[bdy_scaled:.*]] = muli %[[bdy]], %[[C2]] : index
+  // CHECK-NEXT:   %{{.*}} = scf.for %[[J:.*]] = %[[biy_scaled]] to %[[C4]] step %[[bdy_scaled]] iter_args(%arg3 = %arg1) -> (tensor<2x4xf32>) {
+  // CHECK-NEXT:     subtensor %{{.*}}[%[[I]], 0] [1, 3] [1, 1] : tensor<2x3xf32> to tensor<1x3xf32>
+  //
+  // Canonicalizations not yet powerful enough here.
+  // CHECK-NEXT:     %[[J_slice_1:.*]] = affine.min #[[$MAP]](%[[J]])
+  // CHECK-NEXT:     subtensor %1[0, %[[J]]] [3, %[[J_slice_1]]] [1, 1] : tensor<3x4xf32> to tensor<3x?xf32>
+  //
+  // Canonicalizations not yet powerful enough here.
+  // CHECK-NEXT:     %[[J_slice_2:.*]] = affine.min #[[$MAP]](%[[J]])
+  // CHECK-NEXT:     subtensor %arg3[%[[I]], %[[J]]] [1, %[[J_slice_2]]] [1, 1] : tensor<2x4xf32> to tensor<1x?xf32>
+  // CHECK-NEXT:     linalg.matmul
+  // CHECK-NEXT:     subtensor_insert {{.*}} : tensor<1x?xf32> into tensor<2x4xf32>
+  // CHECK-NEXT:     scf.yield %{{.*}} : tensor<2x4xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   scf.yield %{{.*}} : tensor<2x4xf32>
+  %E = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
+                    init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+  return %E : tensor<2x4xf32>
+}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 6157e00..e250550 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -66,6 +66,7 @@
     // LinalgToLLVM
     createConvImg2ColMatmulConversionPass();
     createLinalgTileAndDistributePass();
+    createLinalgTileAndDistributeOnTensorsPass();
     createLinalgTileAndVectorizeWorkgroupsPass();
     createLinalgRewriteDestructiveUpdatesPass();
     return true;