Make Dispatch region formation handle `LinalgExtOp`s that implement `TiledOpInterface`. (#6540)
Using the pattern to tile + distribute operations that implement the
TiledOpInterface, this change extends dispatch region creation logic
to handle LinalgExpOp that implement the tiled op interface.
Also a few clean ups
Add utility functions to interpret the root and fusion attributes to
make the intent clearer.
Modify the logic that decides the loops to be partitioned to be more
configurable
Make the analysis to find tied operands of the dispatch region run
after destructive updates are resolved to use
flow.dispatch.tensor.load/store. This seems to be a bug, that
would disallow having any LinalgOp not have tied operands either.
Remove the use of AffineMinSCFCanonicalizationPattern. It seems to
hit an assertion (see Issue Segfault during application of AffineMinSCFCanonicalizationPattern #6520), also since all loop bounds are
dynamic, there is no reason to apply this canonicalization.
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index f0e8010..17608bc 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -180,40 +180,37 @@
namespace iree_compiler {
LogicalResult initGPULaunchConfig(ModuleOp moduleOp) {
- linalg::LinalgOp rootOperation;
+ // TODO(ravishankarm): The following logic to get LinalgOps needs
+ // fixing.
+ // - Should be able to handle multiple entry points (so assert on single
+ // funcop is unnecessary)
+ // - The funcOp itself cannot be handled if it doesnt have a single block. The
+ // compilation logic here (which also sets the entry point configuration),
+ // doesnt work when there is arbitrary control flow.
auto funcOps = moduleOp.getOps<FuncOp>();
assert(llvm::hasSingleElement(funcOps));
FuncOp funcOp = *funcOps.begin();
SmallVector<linalg::LinalgOp, 4> linalgOps;
funcOp.walk([&](linalg::LinalgOp op) { linalgOps.push_back(op); });
+
if (linalgOps.empty()) {
return ::setTranslationInfo(
funcOp, IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUDistribute,
{1, 1, 1});
}
- if (linalgOps.size() == 1) rootOperation = *linalgOps.begin();
+
+ linalg::LinalgOp rootOperation;
// if there is more than one linalg op, look for the root one.
- for (linalg::LinalgOp op : linalgOps) {
- if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
- linalg::ConvInputNHWCFilterHWCFOp,
- linalg::DepthwiseConvInputNHWCFilterHWCOp,
- linalg::ConvInputNHWCFilterHWCFOp,
- linalg::DepthwiseConvInputNHWCFilterHWCFOp,
- linalg::DepthwiseConvInputNHWCFilterHWCOp, linalg::PoolingNhwcMaxOp,
- linalg::PoolingNhwcMinOp, linalg::PoolingNhwcSumOp>(
+ for (linalg::LinalgOp op : llvm::reverse(linalgOps)) {
+ if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp>(
op.getOperation())) {
rootOperation = op;
break;
}
}
if (!rootOperation) {
- // If no named ops the dispatch region should have at exactly one generic op
- // which is root operation.
- assert(llvm::count_if(linalgOps, [](linalg::LinalgOp op) {
- return isa<linalg::GenericOp>(op);
- }) == 1);
- for (linalg::LinalgOp op : linalgOps) {
- if (isa<linalg::GenericOp>(op)) {
+ for (linalg::LinalgOp op : llvm::reverse(linalgOps)) {
+ if (isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp>(op)) {
rootOperation = op;
break;
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 23e5bb0..167a506 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -67,6 +67,8 @@
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/LinalgExt/IR",
+ "//iree/compiler/Dialect/LinalgExt/Transforms",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Shape/Utils:TypeConversion",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index f6ecdbb..57484a1 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -71,6 +71,8 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::LinalgExt::IR
+ iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Shape::Utils::TypeConversion
diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
index dc6e9d0..69249fc 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
@@ -56,25 +56,16 @@
// TODO(nicolasvasilache): Use some interface instead of op names directly.
static bool hasDestructiveUpdateSubTensorUses(
- Value v, SpecialTerminatorOpCapture &capture) {
- SmallVector<tensor::ExtractSliceOp, 4> reads;
- SmallVector<tensor::InsertSliceOp, 4> writes;
- for (auto &u : v.getUses()) {
- if (auto subTensorOp = dyn_cast<tensor::ExtractSliceOp>(u.getOwner())) {
- reads.push_back(subTensorOp);
- continue;
- }
+ BlockArgument arg, SpecialTerminatorOpCapture &capture) {
+ SmallVector<Operation *> reads;
+ SmallVector<tensor::InsertSliceOp> writes;
+ for (OpOperand &u : arg.getUses()) {
if (auto subTensorInsertOp =
dyn_cast<tensor::InsertSliceOp>(u.getOwner())) {
writes.push_back(subTensorInsertOp);
- continue;
+ } else {
+ reads.push_back(u.getOwner());
}
- if (auto dimOp = dyn_cast<tensor::DimOp>(u.getOwner())) {
- continue;
- }
- LLVM_DEBUG(llvm::dbgs() << "found non-destructive update pattern use: "
- << *(u.getOwner()) << "\n");
- return false;
}
// For now, only allow exactly a single SubTensorInsertOp that must be
// dominated by all SubTensorOp.
@@ -82,12 +73,11 @@
// Small local dominance computation.
DominanceInfo domInfo(writes.front()->getParentOp());
for (auto read : reads) {
- LLVM_DEBUG(llvm::dbgs() << "read: " << *read.getOperation() << "\n");
- if (!domInfo.properlyDominates(read.getOperation(), writes.front())) {
- LLVM_DEBUG(llvm::dbgs()
- << "non-destructive use-def: " << *(read.getOperation())
- << " does not properly dominate "
- << *(writes.front().getOperation()) << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "read: " << *read << "\n");
+ if (!domInfo.properlyDominates(read, writes.front())) {
+ LLVM_DEBUG(llvm::dbgs() << "non-destructive use-def: " << *read
+ << " does not properly dominate "
+ << *(writes.front().getOperation()) << "\n");
return false;
}
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index ab38622..d89902d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -10,10 +10,13 @@
#include "iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -63,57 +66,68 @@
static unsigned kNumMaxParallelDims = 3;
-namespace {
-/// PatternRewriter that allows replacing only a subset of uses.
-/// Since this only adds a method, it can just be static_cast'ed to when
-/// applying a rewrite.
-/// TODO(nicolasvasilache): upstream support for this is landing, rebase on that
-struct PatternRewriterWithScopedReplaceOp : public PatternRewriter {
- void replaceOpWithinScope(Operation *op, ValueRange newValues, Block *block) {
- // Notify the rewriter subclass that we're about to replace this root.
- notifyRootReplaced(op);
+//===----------------------------------------------------------------------===//
+// Root and fusion group attribute handling
+//===----------------------------------------------------------------------===//
- assert(op->getNumResults() == newValues.size() &&
- "incorrect # of replacement values");
- bool erase = true;
- SmallVector<Operation *, 4> ops;
- SmallVector<Value, 4> operands, repls;
- for (auto &use : op->getUses()) {
- if (!block->getParentOp()->isProperAncestor(use.getOwner())) {
- erase = false;
- continue;
- }
- OpResult opResult = use.get().cast<OpResult>();
- ops.push_back(use.getOwner());
- operands.push_back(use.get());
- repls.push_back(newValues[opResult.getResultNumber()]);
- }
- // Perform the actual replacements.
- for (auto it : llvm::zip(ops, operands, repls))
- std::get<0>(it)->replaceUsesOfWith(std::get<1>(it), std::get<2>(it));
- if (erase) {
- notifyOperationRemoved(op);
- op->erase();
- }
+/// Returns true if an op has a root operation.
+static bool hasRootOpAttribute(Operation *op) {
+ return static_cast<bool>(op->getAttrOfType<IntegerAttr>(kRootOpAttr));
+}
+/// Removes root attribute. Asserts if root attribute is not present.
+static void removeRootOpAttribute(Operation *op) {
+ assert(op->hasAttr(kRootOpAttr) &&
+ "removing root attribute from op that is not a root attribute");
+ op->removeAttr(kRootOpAttr);
+}
+/// Sets the root attribute for an operation. The root attribute needs a number
+/// to identify the root. Asserts if root attribute is already set on an
+/// operation.
+static void setRootAttribute(MLIRContext *context, Operation *op,
+ int64_t rootNumber) {
+ assert(!op->hasAttr(kRootOpAttr) &&
+ "invalid to update root attribute on an op");
+ op->setAttr(kRootOpAttr,
+ IntegerAttr::get(IntegerType::get(context, 64), rootNumber));
+}
+/// Returns the number of the root. Asserts if the operation is not already set
+/// as a root.
+static int64_t getRootNumber(Operation *op) {
+ return op->getAttrOfType<IntegerAttr>(kRootOpAttr).getInt();
+}
+/// Returns true if an op is part of a fusion group.
+static bool hasFusionGroupsAttribute(Operation *op) {
+ return static_cast<bool>(op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr));
+}
+/// Returns the fusion groups for the given `op`.
+static SmallVector<int64_t, 1> getFusionGroups(Operation *op) {
+ SmallVector<int64_t, 1> fusionGroups = {};
+ if (auto fusionGroupsAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
+ fusionGroups = llvm::to_vector<1>(llvm::map_range(
+ fusionGroupsAttr,
+ [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
}
-};
-
-struct DispatchLinalgOnTensorsPass
- : public DispatchLinalgOnTensorsBase<DispatchLinalgOnTensorsPass> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
- scf::SCFDialect, ShapeDialect, tensor::TensorDialect>();
+ return fusionGroups;
+}
+/// Appends the given `op` to the `newGroups` fusion groups.
+static void appendToFusionGroup(Operation *op, ArrayRef<int64_t> newGroups) {
+ SmallVector<int64_t, 1> fusionGroups = getFusionGroups(op);
+ fusionGroups.append(newGroups.begin(), newGroups.end());
+ op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups));
+}
+/// Returns true if the given `op` is in the `targetGroup` fusion group.
+static bool isInFusionGroup(Operation *op, unsigned targetGroup) {
+ if (ArrayAttr opGroupAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
+ return llvm::any_of(opGroupAttr, [&targetGroup](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt() == targetGroup;
+ });
}
- DispatchLinalgOnTensorsPass() = default;
- DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
- void runOnOperation() override;
-
- private:
- Statistic numDispatches{this, "number of dispatches",
- "Number of Flow dispatches created"};
-};
-} // namespace
+ return false;
+}
+/// Removes the fusion groups attribute.
+static void removeFusionGroupsAttribute(Operation *op) {
+ op->removeAttr(kFusionGroupsAttr);
+}
//===----------------------------------------------------------------------===//
// Utility methods
@@ -132,11 +146,6 @@
.size();
}
-/// Returns the number of loops of the operation that are to be tiled.
-static size_t getNumTilableLoops(linalg::LinalgOp op) {
- return std::min<size_t>(getNumOuterParallelLoops(op), kNumMaxParallelDims);
-}
-
/// Given the `shape` of the computation with the first element being the
/// slowest varying and last element being the fastest warying returns the
/// workload value with
@@ -154,34 +163,6 @@
return workload;
}
-/// Returns the fusion groups for the given `op`.
-static SmallVector<int64_t, 1> getFusionGroups(Operation *op) {
- SmallVector<int64_t, 1> fusionGroups = {};
- if (auto fusionGroupsAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
- fusionGroups = llvm::to_vector<1>(llvm::map_range(
- fusionGroupsAttr,
- [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
- }
- return fusionGroups;
-}
-
-/// Appends the given `op` to the `newGroups` fusion groups.
-static void appendToFusionGroup(Operation *op, ArrayRef<int64_t> newGroups) {
- SmallVector<int64_t, 1> fusionGroups = getFusionGroups(op);
- fusionGroups.append(newGroups.begin(), newGroups.end());
- op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups));
-}
-
-/// Returns true if the given `op` is in the `targetGroup` fusion group.
-static bool isInFusionGroup(Operation *op, unsigned targetGroup) {
- if (ArrayAttr opGroupAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
- return llvm::any_of(opGroupAttr, [&targetGroup](Attribute attr) {
- return attr.cast<IntegerAttr>().getInt() == targetGroup;
- });
- }
- return false;
-}
-
//===----------------------------------------------------------------------===//
// Op property charecterizations
//===----------------------------------------------------------------------===//
@@ -213,19 +194,12 @@
/// linalg.init_tensor operations.
static bool isRootOp(Operation *op) {
- if (auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op)) {
- if (contractionOp.isRowMajorMatmul() ||
- contractionOp.isColumnMajorMatmul() ||
- contractionOp.isRowMajorBatchMatmul()) {
- return true;
- }
+ if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ return false;
}
-
- return isa<linalg::ConvInputNHWCFilterHWCFOp,
- linalg::DepthwiseConvInputNHWCFilterHWCOp,
- linalg::DepthwiseConvInputNHWCFilterHWCFOp,
- linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
- linalg::PoolingNhwcMinOp>(op);
+ return (isa<linalg::LinalgOp>(op) &&
+ !isa<linalg::GenericOp, linalg::FillOp>(op)) ||
+ isa<linalg_ext::LinalgExtOp>(op);
}
static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
@@ -316,65 +290,48 @@
// that one. This avoid any concerns related to tensor operands that are only
// used for their DimOp. This is a canonicalization that is more involved than
// necessary across the boundary of regions without captures.
-//
-// TODO(nicolasvasilache): This implementation jumps an abstraction gap as it
-// knows that `clonedLinalgOp` has been tiled into `tiledLinalgOp`. In the case
-// where a `rootOp`, i.e. the untiled original operation used to create the
-// dispatch region, can be fused with its producer, this allows calling into a
-// `fuseProducerOfTensor` to which we provide the producer by construction. This
-// avoids an analysis that would need to reconstruct a destructive update from
-// the loop nest + operations in order to get the producer of an `out` tensor.
-// In the future, this analysis should be implemented in core but for now it is
-// IREE-only.
-//
-// TODO(antiagainst): Right now this function requires taking all shaped
-// operands of the tiled op to inspect them. This should probably be changed to
-// just take one operand we know that need to be fused.
static void pullInProducersInSameGroup(
PatternRewriter &rewriter, IREE::Flow::DispatchWorkgroupsOp dispatchOp,
- linalg::LinalgOp tiledOp, ValueRange tiledOpOperands,
+ linalg::LinalgOp tiledOp, ValueRange untiledOpOperands,
ArrayRef<Operation *> tiledLoops, int64_t groupNum) {
DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs() << "pull in producers for tiled op: "
<< tiledOp << "\n");
+
// Scoped within DispatchWorkgroupOp.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(&dispatchOp.getRegion().front());
- for (auto en : llvm::enumerate(tiledOpOperands)) {
+ for (auto en : llvm::enumerate(untiledOpOperands)) {
if (auto producer = en.value().getDefiningOp<linalg::LinalgOp>()) {
if (!isInFusionGroup(producer, groupNum)) continue;
DEBUG_WITH_TYPE(DEBUG_TYPE,
llvm::dbgs() << "current producer: " << producer << "\n");
- Operation *clonedOpToFuse = rewriter.clone(*producer);
+ Operation *clonedOrigProducer = rewriter.clone(*producer);
+ rewriter.replaceOpWithinBlock(producer, clonedOrigProducer->getResults(),
+ &dispatchOp.getRegion().front());
+
linalg::LinalgOp fusedProducer;
-
- static_cast<PatternRewriterWithScopedReplaceOp &>(rewriter)
- .replaceOpWithinScope(producer, clonedOpToFuse->getResults(),
- &dispatchOp.getRegion().front());
-
if (tiledLoops.empty()) {
DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
<< "no loops; just copy over the op\n");
- // The root op wasn't tiled. We are done then; just to remove the
- // attribute.
- clonedOpToFuse->removeAttr(kFusionGroupsAttr);
- fusedProducer = cast<linalg::LinalgOp>(clonedOpToFuse);
+ // The root op wasn't tiled. We are done then.
+ removeFusionGroupsAttribute(clonedOrigProducer);
+ fusedProducer = cast<linalg::LinalgOp>(clonedOrigProducer);
} else {
// TODO: this is incorrect on general pattern failures, try pattern
// within pattern.
OpResult opResult = en.value().cast<OpResult>();
auto maybeFusionInfo = linalg::fuseProducerOfTensor(
- rewriter, clonedOpToFuse->getResult(opResult.getResultNumber()),
- *tiledOp.getInputAndOutputOperands()[en.index()]);
+ rewriter, clonedOrigProducer->getResult(opResult.getResultNumber()),
+ tiledOp->getOpOperand(en.index()));
if (!maybeFusionInfo.hasValue()) {
DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
<< "failed to fuse with tensor\n");
- rewriter.replaceOp(clonedOpToFuse, producer->getResults());
+ rewriter.replaceOp(clonedOrigProducer, producer->getResults());
} else {
DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
<< "succeeded to fuse with tensor\n");
- maybeFusionInfo->fusedProducer.getOperation()->removeAttr(
- kFusionGroupsAttr);
+ removeFusionGroupsAttribute(maybeFusionInfo->fusedProducer);
fusedProducer = maybeFusionInfo->fusedProducer;
}
}
@@ -383,10 +340,12 @@
// producer's operands and pull them in if they are marked to be fused
// into the current group.
if (fusedProducer) {
- SmallVector<Value> producerOperands =
- cast<linalg::LinalgOp>(clonedOpToFuse).getInputAndOutputOperands();
+ SmallVector<Value> origProducerOpOperands =
+ cast<linalg::LinalgOp>(clonedOrigProducer)
+ .getInputAndOutputOperands();
pullInProducersInSameGroup(rewriter, dispatchOp, fusedProducer,
- producerOperands, tiledLoops, groupNum);
+ origProducerOpOperands, tiledLoops,
+ groupNum);
}
}
}
@@ -540,21 +499,31 @@
// TODO(antiagainst): use TiedOpInterface here instead of hardcoding ops
// when it's available in MLIR core in some form.
- if (auto insertOp = dyn_cast_or_null<tensor::InsertSliceOp>(tieOp)) {
- auto loadOp =
- insertOp.dest().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
- if (!loadOp) return nullptr;
- return loadOp.source().cast<BlockArgument>();
- } else if (auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(tieOp)) {
- unsigned resultIndex = storeOp.value().cast<OpResult>().getResultNumber();
- auto loadOp = linalgOp.getOutputTensorOperands()[resultIndex]
- ->get()
- .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
- if (!loadOp) return nullptr;
- return loadOp.source().cast<BlockArgument>();
- }
+ BlockArgument tiedArg =
+ TypeSwitch<Operation *, BlockArgument>(tieOp)
+ .Case<tensor::InsertSliceOp>(
+ [&](tensor::InsertSliceOp insertOp) -> BlockArgument {
+ auto loadOp = insertOp.dest()
+ .template getDefiningOp<
+ IREE::Flow::DispatchTensorLoadOp>();
+ if (!loadOp) return nullptr;
+ return loadOp.source().cast<BlockArgument>();
+ })
+ .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>(
+ [&](auto linalgLikeOp) -> BlockArgument {
+ unsigned resultIndex =
+ storeOp.value().cast<OpResult>().getResultNumber();
+ auto loadOp =
+ linalgLikeOp.getOutputTensorOperands()[resultIndex]
+ ->get()
+ .template getDefiningOp<
+ IREE::Flow::DispatchTensorLoadOp>();
+ if (!loadOp) return nullptr;
+ return loadOp.source().template cast<BlockArgument>();
+ })
+ .Default([&](Operation *) -> BlockArgument { return nullptr; });
- return nullptr;
+ return tiedArg;
};
SmallVector<BlockArgument, 4> tiedOperands;
@@ -680,15 +649,38 @@
dispatchOp.operandsMutable().assign(llvm::to_vector<4>(valuesDefinedAbove));
dispatchOp.operand_dimsMutable().assign(operandDynamicDims);
- // Now try to see if we can tie certain results to operands in order to
- // indicate sharing storage. This need to happen here because it needs to
- // access region block arguments for input/output tensors, which aren't
- // available until now.
- tryToTieOperandsAndResults(dispatchOp);
-
return success();
}
+/// Returns the loops that are partitioned during dispatch region formations, in
+/// order, i.e. starting from the outer-most to innermost.
+static SmallVector<unsigned> getPartitionedLoops(Operation *op) {
+ SmallVector<unsigned> partitionedLoops;
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ size_t numOuterParallelLoops = getNumOuterParallelLoops(linalgOp);
+ partitionedLoops =
+ llvm::to_vector<4>(llvm::seq<unsigned>(0, numOuterParallelLoops));
+ if (partitionedLoops.size() > kNumMaxParallelDims) {
+ partitionedLoops.erase(
+ partitionedLoops.begin(),
+ std::next(partitionedLoops.begin(),
+ numOuterParallelLoops - kNumMaxParallelDims));
+ }
+ return partitionedLoops;
+ }
+ if (auto tilableOp = dyn_cast<linalg_ext::TiledOpInterface>(op)) {
+ auto iteratorTypes = tilableOp.getLoopIteratorTypes();
+ for (auto en : llvm::enumerate(iteratorTypes)) {
+ if (en.value() == getParallelIteratorTypeName()) {
+ partitionedLoops.push_back(en.index());
+ }
+ if (partitionedLoops.size() == kNumMaxParallelDims) break;
+ }
+ return partitionedLoops;
+ }
+ return {};
+}
+
/// Computes the shape of the output. This is used to get the workload of the
/// dispatch region if a dispatch region contains a single "Dispatchable op"
static Optional<SmallVector<SmallVector<Value, 4>, 1>> computeOutputShape(
@@ -729,10 +721,10 @@
namespace {
// Rewrite pattern to ensure only ops with tensor semantics are tiled.
-struct TileAndDistributeOnTensorsPattern
+struct TileAndDistributeLinalgOpsPattern
: public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
- TileAndDistributeOnTensorsPattern(MLIRContext *context,
+ TileAndDistributeLinalgOpsPattern(MLIRContext *context,
linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
@@ -742,8 +734,7 @@
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp || !linalgOp.hasTensorSemantics()) return failure();
- IntegerAttr rootOpAttr = op->getAttrOfType<IntegerAttr>(kRootOpAttr);
- if (!rootOpAttr) return failure();
+ if (!hasRootOpAttribute(op)) return failure();
// TODO(ravishankarm): It is getting strange to track when to apply this
// pattern and when not to. Need to revisit this, with dynamic shape cases
@@ -753,16 +744,12 @@
// Compute workgroup count to use for the dispatch op. These are the ranges
// of the outermost parallel loops that can be distributed.
Location loc = op->getLoc();
- SmallVector<Value, 4> count = llvm::to_vector<4>(
- llvm::map_range(linalgOp.createLoopRanges(rewriter, loc),
- [](Range r) { return r.size; }));
- size_t numParallelLoops = getNumOuterParallelLoops(op);
- if (numParallelLoops > kNumMaxParallelDims) {
- count.erase(
- count.begin(),
- std::next(count.begin(), numParallelLoops - kNumMaxParallelDims));
+ SmallVector<Range> loopRanges = linalgOp.createLoopRanges(rewriter, loc);
+ SmallVector<unsigned> partitionedLoops = getPartitionedLoops(op);
+ SmallVector<Value> count;
+ for (auto dim : partitionedLoops) {
+ count.push_back(loopRanges[dim].size);
}
- count.resize(getNumTilableLoops(op));
auto workload = convertToWorkload(rewriter, loc, count);
// Capture dynamic result dimensions.
@@ -799,16 +786,95 @@
rewriter.eraseOp(dispatchOp);
return failure();
}
- // Keep track of the tiledOpOperands for fusion.
- SmallVector<Value> tiledOperands =
+
+ SmallVector<Value> clonedOpOperands =
clonedLinalgOp.getInputAndOutputOperands();
+ pullInProducersInSameGroup(rewriter, dispatchOp, tiledLinalgOp.op,
+ clonedOpOperands, tiledLinalgOp.loops,
+ getRootNumber(op));
+
+ // Keep track of the tiledOpOperands for fusion.
rewriter.replaceOp(clonedLinalgOp, tiledLinalgOp.tensorResults);
- pullInProducersInSameGroup(rewriter, dispatchOp, tiledLinalgOp.op,
- tiledOperands, tiledLinalgOp.loops,
- rootOpAttr.getInt());
+ removeRootOpAttribute(tiledLinalgOp.op);
- tiledLinalgOp.op.getOperation()->removeAttr(kRootOpAttr);
+ rewriter.replaceOpWithIf(op, dispatchOp.getResults(),
+ [&](OpOperand &operand) {
+ return !isa<tensor::DimOp>(operand.getOwner());
+ });
+ return success();
+ }
+};
+
+/// Rewrite pattern to tile and distribute `LinalgExt` ops.
+struct TiledOpInterfacePattern
+ : public linalg_ext::TiledOpInterfaceBaseTilingPattern {
+ using Base = linalg_ext::TiledOpInterfaceBaseTilingPattern;
+ using Base::TiledOpInterfaceBaseTilingPattern;
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Check if the op implements the LinalgExt interface and the
+ // TiledOpInterface.
+ auto tilableOp = dyn_cast<linalg_ext::TiledOpInterface>(op);
+ auto linalgExtOp = dyn_cast<linalg_ext::LinalgExtOp>(op);
+ if (!linalgExtOp || !tilableOp) return failure();
+ if (!hasRootOpAttribute(op)) return failure();
+ if (hasOnlyDimUses(op)) return failure();
+
+ SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+ SmallVector<Range> loopRanges = tilableOp.getLoopBounds(rewriter);
+ SmallVector<unsigned> partitionedLoops = getPartitionedLoops(op);
+ SmallVector<Value> count;
+ for (auto dim : partitionedLoops) {
+ count.push_back(loopRanges[dim].size);
+ }
+ Location loc = op->getLoc();
+ auto workload = convertToWorkload(rewriter, loc, count);
+
+ // Capture dynamic result dimensions.
+ SmallVector<Value, 4> resultDynamicDims;
+ for (auto result : linalgExtOp.outputs()) {
+ resultDynamicDims.append(
+ Shape::buildOrFindDynamicDimsForValue(loc, result, rewriter));
+ }
+
+ // Note: DispatchTensorStoreOp generated by the
+ // `buildOperandLessFlowDispatchWorkgroupOp` is an abstraction jump that
+ // consumes the SSA value produced by `clonedOp` but it does not comply with
+ // the semantics of DispatchWorkgroupsOp which explicitly states: "behavior
+ // is undefined if multiple workgroups store to the same regions of the
+ // output tensors". Similarly to sequentialized SPMD loops, the semantics
+ // is valid assuming a sequential ordering of execution. After destructive
+ // update rewrites, the abstraction gap disappears.
+ auto en =
+ buildOperandLessFlowDispatchWorkgroupOp(rewriter, loc, workload, op);
+ IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first;
+ auto clonedOp = cast<linalg_ext::LinalgExtOp>(en.second);
+ dispatchOp.result_dimsMutable().assign(resultDynamicDims);
+
+ // Scoped within DispatchWorkgroupOp.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(clonedOp);
+
+ linalg_ext::TiledOp tiledOp;
+ LogicalResult tilingResult = Base::matchAndRewriteBase(
+ clonedOp, clonedOp.outputs(), rewriter, tiledOp);
+ if (failed(tilingResult)) {
+ // GreedyPatternRewriter is not transactional and does not stop on
+ // failure. Must explicitly delete on all failure paths.
+ rewriter.eraseOp(dispatchOp);
+ return failure();
+ }
+ // Keep track of the tiledOpOperands for fusion.
+ SmallVector<Value> tiledOperands = clonedOp.getInputAndOutputOperands();
+ if (tiledOp.op != clonedOp) {
+ rewriter.replaceOp(clonedOp, tiledOp.results);
+ }
+
+ // TODO(ravishankarm): To fuse ops with `linalg_ext` operations (tile+fuse),
+ // look into calling `pullInProducersInSameGroup`.
+ removeRootOpAttribute(tiledOp.op);
rewriter.replaceOpWithIf(op, dispatchOp.getResults(),
[&](OpOperand &operand) {
@@ -896,8 +962,7 @@
// If this is a dispatchable op that is to be fused into dispatch ops, and
// all its uses are dispatchable ops, don't do anything.
- if ((op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr) ||
- isAlwaysFusedIntoDispatchOp(op)) &&
+ if ((hasFusionGroupsAttribute(op) || isAlwaysFusedIntoDispatchOp(op)) &&
llvm::all_of(op->getUsers(), [](Operation *user) {
return isDispatchableOp(user) ||
user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ||
@@ -955,15 +1020,16 @@
// If this is a root op for fusion, try to pull in the ops to be fused
// together with it.
- if (auto rootOpAttr = op->getAttrOfType<IntegerAttr>(kRootOpAttr)) {
- linalg::LinalgOp clonedLinalgOp = cast<linalg::LinalgOp>(en.second);
- SmallVector<Value> tiledOperands =
- clonedLinalgOp.getInputAndOutputOperands();
-
- pullInProducersInSameGroup(
- rewriter, dispatchOp, clonedLinalgOp, tiledOperands,
- /*tiledLoops=*/ArrayRef<Operation *>(), rootOpAttr.getInt());
- clonedLinalgOp->removeAttr(kRootOpAttr);
+ if (hasRootOpAttribute(op)) {
+ auto clonedLinalgOp = dyn_cast<linalg::LinalgOp>(en.second);
+ if (clonedLinalgOp) {
+ SmallVector<Value> opOperandsVal =
+ clonedLinalgOp.getInputAndOutputOperands();
+ pullInProducersInSameGroup(
+ rewriter, dispatchOp, clonedLinalgOp, opOperandsVal,
+ /*tiledLoops=*/ArrayRef<Operation *>(), getRootNumber(op));
+ removeRootOpAttribute(clonedLinalgOp);
+ }
}
rewriter.replaceOpWithIf(op, dispatchOp.getOperation()->getResults(),
@@ -1000,9 +1066,9 @@
auto linalgOps = block.getOps<linalg::LinalgOp>();
for (linalg::LinalgOp linalgOp : llvm::reverse(linalgOps)) {
Operation *op = linalgOp.getOperation();
- if (op->getAttrOfType<IntegerAttr>(kRootOpAttr) ||
- op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr))
+ if (hasRootOpAttribute(op) || hasFusionGroupsAttribute(op)) {
continue;
+ }
if (!isa<GenericOpTy>(op) ||
!llvm::all_of(
cast<linalg::LinalgOp>(op).getIndexingMaps(),
@@ -1010,7 +1076,7 @@
continue;
}
unsigned newGroup = numRoots++;
- op->setAttr(kRootOpAttr, builder.getI64IntegerAttr(newGroup));
+ setRootAttribute(context, op, newGroup);
for (OpOperand *operand : linalgOp.getOutputTensorOperands()) {
auto producer = operand->get().getDefiningOp<linalg::LinalgOp>();
@@ -1032,19 +1098,23 @@
MLIRContext *context = funcOp.getContext();
OpBuilder builder(context);
for (Block &block : funcOp) {
- auto linalgOps = block.getOps<linalg::LinalgOp>();
-
- // Tiling and fusion in linalg works by tiling the last operation in the
- // fusion group and then pull producer ops into the tiled loops. So go in
- // the reverse order here.
- for (linalg::LinalgOp linalgOp : llvm::reverse(linalgOps)) {
+ // Tiling and fusion works by tiling the last operation in the fusion group
+ // and then pull producer ops into the tiled loops. So go in the reverse
+ // order here.
+ for (Operation &op : llvm::reverse(block)) {
// Start with a root operation and fuse its producers.
- Operation *op = linalgOp.getOperation();
- if (!isRootOp(op)) continue;
+ if (!isRootOp(&op)) continue;
unsigned newGroup = numRootOps++;
- op->setAttr(kRootOpAttr, builder.getI64IntegerAttr(newGroup));
+ setRootAttribute(context, &op, newGroup);
- for (OpOperand *operand : linalgOp.getOutputTensorOperands()) {
+ linalg::OpOperandVector outOperands =
+ TypeSwitch<Operation *, linalg::OpOperandVector>(&op)
+ .Case<linalg::LinalgOp>([&](auto linalgOp) {
+ return linalgOp.getOutputTensorOperands();
+ })
+ .Default(
+ [&](Operation *) -> linalg::OpOperandVector { return {}; });
+ for (OpOperand *operand : outOperands) {
auto producer = operand->get().getDefiningOp<linalg::LinalgOp>();
if (!producer) continue;
if (producer.getNumLoops() != producer.getNumParallelLoops()) continue;
@@ -1059,20 +1129,20 @@
// maps The root operation can be fused with its consumer. To do this,
// mark the consumer as the root and add the operation to the fusion
// group.
- for (linalg::LinalgOp linalgOp : linalgOps) {
+ for (linalg::LinalgOp linalgOp : block.getOps<linalg::LinalgOp>()) {
Operation *op = linalgOp.getOperation();
- IntegerAttr rootOpAttr = op->getAttrOfType<IntegerAttr>(kRootOpAttr);
- if (!rootOpAttr) continue;
+ if (!hasRootOpAttribute(op)) continue;
if (op->getNumResults() != 1 || !op->hasOneUse()) continue;
OpOperand &use = *op->use_begin();
Operation *user = use.getOwner();
- if (user->getAttrOfType<IntegerAttr>(kRootOpAttr) ||
- user->getAttrOfType<IntegerAttr>(kFusionGroupsAttr))
+ if (hasRootOpAttribute(user) || hasFusionGroupsAttribute(user)) {
continue;
+ }
linalg::LinalgOp consumer = dyn_cast<linalg::LinalgOp>(use.getOwner());
if (!consumer ||
- consumer.getNumLoops() != consumer.getNumParallelLoops())
+ consumer.getNumLoops() != consumer.getNumParallelLoops()) {
continue;
+ }
AffineMap consumerIndexingMap = consumer.getTiedIndexingMap(&use);
AffineMap producerIndexingMap =
linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0));
@@ -1081,15 +1151,35 @@
consumerIndexingMap.getResults()) {
continue;
}
- user->setAttr(kRootOpAttr, rootOpAttr);
- op->removeAttr(kRootOpAttr);
- appendToFusionGroup(op, rootOpAttr.getInt());
+ int64_t rootNumber = getRootNumber(op);
+ setRootAttribute(context, user, rootNumber);
+ removeRootOpAttribute(op);
+ appendToFusionGroup(op, rootNumber);
}
}
}
return numRootOps;
}
+namespace {
+/// Pass declaration.
+struct DispatchLinalgOnTensorsPass
+ : public DispatchLinalgOnTensorsBase<DispatchLinalgOnTensorsPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry
+ .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
+ scf::SCFDialect, ShapeDialect, tensor::TensorDialect>();
+ }
+ DispatchLinalgOnTensorsPass() = default;
+ DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
+ void runOnOperation() override;
+
+ private:
+ Statistic numDispatches{this, "number of dispatches",
+ "Number of Flow dispatches created"};
+};
+} // namespace
+
void DispatchLinalgOnTensorsPass::runOnOperation() {
FuncOp funcOp = getOperation();
@@ -1128,30 +1218,35 @@
auto tileSizeFn = [&](OpBuilder &builder,
Operation *op) -> SmallVector<Value, 4> {
- auto numParallelDims = getNumOuterParallelLoops(cast<linalg::LinalgOp>(op));
- auto numTiledLoops = getNumTilableLoops(cast<linalg::LinalgOp>(op));
-
- // Default to zero to skip tiling.
- auto zero = builder.create<ConstantIndexOp>(op->getLoc(), 0);
- SmallVector<Value, 4> useTileSizes(numParallelDims, zero);
+ SmallVector<unsigned> partitionedLoops = getPartitionedLoops(op);
+ if (partitionedLoops.empty()) return {};
+ unsigned maxDepth = partitionedLoops.back() + 1;
if (!clLinalgOnTensorsTileSizes.empty()) {
SmallVector<int64_t, 2> tileSizes(clLinalgOnTensorsTileSizes.begin(),
clLinalgOnTensorsTileSizes.end());
- useTileSizes.resize(std::min<size_t>(tileSizes.size(), numParallelDims));
return llvm::to_vector<4>(llvm::map_range(
ArrayRef<int64_t>(tileSizes).take_front(
- std::min<size_t>(tileSizes.size(), numParallelDims)),
+ std::min<size_t>(tileSizes.size(), maxDepth)),
[&](int64_t t) -> Value {
return builder.create<ConstantIndexOp>(op->getLoc(), t);
}));
}
- // For ops with more than 3 parallel dimensions, we want to ignore the
- // higher dimension and tile along last three dimensions.
- for (size_t dim = 0; dim < numTiledLoops; ++dim) {
- useTileSizes[numParallelDims - dim - 1] =
- buildFlowWorkgroupInfoOp<Flow::DispatchWorkgroupSizeOp>(builder, dim);
+ // Set all loops not partitioned to tile size 0. and those partitioned to
+ // `flow.workgroup.size`.
+ auto zero = builder.create<ConstantIndexOp>(op->getLoc(), 0);
+ SmallVector<Value, 4> useTileSizes(maxDepth, zero);
+ llvm::DenseSet<unsigned> partitionedLoopsSet;
+ partitionedLoopsSet.insert(partitionedLoops.begin(),
+ partitionedLoops.end());
+ unsigned currFlowDim = 0;
+ for (size_t dim = maxDepth; dim > 0; dim--) {
+ if (partitionedLoopsSet.count(dim - 1)) {
+ useTileSizes[dim - 1] =
+ buildFlowWorkgroupInfoOp<Flow::DispatchWorkgroupSizeOp>(
+ builder, currFlowDim++);
+ }
}
return useTileSizes;
};
@@ -1168,7 +1263,7 @@
.setTileSizeComputationFunction(tileSizeFn);
assert(linalgTilingOptions.distribution.hasValue());
- patterns.insert<TileAndDistributeOnTensorsPattern>(
+ patterns.insert<TileAndDistributeLinalgOpsPattern, TiledOpInterfacePattern>(
context, linalgTilingOptions,
// TODO(nicolavasilache): use refactored `getWorkgroupMarker()`
linalg::LinalgTransformationFilter(
@@ -1176,7 +1271,6 @@
// Add canonicalization patterns.
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
- patterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
@@ -1246,6 +1340,14 @@
.wasInterrupted()) {
signalPassFailure();
}
+
+ // Now try to see if we can tie certain results to operands in order to
+ // indicate sharing storage. This need to happen here because it needs to
+ // access region block arguments for input/output tensors, which aren't
+ // available until now.
+ funcOp.walk([&](IREE::Flow::DispatchWorkgroupsOp op) {
+ tryToTieOperandsAndResults(op);
+ });
}
std::unique_ptr<OperationPass<FuncOp>> createDispatchLinalgOnTensorsPass() {
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 8382262..5e30bed 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -15,8 +15,7 @@
// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]])
// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
// CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0]
// CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1]
// CHECK-DAG: %[[WGID_X:.+]] = flow.dispatch.workgroup.id[0]
@@ -40,7 +39,7 @@
// CHECK: %[[RESULT:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
-// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]]
+// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG5]]
// CHECK-SAME: offsets = [%[[ARG7]], %[[ARG8]]]
// -----
@@ -876,3 +875,135 @@
// CHECK-NOT: linalg.fill
// CHECK-NOT: linalg.matmul
// CHECK: return
+
+// -----
+
+func @scatter(
+ %original : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
+ %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %1 = addf %arg0, %arg1 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK: func @scatter(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[WORKLOAD:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups[%[[WORKLOAD]], %[[C1]], %[[C1]]]
+// CHECK-SAME: (%[[ARG2]], %[[ARG1]], %[[ARG0]])
+// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x1xi32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
+// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG5]], offsets = [], sizes = [], strides = []
+// CHECK: scf.for %[[ARG6:.+]] =
+// CHECK-DAG: %[[UPDATE_TILE:.+]] = flow.dispatch.tensor.load %[[ARG3]], offsets = [%[[ARG6]], 0]
+// CHECK-DAG: %[[INDICES_TILE:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [%[[ARG6]], 0]
+// CHECK-DAG: %[[RESULT_TILE:.+]] = linalg_ext.scatter
+// CHECK-SAME: {__internal_linalg_transform__ = "workgroup"}
+// CHECK-SAME: ins(%[[UPDATE_TILE]], %[[INDICES_TILE]] : tensor<?x?xf32>, tensor<?x1xi32>)
+// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<?x?xf32>)
+// CHECK: flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG5]], offsets = [0, 0]
+// CHECK: return %[[RESULT]] : tensor<?x?xf32>
+
+// -----
+
+func @sort_3d(%arg0: tensor<?x?x?xi32>, %arg1 : tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+ %0, %1 = linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %2 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %2 : i1
+ } -> tensor<?x?x?xi32>, tensor<?x?x?xf32>
+ return %0, %1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MULMAP:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[MINMAP:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+// CHECK: func @sort_3d(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[WLY:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[WLX:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[RESULT:.+]]:2 = flow.dispatch.workgroups[%[[WLX]], %[[WLY]], %[[C1]]]
+// CHECK-SAME: (%[[ARG0]], %[[ARG1]])
+// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:?x?x?xi32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:?x?x?xf32>) {
+// CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1]
+// CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0]
+// CHECK-DAG: %[[SHAPE:.+]] = flow.dispatch.shape %[[ARG2]]
+// CHECK-DAG: %[[D0:.+]] = shapex.ranked_dim %[[SHAPE]][0]
+// CHECK-DAG: %[[D1:.+]] = shapex.ranked_dim %[[SHAPE]][1]
+// CHECK-DAG: %[[D2:.+]] = shapex.ranked_dim %[[SHAPE]][2]
+// CHECK-DAG: %[[WGID_X:.+]] = flow.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[WGCOUNT_X:.+]] = flow.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[WGID_Y:.+]] = flow.dispatch.workgroup.id[1]
+// CHECK-DAG: %[[WGCOUNT_Y:.+]] = flow.dispatch.workgroup.count[1]
+// CHECK-DAG: %[[LB_Y:.+]] = affine.apply #[[MULMAP]](%[[WGID_Y]])[%[[WGSIZE_Y]]]
+// CHECK-DAG: %[[STEP_Y:.+]] = affine.apply #[[MULMAP]](%[[WGCOUNT_Y]])[%[[WGSIZE_Y]]]
+// CHECK: scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[D1]] step %[[STEP_Y]] {
+// CHECK-DAG: %[[TS_Y:.+]] = affine.min #[[MINMAP]](%[[IV0]])[%[[WGSIZE_Y]], %[[D1]]]
+// CHECK-DAG: %[[LB_X:.+]] = affine.apply #[[MULMAP]](%[[WGID_X]])[%[[WGSIZE_X]]]
+// CHECK-DAG: %[[STEP_X:.+]] = affine.apply #[[MULMAP]](%[[WGCOUNT_X]])[%[[WGSIZE_X]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[LB_X]] to %[[D2]] step %[[STEP_X]] {
+// CHECK: %[[TS_X:.+]] = affine.min #[[MINMAP]](%[[IV1]])[%[[WGSIZE_X]], %[[D2]]]
+// CHECK-DAG: %[[OUT1_TILE:.+]] = flow.dispatch.tensor.load %[[ARG2]]
+// CHECK-SAME: offsets = [0, %[[IV0]], %[[IV1]]]
+// CHECK-SAME: sizes = [%[[D0]], %[[TS_Y]], %[[TS_X]]]
+// CHECK-DAG: %[[OUT2_TILE:.+]] = flow.dispatch.tensor.load %[[ARG3]]
+// CHECK-SAME: offsets = [0, %[[IV0]], %[[IV1]]]
+// CHECK-SAME: sizes = [%[[D0]], %[[TS_Y]], %[[TS_X]]]
+// CHECK: %[[RESULT_TILE:.+]]:2 = linalg_ext.sort dimension(0)
+// CHECK-SAME: {__internal_linalg_transform__ = "workgroup"}
+// CHECK-SAME: outs(%[[OUT1_TILE]], %[[OUT2_TILE]] : tensor<?x?x?xi32>, tensor<?x?x?xf32>)
+// CHECK-DAG: flow.dispatch.tensor.store %[[RESULT_TILE]]#0
+// CHECK-SAME: offsets = [0, %[[IV0]], %[[IV1]]]
+// CHECK-SAME: sizes = [%[[D0]], %[[TS_Y]], %[[TS_X]]]
+// CHECK-DAG: flow.dispatch.tensor.store %[[RESULT_TILE]]#1
+// CHECK-SAME: offsets = [0, %[[IV0]], %[[IV1]]]
+// CHECK-SAME: sizes = [%[[D0]], %[[TS_Y]], %[[TS_X]]]
+// CHECK: }
+// CHECK: }
+// CHECK: flow.return
+// CHECK: }
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @sort_1d(%arg0: tensor<?xi32>, %arg1 : tensor<?xf32>)
+ -> (tensor<?xi32>, tensor<?xf32>) {
+ %0, %1 = linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?xi32>, tensor<?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %2 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %2 : i1
+ } -> tensor<?xi32>, tensor<?xf32>
+ return %0, %1 : tensor<?xi32>, tensor<?xf32>
+}
+// CHECK: func @sort_1d(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESULT:.+]]:2 = flow.dispatch.workgroups[%[[C1]], %[[C1]], %[[C1]]]
+// CHECK-SAME: (%[[ARG0]], %[[ARG1]])
+// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:?xi32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:?xf32>) {
+// CHECK-DAG: %[[OUT1_TILE:.+]] = flow.dispatch.tensor.load %[[ARG2]], offsets = [], sizes = []
+// CHECK-DAG: %[[OUT2_TILE:.+]] = flow.dispatch.tensor.load %[[ARG3]], offsets = [], sizes = []
+// CHECK: %[[RESULT_TILE:.+]]:2 = linalg_ext.sort dimension(0)
+// CHECK-SAME: {__internal_linalg_transform__ = "workgroup"}
+// CHECK-SAME: outs(%[[OUT1_TILE]], %[[OUT2_TILE]] : tensor<?xi32>, tensor<?xf32>)
+// CHECK-DAG: flow.dispatch.tensor.store %[[RESULT_TILE]]#0, %[[ARG2]]
+// CHECK-DAG: flow.dispatch.tensor.store %[[RESULT_TILE]]#1, %[[ARG3]]
+// CHECK: flow.return
+// CHECK: }
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
index f686600..d0d413e 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -380,27 +380,6 @@
//===----------------------------------------------------------------------===//
namespace {
-/// Base pattern for tiling TiledOpInterfaceOps.
-struct TiledOpInterfaceBaseTilingPattern : public RewritePattern {
- TiledOpInterfaceBaseTilingPattern(StringRef opName, MLIRContext *context,
- linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : RewritePattern(opName, benefit, context),
- filter(filter),
- options(options) {}
-
- LogicalResult matchAndRewriteBase(Operation *op, ValueRange dest,
- PatternRewriter &rewriter,
- TiledOp &result) const;
-
- private:
- /// LinalgTransformMarker handles special attribute manipulations.
- linalg::LinalgTransformationFilter filter;
- /// Options to control tiling;
- linalg::LinalgTilingOptions options;
-};
template <typename OpTy>
struct LinalgExtTilingPattern : public TiledOpInterfaceBaseTilingPattern {
@@ -409,12 +388,12 @@
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : TiledOpInterfaceBaseTilingPattern(OpTy::getOperationName(), context,
- options, filter, benefit) {}
+ : TiledOpInterfaceBaseTilingPattern(context, options, filter, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto linalgExtOp = cast<LinalgExtOp>(op);
+ auto linalgExtOp = dyn_cast<LinalgExtOp>(op);
+ if (!linalgExtOp) return failure();
TiledOp tiledOp;
// Check for failure.
if (failed(TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
@@ -440,13 +419,12 @@
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : TiledOpInterfaceBaseTilingPattern(
- tensor::InsertSliceOp::getOperationName(), context, options, filter,
- benefit) {}
+ : TiledOpInterfaceBaseTilingPattern(context, options, filter, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- tensor::InsertSliceOp insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op);
+ if (!insertSliceOp) return failure();
TiledOp tiledOp;
// Check for failure.
if (failed(TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
index e251cab..a292c3d 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
@@ -30,6 +30,30 @@
FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, Operation *op, ValueRange dest,
const linalg::LinalgTilingOptions &options);
+/// Base rewrite pattern to tile and distribute operations that implement the
+/// `TiledOpInterface`.
+/// Base pattern for tiling TiledOpInterfaceOps.
+struct TiledOpInterfaceBaseTilingPattern : public RewritePattern {
+ TiledOpInterfaceBaseTilingPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+ filter(filter),
+ options(options) {}
+
+ LogicalResult matchAndRewriteBase(Operation *op, ValueRange dest,
+ PatternRewriter &rewriter,
+ TiledOp &result) const;
+
+ private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ linalg::LinalgTransformationFilter filter;
+ /// Options to control tiling;
+ linalg::LinalgTilingOptions options;
+};
+
} // namespace linalg_ext
} // namespace iree_compiler
} // namespace mlir