blob: a02b5a5d904aafb4757911299cee854fde91bd14 [file]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===--------------- SplitReduction.cpp ----------------------------===//
//
// Split reduction dimension to increase parallelism of a linalg operation.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_SPLITREDUCTIONPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
// TODO(thomasraoux): Move to attributes.
// Split ratio that SplitReductionPass forwards (through its matmul control
// function) to linalg::splitReduction for each linalg.matmul op. Defaults to
// 1. The pass returns early without touching the IR when this value is <= 1
// and the top-k ratio list below is empty.
static llvm::cl::opt<int64_t>
splitReductionRatio("iree-dispatch-creation-split-matmul-reduction",
llvm::cl::desc("split ratio"), llvm::cl::init(1));
// Comma-separated list of split ratios for LinalgExt top-k ops. The pass
// indexes this list by the `splitReductionDepth` passed to its top-k control
// function, and returns -1 for depths that have no entry. Empty by default.
static llvm::cl::list<int64_t> topkSplitReductionRatio(
"iree-dispatch-creation-topk-split-reduction",
llvm::cl::desc("comma separated list of split ratios"),
llvm::cl::CommaSeparated);
/// Runs linalg::splitReduction on `op` using `controlSplitReductionFn` and, on
/// success, installs the attribute list pruned from the original matmul onto
/// the resulting split linalg op.
///
/// Returns failure() if linalg::splitReduction fails (no attributes are
/// written in that case) and success() otherwise.
static LogicalResult splitReductionOnMatmul(
    RewriterBase &rewriter, linalg::MatmulOp op,
    linalg::ControlSplitReductionFn controlSplitReductionFn) {
  // User-provided compilation information travels as attributes on the op.
  // Capture it before the rewrite so it can be propagated to the new op.
  SmallVector<NamedAttribute> preservedAttrs =
      linalg::getPrunedAttributeList(op);

  FailureOr<linalg::SplitReductionResult> splitResult =
      linalg::splitReduction(rewriter, op, controlSplitReductionFn);
  if (succeeded(splitResult)) {
    splitResult->splitLinalgOp->setAttrs(preservedAttrs);
    return success();
  }
  return failure();
}
namespace {
/// Pass that attempts split-reduction rewrites on every linalg.matmul op and
/// every LinalgExt top-k op nested under the operation it runs on.
///
/// The split ratios come from the `splitReductionRatio` and
/// `topkSplitReductionRatio` command-line options. Individual rewrite failures
/// are ignored: the pass never signals failure itself.
struct SplitReductionPass final
    : public impl::SplitReductionPassBase<SplitReductionPass> {
  void runOnOperation() override {
    // Neither flag requests a split: leave the IR untouched.
    if (splitReductionRatio.getValue() <= 1 &&
        topkSplitReductionRatio.empty()) {
      return;
    }

    MLIRContext *context = &getContext();
    auto funcOp = getOperation();
    IRRewriter rewriter(context);

    auto matmulSplitReductionControlFn =
        [&](linalg::LinalgOp op) -> linalg::SplitReductionOptions {
      // For matmul make the new parallel dimension first so that it looks
      // like a batch_matmul and can follow the same codegen.
      return {int64_t(splitReductionRatio), 0, /*innerParallel=*/false};
    };

    // Gather the candidates up front and rewrite them afterwards (presumably
    // so that the IR is not mutated while it is being walked).
    SmallVector<linalg::MatmulOp> matmulCandidates;
    funcOp->walk([&](linalg::MatmulOp op) { matmulCandidates.push_back(op); });
    for (linalg::MatmulOp op : matmulCandidates) {
      // A matmul that cannot be split is simply left as is.
      (void)splitReductionOnMatmul(rewriter, op, matmulSplitReductionControlFn);
    }

    // Snapshot the configured top-k ratios once instead of re-copying the
    // command-line list on every invocation of the control function. The
    // snapshot outlives every use of the callback below.
    const SmallVector<int64_t> topkRatios(topkSplitReductionRatio.begin(),
                                          topkSplitReductionRatio.end());
    IREE::LinalgExt::TopkSplitReductionControlFn topkSplitReductionControlFn =
        [&](int64_t splitReductionDepth) -> int64_t {
      // Depths without a configured ratio yield -1 (NOTE(review): presumably
      // the "do not split further" sentinel; confirm against the
      // TopkSplitReductionControlFn contract). Negative depths are rejected
      // explicitly rather than through signed-to-unsigned wrap-around.
      if (splitReductionDepth < 0 ||
          static_cast<size_t>(splitReductionDepth) >= topkRatios.size()) {
        return -1;
      }
      return topkRatios[static_cast<size_t>(splitReductionDepth)];
    };

    SmallVector<IREE::LinalgExt::TopkOp> topkCandidates;
    funcOp->walk(
        [&](IREE::LinalgExt::TopkOp op) { topkCandidates.push_back(op); });
    for (IREE::LinalgExt::TopkOp op : topkCandidates) {
      // As with matmul, a failed split is not treated as an error.
      (void)splitReduction(rewriter, op, topkSplitReductionControlFn);
    }
  }
};
} // namespace
} // namespace mlir::iree_compiler::DispatchCreation