blob: bcc2d00c69bdf39e5567ec64c0fab99110a55360 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir::iree_compiler {
#define GEN_PASS_DEF_LLVMGPUVECTORLOWERINGPASS
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
//====---------------------------------------------------------------------===//
// Patterns for late vector op lowering.
//====---------------------------------------------------------------------===//
namespace {
/// Pass that lowers vector ops in two greedy rewrite stages on the operation
/// it runs on:
///   1. contract / broadcast / mask / shape_cast / multi_reduction (and
///      transfer permutation map) lowering patterns, and
///   2. vector-to-SCF conversion together with memref alias folding and
///      vector transfer lowering patterns.
/// The pass signals failure as soon as either greedy application fails.
struct LLVMGPUVectorLoweringPass final
    : impl::LLVMGPUVectorLoweringPassBase<LLVMGPUVectorLoweringPass> {
  /// Declares the dialects this pass depends on.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, memref::MemRefDialect,
                    vector::VectorDialect, scf::SCFDialect>();
  }

  /// Applies the two pattern sets in sequence to the current operation.
  void runOnOperation() override {
    auto target = getOperation();
    auto *ctx = target.getContext();

    // Stage 1: break down the higher-level vector ops (e.g. contractions and
    // multi-dimensional reductions) into simpler vector ops.
    {
      RewritePatternSet stageOnePatterns(ctx);

      vector::populateVectorTransferPermutationMapLoweringPatterns(
          stageOnePatterns);
      vector::TransposeOp::getCanonicalizationPatterns(stageOnePatterns, ctx);
      vector::populateVectorBroadcastLoweringPatterns(stageOnePatterns);

      // Contractions are lowered using the outer-product strategy.
      vector::VectorTransformsOptions contractOptions;
      contractOptions.setVectorTransformsOptions(
          vector::VectorContractLowering::OuterProduct);
      vector::populateVectorContractLoweringPatterns(stageOnePatterns,
                                                     contractOptions);

      vector::populateVectorMaskOpLoweringPatterns(stageOnePatterns);
      vector::populateVectorShapeCastLoweringPatterns(stageOnePatterns);

      // Multi-reductions are lowered using the inner-parallel strategy.
      vector::populateVectorMultiReductionLoweringPatterns(
          stageOnePatterns,
          vector::VectorMultiReductionLowering::InnerParallel);

      auto stageOneResult =
          applyPatternsAndFoldGreedily(target, std::move(stageOnePatterns));
      if (failed(stageOneResult)) {
        signalPassFailure();
        return;
      }
    }

    // Stage 2: convert vector transfers to SCF with full unrolling requested,
    // fold memref alias ops, and lower the remaining vector transfers.
    VectorTransferToSCFOptions toSCFOptions;
    toSCFOptions.enableFullUnroll();

    RewritePatternSet stageTwoPatterns(ctx);
    populateVectorToSCFConversionPatterns(stageTwoPatterns, toSCFOptions);
    memref::populateFoldMemRefAliasOpPatterns(stageTwoPatterns);
    vector::populateVectorTransferLoweringPatterns(stageTwoPatterns);

    auto stageTwoResult =
        applyPatternsAndFoldGreedily(target, std::move(stageTwoPatterns));
    if (failed(stageTwoResult)) {
      signalPassFailure();
      return;
    }
  }
};
} // namespace
} // namespace mlir::iree_compiler