blob: 78102023b7fd6980bf34511f433fc22cf0b4d0a0 [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <cassert>
#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinAttributes.h"
// Returns true when `attr` wraps an IREE GPU `GPUPipelineOptionsAttr`.
bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<GPUPipelineOptionsAttr>(unwrapped);
}
// Builds a `GPUPipelineOptionsAttr` in the context wrapped by `mlirCtx`.
//
// Every pointer parameter is optional: a null pointer leaves the matching
// option as a default-constructed (null) attribute, otherwise the pointee is
// converted to the corresponding attribute. `*reorderWorkgroupsStrategy` is
// used only when it holds a `ReorderWorkgroupsStrategyAttr`; a null or
// differently-typed attribute leaves the strategy unset.
MlirAttribute
ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory,
                              bool *noReduceSharedMemoryBankConflicts,
                              bool *useIgemmConvolution,
                              MlirAttribute *reorderWorkgroupsStrategy) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;

  mlir::MLIRContext *ctx = unwrap(mlirCtx);
  mlir::Builder b(ctx);

  // Maps an optional flag to a BoolAttr, or to a null BoolAttr when absent.
  auto toOptionalBoolAttr = [&b](const bool *flag) {
    return flag ? b.getBoolAttr(*flag) : mlir::BoolAttr();
  };

  mlir::BoolAttr prefetchSharedMemoryAttr =
      toOptionalBoolAttr(prefetchSharedMemory);
  mlir::BoolAttr noReduceSharedMemoryBankConflictsAttr =
      toOptionalBoolAttr(noReduceSharedMemoryBankConflicts);
  mlir::BoolAttr useIgemmConvolutionAttr =
      toOptionalBoolAttr(useIgemmConvolution);

  GPU::ReorderWorkgroupsStrategyAttr strategyAttr;
  if (reorderWorkgroupsStrategy) {
    // `dyn_cast_if_present` tolerates a null wrapped attribute, which plain
    // `dyn_cast` does not accept as input.
    strategyAttr =
        llvm::dyn_cast_if_present<GPU::ReorderWorkgroupsStrategyAttr>(
            unwrap(*reorderWorkgroupsStrategy));
  }

  return wrap(GPU::GPUPipelineOptionsAttr::get(
      ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr,
      useIgemmConvolutionAttr, strategyAttr));
}
// Returns the `prefetchSharedMemory` option stored on `attr`, which must wrap
// a `GPUPipelineOptionsAttr` (enforced by `llvm::cast`).
MlirAttribute
ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  auto pipelineOptions = llvm::cast<GPUPipelineOptionsAttr>(unwrapped);
  return wrap(pipelineOptions.getPrefetchSharedMemory());
}
// Returns the `noReduceSharedMemoryBankConflicts` option stored on `attr`,
// which must wrap a `GPUPipelineOptionsAttr` (enforced by `llvm::cast`).
MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts(
    MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  auto pipelineOptions = llvm::cast<GPUPipelineOptionsAttr>(unwrapped);
  return wrap(pipelineOptions.getNoReduceSharedMemoryBankConflicts());
}
// Returns the `useIgemmConvolution` option stored on `attr`, which must wrap
// a `GPUPipelineOptionsAttr` (enforced by `llvm::cast`).
MlirAttribute
ireeGPUPipelineOptionsAttrGetUseIgemmConvolution(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  auto pipelineOptions = llvm::cast<GPUPipelineOptionsAttr>(unwrapped);
  return wrap(pipelineOptions.getUseIgemmConvolution());
}
// Returns the reorder-workgroups strategy stored on `attr`, which must wrap a
// `GPUPipelineOptionsAttr` (enforced by `llvm::cast`).
MlirAttribute
ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  auto pipelineOptions = llvm::cast<GPUPipelineOptionsAttr>(unwrapped);
  return wrap(pipelineOptions.getReorderWorkgroupsStrategy());
}
// Returns the type ID of `GPUPipelineOptionsAttr`.
MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr;
  auto typeID = GPUPipelineOptionsAttr::getTypeID();
  return wrap(typeID);
}
// Returns true when `attr` wraps an IREE GPU `ReorderWorkgroupsStrategyAttr`.
bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<ReorderWorkgroupsStrategyAttr>(unwrapped);
}
// Returns the type ID of `ReorderWorkgroupsStrategyAttr`.
MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr;
  auto typeID = ReorderWorkgroupsStrategyAttr::getTypeID();
  return wrap(typeID);
}
// The C API below exchanges ReorderWorkgroupsStrategy values as uint32_t;
// fail the build if the enum's underlying type stops matching.
static_assert(
    std::is_same<
        uint32_t,
        std::underlying_type<
            mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>::type>::
        value,
    "Enum type changed");
// Creates a `ReorderWorkgroupsStrategyAttr` in the context wrapped by
// `mlirCtx`. `value` is converted to the `ReorderWorkgroupsStrategy` enum with
// a plain `static_cast`; no range validation is performed here.
MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(MlirContext mlirCtx,
                                                      uint32_t value) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  const auto strategy = static_cast<GPU::ReorderWorkgroupsStrategy>(value);
  mlir::MLIRContext *context = unwrap(mlirCtx);
  return wrap(GPU::ReorderWorkgroupsStrategyAttr::get(context, strategy));
}
// Returns the enum value held by `attr` as a uint32_t. `attr` must wrap a
// `ReorderWorkgroupsStrategyAttr`; this is asserted.
uint32_t ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr;
  assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) &&
         "attr is not a GPUReorderWorkgroupsStrategyAttr");
  auto strategyAttr = llvm::cast<ReorderWorkgroupsStrategyAttr>(unwrap(attr));
  return static_cast<uint32_t>(strategyAttr.getValue());
}
// Returns true when `attr` wraps an IREE GPU `MMAIntrinsicAttr`.
bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<MMAIntrinsicAttr>(unwrapped);
}
// Returns true when `attr` wraps an IREE GPU `VirtualMMAIntrinsicAttr`.
bool ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<VirtualMMAIntrinsicAttr>(unwrapped);
}
// Returns the type ID of `MMAIntrinsicAttr`.
MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr;
  auto typeID = MMAIntrinsicAttr::getTypeID();
  return wrap(typeID);
}
// Returns the type ID of `VirtualMMAIntrinsicAttr`.
MlirTypeID ireeGPUVirtualMMAIntrinsicAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr;
  auto typeID = VirtualMMAIntrinsicAttr::getTypeID();
  return wrap(typeID);
}
// The C API below exchanges MMAIntrinsic and VirtualMMAIntrinsic values as
// uint32_t; fail the build if either enum's underlying type stops matching.
static_assert(
    std::is_same<uint32_t,
                 std::underlying_type<
                     mlir::iree_compiler::IREE::GPU::MMAIntrinsic>::type>::value,
    "Enum type changed");
static_assert(
    std::is_same<
        uint32_t,
        std::underlying_type<
            mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic>::type>::value,
    "Enum type changed");
// Creates an `MMAIntrinsicAttr` in the context wrapped by `mlirCtx`. `value`
// is converted to the `MMAIntrinsic` enum with a plain `static_cast`; no range
// validation is performed here.
MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, uint32_t value) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  const auto intrinsic = static_cast<GPU::MMAIntrinsic>(value);
  mlir::MLIRContext *context = unwrap(mlirCtx);
  return wrap(GPU::MMAIntrinsicAttr::get(context, intrinsic));
}
// Creates a `VirtualMMAIntrinsicAttr` in the context wrapped by `mlirCtx`.
// `value` is converted to the `VirtualMMAIntrinsic` enum with a plain
// `static_cast`; no range validation is performed here.
MlirAttribute ireeGPUVirtualMMAIntrinsicAttrGet(MlirContext mlirCtx,
                                                uint32_t value) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  const auto intrinsic = static_cast<GPU::VirtualMMAIntrinsic>(value);
  mlir::MLIRContext *context = unwrap(mlirCtx);
  return wrap(GPU::VirtualMMAIntrinsicAttr::get(context, intrinsic));
}
// Returns the enum value held by `attr` as a uint32_t. `attr` must wrap an
// `MMAIntrinsicAttr`; this is asserted.
uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr;
  assert(ireeAttributeIsAGPUMMAIntrinsicAttr(attr) &&
         "attr is not a GPUMMAIntrinsicAttr");
  auto intrinsicAttr = llvm::cast<MMAIntrinsicAttr>(unwrap(attr));
  return static_cast<uint32_t>(intrinsicAttr.getValue());
}
// Returns the enum value held by `attr` as a uint32_t. `attr` must wrap a
// `VirtualMMAIntrinsicAttr`; this is asserted.
uint32_t ireeGPUVirtualMMAIntrinsicAttrGetValue(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr;
  assert(ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(attr) &&
         "attr is not a GPUVirtualMMAIntrinsicAttr");
  auto intrinsicAttr = llvm::cast<VirtualMMAIntrinsicAttr>(unwrap(attr));
  return static_cast<uint32_t>(intrinsicAttr.getValue());
}
// Returns true when `attr` wraps an IREE GPU `MMAAttr`.
bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::MMAAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<MMAAttr>(unwrapped);
}
// Returns true when `attr` wraps an IREE GPU `VirtualMMAAttr`.
bool ireeAttributeIsAGPUVirtualMMAAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::VirtualMMAAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<VirtualMMAAttr>(unwrapped);
}
// Returns the type ID of `MMAAttr`.
MlirTypeID ireeGPUMMAAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::MMAAttr;
  auto typeID = MMAAttr::getTypeID();
  return wrap(typeID);
}
// Returns the type ID of `VirtualMMAAttr`.
MlirTypeID ireeGPUVirtualMMAAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::VirtualMMAAttr;
  auto typeID = VirtualMMAAttr::getTypeID();
  return wrap(typeID);
}
// Creates an `MMAAttr` in the context wrapped by `mlirCtx` from `value`, which
// is converted to the `MMAIntrinsic` enum with a plain `static_cast`.
MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx, uint32_t value) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  const auto intrinsic = static_cast<GPU::MMAIntrinsic>(value);
  mlir::MLIRContext *context = unwrap(mlirCtx);
  return wrap(GPU::MMAAttr::get(context, intrinsic));
}
// Creates a `VirtualMMAAttr` in the context wrapped by `mlirCtx` from `value`,
// which is converted to the `VirtualMMAIntrinsic` enum with a plain
// `static_cast`.
MlirAttribute ireeGPUVirtualMMAAttrGet(MlirContext mlirCtx, uint32_t value) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  const auto intrinsic = static_cast<GPU::VirtualMMAIntrinsic>(value);
  mlir::MLIRContext *context = unwrap(mlirCtx);
  return wrap(GPU::VirtualMMAAttr::get(context, intrinsic));
}
// Collects the A/B/C element types, the A/B/C vector types and the M/N/K
// shape reported by an `MMAAttr` or `VirtualMMAAttr` into an `ireeGPUMMAInfo`.
// Any other attribute kind trips an assertion; with assertions disabled a
// value-initialized struct is returned.
ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;

  // Shared handler: both attribute kinds expose the same query methods.
  auto describe = [](auto mmaLike) {
    ireeGPUMMAInfo out = {};

    auto [elemA, elemB, elemC] = mmaLike.getABCElementTypes();
    out.aElementType = wrap(elemA);
    out.bElementType = wrap(elemB);
    out.cElementType = wrap(elemC);

    auto [vecA, vecB, vecC] = mmaLike.getABCVectorTypes();
    out.aVectorType = wrap(vecA);
    out.bVectorType = wrap(vecB);
    out.cVectorType = wrap(vecC);

    auto [m, n, k] = mmaLike.getMNKShape();
    out.mElements = m;
    out.nElements = n;
    out.kElements = k;
    return out;
  };

  return llvm::TypeSwitch<mlir::Attribute, ireeGPUMMAInfo>(unwrap(attr))
      .Case<GPU::MMAAttr, GPU::VirtualMMAAttr>(describe)
      .Default([](mlir::Attribute) -> ireeGPUMMAInfo {
        assert(false && "Unexpected attribute type for MMA info");
        return {};
      });
}
// Returns an i64 array attribute holding the integer value of every virtual
// intrinsic reported by `attr.getVirtualIntrinsics()`. `attr` must wrap an
// `MMAAttr`; this is asserted.
MlirAttribute ireeGPUMMAAttrGetVirtualMMAIntrinsic(MlirAttribute attr) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  assert(ireeAttributeIsAGPUMMAAttr(attr) && "attr is not a MMAAttr");

  auto mmaAttr = llvm::cast<GPU::MMAAttr>(unwrap(attr));
  llvm::SmallVector<GPU::VirtualMMAIntrinsic> intrinsics =
      mmaAttr.getVirtualIntrinsics();

  // Widen each enum value to int64_t so it fits an I64ArrayAttr.
  llvm::SmallVector<int64_t> encoded;
  for (GPU::VirtualMMAIntrinsic intrinsic : intrinsics) {
    encoded.push_back(static_cast<int64_t>(intrinsic));
  }

  mlir::Builder arrayBuilder(mmaAttr.getContext());
  return wrap(arrayBuilder.getI64ArrayAttr(encoded));
}
// Returns true when `attr` wraps an IREE GPU `LoweringConfigAttr`.
bool ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::LoweringConfigAttr;
  mlir::Attribute unwrapped = unwrap(attr);
  return llvm::isa<LoweringConfigAttr>(unwrapped);
}
// Returns the type ID of `LoweringConfigAttr`.
MlirTypeID ireeGPULoweringConfigAttrGetTypeID() {
  using mlir::iree_compiler::IREE::GPU::LoweringConfigAttr;
  auto typeID = LoweringConfigAttr::getTypeID();
  return wrap(typeID);
}
// Creates a `LoweringConfigAttr` in the context wrapped by `mlirCtx` from
// `attributesDictionary`, which must wrap a dictionary attribute (asserted).
MlirAttribute ireeGPULoweringConfigAttrGet(MlirContext mlirCtx,
                                           MlirAttribute attributesDictionary) {
  using mlir::iree_compiler::IREE::GPU::LoweringConfigAttr;
  assert(mlirAttributeIsADictionary(attributesDictionary));

  mlir::Attribute unwrapped = unwrap(attributesDictionary);
  auto dictionary = llvm::cast<mlir::DictionaryAttr>(unwrapped);
  mlir::MLIRContext *context = unwrap(mlirCtx);
  return wrap(LoweringConfigAttr::get(context, dictionary));
}
// Returns the attribute dictionary stored on `attr`, which must wrap a
// `LoweringConfigAttr` (asserted).
MlirAttribute ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr) {
  using mlir::iree_compiler::IREE::GPU::LoweringConfigAttr;
  assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
  auto loweringConfig = llvm::cast<LoweringConfigAttr>(unwrap(attr));
  return wrap(loweringConfig.getAttributes());
}
// Looks up the workgroup and reduction tiling-level entries in the attribute
// dictionary of `attr` (which must wrap a `LoweringConfigAttr`, asserted).
// Each result field is filled only when the dictionary holds an `ArrayAttr`
// under the corresponding tiling-level name; otherwise it stays
// value-initialized.
ireeGPUTileSizes ireeGPULoweringConfigAttrGetTileSizes(MlirAttribute attr) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  assert(ireeAttributeIsAGPULoweringConfigAttr(attr));

  ireeGPUTileSizes result = {};
  auto loweringConfig = llvm::cast<GPU::LoweringConfigAttr>(unwrap(attr));
  mlir::DictionaryAttr entries = loweringConfig.getAttributes();

  llvm::StringRef workgroupKey =
      GPU::getTilingLevelName(GPU::TilingLevel::Workgroup);
  auto workgroupSizes = entries.getAs<mlir::ArrayAttr>(workgroupKey);
  if (workgroupSizes) {
    result.workgroupAttr = wrap(workgroupSizes);
  }

  llvm::StringRef reductionKey =
      GPU::getTilingLevelName(GPU::TilingLevel::Reduction);
  auto reductionSizes = entries.getAs<mlir::ArrayAttr>(reductionKey);
  if (reductionSizes) {
    result.reductionAttr = wrap(reductionSizes);
  }

  return result;
}
// Queries the subgroup M and N counts of `attr` (which must wrap a
// `LoweringConfigAttr`, asserted). Each count that is present is returned as
// an index-typed `IntegerAttr`; a missing count leaves its field
// value-initialized.
ireeGPUSubgroupCountInfo
ireeGPULoweringConfigAttrGetSubgroupCount(MlirAttribute attr) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  assert(ireeAttributeIsAGPULoweringConfigAttr(attr));

  auto config = llvm::cast<GPU::LoweringConfigAttr>(unwrap(attr));
  std::optional<int64_t> countM = GPU::getSubgroupMCount(config);
  std::optional<int64_t> countN = GPU::getSubgroupNCount(config);

  ireeGPUSubgroupCountInfo counts = {};
  if (countM.has_value()) {
    mlir::IndexType indexType = mlir::IndexType::get(config.getContext());
    counts.subgroupMCountAttr =
        wrap(mlir::IntegerAttr::get(indexType, *countM));
  }
  if (countN.has_value()) {
    mlir::IndexType indexType = mlir::IndexType::get(config.getContext());
    counts.subgroupNCountAttr =
        wrap(mlir::IntegerAttr::get(indexType, *countN));
  }
  return counts;
}
// Returns the result of `getMmaKind` for `attr`, which must wrap a
// `LoweringConfigAttr` (asserted).
MlirAttribute ireeGPULoweringConfigAttrGetMmaKind(MlirAttribute attr) {
  namespace IREE = mlir::iree_compiler::IREE;
  assert(ireeAttributeIsAGPULoweringConfigAttr(attr));

  auto config = llvm::cast<IREE::GPU::LoweringConfigAttr>(unwrap(attr));
  IREE::Codegen::InnerTileDescAttrInterface mmaKind =
      IREE::GPU::getMmaKind(config);
  return wrap(mmaKind);
}
// Computes the single-subgroup layout of the intrinsic held by `attr` for the
// fragment selected by `fragment`, and returns its `outer`, `thread`,
// `tstrides` and `element` components as i64 array attributes.
//
// `attr` must wrap an `MMAIntrinsicAttr` or a `VirtualMMAIntrinsicAttr`
// (asserted). `fragment` is converted to `MMAFragment` with a plain
// `static_cast`; no range validation is performed here.
ireeGPUMMASingleSubgroupLayout
ireeGPUGetSingleSubgroupLayout(MlirAttribute attr, uint32_t fragment) {
  namespace GPU = mlir::iree_compiler::IREE::GPU;
  assert((ireeAttributeIsAGPUMMAIntrinsicAttr(attr) ||
          ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(attr)) &&
         "Expected MMA or VirtualMMA Intrinsic");

  mlir::Attribute intrinsic = unwrap(attr);
  const auto whichFragment = static_cast<GPU::MMAFragment>(fragment);

  // Dispatch on the concrete intrinsic kind; both kinds share the same
  // layout query name.
  GPU::MMASingleSubgroupLayout subgroupLayout;
  if (auto concrete = llvm::dyn_cast<GPU::MMAIntrinsicAttr>(intrinsic)) {
    subgroupLayout =
        GPU::getSingleSubgroupLayout(concrete.getValue(), whichFragment);
  } else if (auto virt =
                 llvm::dyn_cast<GPU::VirtualMMAIntrinsicAttr>(intrinsic)) {
    subgroupLayout =
        GPU::getSingleSubgroupLayout(virt.getValue(), whichFragment);
  } else {
    assert(false &&
           "Unreachable: attribute must be MMA or VirtualMMA intrinsic");
  }

  mlir::Builder arrayBuilder(intrinsic.getContext());
  ireeGPUMMASingleSubgroupLayout wrapped = {};
  wrapped.outer = wrap(arrayBuilder.getI64ArrayAttr(subgroupLayout.outer));
  wrapped.thread = wrap(arrayBuilder.getI64ArrayAttr(subgroupLayout.thread));
  wrapped.tstrides =
      wrap(arrayBuilder.getI64ArrayAttr(subgroupLayout.tstrides));
  wrapped.element = wrap(arrayBuilder.getI64ArrayAttr(subgroupLayout.element));
  return wrapped;
}