[GPU] Remove MMAScheduleAttr (#21884)
This attribute is deprecated and no longer used anywhere in the compiler. Remove its definition from IREEGPUAttrs.td, replace the last internal uses in LLVMGPUConfigureTensorLayouts.cpp with a plain local struct, and update the affected tests.
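
For context, the replacement is a file-local struct plus aggregate initialization instead of building an attribute through the rewriter, and transposing a schedule becomes a std::swap of the subgroup counts rather than rebuilding an attribute (the old transposeSchedule helper is deleted). A minimal self-contained sketch of that shape; the Intrinsic alias below is a stand-in for IREE::Codegen::InnerTileDescAttrInterface, which is not reproduced here:

    #include <cstdint>
    #include <string>
    #include <utility>

    // Stand-in for IREE::Codegen::InnerTileDescAttrInterface.
    using Intrinsic = std::string;

    // Mirrors the plain struct that replaces MMAScheduleAttr in
    // LLVMGPUConfigureTensorLayouts.cpp.
    struct MMASchedule {
      Intrinsic intrinsic;
      int64_t subgroupMCount;
      int64_t subgroupNCount;
    };

    int main() {
      // Aggregate initialization replaces
      // rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(...).
      MMASchedule schedule = {"MFMA_F32_16x16x16_F16", /*m=*/1, /*n=*/4};

      // Transposition is now a swap of the counts in place.
      std::swap(schedule.subgroupMCount, schedule.subgroupNCount);
      return 0;
    }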
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 40aa600..2f9c623 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -230,7 +230,7 @@
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
-func.func @conv() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+func.func @conv() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 857482b..1a73ac7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -2567,7 +2567,7 @@
hal.return %x, %y, %z : index, index, index
}
builtin.module {
- func.func @set_size_to_tilesize_when_divisible() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 32, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+ func.func @set_size_to_tilesize_when_divisible() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 32>} {
%c0 = arith.constant 0 : index
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f16
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 7357997..d8e5875 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -415,28 +415,6 @@
}
//===----------------------------------------------------------------------===//
-// MMA schedule
-//===----------------------------------------------------------------------===//
-
-def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
- let mnemonic = "mma_schedule";
- let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
-
- string description = [{
- A schedule of MMA intrinsic instruction and various levels of tile sizes
- to solve a specific contraction problem.
- }];
-
- let parameters = (ins
- "::mlir::iree_compiler::IREE::Codegen::InnerTileDescAttrInterface":$intrinsic,
- "int64_t":$subgroup_m_count,
- "int64_t":$subgroup_n_count
- );
-
- let assemblyFormat = "`<` struct(params) `>`";
-}
-
-//===----------------------------------------------------------------------===//
// iree_gpu.gpu_encoding_resolver
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index da94801..c87c495 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -199,23 +199,24 @@
return bounds;
}
-static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+struct MMASchedule {
+ IREE::Codegen::InnerTileDescAttrInterface intrinsic;
+ int64_t subgroupMCount;
+ int64_t subgroupNCount;
+};
+
+static LogicalResult setContractionAnchor(MMASchedule &schedule,
SmallVector<bool> promotedOperands,
RewriterBase &rewriter,
linalg::LinalgOp contract) {
- // TODO: Add SIMT fallback.
- if (!schedule) {
- return contract->emitError("missing mma schedule for contraction");
- }
-
// This function should only be called on a contraction op.
assert(linalg::isaContractionOpInterface(contract) &&
"cannot set contraction anchor on non contraction op");
SmallVector<int64_t> bounds = getIterationSpaceBounds(contract);
auto layouts = getContractionLayout(
- schedule.getIntrinsic(), schedule.getSubgroupMCount(),
- schedule.getSubgroupNCount(), bounds, contract.getIndexingMapsArray());
+ schedule.intrinsic, schedule.subgroupMCount, schedule.subgroupNCount,
+ bounds, contract.getIndexingMapsArray());
if (failed(layouts)) {
return contract->emitError("cannot get concrete layout for contraction");
}
@@ -230,11 +231,11 @@
// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(contract);
auto layoutedLhs =
- ToLayoutOp::create(rewriter, loc, lhs, aLayout, schedule.getIntrinsic());
+ rewriter.create<ToLayoutOp>(loc, lhs, aLayout, schedule.intrinsic);
auto layoutedRhs =
- ToLayoutOp::create(rewriter, loc, rhs, bLayout, schedule.getIntrinsic());
+ rewriter.create<ToLayoutOp>(loc, rhs, bLayout, schedule.intrinsic);
auto layoutedAcc =
- ToLayoutOp::create(rewriter, loc, acc, cLayout, schedule.getIntrinsic());
+ rewriter.create<ToLayoutOp>(loc, acc, cLayout, schedule.intrinsic);
// Promote matmul lhs and rhs.
// TODO: This is a hack until layout analysis is improved. The layout analysis
@@ -257,23 +258,18 @@
// Set layout for result.
rewriter.setInsertionPointAfter(contract);
- auto toLayout = ToLayoutOp::create(rewriter, loc, contract->getResult(0),
- cLayout, schedule.getIntrinsic());
+ auto toLayout = rewriter.create<ToLayoutOp>(loc, contract->getResult(0),
+ cLayout, schedule.intrinsic);
rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
toLayout);
return success();
}
-static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+static LogicalResult setConvolutionAnchor(MMASchedule schedule,
SmallVector<bool> promotedOperands,
RewriterBase &rewriter,
linalg::LinalgOp conv) {
- // TODO: Add SIMT fallback.
- if (!schedule) {
- return conv->emitError("missing mma schedule for convolution");
- }
-
// This function should only be called on a convolution op.
FailureOr<linalg::ConvolutionDimensions> convDims =
linalg::inferConvolutionDims(conv);
@@ -298,9 +294,9 @@
}
SmallVector<int64_t> bounds = getIterationSpaceBounds(conv);
- auto layouts = getContractionLayout(
- schedule.getIntrinsic(), schedule.getSubgroupMCount(),
- schedule.getSubgroupNCount(), bounds, maps);
+ FailureOr<ContractionLayout> layouts =
+ getContractionLayout(schedule.intrinsic, schedule.subgroupMCount,
+ schedule.subgroupNCount, bounds, maps);
auto [aLayout, bLayout, cLayout] = *layouts;
Location loc = conv.getLoc();
@@ -312,11 +308,11 @@
// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(conv);
auto layoutedLhs =
- ToLayoutOp::create(rewriter, loc, lhs, aLayout, schedule.getIntrinsic());
+ rewriter.create<ToLayoutOp>(loc, lhs, aLayout, schedule.intrinsic);
auto layoutedRhs =
- ToLayoutOp::create(rewriter, loc, rhs, bLayout, schedule.getIntrinsic());
+ rewriter.create<ToLayoutOp>(loc, rhs, bLayout, schedule.intrinsic);
auto layoutedAcc =
- ToLayoutOp::create(rewriter, loc, acc, cLayout, schedule.getIntrinsic());
+ rewriter.create<ToLayoutOp>(loc, acc, cLayout, schedule.intrinsic);
// Promote matmul lhs and rhs.
// TODO: This is a hack until layout analysis is improved. The layout analysis
@@ -339,8 +335,8 @@
// Set layout for result.
rewriter.setInsertionPointAfter(conv);
- auto toLayout = ToLayoutOp::create(rewriter, loc, conv->getResult(0), cLayout,
- schedule.getIntrinsic());
+ auto toLayout = rewriter.create<ToLayoutOp>(loc, conv->getResult(0), cLayout,
+ schedule.intrinsic);
rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
toLayout);
@@ -401,26 +397,14 @@
contractOp->setOperand(1, lhs);
}
-static IREE::GPU::MMAScheduleAttr
-transposeSchedule(RewriterBase &rewriter, IREE::GPU::MMAScheduleAttr schedule) {
- return rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
- schedule.getIntrinsic(), schedule.getSubgroupNCount(),
- schedule.getSubgroupMCount());
-}
-
static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
linalg::LinalgOp qkMatmul,
linalg::LinalgOp pvMatmul) {
- IREE::GPU::MMAScheduleAttr qkSchedule =
- rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(qkMatmul),
- getSubgroupMCount(qkMatmul),
- getSubgroupNCount(qkMatmul));
-
- IREE::GPU::MMAScheduleAttr pvSchedule =
- rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(pvMatmul),
- getSubgroupMCount(pvMatmul),
- getSubgroupNCount(pvMatmul));
+ MMASchedule qkSchedule = {getIntrinsic(qkMatmul), getSubgroupMCount(qkMatmul),
+ getSubgroupNCount(qkMatmul)};
+ MMASchedule pvSchedule = {getIntrinsic(pvMatmul), getSubgroupMCount(pvMatmul),
+ getSubgroupNCount(pvMatmul)};
// Check if the intrinsic output for qkMatmul can be reused for pvMatmul.
// We know that pvMatmul takes the result of qkMatmul as its lhs.
@@ -429,10 +413,10 @@
bool reuseIntrinsicOutput = false;
bool transposeIntrinsic = false;
- auto qkIntrinsic = cast<IREE::Codegen::InnerTileDescAttrInterface>(
- qkSchedule.getIntrinsic());
- auto pvIntrinsic = cast<IREE::Codegen::InnerTileDescAttrInterface>(
- pvSchedule.getIntrinsic());
+ auto qkIntrinsic =
+ cast<IREE::Codegen::InnerTileDescAttrInterface>(qkSchedule.intrinsic);
+ auto pvIntrinsic =
+ cast<IREE::Codegen::InnerTileDescAttrInterface>(pvSchedule.intrinsic);
IREE::GPU::MMASingleSubgroupLayout lhsLayout =
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Lhs);
IREE::GPU::MMASingleSubgroupLayout rhsLayout =
@@ -475,8 +459,8 @@
}
swapOperandsToTransposeIntrinsic(rewriter, qkGeneric);
swapOperandsToTransposeIntrinsic(rewriter, pvGeneric);
- qkSchedule = transposeSchedule(rewriter, qkSchedule);
- pvSchedule = transposeSchedule(rewriter, pvSchedule);
+ std::swap(qkSchedule.subgroupMCount, qkSchedule.subgroupNCount);
+ std::swap(pvSchedule.subgroupMCount, pvSchedule.subgroupNCount);
// Swap promoted operands.
std::swap(promotedQKOperands[0], promotedQKOperands[1]);
@@ -581,9 +565,8 @@
ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
SmallVector<bool> promotedOperands = getPromotedOperands(candidate);
- auto schedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
- getIntrinsic(candidate), getSubgroupMCount(candidate),
- getSubgroupNCount(candidate));
+ MMASchedule schedule = {getIntrinsic(candidate), getSubgroupMCount(candidate),
+ getSubgroupNCount(candidate)};
if (linalg::isaContractionOpInterface(candidate)) {
if (succeeded(setContractionAnchor(schedule, promotedOperands, rewriter,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index ea10c44..1683f88 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -1,10 +1,6 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-cast-type-to-fit-mma))' -mlir-print-local-scope %s | FileCheck %s
-func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
- mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
- subgroup_m_count = 1, subgroup_n_count = 1>,
- workgroup_size = [64, 1, 1]} {
+func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -28,11 +24,7 @@
// -----
-func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
- mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
- subgroup_m_count = 1, subgroup_n_count = 1>,
- workgroup_size = [64, 1, 1]} {
+func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -53,11 +45,7 @@
// -----
-func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> attributes {
- mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
- subgroup_m_count = 1, subgroup_n_count = 1>,
- workgroup_size = [64, 1, 1]} {
+func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -77,11 +65,7 @@
// -----
-func.func @wmmar3_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf16>) -> vector<48x32xf16> attributes {
- mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>,
- subgroup_m_count = 1, subgroup_n_count = 1>,
- workgroup_size = [32, 1, 1]} {
+func.func @wmmar3_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf16>) -> vector<48x32xf16> {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -109,11 +93,7 @@
// "iree.amdgpu.mma" will be generated from the "intrinsic" attribute of to_layout.
// This also shows that we can override default intrinsics if explicitly set.
-func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
- mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
- subgroup_m_count = 1, subgroup_n_count = 1>,
- workgroup_size = [64, 1, 1]} {
+func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -141,7 +121,7 @@
// it will not have mma_schedule on its function attributes, but instead it will
// have the "iree.amdgpu.mma" attribute directly on vector.contract.
-func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {translation_info = #iree_codegen.translation_info<pipeline = None workgroup_size = [64, 1, 1] subgroup_size = 64>} {
+func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index 074664c..62e48ef 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -152,10 +152,7 @@
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
- subgroup_m_count = 4,
- subgroup_n_count = 1>}>
+ subgroup_size = 64>
#maps = [
affine_map<(m, n, k) -> (m, k)>,