NFC refactor dispatch region formation to allow multiple results. (#8937)
Current dispatch region formation does not handle cases where the
fused operations themselves have uses outside of the ops pulled into
the same dispatch, i.e. it only handles cases where the results of the
dispatch are same as the results of the root op. Relax this to allow
for fusing with operations, where even the fused ops can have uses
outside of the dispatch. Enable this in a following commit.
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index c4a87bb..ec4df48 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -4,6 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <deque>
+
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h"
@@ -125,27 +127,6 @@
}
//===----------------------------------------------------------------------===//
-// Utility methods
-//===----------------------------------------------------------------------===//
-
-/// Given the `shape` of the computation with the first element being the
-/// slowest varying and last element being the fastest warying returns the
-/// workload value with
-/// - fastest varying dimension first, i.e., x, y, z order
-/// - the workload padded to `kNumMaxParallelDims` with ones if needed.
-/// The `shape` is expected to be of size less than or equal to
-/// `kNumMaxParallelDims`.
-static SmallVector<Value, 4> convertToWorkload(OpBuilder &b, Location loc,
- ArrayRef<Value> shape) {
- assert(shape.size() <= kNumMaxParallelDims &&
- "workload cannot be more than 3D for now");
- SmallVector<Value, 4> workload = llvm::to_vector<4>(llvm::reverse(shape));
- Value one = b.create<arith::ConstantIndexOp>(loc, 1);
- workload.resize(kNumMaxParallelDims, one);
- return workload;
-}
-
-//===----------------------------------------------------------------------===//
// Op property charecterizations
//===----------------------------------------------------------------------===//
@@ -174,7 +155,7 @@
// trivially clonable too, but they cause problems
// with bufferization. Make them clonable when fixed.
if (isa<arith::IndexCastOp, linalg::InitTensorOp, tensor::CastOp,
- tensor::ExtractOp, tensor::ExtractSliceOp>(op)) {
+ tensor::ExtractOp, tensor::ExtractSliceOp, tensor::PadOp>(op)) {
return true;
}
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
@@ -204,93 +185,265 @@
}
//===----------------------------------------------------------------------===//
+// Methods for getting the workload information for dispatch region creation.
+//===----------------------------------------------------------------------===//
+
+/// For a given operation returns the loop ranges needed to compute the op.
+template <typename T>
+static SmallVector<Range> getLoopRanges(T operation, Location loc,
+ PatternRewriter &rewriter);
+
+template <>
+SmallVector<Range> getLoopRanges<linalg::LinalgOp>(linalg::LinalgOp linalgOp,
+ Location loc,
+ PatternRewriter &rewriter) {
+ return linalgOp.createLoopRanges(rewriter, loc);
+}
+
+template <>
+SmallVector<Range> getLoopRanges<IREE::LinalgExt::TiledOpInterface>(
+ IREE::LinalgExt::TiledOpInterface tilableOp, Location loc,
+ PatternRewriter &rewriter) {
+ return tilableOp.getIterationDomain(rewriter);
+}
+
+template <>
+SmallVector<Range> getLoopRanges<tensor::InsertSliceOp>(
+ tensor::InsertSliceOp insertSliceOp, Location loc,
+ PatternRewriter &rewriter) {
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value source = insertSliceOp.source();
+ SmallVector<Range> loopRanges(insertSliceOp.getSourceType().getRank(),
+ Range{zero, one, one});
+ for (auto dim : llvm::seq<unsigned>(0, loopRanges.size())) {
+ loopRanges[dim].size = rewriter.create<tensor::DimOp>(loc, source, dim);
+ }
+ return loopRanges;
+}
+
+template <>
+SmallVector<Range> getLoopRanges<tensor::ExtractSliceOp>(
+ tensor::ExtractSliceOp sliceOp, Location loc, PatternRewriter &rewriter) {
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ ReifiedRankedShapedTypeDims resultDims;
+ (void)sliceOp.reifyResultShapes(rewriter, resultDims);
+ return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) {
+ return Range{zero, v, one};
+ }));
+}
+
+/// Given the `shape` of the computation with the first element being the
+/// slowest varying and last element being the fastest warying returns the
+/// workload value with
+/// - fastest varying dimension first, i.e., x, y, z order
+/// - the workload padded to `kNumMaxParallelDims` with ones if needed.
+/// The `shape` is expected to be of size less than or equal to
+/// `kNumMaxParallelDims`.
+static SmallVector<Value> convertToWorkload(OpBuilder &b, Location loc,
+ ArrayRef<Value> shape) {
+ assert(shape.size() <= kNumMaxParallelDims &&
+ "workload cannot be more than 3D for now");
+ SmallVector<Value> workload = llvm::to_vector(llvm::reverse(shape));
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ workload.resize(kNumMaxParallelDims, one);
+ return workload;
+}
+
+/// Compute the workload to use for the workgroup based on the root op.
+template <typename OpTy>
+static FailureOr<SmallVector<Value>> getWorkloadForRootOp(
+ PatternRewriter &rewriter, OpTy rootOp) {
+ // Compute workgroup count to use for the dispatch op. These are the ranges
+ // of the outermost parallel loops that can be distributed.
+ Location loc = rootOp->getLoc();
+ SmallVector<Range> loopRanges = getLoopRanges(rootOp, loc, rewriter);
+
+ // TODO: The use of PartitionableLoopsInterface to get the loop bounds
+ // of the distributed loop is legacy. This can be controlled purely in the
+ // backend.
+ auto partitionableLoopsOp =
+ dyn_cast<PartitionableLoopsInterface>(rootOp.getOperation());
+ if (!partitionableLoopsOp) {
+ return rewriter.notifyMatchFailure(
+ rootOp, "expected op to implement ParitionableLoopsInterface");
+ }
+ SmallVector<unsigned> partitionedLoops =
+ partitionableLoopsOp.getPartitionableLoops(kNumMaxParallelDims);
+ SmallVector<Value> count;
+ for (auto dim : partitionedLoops) {
+ count.push_back(loopRanges[dim].size);
+ }
+ return convertToWorkload(rewriter, loc, count);
+}
+
+//===----------------------------------------------------------------------===//
// Methods that help creating the dispatch regions
//===----------------------------------------------------------------------===//
-// Creates a flow.dispatch.workgroup op without arguments.
-// All the necessary operands are transiently captured and rewritten late as
-// operands. This greatly simplifies transformations into the resulting op.
-static std::pair<IREE::Flow::DispatchWorkgroupsOp, Operation *>
-buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc,
- ArrayRef<Value> count, Operation *op,
- ValueRange resultDynamicDims) {
- SmallVector<Value> operands, operandDims;
- SmallVector<int64_t> tiedOperands;
+/// For an operation to be moved into the dispatch region, append `resultTypes`
+/// with the type of the results dispatch region has to return. Also
+/// append `resultDynamicDims` with values that represent the dynamic shapes of
+/// result values returned.
+static LogicalResult computeDispatchResultTypeAndDynamicDims(
+ PatternRewriter &rewriter, Operation *dispatchOp,
+ SmallVector<Type> &resultTypes, SmallVector<Value> &resultDynamicDims) {
+ auto currResultTypes = dispatchOp->getResultTypes();
+ resultTypes.append(currResultTypes.begin(), currResultTypes.end());
+ auto rankedShapedTypeOp =
+ dyn_cast<ReifyRankedShapedTypeOpInterface>(dispatchOp);
+ if (!rankedShapedTypeOp) {
+ return rewriter.notifyMatchFailure(
+ dispatchOp,
+ "expected op to implement the ReifyRankedShapedTypeOpInterface");
+ }
+ // Get the values for the result dims.
+ ReifiedRankedShapedTypeDims resultDims;
+ if (failed(rankedShapedTypeOp.reifyResultShapes(rewriter, resultDims))) {
+ return rewriter.notifyMatchFailure(dispatchOp,
+ "failed to reify shape of the result");
+ }
+ if (currResultTypes.size() != resultDims.size()) {
+ return rewriter.notifyMatchFailure(
+ dispatchOp, "expected as many result shapes as number of outputs");
+ }
+ for (auto outputType : llvm::enumerate(currResultTypes)) {
+ auto shapedOutputType = outputType.value().dyn_cast<ShapedType>();
+ if (!shapedOutputType) continue;
+ for (auto dim : llvm::enumerate(shapedOutputType.getShape())) {
+ if (ShapedType::isDynamic(dim.value())) {
+ resultDynamicDims.push_back(
+ resultDims[outputType.index()][dim.index()]);
+ }
+ }
+ }
+ return success();
+}
+
+/// Returns true if the operation has only uses in `tensor.dim` ops.
+static bool hasComputeUsesOutsideDispatch(
+ Operation *op, ArrayRef<Operation *> dispatchOps = {}) {
+ return !llvm::all_of(op->getUsers(), [&](Operation *user) {
+ return isa<tensor::DimOp>(user) || llvm::is_contained(dispatchOps, user);
+ });
+}
+
+/// Creates a flow.dispatch.workgroup op without arguments.
+/// All the necessary operands are transiently captured and rewritten late as
+/// operands. This greatly simplifies transformations into the resulting op.
+static FailureOr<SmallVector<Operation *>>
+buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc,
+ ArrayRef<Value> workload,
+ ArrayRef<Operation *> dispatchOps) {
+ SmallVector<Value> resultDynamicDims;
+ SmallVector<Type> resultTypes;
+
+ // 1. Compute the result types for the dispatch and the dynamic dimensions
+ // of the result of the dispatch. If operation has only dim uses
+ // do not make the dispatch op return those values. Those uses are
+ // kept on the original op, and later patterns are expected to take care
+ // of them.
+ for (auto op : dispatchOps) {
+ if (!hasComputeUsesOutsideDispatch(op, dispatchOps)) continue;
+ if (failed(computeDispatchResultTypeAndDynamicDims(
+ rewriter, op, resultTypes, resultDynamicDims))) {
+ return failure();
+ }
+ }
+
+ // 2. Create a dispatch op with just the `flow.return` terminator.
auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
- loc, count, op->getResultTypes(), resultDynamicDims, operands,
- operandDims, tiedOperands);
+ loc, workload, resultTypes, resultDynamicDims,
+ /*operands=*/ArrayRef<Value>{}, /*operandDims=*/ArrayRef<Value>{},
+ /*tiedOperands=*/ArrayRef<int64_t>{});
Region ®ion = dispatchOp.body();
Block *block = ®ion.front();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToEnd(block);
+ auto returnOp = rewriter.create<IREE::Flow::ReturnOp>(loc);
+ rewriter.setInsertionPoint(returnOp);
- Operation *clonedOp;
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointToStart(block);
- clonedOp = rewriter.clone(*op);
- unsigned dynamicDimIdx = 0;
- for (auto it : llvm::zip(clonedOp->getResults(),
- dispatchOp.body().getArguments().take_back(
- clonedOp->getNumResults()))) {
- auto resultType = std::get<0>(it).getType().cast<ShapedType>();
- rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
- loc, std::get<0>(it), std::get<1>(it),
- resultDynamicDims.slice(dynamicDimIdx,
- resultType.getNumDynamicDims()));
- dynamicDimIdx += resultType.getNumDynamicDims();
+ // 3. Clone the necessary operations into the dispatch and replace
+ // all uses of the original op with the cloned op within the dispatch.
+ auto resultArgs = region.getArguments();
+ unsigned resultPos = 0;
+ unsigned resultDynamicDimsPos = 0;
+ SmallVector<Value> dispatchOpResults = dispatchOp.getResults();
+ SmallVector<Operation *> clonedOps;
+ clonedOps.reserve(dispatchOps.size());
+ for (auto op : dispatchOps) {
+ Operation *clonedOp = rewriter.clone(*op);
+ clonedOps.push_back(clonedOp);
+ rewriter.replaceOpWithinBlock(op, clonedOp->getResults(), block);
+ rewriter.setInsertionPoint(clonedOp);
+ if (!hasComputeUsesOutsideDispatch(op, dispatchOps)) continue;
+
+ // 3a. Replace all non-dim uses of the original operation with the
+ // corresponding result of the dispatch.
+ rewriter.replaceOpWithIf(op,
+ ArrayRef<Value>(dispatchOpResults)
+ .slice(resultPos, op->getNumResults()),
+ [&](OpOperand &operand) {
+ return !isa<tensor::DimOp>(operand.getOwner());
+ });
+
+ // 3b. For each of the result create a `flow.dispatch.tensor.store`
+ // operation to publish the result of the cloned operation (from within
+ // the dispatch).
+ for (auto clonedOpResult : clonedOp->getResults()) {
+ auto resultType = clonedOpResult.getType().dyn_cast<ShapedType>();
+ if (resultType) {
+ OpBuilder::InsertionGuard g2(rewriter);
+ rewriter.setInsertionPoint(returnOp);
+ unsigned numDynamicDims = resultType.getNumDynamicDims();
+ rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
+ loc, clonedOpResult, resultArgs[resultPos],
+ ArrayRef<Value>(resultDynamicDims)
+ .slice(resultDynamicDimsPos, numDynamicDims));
+ resultDynamicDimsPos += numDynamicDims;
+ }
+ resultPos++;
}
- rewriter.create<IREE::Flow::ReturnOp>(loc);
}
-
LLVM_DEBUG(llvm::dbgs() << "Created dispatchOp shell \n"
<< *dispatchOp << "\n");
- return {dispatchOp, clonedOp};
+ return clonedOps;
}
-// Fuses producers marked in the same group recursively.
-//
-// The impl does not worry about the dispatchOp, operands and arguments are set
-// in a post-pattern `legalizeDispatchWorkgroupOperands` function.
-// To simplify the implementation of the dispatch region formation, we just
-// clone the op that needs to be fused inside the dispatch region and just fuse
-// that one. This avoid any concerns related to tensor operands that are only
-// used for their DimOp. This is a canonicalization that is more involved than
-// necessary across the boundary of regions without captures.
-static void pullInProducersInSameGroup(
- PatternRewriter &rewriter, IREE::Flow::DispatchWorkgroupsOp dispatchOp,
- linalg::LinalgOp rootOp, int64_t groupNum) {
- LLVM_DEBUG(llvm::dbgs() << "pull in producers for op: " << rootOp << "\n");
+/// Returns the list of operations that are to be cloned into the dispatch
+/// based on the root operation.
+static SmallVector<Operation *> getOperationsToMoveIntoDispatch(
+ Operation *rootOp) {
+ SmallVector<Operation *> dispatchOps;
+ dispatchOps.push_back(rootOp);
+ if (!hasRootOpAttribute(rootOp)) return dispatchOps;
- // Scoped within DispatchWorkgroupOp.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointToStart(&dispatchOp.getRegion().front());
- for (auto en : llvm::enumerate(rootOp->getOperands())) {
- if (auto producer = en.value().getDefiningOp<linalg::LinalgOp>()) {
- if (!isInFusionGroup(producer, groupNum)) continue;
- DEBUG_WITH_TYPE(DEBUG_TYPE,
- llvm::dbgs() << "current producer: " << producer << "\n");
+ int64_t groupNum = getRootNumber(rootOp);
+ std::deque<Operation *> worklist;
+ worklist.push_back(rootOp);
+ llvm::SmallDenseSet<Operation *, 2> movedOps;
+ movedOps.insert(rootOp);
- Operation *fusedProducer = rewriter.clone(*producer);
- rewriter.replaceOpWithinBlock(producer, fusedProducer->getResults(),
- &dispatchOp.getRegion().front());
- removeFusionGroupsAttribute(fusedProducer);
-
- pullInProducersInSameGroup(rewriter, dispatchOp, fusedProducer, groupNum);
- } else if (auto producer = en.value().getDefiningOp<tensor::PadOp>()) {
- DEBUG_WITH_TYPE(DEBUG_TYPE,
- llvm::dbgs() << "current producer: " << producer << "\n");
-
- Operation *fusedProducer = rewriter.clone(*producer);
- rewriter.replaceOpWithinBlock(producer, fusedProducer->getResults(),
- &dispatchOp.getRegion().front());
+ while (!worklist.empty()) {
+ Operation *currRoot = worklist.front();
+ worklist.pop_front();
+ for (auto operand : currRoot->getOperands()) {
+ auto producer = operand.getDefiningOp();
+ if (movedOps.count(producer)) continue;
+ if (!producer || !isInFusionGroup(producer, groupNum)) continue;
+ movedOps.insert(producer);
+ worklist.push_back(producer);
+ dispatchOps.push_back(producer);
}
}
+ return dispatchOps;
}
-template <typename OpTy>
-static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
- return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
-}
+//===---------------------------------------------------------------------===//
+// Methods to legalize a dispatch region op, i.e. make it isolated from above.
+//===---------------------------------------------------------------------===//
/// Reorders the operations in `ops` such that they could be inlined into the
/// dispatch region in that order to satisfy dependencies.
@@ -441,6 +594,10 @@
std::swap(reversedValues, valuesDefinedAbove);
}
+//===---------------------------------------------------------------------===//
+// Methods to tie operands and results of a dispatch op.
+//===---------------------------------------------------------------------===//
+
/// Returns the tied operand for the given `resultArg`. Returns nullptr if error
/// or not found.
static BlockArgument getTiedOperandBlockArgument(BlockArgument resultArg) {
@@ -671,61 +828,10 @@
return success();
}
-static bool hasOnlyDimUses(Operation *op) {
- return llvm::all_of(op->getUsers(), [&](Operation *user) {
- return isa<tensor::DimOp>(user);
- });
-}
-
//===----------------------------------------------------------------------===//
-// Patterns that create the dispatch region.
+// Pattern that create the dispatch region.
//===----------------------------------------------------------------------===//
-template <typename T>
-static SmallVector<Range> getLoopRanges(T operation, Location loc,
- PatternRewriter &rewriter);
-
-template <>
-SmallVector<Range> getLoopRanges<linalg::LinalgOp>(linalg::LinalgOp linalgOp,
- Location loc,
- PatternRewriter &rewriter) {
- return linalgOp.createLoopRanges(rewriter, loc);
-}
-
-template <>
-SmallVector<Range> getLoopRanges<IREE::LinalgExt::TiledOpInterface>(
- IREE::LinalgExt::TiledOpInterface tilableOp, Location loc,
- PatternRewriter &rewriter) {
- return tilableOp.getIterationDomain(rewriter);
-}
-
-template <>
-SmallVector<Range> getLoopRanges<tensor::InsertSliceOp>(
- tensor::InsertSliceOp insertSliceOp, Location loc,
- PatternRewriter &rewriter) {
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- Value source = insertSliceOp.source();
- SmallVector<Range> loopRanges(insertSliceOp.getSourceType().getRank(),
- Range{zero, one, one});
- for (auto dim : llvm::seq<unsigned>(0, loopRanges.size())) {
- loopRanges[dim].size = rewriter.create<tensor::DimOp>(loc, source, dim);
- }
- return loopRanges;
-}
-
-template <>
-SmallVector<Range> getLoopRanges<tensor::ExtractSliceOp>(
- tensor::ExtractSliceOp sliceOp, Location loc, PatternRewriter &rewriter) {
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- ReifiedRankedShapedTypeDims resultDims;
- (void)sliceOp.reifyResultShapes(rewriter, resultDims);
- return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) {
- return Range{zero, v, one};
- }));
-}
-
namespace {
template <typename OpType, template <typename> class Base>
struct CreateDispatchRegionOp : Base<OpType> {
@@ -739,7 +845,7 @@
// TODO(ravishankarm): It is getting strange to track when to apply this
// pattern and when not to. Need to revisit this, with dynamic shape cases
// in mind.
- if (hasOnlyDimUses(rootOp)) return failure();
+ if (!hasComputeUsesOutsideDispatch(rootOp)) return failure();
if (rootOp->template getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
return failure();
}
@@ -748,74 +854,26 @@
return failure();
}
- // Compute workgroup count to use for the dispatch op. These are the ranges
- // of the outermost parallel loops that can be distributed.
- Location loc = rootOp->getLoc();
- SmallVector<Range> loopRanges = getLoopRanges(rootOp, loc, rewriter);
-
- // TODO: The use of PartitionableLoopsInterface to get the loop bounds
- // of the distributed loop is legacy. This can be controlled purely in the
- // backend.
- auto partitionableLoopsOp =
- dyn_cast<PartitionableLoopsInterface>(rootOp.getOperation());
- if (!partitionableLoopsOp) {
- return rewriter.notifyMatchFailure(
- rootOp, "expected op to implement ParitionableLoopsInterface");
- }
- SmallVector<unsigned> partitionedLoops =
- partitionableLoopsOp.getPartitionableLoops(kNumMaxParallelDims);
- SmallVector<Value> count;
- for (auto dim : partitionedLoops) {
- count.push_back(loopRanges[dim].size);
- }
- auto workload = convertToWorkload(rewriter, loc, count);
-
- // Capture dynamic result dimensions.
- ReifiedRankedShapedTypeDims resultDims;
- auto rankedShapedTypeOp =
- dyn_cast<ReifyRankedShapedTypeOpInterface>(rootOp.getOperation());
- if (!rankedShapedTypeOp) {
- return rewriter.notifyMatchFailure(
- rootOp,
- "expected op to implement the ReifyRankedShapedTypeOpInterface");
- }
- if (failed(rankedShapedTypeOp.reifyResultShapes(rewriter, resultDims))) {
- return rewriter.notifyMatchFailure(rootOp,
- "failed to reify shape of the result");
+ // Get the workload to use for the dispatch.
+ FailureOr<SmallVector<Value>> workload =
+ getWorkloadForRootOp(rewriter, rootOp);
+ if (failed(workload)) {
+ return failure();
}
- SmallVector<Value, 4> resultDynamicDims;
- for (auto output : llvm::enumerate(rootOp->getResults())) {
- auto outputType =
- output.value().getType().template dyn_cast<ShapedType>();
- if (!outputType) continue;
- for (auto dim : llvm::enumerate(outputType.getShape())) {
- if (!ShapedType::isDynamic(dim.value())) continue;
- resultDynamicDims.push_back(resultDims[output.index()][dim.index()]);
- }
- }
-
+ SmallVector<Operation *> dispatchOps =
+ getOperationsToMoveIntoDispatch(rootOp);
// Create a simple dispatch op with no operands, and not isolated from
// above.
- auto en = buildOperandLessFlowDispatchWorkgroupOp(
- rewriter, loc, workload, rootOp, resultDynamicDims);
- IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first;
- Operation *clonedOp = en.second;
-
- // Scoped within DispatchWorkgroupOp.
- if (hasRootOpAttribute(rootOp)) {
- if (auto clonedLinalgOp = dyn_cast<linalg::LinalgOp>(clonedOp)) {
- pullInProducersInSameGroup(rewriter, dispatchOp, clonedLinalgOp,
- getRootNumber(rootOp));
- }
+ auto clonedOps = buildOperandLessFlowDispatchWorkgroupOp(
+ rewriter, rootOp.getLoc(), workload.getValue(), dispatchOps);
+ if (failed(clonedOps)) {
+ return failure();
}
- rewriter.replaceOpWithIf(rootOp, dispatchOp.getResults(),
- [&](OpOperand &operand) {
- return !isa<tensor::DimOp>(operand.getOwner());
- });
transformationFilter.replaceLinalgTransformationFilter(rewriter, rootOp);
- transformationFilter.replaceLinalgTransformationFilter(rewriter, clonedOp);
+ transformationFilter.replaceLinalgTransformationFilter(
+ rewriter, clonedOps.getValue()[0]);
return success();
}
@@ -1141,6 +1199,7 @@
// Finally walk all the ops and remove the attributes
funcOp.walk([](Operation *op) {
+ removeFusionGroupsAttribute(op);
removeRootOpAttribute(op);
op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
});
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index c28d184..0c472c6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -106,7 +106,8 @@
// CHECK-DAG: %[[ARG1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-NEXT: flow.dispatch.workgroups[%[[ARG1_DIM1]], %[[ARG0_DIM0]], %[[C1]]]
// CHECK-SAME: (%[[ARG0_DIM0]], %[[ARG1_DIM1]], %[[ARG0]], %[[ARG0_DIM1]], %[[ARG1]], %[[ARG1_DIM0]])
-// CHECK-NEXT: (%[[ARG0_DIM0_CAPTURE:[a-zA-Z0-9_]+]]: index, %[[ARG1_DIM1_CAPTURE:[a-zA-Z0-9_]+]]: index,
+// CHECK-NEXT: (%[[ARG0_DIM0_CAPTURE:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG1_DIM1_CAPTURE:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG0_CAPTURE:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>,
// CHECK-SAME: %[[ARG0_DIM1_CAPTURE:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG1_CAPTURE:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>,
@@ -1286,13 +1287,13 @@
// CHECK-SAME: %[[LHS:.+]]: tensor<?xf32>
// CHECK-DAG: %[[RESHAPE:.+]] = flow.tensor.reshape %[[LHS]]
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.workgroups
-// CHECK-SAME: , %[[RESHAPE]]
+// CHECK-SAME: %[[RESHAPE]]
// CHECK-NOT: tensor.expand_shape
// CHECK: linalg.fill
// CHECK: linalg.matmul
// CHECK: flow.return
// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.workgroups
-// CHECK-SAME: , %[[RESHAPE]]
+// CHECK-SAME: %[[RESHAPE]]
// CHECK-NOT: tensor.expand_shape
// CHECK: linalg.fill
// CHECK: linalg.matmul