More IREEGPUAttrs.cpp cleanups (#19142)
Two things in this PR:
1. Make a big switch statement more concise by grouping the cases that
return the same element types.
2. Currently, `DataTileMMAAttr::buildMmaOperation` creates an `MMAAttr`
just to call `buildMmaOperation` on it, in order to reuse that
implementation. Besides being roundabout, this required a comment
explaining why we discarded the error status: `MMAAttr::buildMmaOperation`
is fallible, but at that point we were already past validation and
mutating IR. This PR refactors both paths to call a shared, infallible
helper, `createMmaOp` (sketched below).
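
For illustration only, here is a minimal, self-contained sketch of the
pattern in plain C++, with `std::optional` standing in for `FailureOr`.
The names `createOpImpl`, `buildOp`, and `buildValidatedOp` are
hypothetical, not the actual IREE API; the real names are in the diff.

```cpp
#include <optional>

// Stand-ins for MLIR types, just to make the sketch compile.
struct Value {
  int v = 0;
  explicit operator bool() const { return v != 0; }
};
enum class Intrinsic { Supported, Unsupported };

// Infallible helper shared by both callers (analogous to createMmaOp):
// returns a null Value for unsupported intrinsics instead of an error.
static Value createOpImpl(Intrinsic intrinsic, Value lhs, Value rhs,
                          Value acc) {
  if (intrinsic == Intrinsic::Supported)
    return Value{lhs.v + rhs.v * acc.v};
  return {};
}

// Fallible entry point (analogous to MMAAttr::buildMmaOperation):
// error reporting lives here, wrapping the infallible helper.
static std::optional<Value> buildOp(Intrinsic intrinsic, Value lhs,
                                    Value rhs, Value acc) {
  if (Value result = createOpImpl(intrinsic, lhs, rhs, acc))
    return result;
  return std::nullopt;
}

// Already-validated caller (analogous to the unroll loop in
// DataTileMMAAttr::buildMmaOperation): calls the helper directly, so
// there is no error status to discard.
static Value buildValidatedOp(Value lhs, Value rhs, Value acc) {
  return createOpImpl(Intrinsic::Supported, lhs, rhs, acc);
}

int main() {
  Value lhs{2}, rhs{3}, acc{4};
  std::optional<Value> viaFallible =
      buildOp(Intrinsic::Supported, lhs, rhs, acc);
  Value viaValidated = buildValidatedOp(lhs, rhs, acc);
  return (viaFallible && viaFallible->v == viaValidated.v) ? 0 : 1;
}
```

The diff below follows the same shape: `createMmaOp` is the infallible
helper, `MMAAttr::buildMmaOperation` keeps its `FailureOr` signature, and
`DataTileMMAAttr::buildMmaOperation` calls the helper directly with no
status to discard.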
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index e011b78..0a27437 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -220,88 +220,45 @@
Type i8 = IntegerType::get(context, 8);
Type i32 = IntegerType::get(context, 32);
switch (intrinsic) {
- case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
+ case MMAIntrinsic::MFMA_F64_16x16x4_F64:
return {f64, f64, f64};
- }
- case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+ case MMAIntrinsic::MFMA_F32_16x16x4_F32:
return {f32, f32, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
+ case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+ case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
return {f16, f16, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
- return {f16, f16, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
- return {bf16, bf16, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
- return {bf16, bf16, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
- return {bf16, bf16, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
- return {bf16, bf16, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
- return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
- return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
- return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
- return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
- return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
- return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
- return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
- return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
- }
- case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
- return {i8, i8, i32};
- }
- case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
- return {i8, i8, i32};
- }
- case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
- return {i8, i8, i32};
- }
- case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
- return {i8, i8, i32};
- }
- case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
- return {f16, f16, f32};
- }
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+ case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
return {f16, f16, f16};
- }
- case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
+ case MMAIntrinsic::MFMA_F32_16x16x8_BF16:
+ case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
+ case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
return {bf16, bf16, f32};
- }
- case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
+ case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
return {bf16, bf16, bf16};
- }
- case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
+ return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
+ return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
+ return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
+ return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
+ case MMAIntrinsic::MFMA_I32_16x16x16_I8:
+ case MMAIntrinsic::MFMA_I32_32x32x8_I8:
+ case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+ case MMAIntrinsic::MFMA_I32_32x32x16_I8:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8:
return {i8, i8, i32};
}
- case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
- return {f16, f16, f16};
- }
- case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
- return {f16, f16, f32};
- }
- }
assert(false && "unexpected enum value");
return {};
}
@@ -498,11 +455,15 @@
return IREE::GPU::getContractionLayout(contract, layout);
}
-int64_t MMAAttr::getBlockSize() const {
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
// Not supporting any block size other than 1 at the moment.
return 1;
}
+int64_t MMAAttr::getBlockSize() const {
+ return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
static uint32_t getArchID(MMAIntrinsic intrinsic) {
return static_cast<int>(intrinsic) & 0xFF00;
}
@@ -704,6 +665,31 @@
}
}
+static Value createMmaOp(OpBuilder &builder, Location loc,
+ MMAIntrinsic intrinsic, Type resultType, Value lhs,
+ Value rhs, Value acc) {
+ auto getVecOrSingleElem = [&](Value vec) -> Value {
+ bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
+ return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
+ };
+ auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic);
+ if (is_AMD_MFMA(intrinsic)) {
+ // MFMA intrinsics want single-element operands of element type, not vector.
+ lhs = getVecOrSingleElem(lhs);
+ rhs = getVecOrSingleElem(rhs);
+ return builder
+ .create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
+ layout.kSize, getBlockSize(intrinsic), lhs, rhs,
+ acc)
+ .getResult();
+ }
+ if (is_AMD_WMMA(intrinsic)) {
+ return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
+ .getResult();
+ }
+ return {};
+}
+
// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
// type.
FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -719,23 +705,9 @@
if (cType != resultType) {
return failure();
}
- auto getVecOrSingleElem = [&](Value vec) -> Value {
- bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
- return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
- };
- MMAIntrinsic intrinsic = getIntrinsic().getValue();
- if (is_AMD_MFMA(intrinsic)) {
- // MFMA intrinsics want single-element operands of element type, not vector.
- lhs = getVecOrSingleElem(lhs);
- rhs = getVecOrSingleElem(rhs);
- auto [m, n, k] = getMNKShape();
- return builder
- .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
- rhs, acc)
- .getResult();
- } else if (is_AMD_WMMA(intrinsic)) {
- return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
- .getResult();
+ if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
+ resultType, lhs, rhs, acc)) {
+ return value;
}
return failure();
}
@@ -1168,23 +1140,18 @@
SmallVector<Value> intrinsicsAcc =
distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);
- // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
- // to create the target intrinsics.
- auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue());
- auto [intrinsicAType, intrinsicBType, intrinsicCType] =
- intrinsicMma.getABCVectorTypes();
+ MMAIntrinsic intrinsic = getIntrinsic().getValue();
+ VectorType intrinCType =
+ getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);
// Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics.
for (int mu = 0; mu < getUnrollM(); ++mu) {
for (int nu = 0; nu < getUnrollN(); ++nu) {
for (int ku = 0; ku < getUnrollK(); ++ku) {
- // Assume intrinsicMma.buildMmaOperation() success: validation should be
- // completed prior to mutating IR.
Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
- acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
- rhs, acc);
+ acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc);
}
}
}