blob: a6d630717bb6e7f5e2a1589f4f8fbe48a902bc18 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
namespace mlir::iree_compiler {
#define GEN_PASS_DEF_LLVMGPUSELECTLOWERINGSTRATEGYPASS
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
namespace {
/// Selects a lowering strategy for taking a hal.executable.variant operation
/// to scalar/native-vector code.
class LLVMGPUSelectLoweringStrategyPass final
: public impl::LLVMGPUSelectLoweringStrategyPassBase<
LLVMGPUSelectLoweringStrategyPass> {
public:
using impl::LLVMGPUSelectLoweringStrategyPassBase<
LLVMGPUSelectLoweringStrategyPass>::LLVMGPUSelectLoweringStrategyPassBase;
void getDependentDialects(DialectRegistry &registry) const override {
// TODO(qedawkins): Once TransformStrategies is deprecated, drop the
// unnecessary dialect registrations.
// clang-format off
registry
.insert<IREE::Codegen::IREECodegenDialect,
IREE::GPU::IREEGPUDialect,
IREE::HAL::HALDialect,
IREE::LinalgExt::IREELinalgExtDialect,
linalg::LinalgDialect,
gpu::GPUDialect,
nvgpu::NVGPUDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
scf::SCFDialect,
tensor::TensorDialect,
transform::TransformDialect,
vector::VectorDialect>();
// clang-format on
}
void runOnOperation() override;
};
} // namespace
/// Verify that valid configuration is set for all ops within the compiled
/// module.
template <typename F>
static LogicalResult
verifyLoweringConfiguration(FunctionOpInterface funcOp,
IREE::Codegen::TranslationInfoAttr translationInfo,
ArrayRef<int64_t> workgroupSize, F verificationFn) {
auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
auto loweringConfig =
getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(op);
if (!loweringConfig)
return WalkResult::advance();
return verificationFn(op, loweringConfig, translationInfo, workgroupSize);
});
return failure(walkResult.wasInterrupted());
}
static LogicalResult
verifyEntryPoint(FunctionOpInterface funcOp,
IREE::Codegen::TranslationInfoAttr translationInfo) {
std::optional<SmallVector<int64_t>> workgroupSize = getWorkgroupSize(funcOp);
if (!workgroupSize) {
return funcOp->emitOpError(
"failed to get workgroup size needed for verification");
}
return verifyLoweringConfiguration(
funcOp, translationInfo, workgroupSize.value(), verifyGPUMatmulPipeline);
return success();
}
void LLVMGPUSelectLoweringStrategyPass::runOnOperation() {
auto moduleOp = getOperation();
for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
if (failed(initGPULaunchConfig(funcOp))) {
return signalPassFailure();
}
IREE::Codegen::TranslationInfoAttr translationInfo =
getTranslationInfo(funcOp);
if (!translationInfo) {
// Dont do anything if translation info is not set.
return;
}
// Verify the properties of each entry point based on the target pipeline.
if (failed(verifyEntryPoint(funcOp, translationInfo))) {
return signalPassFailure();
}
}
}
} // namespace mlir::iree_compiler