Initial Implementation of tile (+distribute) of LinalgExt ops using `TiledOpInterface` (#6423)

This change adds tiling transformations for LinalgExt ops using the
`TiledOpInterface`. The `linalg_ext.scatter` and `linalg_ext.sort`
(parallel dimensions only) operations are used as the initial
candidates for tiling. Each operation implements the
`TiledOpInterface`, which the tiling transformation uses to tile +
distribute the op.
The tiling transformation is controlled the same way as for Linalg
ops, using `LinalgTilingOptions`, `LinalgTransformationFilter`, and
`LinalgLoopDistributionOptions`.
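
As a rough sketch of the intended usage (this mirrors the test pass added in
this change; `context` and `funcOp` are assumed to be in scope):

```cpp
// Tile linalg_ext.scatter with tile sizes {10, 20}: ops carrying the
// "tiling_input" marker are rewritten, and the tiled op is re-marked with
// "tiling_output" so the pattern does not fire on it again.
RewritePatternSet patterns(context);
patterns.add<LinalgExtTilingPattern<ScatterOp>>(
    context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
    linalg::LinalgTransformationFilter(
        Identifier::get("tiling_input", context),
        Identifier::get("tiling_output", context)));
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
```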
diff --git a/iree/compiler/Dialect/LinalgExt/IR/BUILD b/iree/compiler/Dialect/LinalgExt/IR/BUILD
index 2530839..23c4f1e 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/IR/BUILD
@@ -23,6 +23,7 @@
             "LinalgExtBase.td",
             "LinalgExtOps.td",
             "LinalgExtInterfaces.td",
+            "TiledOpInterface.td",
         ],
         include = ["*.td"],
     ),
@@ -47,13 +48,17 @@
     deps = [
         ":LinalgExtInterfacesGen",
         ":LinalgExtOpsGen",
+        ":TiledOpInterface",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ControlFlowInterfaces",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:SideEffects",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
     ],
 )
 
@@ -116,3 +121,41 @@
         "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
     ],
 )
+
+gentbl_cc_library(
+    name = "TiledOpInterfaceGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "TiledOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "TiledOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "TiledOpInterface.td",
+    td_srcs = [
+        "@llvm-project//mlir:OpBaseTdFiles",
+    ],
+)
+
+cc_library(
+    name = "TiledOpInterface",
+    srcs = [
+        "TiledOpInterface.cpp",
+        "TiledOpInterface.cpp.inc",
+    ],
+    hdrs = [
+        "TiledOpInterface.h",
+        "TiledOpInterface.h.inc",
+    ],
+    deps = [
+        ":TiledOpInterfaceGen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:ViewLikeInterface",
+    ],
+)
diff --git a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index 010ef31..271f30e 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -28,13 +28,16 @@
   DEPS
     ::LinalgExtInterfacesGen
     ::LinalgExtOpsGen
+    ::TiledOpInterface
     LLVMSupport
     MLIRControlFlowInterfaces
     MLIRIR
+    MLIRMemRef
     MLIRParser
     MLIRSideEffectInterfaces
     MLIRStandard
     MLIRSupport
+    MLIRTensor
   PUBLIC
 )
 
@@ -67,4 +70,32 @@
     -gen-dialect-doc LinalgExtDialect.md
 )
 
+iree_tablegen_library(
+  NAME
+    TiledOpInterfaceGen
+  TD_FILE
+    "TiledOpInterface.td"
+  OUTS
+    -gen-op-interface-decls TiledOpInterface.h.inc
+    -gen-op-interface-defs TiledOpInterface.cpp.inc
+)
+
+iree_cc_library(
+  NAME
+    TiledOpInterface
+  HDRS
+    "TiledOpInterface.h"
+    "TiledOpInterface.h.inc"
+  SRCS
+    "TiledOpInterface.cpp"
+    "TiledOpInterface.cpp.inc"
+  DEPS
+    ::TiledOpInterfaceGen
+    LLVMSupport
+    MLIRIR
+    MLIRSupport
+    MLIRViewLikeInterface
+  PUBLIC
+)
+
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 832b5c5..eb0c002 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -9,10 +9,16 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -49,6 +55,43 @@
   }
 }
 
+/// Returns a memref.subview or a tensor.extract_slice based on the type of the
+/// `source`.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Value>(source.getType())
+      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                strides);
+      })
+      .Case<MemRefType>([&](MemRefType type) -> Value {
+        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+                                           strides);
+      })
+      .Default([&](Type t) { return nullptr; });
+}
+
+Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) {
+  return TypeSwitch<Type, Value>(v.getType())
+      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+        return builder.create<tensor::DimOp>(loc, v, dim);
+      })
+      .Case<MemRefType>([&](MemRefType t) -> Value {
+        return builder.create<memref::DimOp>(loc, v, dim);
+      })
+      .Default([&](Type t) { return Value(); });
+}
+
+OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) {
+  auto t = v.getType().cast<ShapedType>();
+  if (t.isDynamicDim(dim)) {
+    return getDimValue(builder, loc, v, dim);
+  }
+  return builder.getI64IntegerAttr(t.getDimSize(dim));
+}
+
 //===----------------------------------------------------------------------===//
 // Common methods from Linalg dialect.
 //===----------------------------------------------------------------------===//
@@ -123,7 +166,7 @@
     return t1.getShape()[dim] == t2.getShape()[dim];
   };
 
-  auto indicesType = op.inputs()[1].getType().cast<ShapedType>();
+  auto indicesType = op.getIndicesType();
   if (indicesType.getRank() != 2 ||
       !indicesType.getElementType().isInteger(32)) {
     return op.emitOpError(
@@ -136,7 +179,7 @@
 
   // The first dimension of the indices should match the first dimension of the
   // output. They indicate to the number of updates.
-  auto updateType = op.inputs()[0].getType().cast<ShapedType>();
+  auto updateType = op.getUpdateType();
   if (updateType.getRank() < 1) {
     return op.emitOpError("expected update value to be at least rank 1");
   }
@@ -144,7 +187,7 @@
     return op.emitOpError(
         "mismatch in shape of indices and update value at dim#0");
   }
-  auto originalType = op.outputs()[0].getType().cast<ShapedType>();
+  auto originalType = op.getOriginalType();
   // indexDepth + update dims should match to original dims. The first dim of
   // update is the number of updates.
   if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
@@ -196,6 +239,66 @@
   return success();
 }
 
+SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
+  return {getParallelIteratorTypeName()};
+}
+
+SmallVector<Range> ScatterOp::getLoopBounds(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<ConstantIndexOp>(loc, 0);
+  Value one = builder.create<ConstantIndexOp>(loc, 1);
+  Value ub = getDimValue(builder, loc, updates(), 0);
+  return {Range{zero, ub, one}};
+}
+
+Operation *ScatterOp::getTiledImplementation(
+    OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes,
+    SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets) {
+  assert(outputs.size() == 1 && offsets.size() == 1 && sizes.size() == 1);
+  Location loc = getLoc();
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  // Slice of the updates.
+  auto updateRank = getUpdateType().getRank();
+  SmallVector<OpFoldResult> updateOffsets(updateRank, zeroAttr);
+  SmallVector<OpFoldResult> updateSizes(updateRank, zeroAttr);
+  updateOffsets[0] = offsets[0];
+  updateSizes[0] = sizes[0];
+  for (auto dim : llvm::seq<int64_t>(1, updateRank)) {
+    updateSizes[dim] = getDim(builder, loc, updates(), dim);
+  }
+  SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr);
+  Value tiledUpdate = getSlice(builder, loc, updates(), updateOffsets,
+                               updateSizes, updateStrides);
+  assert(tiledUpdate && "failed to get slice of update");
+
+  // Slice of indices.
+  auto indicesRank = getIndicesType().getRank();
+  SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr);
+  SmallVector<OpFoldResult> indicesSizes(indicesRank, zeroAttr);
+  indicesOffsets[0] = offsets[0];
+  indicesSizes[0] = sizes[0];
+  for (auto dim : llvm::seq<int64_t>(1, indicesRank)) {
+    indicesSizes[dim] = getDim(builder, loc, indices(), dim);
+  }
+  SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
+  Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets,
+                                indicesSizes, indicesStrides);
+  assert(tiledIndices && "failed to get slice of indices");
+
+  resultOffsets.resize(1);
+  resultOffsets[0].resize(getUpdateType().getRank(), zeroAttr);
+  SmallVector<Type> resultTypes;
+  if (getNumResults()) {
+    resultTypes.push_back(getResultTypes()[0]);
+  }
+  return cast<LinalgExtOp>(getOperation())
+      .clone(builder, loc, resultTypes,
+             ValueRange{tiledUpdate, tiledIndices, outputs[0]});
+}
+
 //===----------------------------------------------------------------------===//
 // SortOp
 //===----------------------------------------------------------------------===//
@@ -213,6 +316,9 @@
   if (op.getNumInputs()) {
     return op.emitOpError("does not expect to take any inputs");
   }
+  if (op.getNumOutputs() == 0) {
+    return op.emitOpError("expected at least one `outs` operand");
+  }
 
   Block &block = op.region().front();
   size_t numOutputs = op.getNumOutputs();
@@ -221,7 +327,8 @@
            << 2 * numOutputs << " arguments";
   }
 
-  int rank = op.getRank(op.getOutputOperand(0));
+  int64_t rank = op.getOperandRank();
+  ArrayRef<int64_t> shape = op.getOperandShape();
   if (rank > 1 && !op.dimensionAttr()) {
     return op.emitOpError("dimension must be specified if rank > 1");
   }
@@ -233,10 +340,18 @@
     return op.emitOpError("dimension must be within (0, ") << rank << "]";
   }
 
-  for (auto indexedOperand : llvm::enumerate(op.inputs())) {
+  for (auto indexedOperand : llvm::enumerate(op.outputs())) {
     int index = indexedOperand.index();
-    Type elemType =
-        indexedOperand.value().getType().cast<ShapedType>().getElementType();
+    auto operandType = op.getOperandType(index);
+    if (operandType.getRank() != rank) {
+      return op.emitOpError("expected operand ")
+             << index << " to be rank " << rank << ", same as other operands";
+    }
+    if (operandType.getShape() != shape) {
+      return op.emitOpError("expected operand ")
+             << index << " to have same shape as other operands";
+    }
+    Type elemType = operandType.getElementType();
     for (int i : {2 * index, 2 * index + 1}) {
       Type argType = block.getArgument(i).getType();
       if (argType != elemType) {
@@ -259,6 +374,57 @@
   return success();
 }
 
+SmallVector<StringRef> SortOp::getLoopIteratorTypes() {
+  // All loops except the dimension to sort along are parallel.
+  SmallVector<StringRef> iteratorTypes(getOperandRank(),
+                                       getParallelIteratorTypeName());
+  iteratorTypes[getSortedDimension()] = getReductionIteratorTypeName();
+  return iteratorTypes;
+}
+
+SmallVector<Range> SortOp::getLoopBounds(OpBuilder &builder) {
+  int64_t operandRank = getOperandRank();
+  SmallVector<Range> loopBounds(operandRank);
+  Location loc = getLoc();
+  Value zero = builder.create<ConstantIndexOp>(loc, 0);
+  Value one = builder.create<ConstantIndexOp>(loc, 1);
+  Value source = operand(0);
+  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+Operation *SortOp::getTiledImplementation(
+    OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes,
+    SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets) {
+  assert(outputs.size() == this->outputs().size());
+  int64_t rank = getOperandRank();
+  assert(offsets.size() == static_cast<size_t>(rank) &&
+         sizes.size() == static_cast<size_t>(rank));
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  SmallVector<OpFoldResult> strides(rank, oneAttr);
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands(outputs.size());
+  resultOffsets.resize(outputs.size());
+  for (auto en : llvm::enumerate(outputs)) {
+    tiledOperands[en.index()] =
+        getSlice(builder, getLoc(), en.value(), offsets, sizes, strides);
+    assert(tiledOperands[en.index()] && "failed to get slice of operand");
+    resultOffsets[en.index()].assign(offsets.begin(), offsets.end());
+  }
+  SmallVector<Type, 4> resultTypes;
+  if (getNumResults()) {
+    resultTypes = llvm::to_vector<4>(
+        llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); }));
+  }
+  return cast<LinalgExtOp>(getOperation())
+      .clone(builder, loc, resultTypes, tiledOperands);
+}
+
 }  // namespace linalg_ext
 }  // namespace iree_compiler
 }  // namespace mlir
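
For orientation, a minimal sketch of how the `getSlice` helper above is used
(a hypothetical snippet; `b`, `loc`, and `source` are assumed to be in scope,
and the static tile shape is illustrative):

```cpp
// Extract a 4x8 tile at offset [0, 0] with unit strides from `source`.
// getSlice emits tensor.extract_slice for tensor operands and memref.subview
// for memref operands.
SmallVector<OpFoldResult> offsets(2, b.getI64IntegerAttr(0));
SmallVector<OpFoldResult> sizes = {b.getI64IntegerAttr(4),
                                   b.getI64IntegerAttr(8)};
SmallVector<OpFoldResult> strides(2, b.getI64IntegerAttr(1));
Value slice = getSlice(b, loc, source, offsets, sizes, strides);
```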
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
index 1925f54..3100b6f 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -7,6 +7,8 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -17,14 +19,19 @@
 namespace mlir {
 namespace iree_compiler {
 namespace linalg_ext {
-class LinalgExtOp;
+
+/// Returns a `memref.dim` or `tensor.dim` operation to get the size of the
+/// `dim`-th dimension of `v`.
+Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);
+
+/// Returns the size of the `dim`-th dimension of `v` as an `OpFoldResult`:
+/// an `IntegerAttr` if the size is static, else a `memref.dim`/`tensor.dim`.
+OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);
 
 }  // namespace linalg_ext
 }  // namespace iree_compiler
 }  // namespace mlir
 
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h.inc"  // IWYU pragma: export
 
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index e76c441..d691fed 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,6 +9,7 @@
 
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
+include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 
@@ -34,7 +35,9 @@
 // Non-structured ops
 //===----------------------------------------------------------------------===//
 
-def LinalgExt_ScatterOp : LinalgExt_Op<"scatter"> {
+def LinalgExt_ScatterOp : LinalgExt_Op<"scatter",
+    [DeclareOpInterfaceMethods<TiledOpInterface,
+        ["getTiledImplementation"]>]> {
   let summary = "Scatter operator";
   let description = [{
     Based on XLA operation semantics, takes two `inputs` (`update` and
@@ -88,14 +91,26 @@
       return getInputOperand(0)->get();
     }
 
+    ShapedType getUpdateType() {
+      return updates().getType().cast<ShapedType>();
+    }
+
     Value indices() {
       return getInputOperand(1)->get();
     }
 
+    ShapedType getIndicesType() {
+      return indices().getType().cast<ShapedType>();
+    }
+
     Value original() {
       return getOutputOperand(0)->get();
     }
 
+    ShapedType getOriginalType() {
+      return original().getType().cast<ShapedType>();
+    }
+
     int64_t getUpdateSliceRank() {
       return updates().getType().cast<ShapedType>().getRank() - 1;
     }
@@ -106,7 +121,9 @@
   }];
 }
 
-def LinalgExt_SortOp : LinalgExt_Op<"sort"> {
+def LinalgExt_SortOp : LinalgExt_Op<"sort",
+    [DeclareOpInterfaceMethods<TiledOpInterface,
+        ["getTiledImplementation"]>]> {
   let summary = "Sort operator";
   let description = [{
     Based on XLA operation semantics, sorts the given `operands` at the given
@@ -116,7 +133,7 @@
   }];
 
   // Define arguments and results like linalg.generic op. The attribute has the
-  // same definision as mhlo.sort::dimension. If the rank is greater than 1,
+  // same definition as mhlo.sort::dimension. If the rank is greater than 1,
   // the attribute must be set. If the rank is exactly 1, the dimension is
   // optional.
   let arguments = (ins Variadic<AnyType>:$inputs,
@@ -131,6 +148,27 @@
     custom<LinalgExtOutsList>($outputs, type($outputs))
     $region (`->` type($results)^)?
   }];
+  let extraClassDeclaration = [{
+    Value operand(int index) {
+      return outputs()[index];
+    }
+    ShapedType getOperandType(int index) {
+      return operand(index).getType().cast<ShapedType>();
+    }
+    int64_t getOperandRank() {
+      return getOperandType(0).getRank();
+    }
+    ArrayRef<int64_t> getOperandShape() {
+      return getOperandType(0).getShape();
+    }
+    uint64_t getSortedDimension() {
+      uint64_t sortedDim = 0;
+      if (auto setSortedDim = dimension()) {
+        sortedDim = *setSortedDim;
+      }
+      return sortedDim;
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
new file mode 100644
index 0000000..b65d5e8
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -0,0 +1,17 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace linalg_ext {
+
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp.inc"
+
+}  // namespace linalg_ext
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h
new file mode 100644
index 0000000..164531b
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h
@@ -0,0 +1,19 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+#define IREE_COMPILER_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LLVM.h"
+
+/// Include the ODS generated interface header files.
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h.inc"
+
+#endif  // IREE_COMPILER_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
new file mode 100644
index 0000000..697acae
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
@@ -0,0 +1,61 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
+#define IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def TiledOpInterface : OpInterface<"TiledOpInterface"> {
+  let description = [{
+    Interface that allows an operation to expose the information needed
+    to tile it (similar to LinalgOp, but without requiring access to
+    indexing maps).
+  }];
+  let cppNamespace = "::mlir::iree_compiler::linalg_ext";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of `StringRef`s that describe the number of
+          loops and the iterator types of the operation. The list is
+          expected to use
+          `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
+          from MLIR Structured Op Utils.
+        }],
+        /*retTy=*/"SmallVector<StringRef>",
+        /*methodName=*/"getLoopIteratorTypes"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of ranges that describe the loop bounds and
+          step for the loops of the operation.
+        }],
+        /*retTy=*/"SmallVector<Range>",
+        /*methodName=*/"getLoopBounds",
+        /*args=*/(ins "OpBuilder &":$b)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Generates a tiled version of the operation given the tile
+          size for the loops.
+        }],
+        /*retTy=*/"Operation *",
+        /*methodName=*/"getTiledImplementation",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "ValueRange ":$outputs,
+            "ArrayRef<OpFoldResult> ":$offsets,
+            "ArrayRef<OpFoldResult> ":$sizes,
+            "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultOffsets),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return nullptr;
+        }]
+      >
+  ];
+}
+
+#endif  // IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
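
To make the contract concrete, here is a trimmed sketch of the first two
methods for a hypothetical single-loop op `MyOp` (modeled directly on the
`ScatterOp` implementation earlier in this change; `input()` is a placeholder
accessor):

```cpp
// Hypothetical TiledOpInterface implementation with a single parallel loop
// over the outermost dimension of the placeholder input() operand.
SmallVector<StringRef> MyOp::getLoopIteratorTypes() {
  return {getParallelIteratorTypeName()};
}

SmallVector<Range> MyOp::getLoopBounds(OpBuilder &b) {
  Location loc = getLoc();
  Value zero = b.create<ConstantIndexOp>(loc, 0);
  Value one = b.create<ConstantIndexOp>(loc, 1);
  Value ub = getDimValue(b, loc, input(), 0);
  return {Range{zero, ub, one}};
}
```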
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 13f9acd..86db7f0 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -26,6 +26,34 @@
 
 // -----
 
+func @sort_mismatch_rank(%arg0: tensor<?x?xi32>, %arg1: tensor<?xf32>)
+    -> (tensor<?x?xi32>, tensor<?xf32>) {
+  // expected-error @+1 {{expected operand 1 to be rank 2, same as other operands}}
+  %0:2 = linalg_ext.sort dimension(0)
+      outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?xf32>) {
+      ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+        %1 = cmpf ogt, %arg4, %arg5 : f32
+        linalg_ext.yield %1 : i1
+      } -> tensor<?x?xi32>, tensor<?xf32>
+  return %0#0, %0#1 : tensor<?x?xi32>, tensor<?xf32>
+}
+
+// -----
+
+func @sort_mismatch_shape(%arg0: tensor<?xi32>, %arg1: tensor<42xf32>)
+    -> (tensor<?xi32>, tensor<42xf32>) {
+  // expected-error @+1 {{expected operand 1 to have same shape as other operands}}
+  %0:2 = linalg_ext.sort dimension(0)
+      outs(%arg0, %arg1 : tensor<?xi32>, tensor<42xf32>) {
+      ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+        %1 = cmpf ogt, %arg4, %arg5 : f32
+        linalg_ext.yield %1 : i1
+      } -> tensor<?xi32>, tensor<42xf32>
+  return %0#0, %0#1 : tensor<?xi32>, tensor<42xf32>
+}
+
+// -----
+
 func @scatter_mixed_tensor_memref(
     %update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
     %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 12dc335..6d73fbf 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -32,6 +32,44 @@
 
 // -----
 
+func @sort_multi_result_tensor(
+    %arg0: tensor<?x?xi32>, %arg1: tensor<?x?xf32>)
+    -> (tensor<?x?xi32>, tensor<?x?xf32>) {
+  %0:2 = linalg_ext.sort dimension(0)
+      outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xf32>) {
+      ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+        %1 = cmpf ogt, %arg4, %arg5 : f32
+        linalg_ext.yield %1 : i1
+      } -> tensor<?x?xi32>, tensor<?x?xf32>
+  return %0#0, %0#1 : tensor<?x?xi32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @sort_multi_result_tensor
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>
+//  CHECK-SAME:   %[[ARG1:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]]:2 = linalg_ext.sort dimension(0)
+//  CHECK-SAME:      outs(%[[ARG0]], %[[ARG1]]
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @sort_multi_result_memref(
+    %arg0: memref<?x?xi32>, %arg1: memref<?x?xf32>) {
+  linalg_ext.sort dimension(0)
+     outs(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xf32>) {
+     ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+       %1 = cmpf ogt, %arg4, %arg5 : f32
+       linalg_ext.yield %1 : i1
+     }
+  return
+}
+// CHECK-LABEL: func @sort_multi_result_memref
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?xi32>
+//  CHECK-SAME:   %[[ARG1:.+]]: memref<?x?xf32>
+//       CHECK:   linalg_ext.sort dimension(0)
+//  CHECK-SAME:      outs(%[[ARG0]], %[[ARG1]]
+
+// -----
+
 func @scatter_tensor_dynamic(
     %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
     %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
index 3139752..e48f8ef 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
@@ -31,24 +31,30 @@
     name = "Transforms",
     srcs = [
         "ConvertToLoops.cpp",
-        "PassDetail.h",
         "Passes.cpp",
+        "Tiling.cpp",
     ],
     hdrs = [
+        "PassDetail.h",
         "Passes.h",
         "Passes.h.inc",
+        "Transforms.h",
     ],
     deps = [
         ":PassesIncGen",
+        "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/LinalgExt/IR",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index b0ae75c..e0edb4c 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -23,23 +23,29 @@
   NAME
     Transforms
   HDRS
+    "PassDetail.h"
     "Passes.h"
     "Passes.h.inc"
+    "Transforms.h"
   SRCS
     "ConvertToLoops.cpp"
-    "PassDetail.h"
     "Passes.cpp"
+    "Tiling.cpp"
   DEPS
     ::PassesIncGen
     LLVMSupport
+    MLIRAffine
     MLIRIR
     MLIRLinalg
+    MLIRLinalgTransforms
     MLIRMemRef
     MLIRPass
     MLIRSCF
     MLIRStandard
     MLIRSupport
+    MLIRTensor
     MLIRTransforms
+    iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::LinalgExt::IR
   PUBLIC
 )
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index df5ed5a..d84b97e 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -13,6 +13,8 @@
 namespace iree_compiler {
 namespace linalg_ext {
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgExtTilingPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgExtToLoopsPass();
 
 void registerLinalgExtPasses();
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 41cc182..e221e33 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -15,4 +15,10 @@
   let constructor = "mlir::iree_compiler::linalg_ext::createLinalgExtToLoopsPass()";
 }
 
+def LinalgExtTiling :
+    Pass<"iree-linalg-ext-tile", "FuncOp"> {
+  let summary = "Test pass for tiling LinalgExt ops using TiledOpInterface";
+  let constructor = "mlir::iree_compiler::linalg_ext::createLinalgExtTilingPass()";
+}
+
 #endif  // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
new file mode 100644
index 0000000..7b4be2c
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -0,0 +1,433 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace linalg_ext {
+
+//===----------------------------------------------------------------------===//
+// Utility methods for tiling a linalg_ext operation that implements the
+// TiledOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Returns failure if the options are unsupported.
+static LogicalResult verifySupportedTilingOptions(
+    PatternRewriter &rewriter, Operation *op,
+    const linalg::LinalgTilingOptions &options) {
+  if (!options.interchangeVector.empty()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unsupported interchange during tiling");
+  }
+  if (options.paddingValueComputationFunction) {
+    return rewriter.notifyMatchFailure(op, "unsupported tile + pad option");
+  }
+  if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only tiling with scf.for is supported");
+  }
+  if (options.distribution) {
+    if (llvm::any_of(options.distribution->distributionMethod,
+                     [](linalg::DistributionMethod method) {
+                       return method != linalg::DistributionMethod::Cyclic;
+                     })) {
+      return rewriter.notifyMatchFailure(op,
+                                         "only cyclic distribution is allowed");
+    }
+  }
+  return success();
+}
+
+/// Converts a `Value` to an `OpFoldResult` by extracting the constant value if
+/// the value is defined by a constant op.
+static OpFoldResult getOpFoldResult(Value value) {
+  IntegerAttr::ValueType attr;
+  if (matchPattern(value, m_ConstantInt(&attr))) {
+    return IntegerAttr::get(value.getType(), attr);
+  }
+  return value;
+}
+static SmallVector<OpFoldResult, 4> getOpFoldResult(ArrayRef<Value> values) {
+  return llvm::to_vector<4>(llvm::map_range(
+      values, [](Value value) { return getOpFoldResult(value); }));
+}
+
+/// Converts an `OpFoldResult` to a `Value` by building a constant op if
+/// the `OpFoldResult` is an `IntegerAttr`.
+static Value getValue(OpBuilder &builder, Location loc,
+                      OpFoldResult valueOrAttr) {
+  if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+    return builder.create<ConstantIndexOp>(loc,
+                                           attr.cast<IntegerAttr>().getInt());
+  }
+  return valueOrAttr.get<Value>();
+}
+
+/// Returns true if the loop is untiled, i.e., the tile size is statically
+/// zero. It is assumed that a `Value` defined by a constant op has already
+/// been converted to an `IntegerAttr` of that value, so it suffices to check
+/// for an attribute with a zero value.
+static bool isUntiledLoop(OpFoldResult valueOrAttr) {
+  auto attr = valueOrAttr.dyn_cast<Attribute>();
+  return attr && attr.cast<IntegerAttr>().getValue() == 0;
+}
+
+/// Generates the tiled loops and the body by invoking the interface methods of
+/// TiledOpInterface.
+/// - `outputs` are the operands to use for outputs of the tiled operation.
+/// - `tileSizes` are the tile sizes specified for all loops of the operation.
+///   A tile size of 0 means the loop is left untiled.
+/// - `iteratorTypes` are the iterator types of the loops, as returned by the
+///   TiledOpInterface.
+/// - `loopBounds` are the bounds of all the loops of the op returned by the
+///   TiledOpInterface.
+/// - `loopDepth` is the current loop depth being processed.
+/// - `offsets` are the `OpFoldResult`s that represent the position of the tile being
+///   operated on. The offsets are computed as the tiled loops are being
+///   generated.
+/// - `distributionInfo` contains the proc_id and nprocs `Value`s to use for
+///   distributed loops. It is a stack, and once an entry at the top of the
+///   stack is used for distribution it is popped before processing the inner
+///   loops.
+static FailureOr<TiledOp> tileLinalgExtOpImpl(
+    OpBuilder &builder, TiledOpInterface op, ValueRange outputs,
+    MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<Range> loopBounds, unsigned loopDepth,
+    SmallVectorImpl<OpFoldResult> &offsets,
+    ArrayRef<linalg::ProcInfo> distributionInfo) {
+  Location loc = op.getLoc();
+  // If this is the innermost loop, then generate the tiled implementation of
+  // the op by invoking the TiledOpInterface methods.
+  if (loopDepth == tileSizes.size()) {
+    SmallVector<SmallVector<OpFoldResult, 4>> resultOffsets;
+    Operation *tiledOp = op.getTiledImplementation(builder, outputs, offsets,
+                                                   tileSizes, resultOffsets);
+    if (!tiledOp) {
+      return static_cast<LogicalResult>(
+          op.emitOpError("failed to get tiled implementation"));
+    }
+    assert(tiledOp->getNumResults() == 0 ||
+           (resultOffsets.size() == tiledOp->getNumResults()));
+    TiledOp ret;
+    ret.op = tiledOp;
+
+    // If the operation has results, the results of the tiled operation are
+    // inserted into the current `outputs` and returned as the replacements.
+    if (tiledOp->getNumResults()) {
+      SmallVector<Value> results;
+      results.reserve(tiledOp->getNumResults());
+      for (auto en : llvm::enumerate(tiledOp->getResults())) {
+        Value result = en.value();
+        ArrayRef<OpFoldResult> offsets(resultOffsets[en.index()]);
+        auto resultType = result.getType().cast<ShapedType>();
+        auto oneAttr = builder.getI64IntegerAttr(1);
+        SmallVector<OpFoldResult> strides(resultType.getRank(), oneAttr);
+        auto sizes = llvm::to_vector<4>(llvm::map_range(
+            llvm::seq<int64_t>(0, resultType.getRank()),
+            [&](int64_t dim) { return getDim(builder, loc, result, dim); }));
+        Value insert = builder.create<tensor::InsertSliceOp>(
+            loc, result, outputs[en.index()], offsets, sizes, strides);
+        results.push_back(insert);
+      }
+      std::swap(ret.results, results);
+    }
+    return ret;
+  }
+
+  // If the loop at this depth is untiled (tile size 0), use the full extent.
+  if (isUntiledLoop(tileSizes[loopDepth])) {
+    auto zeroAttr = builder.getI64IntegerAttr(0);
+    offsets.push_back(zeroAttr);
+    assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
+           "expected loop bounds to have lower bound of zero");
+    tileSizes[loopDepth] = getOpFoldResult(loopBounds[loopDepth].size);
+    return tileLinalgExtOpImpl(builder, op, outputs, tileSizes, iteratorTypes,
+                               loopBounds, loopDepth + 1, offsets,
+                               distributionInfo);
+  }
+
+  // Generate an scf.for for the current loop depth.
+  Value lb = loopBounds[loopDepth].offset;
+  Value ub = loopBounds[loopDepth].size;
+  if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
+    return static_cast<LogicalResult>(
+        op.emitOpError("expected stride to be 1"));
+  }
+  Value step = getValue(builder, loc, tileSizes[loopDepth]);
+
+  // Update lb, ub and step for cyclic distribution.
+  if (!distributionInfo.empty() &&
+      iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
+    linalg::updateBoundsForCyclicDistribution(
+        builder, loc, distributionInfo.front().procId,
+        distributionInfo.front().nprocs, lb, ub, step);
+    distributionInfo = distributionInfo.drop_front();
+  }
+  FailureOr<TiledOp> innerReturnValue;
+  bool isBufferTiling = op->getNumResults() == 0;
+  ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
+  auto forOp = builder.create<scf::ForOp>(
+      loc, lb, ub, step, initValues,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        offsets.push_back(iv);
+        auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
+            b.getAffineSymbolExpr(0),
+            b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
+        // Similar to linalg tiling, the tile size used is min(tileSize, ub -
+        // iv) to account for cases where the tile size does not divide
+        // (ub - lb) exactly.
+        Value inBoundsTileSize = b.create<AffineMinOp>(
+            loc, affineMaps,
+            ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});
+        tileSizes[loopDepth] = getOpFoldResult(inBoundsTileSize);
+        // Recursively proceed to generate the tiled loop for the next level.
+        innerReturnValue = tileLinalgExtOpImpl(
+            b, op, (isBufferTiling ? outputs : args), tileSizes, iteratorTypes,
+            loopBounds, loopDepth + 1, offsets, distributionInfo);
+        if (failed(innerReturnValue)) return;
+        b.create<scf::YieldOp>(loc, innerReturnValue->results);
+      });
+  if (failed(innerReturnValue)) {
+    return innerReturnValue;
+  }
+  innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
+                                 forOp.getOperation());
+  innerReturnValue->results = forOp.getResults();
+  return innerReturnValue;
+}
+
+FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, LinalgExtOp op,
+                                   const linalg::LinalgTilingOptions &options) {
+  TiledOpInterface tilableOp = dyn_cast<TiledOpInterface>(op.getOperation());
+  if (!tilableOp) return TiledOp{};
+
+  SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+  SmallVector<Value, 4> tileSizesVals =
+      options.tileSizeComputationFunction(b, tilableOp.getOperation());
+  auto zeroAttr = b.getI64IntegerAttr(0);
+
+  // For the actual tile sizes, convert any `Value` defined by a constant op
+  // into an `IntegerAttr` of that value. Tiling of non-"parallel" loops is
+  // not yet implemented, so their tile sizes must be zero.
+  auto tileSizes = getOpFoldResult(tileSizesVals);
+  tileSizes.resize(iteratorTypes.size(), zeroAttr);
+  for (auto en : llvm::enumerate(iteratorTypes)) {
+    if (en.value() == getParallelIteratorTypeName()) continue;
+    if (!isUntiledLoop(tileSizes[en.index()])) {
+      return static_cast<LogicalResult>(op.emitOpError(
+          "unimplemented tiling of non-parallel loop iterator type"));
+    }
+  }
+
+  // Trivial early exit case of tile sizes being zero for all parallel loops.
+  if (llvm::all_of(tileSizes, isUntiledLoop)) {
+    return TiledOp{op.getOperation(), {}, {}};
+  }
+
+  SmallVector<Range> loopBounds = tilableOp.getLoopBounds(b);
+  SmallVector<linalg::ProcInfo> distributionInfo;
+  // If the tiled loops are distributed, get the proc_id and nprocs for the
+  // distributed loops. First collect the loops that are distributed, i.e.,
+  // the loops that are both
+  // - parallel, i.e. the iterator type is "parallel", and
+  // - tiled, i.e. the tile size is != 0.
+  if (options.distribution) {
+    SmallVector<Range> distributedLoopRange;
+    for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
+      if (isUntiledLoop(tileSizes[i])) continue;
+      if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
+      distributedLoopRange.push_back(loopBounds[i]);
+    }
+    distributionInfo =
+        options.distribution->procInfo(b, op.getLoc(), distributedLoopRange);
+  }
+
+  SmallVector<OpFoldResult> offsets;
+  return tileLinalgExtOpImpl(b, tilableOp, op.outputs(), tileSizes,
+                             iteratorTypes, loopBounds, 0, offsets,
+                             distributionInfo);
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns for tiling LinalgExtOps.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Base pattern for tiling LinalgExtOps.
+struct LinalgExtBaseTilingPattern : public RewritePattern {
+  LinalgExtBaseTilingPattern(StringRef opName, MLIRContext *context,
+                             linalg::LinalgTilingOptions options,
+                             linalg::LinalgTransformationFilter filter =
+                                 linalg::LinalgTransformationFilter(),
+                             PatternBenefit benefit = 1)
+      : RewritePattern(opName, benefit, context),
+        filter(filter),
+        options(options) {}
+
+  LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+                                    TiledOp &result) const;
+
+ private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  linalg::LinalgTransformationFilter filter;
+  /// Options to control tiling.
+  linalg::LinalgTilingOptions options;
+};
+
+template <typename OpTy>
+struct LinalgExtTilingPattern : public LinalgExtBaseTilingPattern {
+  LinalgExtTilingPattern(MLIRContext *context,
+                         linalg::LinalgTilingOptions options,
+                         linalg::LinalgTransformationFilter filter =
+                             linalg::LinalgTransformationFilter(),
+                         PatternBenefit benefit = 1)
+      : LinalgExtBaseTilingPattern(OpTy::getOperationName(), context, options,
+                                   filter, benefit) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    TiledOp tiledOp;
+    // Check for failure.
+    if (failed(LinalgExtBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+                                                               tiledOp))) {
+      return failure();
+    }
+    // Check for do-nothing case.
+    if (!tiledOp.op) return failure();
+    if (tiledOp.op != op) {
+      if (tiledOp.results.empty()) {
+        rewriter.eraseOp(op);
+      } else {
+        rewriter.replaceOp(op, tiledOp.results);
+      }
+    }
+    return success();
+  }
+};
+}  // namespace
+
+LogicalResult LinalgExtBaseTilingPattern::matchAndRewriteBase(
+    Operation *op, PatternRewriter &rewriter, TiledOp &result) const {
+  auto linalgExtOp = dyn_cast<LinalgExtOp>(op);
+  if (!linalgExtOp) return failure();
+  if (failed(filter.checkAndNotify(rewriter, op))) return failure();
+  if (failed(verifySupportedTilingOptions(rewriter, op, options))) {
+    return failure();
+  }
+
+  FailureOr<TiledOp> res = tileLinalgExtOp(rewriter, linalgExtOp, options);
+  if (failed(res)) return res;
+  result = *res;
+  if (result.op) {
+    filter.replaceLinalgTransformationFilter(rewriter, result.op);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass for tiling LinalgExt ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LinalgExtTilingPass : public LinalgExtTilingBase<LinalgExtTilingPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
+                memref::MemRefDialect, StandardOpsDialect,
+                tensor::TensorDialect, scf::SCFDialect>();
+  }
+
+  void runOnOperation() override;
+};
+}  // namespace
+
+template <typename OpTy>
+static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
+  return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
+}
+
+void LinalgExtTilingPass::runOnOperation() {
+  FuncOp funcOp = getOperation();
+  MLIRContext *context = funcOp.getContext();
+  RewritePatternSet patterns(context);
+  patterns.add<LinalgExtTilingPattern<ScatterOp>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("tiling_input", context),
+          Identifier::get("tiling_output", context)));
+  patterns.add<LinalgExtTilingPattern<ScatterOp>>(
+      context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("no_tiling_input", context),
+          Identifier::get("no_tiling_output", context)));
+  patterns.add<LinalgExtTilingPattern<SortOp>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("outer_reduce_input", context),
+          Identifier::get("outer_reduce_output", context)));
+  patterns.add<LinalgExtTilingPattern<SortOp>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("inner_reduce_input", context),
+          Identifier::get("inner_reduce_output", context)));
+
+  static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+      [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+        auto numParallelDims = parallelLoopRanges.size();
+
+        SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
+        for (size_t dim = 0; dim < numParallelDims; ++dim) {
+          procInfo[numParallelDims - dim - 1] = {
+              buildFlowWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp>(
+                  builder, dim),
+              buildFlowWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupCountOp>(
+                  builder, dim)};
+        }
+        return procInfo;
+      },
+      {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+       linalg::DistributionMethod::Cyclic},
+      DenseMap<StringRef,
+               std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+
+  patterns
+      .add<LinalgExtTilingPattern<ScatterOp>, LinalgExtTilingPattern<SortOp>>(
+          context,
+          linalg::LinalgTilingOptions()
+              .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
+              .setDistributionOptions(workgroupDistributionOptions),
+          linalg::LinalgTransformationFilter(
+              Identifier::get("distribute_input", context),
+              Identifier::get("distribute_output", context)));
+
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createLinalgExtTilingPass() {
+  return std::make_unique<LinalgExtTilingPass>();
+}
+
+}  // namespace linalg_ext
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
new file mode 100644
index 0000000..08e1a81
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
@@ -0,0 +1,37 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+#define IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace linalg_ext {
+
+/// Structure to represent the result of a tiling operation.
+struct TiledOp {
+  /// Tiled op.
+  Operation *op;
+  /// Loops generated during tiling.
+  SmallVector<Operation *> loops;
+  /// Values that are replacements for the untiled operations.
+  SmallVector<Value> results;
+};
+
+/// Main entry point for tiling LinalgExtOps using the TiledOpInterface. If
+/// `op` does not implement the `TiledOpInterface`, returns an empty `TiledOp`.
+FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, LinalgExtOp op,
+                                   const linalg::LinalgTilingOptions &options);
+
+}  // namespace linalg_ext
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
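
A minimal sketch of driving this entry point directly, from inside a function
returning `LogicalResult` with a `rewriter` and a `linalgExtOp` in scope (the
replacement logic mirrors `LinalgExtTilingPattern` in `Tiling.cpp`):

```cpp
// Tile with illustrative tile sizes and replace the untiled op.
linalg::LinalgTilingOptions options;
options.setTileSizes({10, 20});
FailureOr<TiledOp> tiled = tileLinalgExtOp(rewriter, linalgExtOp, options);
if (failed(tiled)) return failure();
// `tiled->op` is the tiled op, `tiled->loops` the generated loop nest, and
// `tiled->results` the replacement values (empty when tiling on buffers).
if (tiled->op && tiled->op != linalgExtOp) {
  if (tiled->results.empty())
    rewriter.eraseOp(linalgExtOp);
  else
    rewriter.replaceOp(linalgExtOp, tiled->results);
}
return success();
```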
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD b/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD
index a2aef95..e12869b 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD
@@ -18,6 +18,7 @@
     srcs = enforce_glob(
         [
             "convert_to_loops.mlir",
+            "tiling.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 931e320..0c94851 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "convert_to_loops.mlir"
+    "tiling.mlir"
   DATA
     iree::tools::IreeFileCheck
     iree::tools::iree-opt
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
new file mode 100644
index 0000000..b870530
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -0,0 +1,422 @@
+// RUN: iree-opt -iree-linalg-ext-tile -split-input-file %s | IreeFileCheck %s
+
+func @scatter_tiling(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg_ext.scatter
+    {__internal_linalg_transform__ = "tiling_input"}
+    ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//       CHECK: func @scatter_tiling(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+//  CHECK-SAME:   %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]]
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[ORIGINAL]])
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+//   CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]]
+//       CHECK:     %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
+//       CHECK:     %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], 1]
+//       CHECK:     %[[SCATTER_TILE:.+]] = linalg_ext.scatter
+//  CHECK-SAME:         __internal_linalg_transform__ = "tiling_output"
+//  CHECK-SAME:         ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+//  CHECK-SAME:         outs(%[[INIT]]
+//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C0]]
+//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C1]]
+//       CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
+//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
+//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @scatter_tiling_memref(
+    %original: memref<?x?xf32>, %indices: memref<?x1xi32>,
+    %update : memref<?x?xf32>) {
+  linalg_ext.scatter
+    {__internal_linalg_transform__ = "tiling_input"}
+    ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
+    outs(%original : memref<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    }
+  return
+}
+//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//       CHECK: func @scatter_tiling_memref(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: memref<?x1xi32>
+//  CHECK-SAME:   %[[UPDATES:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]]
+//       CHECK:   scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+//   CHECK-DAG:     %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]]
+//       CHECK:     %[[UPDATE_SLICE:.+]] = memref.subview %[[UPDATES]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
+//       CHECK:     %[[INDEX_SLICE:.+]] = memref.subview %[[INDICES]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], 1]
+//       CHECK:     linalg_ext.scatter
+//  CHECK-SAME:         __internal_linalg_transform__ = "tiling_output"
+//  CHECK-SAME:         ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+//  CHECK-SAME:         outs(%[[ORIGINAL]]
+
+// -----
+
+func @scatter_tiling_distribution(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg_ext.scatter
+    {__internal_linalg_transform__ = "distribute_input"}
+    ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//       CHECK: func @scatter_tiling_distribution(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+//  CHECK-SAME:   %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]]
+//   CHECK-DAG:   %[[ID:.+]] = flow.dispatch.workgroup.id[0]
+//   CHECK-DAG:   %[[COUNT:.+]] = flow.dispatch.workgroup.count[0]
+//   CHECK-DAG:   %[[OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[ID]]]
+//   CHECK-DAG:   %[[STEP:.+]] = affine.apply #[[MAP0]]()[%[[COUNT]]]
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[OFFSET]] to %[[D0]] step %[[STEP]]
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[ORIGINAL]])
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+//   CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]]
+//       CHECK:     %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
+//       CHECK:     %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], 1]
+//       CHECK:     %[[SCATTER_TILE:.+]] = linalg_ext.scatter
+//  CHECK-SAME:         __internal_linalg_transform__ = "distribute_output"
+//  CHECK-SAME:         ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+//  CHECK-SAME:         outs(%[[INIT]]
+//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C0]]
+//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C1]]
+//       CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
+//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
+//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
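+// With the "no_tiling" filter the op is not tiled at all; only the
+// transformation marker is rewritten from the input to the output value.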
+func @scatter_no_tiling(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg_ext.scatter
+    {__internal_linalg_transform__ = "no_tiling_input"}
+    ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//       CHECK: func @scatter_no_tiling
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+//  CHECK-SAME:   %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = linalg_ext.scatter
+//  CHECK-SAME:       __internal_linalg_transform__ = "no_tiling_output"
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[INDICES]]
+//  CHECK-SAME:       outs(%[[ORIGINAL]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
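+// A 1-d sort only has its sort (reduction) dimension, which cannot be
+// tiled, so the op is unchanged apart from the updated marker.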
+func @sort_1d(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+  %0 = linalg_ext.sort
+       {__internal_linalg_transform__ = "outer_reduce_input"}
+       outs(%arg0 : tensor<?xi32>) {
+       ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+         %0 = cmpi sgt, %arg2, %arg3 : i32
+         linalg_ext.yield %0 : i1
+       } -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+//      CHECK: func @sort_1d(
+// CHECK-SAME:   %[[OPERAND:.+]]: tensor<?xi32>
+//      CHECK:   %[[RESULT:.+]] = linalg_ext.sort
+// CHECK-SAME:       {__internal_linalg_transform__ = "outer_reduce_output"}
+// CHECK-SAME:       outs(%[[OPERAND]] :
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
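+// Sort along the inner dimension d1: the outer parallel dimension d0 is
+// tiled with size 10, and the partial result is threaded through the loop
+// as an iter_arg, with extract_slice/insert_slice at [%iv, 0].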
+func @sort_2d(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg_ext.sort dimension(1)
+       {__internal_linalg_transform__ = "inner_reduce_input"}
+       outs(%arg0 : tensor<?x?xi32>) {
+       ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+         %0 = cmpi sgt, %arg2, %arg3 : i32
+         linalg_ext.yield %0 : i1
+       } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//       CHECK: func @sort_2d(
+//  CHECK-SAME:   %[[OPERAND:.+]]: tensor<?x?xi32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]]
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[OPERAND]])
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+//       CHECK:     %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
+//       CHECK:     %[[SORT_TILE:.+]] = linalg_ext.sort
+//  CHECK-SAME:         __internal_linalg_transform__ = "inner_reduce_output"
+//  CHECK-SAME:         outs(%[[OPERAND_SLICE]]
+//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[SORT_TILE]], %[[C0]]
+//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[SORT_TILE]], %[[C1]]
+//       CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
+//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
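+// Sort along the outer dimension d0: now d1 is the parallel dimension, so
+// the loop steps over d1 (tile size 20 for this filter) and the slices are
+// taken at [0, %iv].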
+func @sort_2d_inner_parallel(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %0 = linalg_ext.sort dimension(0)
+       {__internal_linalg_transform__ = "outer_reduce_input"}
+       outs(%arg0 : tensor<?x?xi32>) {
+       ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+         %0 = cmpi sgt, %arg2, %arg3 : i32
+         linalg_ext.yield %0 : i1
+       } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//       CHECK: func @sort_2d_inner_parallel(
+//  CHECK-SAME:   %[[OPERAND:.+]]: tensor<?x?xi32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 20 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]]
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]]
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[OPERAND]])
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]]
+//       CHECK:     %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, %[[IV]]]
+//  CHECK-SAME:         [%[[D0]], %[[USED_TILESIZE]]]
+//       CHECK:     %[[SORT_TILE:.+]] = linalg_ext.sort
+//  CHECK-SAME:         __internal_linalg_transform__ = "outer_reduce_output"
+//  CHECK-SAME:         outs(%[[OPERAND_SLICE]]
+//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[SORT_TILE]], %[[C0]]
+//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[SORT_TILE]], %[[C1]]
+//       CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][0, %[[IV]]]
+//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
+//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
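+// Multi-result sort: both outputs are tiled together. The loop carries one
+// iter_arg per result and yields both updated tensors.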
+func @sort_2d_multi_result(
+    %arg0: tensor<?x?xi32>, %arg1: tensor<?x?xf32>)
+    -> (tensor<?x?xi32>, tensor<?x?xf32>) {
+  %0:2 = linalg_ext.sort dimension(1)
+       {__internal_linalg_transform__ = "inner_reduce_input"}
+       outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xf32>) {
+       ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+         %1 = cmpf ogt, %arg4, %arg5 : f32
+         linalg_ext.yield %1 : i1
+       } -> tensor<?x?xi32>, tensor<?x?xf32>
+  return %0#0, %0#1 : tensor<?x?xi32>, tensor<?x?xf32>
+}
+//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//       CHECK: func @sort_2d_multi_result(
+//  CHECK-SAME:   %[[OPERAND1:.+]]: tensor<?x?xi32>
+//  CHECK-SAME:   %[[OPERAND2:.+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]]
+//       CHECK:   %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+//  CHECK-SAME:       iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]])
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+//       CHECK:     %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
+//       CHECK:     %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT2]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
+//       CHECK:     %[[SORT_TILE:.+]]:2 = linalg_ext.sort
+//  CHECK-SAME:         __internal_linalg_transform__ = "inner_reduce_output"
+//  CHECK-SAME:         outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[SORT_TILE]]#0, %[[C0]]
+//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[SORT_TILE]]#0, %[[C1]]
+//       CHECK:     %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_TILE]]#0 into %[[INIT1]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
+//   CHECK-DAG:     %[[SLICE_D0_0:.+]] = tensor.dim %[[SORT_TILE]]#1, %[[C0]]
+//   CHECK-DAG:     %[[SLICE_D1_0:.+]] = tensor.dim %[[SORT_TILE]]#1, %[[C1]]
+//       CHECK:     %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_TILE]]#1 into %[[INIT2]][%[[IV]], 0]
+//  CHECK-SAME:         [%[[SLICE_D0_0]], %[[SLICE_D1_0]]]
+//       CHECK:     scf.yield %[[YIELD1]], %[[YIELD2]]
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
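+// Multi-result sort on memrefs: both operands are sliced with
+// memref.subview and sorted in place, so the loop produces no results.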
+func @sort_2d_multi_result_memref(
+    %arg0: memref<?x?xi32>, %arg1: memref<?x?xf32>) {
+  linalg_ext.sort dimension(0)
+     {__internal_linalg_transform__ = "outer_reduce_input"}
+     outs(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xf32>) {
+     ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+       %0 = cmpf ogt, %arg4, %arg5 : f32
+       linalg_ext.yield %0 : i1
+     }
+  return
+}
+//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//       CHECK: func @sort_2d_multi_result_memref(
+//  CHECK-SAME:   %[[OPERAND1:.+]]: memref<?x?xi32>
+//  CHECK-SAME:   %[[OPERAND2:.+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE:.+]] = constant 20 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]]
+//       CHECK:   scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]]
+//   CHECK-DAG:     %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]]
+//       CHECK:     %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][0, %[[IV]]]
+//  CHECK-SAME:         [%[[D0]], %[[USED_TILESIZE]]]
+//       CHECK:     %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][0, %[[IV]]]
+//  CHECK-SAME:         [%[[D0]], %[[USED_TILESIZE]]]
+//       CHECK:     linalg_ext.sort
+//  CHECK-SAME:         __internal_linalg_transform__ = "outer_reduce_output"
+//  CHECK-SAME:         outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+
+// -----
+
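+// Sorting a 3-d tensor along d1 leaves two parallel dimensions, d0 and d2.
+// Both are tiled (sizes 10 and 30) and distributed, d0 to workgroup
+// dimension y and d2 to workgroup dimension x, giving a doubly-nested
+// distributed loop.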
+func @sort_3d_multi_result_distribute(
+  %arg0: tensor<?x?x?xi32>, %arg1 : tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+  %0, %1 = linalg_ext.sort dimension(1)
+      {__internal_linalg_transform__ = "distribute_input"}
+      outs(%arg0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+      ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+        %2 = cmpf ogt, %arg4, %arg5 : f32
+        linalg_ext.yield %2 : i1
+      } -> tensor<?x?x?xi32>, tensor<?x?x?xf32>
+  return %0, %1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//       CHECK: func @sort_3d_multi_result_distribute(
+//  CHECK-SAME:   %[[OPERAND1:[a-zA-Z0-9_]+]]: tensor<?x?x?xi32>
+//  CHECK-SAME:   %[[OPERAND2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE1:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[TILESIZE2:.+]] = constant 30 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]]
+//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[OPERAND1]], %[[C2]]
+//   CHECK-DAG:   %[[IDX:.+]] = flow.dispatch.workgroup.id[0]
+//   CHECK-DAG:   %[[COUNTX:.+]] = flow.dispatch.workgroup.count[0]
+//   CHECK-DAG:   %[[IDY:.+]] = flow.dispatch.workgroup.id[1]
+//   CHECK-DAG:   %[[COUNTY:.+]] = flow.dispatch.workgroup.count[1]
+//   CHECK-DAG:   %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]]
+//   CHECK-DAG:   %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]]
+//       CHECK:   %[[RESULT:.+]]:2 = scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]]
+//  CHECK-SAME:       iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]])
+//   CHECK-DAG:     %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]]
+//   CHECK-DAG:     %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]]
+//   CHECK-DAG:     %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]]
+//       CHECK:     %[[RESULT_INNER:.+]]:2 = scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]]
+//  CHECK-SAME:         iter_args(%[[INIT3:.+]] = %[[INIT1]], %[[INIT4:.+]] = %[[INIT2]])
+//   CHECK-DAG:       %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]]
+//       CHECK:       %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT3]][%[[IV0]], 0, %[[IV1]]]
+//  CHECK-SAME:           [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+//       CHECK:       %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT4]][%[[IV0]], 0, %[[IV1]]]
+//  CHECK-SAME:           [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+//       CHECK:       %[[SORT_SLICE:.+]]:2 = linalg_ext.sort
+//  CHECK-SAME:           __internal_linalg_transform__ = "distribute_output"
+//  CHECK-SAME:           outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+//       CHECK:       %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_SLICE]]#0
+//  CHECK-SAME:           into %[[INIT3]][%[[IV0]], 0, %[[IV1]]]
+//       CHECK:       %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_SLICE]]#1
+//  CHECK-SAME:           into %[[INIT4]][%[[IV0]], 0, %[[IV1]]]
+//       CHECK:       scf.yield %[[YIELD1]], %[[YIELD2]]
+//       CHECK:     scf.yield %[[RESULT_INNER]]#0, %[[RESULT_INNER]]#1
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
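+// Memref variant of the distributed 3-d sort: the same doubly-nested loop
+// structure, but with in-place updates through memref.subview slices.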
+func @sort_3d_multi_result_distribute_memref(
+  %arg0: memref<?x?x?xi32>, %arg1 : memref<?x?x?xf32>) {
+  linalg_ext.sort dimension(1)
+      {__internal_linalg_transform__ = "distribute_input"}
+      outs(%arg0, %arg1 : memref<?x?x?xi32>, memref<?x?x?xf32>) {
+      ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32):  // no predecessors
+        %0 = cmpf ogt, %arg4, %arg5 : f32
+        linalg_ext.yield %0 : i1
+      }
+  return
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//       CHECK: func @sort_3d_multi_result_distribute_memref(
+//  CHECK-SAME:   %[[OPERAND1:[a-zA-Z0-9_]+]]: memref<?x?x?xi32>
+//  CHECK-SAME:   %[[OPERAND2:[a-zA-Z0-9_]+]]: memref<?x?x?xf32>
+//   CHECK-DAG:   %[[TILESIZE1:.+]] = constant 10 : index
+//   CHECK-DAG:   %[[TILESIZE2:.+]] = constant 30 : index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]]
+//   CHECK-DAG:   %[[D2:.+]] = memref.dim %[[OPERAND1]], %[[C2]]
+//   CHECK-DAG:   %[[IDX:.+]] = flow.dispatch.workgroup.id[0]
+//   CHECK-DAG:   %[[COUNTX:.+]] = flow.dispatch.workgroup.count[0]
+//   CHECK-DAG:   %[[IDY:.+]] = flow.dispatch.workgroup.id[1]
+//   CHECK-DAG:   %[[COUNTY:.+]] = flow.dispatch.workgroup.count[1]
+//   CHECK-DAG:   %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]]
+//   CHECK-DAG:   %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]]
+//       CHECK:   scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]]
+//   CHECK-DAG:     %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]]
+//   CHECK-DAG:     %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]]
+//   CHECK-DAG:     %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]]
+//       CHECK:     scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]]
+//   CHECK-DAG:       %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]]
+//       CHECK:       %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]]
+//  CHECK-SAME:           [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+//       CHECK:       %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]]
+//  CHECK-SAME:           [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+//       CHECK:       linalg_ext.sort
+//  CHECK-SAME:           __internal_linalg_transform__ = "distribute_output"
+//  CHECK-SAME:           outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]