blob: 47a352af41e212bd65db855ee847b8f2b53410b9 [file]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h"
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-dispatch-creation-convert-tensor-to-flow"
namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_CONVERTTENSORTOFLOWPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
/// Return `true` if the given op is contained in DispatchWorkgroupsOp or in a
/// DispatchRegionOp.
static bool isInDispatchRegion(Operation *op) {
return op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ||
op->getParentOfType<IREE::Flow::DispatchRegionOp>();
}
/// Wrap a single op in a DispatchWorkgroupsOp.
static FailureOr<IREE::Flow::DispatchWorkgroupsOp>
wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter, Operation *op) {
SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
if (failed(IREE::Flow::simplifyDimOps(rewriter, rewriter.getTensorDimOps())))
return failure();
// Wrap operation.
auto regionOp = IREE::Flow::wrapOpInDispatchRegion(rewriter, op);
if (failed(regionOp))
return failure();
if (failed(cloneProducersToRegion(rewriter, *regionOp)))
return failure();
auto workgroupsOp =
IREE::Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(*regionOp,
rewriter);
if (failed(workgroupsOp))
return failure();
return *workgroupsOp;
}
/// Wrap all given ops in a DispatchWorkgroupsOp.
static FailureOr<SmallVector<IREE::Flow::DispatchWorkgroupsOp>>
wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter,
SmallVector<Operation *> rootOps) {
SmallVector<IREE::Flow::DispatchWorkgroupsOp> result;
for (Operation *rootOp : rootOps) {
auto workgroupsOp = wrapInWorkgroupsOp(rewriter, rootOp);
if (failed(workgroupsOp))
return failure();
result.push_back(*workgroupsOp);
}
return result;
}
/// Rewrite top-level InsertSliceOps to FlowUpdateOps or wrap them in a
/// dispatch region. Returns the number of dispatches for non-contiguous insert
/// slices created.
static FailureOr<int> convertInsertSliceOps(
mlir::TensorDimTrackingRewriter &rewriter, mlir::FunctionOpInterface funcOp,
SmallVector<IREE::Flow::DispatchWorkgroupsOp> &workgroupsOps) {
// Find eligible InsertSliceOps.
SmallVector<tensor::InsertSliceOp> insertSliceOps;
funcOp.walk([&](tensor::InsertSliceOp op) {
if (!isInDispatchRegion(op))
insertSliceOps.push_back(op);
});
// Rewrite InsertSliceOps to FlowUpdateOps.
SmallVector<Operation *> remainingInsertSliceOps;
for (tensor::InsertSliceOp insertSliceOp : insertSliceOps) {
if (failed(IREE::Flow::convertInsertSliceOpToFlowUpdateOp(rewriter,
insertSliceOp))) {
remainingInsertSliceOps.push_back(insertSliceOp);
}
}
int64_t numRemainingInsertSliceOps =
static_cast<int64_t>(remainingInsertSliceOps.size());
// Create a DispatchWorkgroupsOp for every remaining InsertSliceOp.
FailureOr<SmallVector<IREE::Flow::DispatchWorkgroupsOp>> newWorkgroupsOps =
wrapInWorkgroupsOp(rewriter, remainingInsertSliceOps);
if (failed(newWorkgroupsOps))
return failure();
workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end());
return numRemainingInsertSliceOps;
}
/// Rewrite top-level ExtractSliceOps to FlowSliceOps or wrap them in a
/// dispatch region. Returns the number of dispatches for non-contiguous extract
/// slices created.
static FailureOr<size_t> convertExtractSliceOps(
mlir::TensorDimTrackingRewriter &rewriter, mlir::FunctionOpInterface funcOp,
SmallVector<IREE::Flow::DispatchWorkgroupsOp> &workgroupsOps) {
// Find eligible ExtractSliceOps.
SmallVector<tensor::ExtractSliceOp> extractSliceOps;
funcOp.walk([&](tensor::ExtractSliceOp op) {
if (!isInDispatchRegion(op))
extractSliceOps.push_back(op);
});
// Rewrite ExtractSliceOps to FlowSliceOps.
SmallVector<Operation *> remainingExtractSliceOps;
for (tensor::ExtractSliceOp extractSliceOp : extractSliceOps) {
if (failed(IREE::Flow::convertExtractSliceOpToFlowSliceOp(
rewriter, extractSliceOp))) {
remainingExtractSliceOps.push_back(extractSliceOp);
}
}
int64_t numRemainingExtractSliceOps =
static_cast<int64_t>(remainingExtractSliceOps.size());
// Create a DispatchWorkgroupsOp for every remaining ExtractSliceOp.
FailureOr<SmallVector<IREE::Flow::DispatchWorkgroupsOp>> newWorkgroupsOps =
wrapInWorkgroupsOp(rewriter, remainingExtractSliceOps);
if (failed(newWorkgroupsOps))
return failure();
workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end());
return numRemainingExtractSliceOps;
}
namespace {
struct ConvertTensorToFlowPass
: public impl::ConvertTensorToFlowPassBase<ConvertTensorToFlowPass> {
using Base::Base;
void runOnOperation() override;
};
} // namespace
void ConvertTensorToFlowPass::runOnOperation() {
mlir::FunctionOpInterface funcOp = getOperation();
mlir::TensorDimTrackingRewriter rewriter(funcOp);
mlir::MLIRContext *context = &getContext();
auto workgroupsOps = SmallVector<IREE::Flow::DispatchWorkgroupsOp>();
funcOp->walk([&](IREE::Flow::DispatchWorkgroupsOp workgroupsOp) {
workgroupsOps.push_back(workgroupsOp);
});
// Rewrite InsertSliceOps to FlowUpdateOps.
FailureOr<size_t> numSlowInsertSliceDispatches =
convertInsertSliceOps(rewriter, funcOp, workgroupsOps);
if (failed(numSlowInsertSliceDispatches)) {
funcOp->emitOpError(
"failed to create dispatch region for `tensor.insert_slice`");
return signalPassFailure();
}
numSlowCopyDispatches += numSlowInsertSliceDispatches.value();
// Rewrite ExtractSliceOps to FlowUpdateOps.
FailureOr<size_t> numSlowExtractSliceDispatches =
convertExtractSliceOps(rewriter, funcOp, workgroupsOps);
if (failed(numSlowExtractSliceDispatches)) {
funcOp->emitOpError(
"failed to create dispatch region for `tensor.extract_slice`");
return signalPassFailure();
}
numSlowCopyDispatches += numSlowExtractSliceDispatches.value();
// Canonicalize to flow.tensor ops.
RewritePatternSet convertToFlowPatterns(context);
IREE::Flow::populateTensorToFlowConversionPatterns(context,
convertToFlowPatterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(
convertToFlowPatterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(
convertToFlowPatterns, context);
IREE::Flow::TensorBitCastOp::getCanonicalizationPatterns(
convertToFlowPatterns, context);
if (failed(applyPatternsGreedily(funcOp, std::move(convertToFlowPatterns)))) {
funcOp->emitOpError("failed conversion to flow.tensor ops");
return signalPassFailure();
}
// fold `tensor.insert_slice/extract_slice` operations with
// `flow.dispatch.tensor.load/store`.
RewritePatternSet foldExtractInsertSliceOps(context);
IREE::Flow::populateTensorSliceOpWithDispatchTensorOpFoldingPatterns(
foldExtractInsertSliceOps, context);
if (failed(applyPatternsGreedily(funcOp,
std::move(foldExtractInsertSliceOps)))) {
funcOp->emitOpError("failed to insert/extract_slice with "
"flow.dispatch.tensor.load/store");
return signalPassFailure();
}
}
} // namespace mlir::iree_compiler::DispatchCreation