[Codegen][CPU] Add a type-polymorphic generic-scalar MMA fallback. (#24389)
Adds two new `MMAIntrinsic` values, `MMA_GENERIC_SCALAR_1x1x1_REG8` and
`_REG16`, that the data-tiling cost model picks when no element-type-
specific intrinsic on the target supports the matmul's (LHS, RHS, ACC)
types. This intentionally breaks the "an MMAIntrinsic enum value pins
down a specific element-type triple" invariant in exchange for not
having to add one enum value per supported triple. Element types live on
new `DataTiledMMAAttr.{lhs,rhs,acc}_type` parameters, populated by the
cost model only when the chosen intrinsic is one of the polymorphic
variants.
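For reference, a condensed sketch of the cost-model side (taken from the
`CPUEncodingExternalModels.cpp` hunk below; for every other intrinsic the
three types stay null):

```c++
// Only the type-polymorphic generic intrinsic carries element types on the
// attr; all other intrinsics bake them into the enum value.
Type lhsType, rhsType, accType;
if (IREE::CPU::isGenericScalar(intr)) {
  lhsType = elementTypes[0];
  rhsType = elementTypes[1];
  accType = elementTypes[2];
}
return IREE::CPU::DataTiledMMAAttr::get(ctx, intr, intrinsicsM, intrinsicsN,
                                        intrinsicsK, transposed, lhsType,
                                        rhsType, accType);
```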
The cost model picks `_REG16` on 64-bit ISAs (x86_64, AArch64, 64-bit
RISC-V) and `_REG8` on 32-bit ISAs. The number is a register budget for the
unroll heuristic — one element of any width occupies one register, but
the architectural register file the lowering ends up in (GPR or SIMD-
scalar lane) is up to LLVM. The budget is encoded in the low byte of the
enum value, so `chooseUnrolling` can read it back.
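The decode helpers are a mask-and-compare on the enum value; this is an
abridged excerpt of the `IREECPUAttrs.cpp` hunk below:

```c++
// High byte 0xF0 marks the generic-scalar family; the low byte is the
// register budget (0x08 = 8, 0x10 = 16), matching the 0xABCD scheme in
// IREECPUEnums.td.
constexpr uint32_t kMMAIntrinsicISAMask = 0xFF00;
constexpr uint32_t kMMAIntrinsicISAGeneric = 0xF000;
constexpr uint32_t kMMAIntrinsicGenericBudgetMask = 0x00FF;

bool isGenericScalar(MMAIntrinsic intr) {
  return (static_cast<uint32_t>(intr) & kMMAIntrinsicISAMask) ==
         kMMAIntrinsicISAGeneric;
}

int64_t getGenericScalarRegisterBudget(MMAIntrinsic intr) {
  assert(isGenericScalar(intr));
  return static_cast<uint32_t>(intr) & kMMAIntrinsicGenericBudgetMask;
}
```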
Since the intrinsic is 1×1×1, the operand tiles after `intrinsics_m` /
`intrinsics_n` / `intrinsics_k` are simple row-major (M, K) / (N, K) /
(M, N) — `linalg.mmt4d`-shaped.
`DataTiledMMAAttr::buildUnderlyingOperations` therefore short-circuits
the swizzle/distribute pipeline for these intrinsics and emits a single
`vector.contract` directly, with `arith.extf` / `arith.extsi` / `arith.extui`
widening narrow LHS/RHS to ACC's element type. For sub-byte LHS/RHS types,
`chooseUnrolling` also picks the smallest power-of-two `intrinsics_k` such
that K*lhsBits and K*rhsBits are both multiples of 8, keeping each packed
K-group byte-aligned (e.g. K=2 for i4/f4, K=4 for f6, K=8 for i1).
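The K-selection loop is short enough to quote; abridged from the
`chooseUnrolling` hunk below (`lhsBits` / `rhsBits` are the LHS/RHS element
bit widths):

```c++
// Smallest power-of-two K such that both K*lhsBits and K*rhsBits are
// multiples of 8, i.e. each packed K-group covers a whole number of bytes.
// K in {1, 2, 4, 8} covers everything from i8/f8 (K=1) down to i1 (K=8).
int64_t intrinsicsK = 1;
while (intrinsicsK <= 8 && (((intrinsicsK * lhsBits) % 8) != 0 ||
                            ((intrinsicsK * rhsBits) % 8) != 0)) {
  intrinsicsK *= 2;
}
if (intrinsicsK > 8) {
  return std::nullopt;  // No byte-aligned grouping within K <= 8.
}
```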
Progress towards #24323

diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_aarch64.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_aarch64.mlir
index 7f07eac..a4efc1b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_aarch64.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_aarch64.mlir
@@ -597,3 +597,78 @@
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: return %[[MMT4D]]
+
+// -----
+
+// AArch64 has no element-type-specific MMA intrinsic for sub-byte float
+// types, so the data-tiling cost model falls back to the type-polymorphic
+// generic scalar. AArch64 is 64-bit, so `pickGenericScalarMMAForTarget`
+// picks `_REG16`. The chosen `DataTiledMMAAttr` carries
+// `lhs_type = f4E2M1FN`, `rhs_type = f4E2M1FN`, `acc_type = f32` — the
+// enum value alone doesn't determine element types.
+//
+// FP4 register pressure: K=2 is the smallest power-of-two K that makes
+// K*4 a multiple of 8 (for byte-addressable packed groups). Inside the
+// 16-register budget the cost model lands on a square (im=2, in=2, ik=2)
+// tile (2*2 + 2*2 + 2*2 = 12 registers), which wins on arithmetic
+// intensity. Whether the lowering can actually densely pack f4E2M1FN
+// values into bytes today is orthogonal — the framework only insists
+// that the packed K-group be byte-aligned.
+
+#map_g_fp4 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_g_fp4_1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_g_fp4_2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_g_fp4_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f4E2M1FN, f4E2M1FN, f32], user_indexing_maps = [#map_g_fp4, #map_g_fp4_1, #map_g_fp4_2], iteration_sizes = [?, ?, ?]>
+#encoding_g_fp4_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f4E2M1FN, f4E2M1FN, f32], user_indexing_maps = [#map_g_fp4, #map_g_fp4_1, #map_g_fp4_2], iteration_sizes = [?, ?, ?]>
+#encoding_g_fp4_res = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f4E2M1FN, f4E2M1FN, f32], user_indexing_maps = [#map_g_fp4, #map_g_fp4_1, #map_g_fp4_2], iteration_sizes = [?, ?, ?]>
+func.func @matmul_fp4_f32_aarch64_generic(
+ %lhs: tensor<?x?xf4E2M1FN, #encoding_g_fp4_lhs>,
+ %rhs: tensor<?x?xf4E2M1FN, #encoding_g_fp4_rhs>,
+ %acc: tensor<?x?xf32, #encoding_g_fp4_res>
+) -> tensor<?x?xf32, #encoding_g_fp4_res> attributes {
+ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "aarch64-xyz-xyz", enable_inner_tiled = true, iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>}>
+} {
+ %0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf4E2M1FN, #encoding_g_fp4_lhs>, tensor<?x?xf4E2M1FN, #encoding_g_fp4_rhs>)
+ outs(%acc : tensor<?x?xf32, #encoding_g_fp4_res>) -> tensor<?x?xf32, #encoding_g_fp4_res>
+ return %0 : tensor<?x?xf32, #encoding_g_fp4_res>
+}
+// CHECK-LABEL: func @matmul_fp4_f32_aarch64_generic(
+// CHECK: %[[INNER_FP4:.+]] = iree_codegen.inner_tiled
+// CHECK-SAME: kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG16, intrinsics_m = 2, intrinsics_n = 2, intrinsics_k = 2, lhs_type = f4E2M1FN, rhs_type = f4E2M1FN, acc_type = f32>
+
+// -----
+
+// FP6 LHS/RHS with f32 ACC, on AArch64 (REG16). K*6 is not a multiple of 8
+// for K∈{1,2}; the smallest K making K*6 a multiple of 8 is K = 4
+// (4*6 = 24 bits = 3 bytes). With K=4, LHS/RHS register pressure becomes
+// `intrinsics_m * 4` and `intrinsics_n * 4`; ACC stays at
+// `intrinsics_m * intrinsics_n`. Inside `_REG16`'s 16-register budget,
+// (im=1, in=2, ik=4) packs 1*4 + 1*2 + 2*4 = 14 registers and ties on
+// arithmetic intensity with its mirror image (2, 1, 4); lower im wins the tie.
+// On `_REG8` neither candidate would fit and we'd fall back to (1, 1) —
+// `_REG16` is the variant that gets a non-degenerate unrolling on FP6.
+// Whether the lowering can actually pack f6E3M2FN values into a byte-
+// aligned 3-byte group today is orthogonal — the framework just
+// guarantees the K-group is byte-aligned.
+
+#map_g_fp6 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_g_fp6_1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_g_fp6_2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_g_fp6_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f6E3M2FN, f6E3M2FN, f32], user_indexing_maps = [#map_g_fp6, #map_g_fp6_1, #map_g_fp6_2], iteration_sizes = [?, ?, ?]>
+#encoding_g_fp6_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f6E3M2FN, f6E3M2FN, f32], user_indexing_maps = [#map_g_fp6, #map_g_fp6_1, #map_g_fp6_2], iteration_sizes = [?, ?, ?]>
+#encoding_g_fp6_res = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f6E3M2FN, f6E3M2FN, f32], user_indexing_maps = [#map_g_fp6, #map_g_fp6_1, #map_g_fp6_2], iteration_sizes = [?, ?, ?]>
+func.func @matmul_fp6_f32_aarch64_generic(
+ %lhs: tensor<?x?xf6E3M2FN, #encoding_g_fp6_lhs>,
+ %rhs: tensor<?x?xf6E3M2FN, #encoding_g_fp6_rhs>,
+ %acc: tensor<?x?xf32, #encoding_g_fp6_res>
+) -> tensor<?x?xf32, #encoding_g_fp6_res> attributes {
+ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "aarch64-xyz-xyz", enable_inner_tiled = true, iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>}>
+} {
+ %0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf6E3M2FN, #encoding_g_fp6_lhs>, tensor<?x?xf6E3M2FN, #encoding_g_fp6_rhs>)
+ outs(%acc : tensor<?x?xf32, #encoding_g_fp6_res>) -> tensor<?x?xf32, #encoding_g_fp6_res>
+ return %0 : tensor<?x?xf32, #encoding_g_fp6_res>
+}
+// `intrinsics_m = 1` is the default, so it is omitted from the printed attr.
+// CHECK-LABEL: func @matmul_fp6_f32_aarch64_generic(
+// CHECK: %[[INNER_FP6:.+]] = iree_codegen.inner_tiled
+// CHECK-SAME: kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG16, intrinsics_n = 2, intrinsics_k = 4, lhs_type = f6E3M2FN, rhs_type = f6E3M2FN, acc_type = f32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir
index 179cb74..4326e7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir
@@ -250,6 +250,76 @@
// -----
+// AVX2 (`+avx2,+fma`) has no element-type-specific MMA intrinsic for
+// bf16/f32, so the cost model falls back to the type-polymorphic generic
+// scalar. x86_64 is 64-bit, so `pickGenericScalarMMAForTarget` picks
+// `_REG16`. The chosen `DataTiledMMAAttr` carries `lhs_type = bf16`,
+// `rhs_type = bf16`, `acc_type = f32` directly (the enum value alone
+// doesn't determine element types).
+
+#map_g = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_g1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_g2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_g_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map_g, #map_g1, #map_g2], iteration_sizes = [?, ?, ?]>
+#encoding_g_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map_g, #map_g1, #map_g2], iteration_sizes = [?, ?, ?]>
+#encoding_g_res = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map_g, #map_g1, #map_g2], iteration_sizes = [?, ?, ?]>
+func.func @matmul_bf16_f32_avx2_generic(
+ %lhs: tensor<?x?xbf16, #encoding_g_lhs>,
+ %rhs: tensor<?x?xbf16, #encoding_g_rhs>,
+ %acc: tensor<?x?xf32, #encoding_g_res>
+) -> tensor<?x?xf32, #encoding_g_res> attributes {
+ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz", cpu_features = "+avx2,+fma", enable_inner_tiled = true, iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>}>
+} {
+ %0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xbf16, #encoding_g_lhs>, tensor<?x?xbf16, #encoding_g_rhs>)
+ outs(%acc : tensor<?x?xf32, #encoding_g_res>) -> tensor<?x?xf32, #encoding_g_res>
+ return %0 : tensor<?x?xf32, #encoding_g_res>
+}
+// `chooseUnrolling` for the generic intrinsic uses the budget encoded in
+// the chosen `_REG*` enum case (16 here). One element per register, so the
+// register pressure is `intrinsics_m * intrinsics_n` (ACC) +
+// `intrinsics_m` (LHS) + `intrinsics_n` (RHS). Matmul dims are dynamic
+// here, so all three dims are free; arithmetic-intensity tie-breaking
+// favors approximately-square tiles. (im=2, in=4) and (im=4, in=2) tie at
+// intensity 8/6 — the search picks (im=2, in=4) (lower im first), packing
+// 8 + 2 + 4 = 14 registers within 16.
+// CHECK-LABEL: func @matmul_bf16_f32_avx2_generic(
+// CHECK: %[[INNER_G:.+]] = iree_codegen.inner_tiled
+// CHECK-SAME: kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG16, intrinsics_m = 2, intrinsics_n = 4, lhs_type = bf16, rhs_type = bf16, acc_type = f32>
+
+// -----
+
+// Sub-byte LHS/RHS forces `chooseUnrolling` to also pick `intrinsics_k > 1`
+// for the generic intrinsic, so each contiguous K-group covers a whole
+// number of bytes (otherwise sub-byte elements aren't byte-addressable in
+// the packed layout). Smallest power-of-two K with K*4 % 8 == 0 is K = 2.
+// LHS pressure becomes `intrinsics_m * intrinsics_k`, RHS pressure
+// `intrinsics_n * intrinsics_k`, ACC `intrinsics_m * intrinsics_n`. Inside
+// `_REG16`'s budget, (im=2, in=2, ik=2) packs 2*2 + 2*2 + 2*2 = 12 registers
+// and wins on arithmetic intensity (4/4 = 1.0).
+
+#map_g4 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_g4_1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_g4_2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_g4_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [i4, i4, i32], user_indexing_maps = [#map_g4, #map_g4_1, #map_g4_2], iteration_sizes = [?, ?, ?]>
+#encoding_g4_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i4, i4, i32], user_indexing_maps = [#map_g4, #map_g4_1, #map_g4_2], iteration_sizes = [?, ?, ?]>
+#encoding_g4_res = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i4, i4, i32], user_indexing_maps = [#map_g4, #map_g4_1, #map_g4_2], iteration_sizes = [?, ?, ?]>
+func.func @matmul_i4_i32_avx2_generic(
+ %lhs: tensor<?x?xi4, #encoding_g4_lhs>,
+ %rhs: tensor<?x?xi4, #encoding_g4_rhs>,
+ %acc: tensor<?x?xi32, #encoding_g4_res>
+) -> tensor<?x?xi32, #encoding_g4_res> attributes {
+ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz", cpu_features = "+avx2,+fma", enable_inner_tiled = true, iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>}>
+} {
+ %0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi4, #encoding_g4_lhs>, tensor<?x?xi4, #encoding_g4_rhs>)
+ outs(%acc : tensor<?x?xi32, #encoding_g4_res>) -> tensor<?x?xi32, #encoding_g4_res>
+ return %0 : tensor<?x?xi32, #encoding_g4_res>
+}
+// CHECK-LABEL: func @matmul_i4_i32_avx2_generic(
+// CHECK: %[[INNER_G4:.+]] = iree_codegen.inner_tiled
+// CHECK-SAME: kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG16, intrinsics_m = 2, intrinsics_n = 2, intrinsics_k = 2, lhs_type = i4, rhs_type = i4, acc_type = i32>
+
+// -----
+
// It tests with bindings and checks that the reshape ops are folded into bindings.
#executable_target_xyz = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz", iree.encoding.resolver = #iree_cpu.cpu_encoding_resolver<>}>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
index ad5b6b4..ea4a2e1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
@@ -352,10 +352,36 @@
case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16:
return Tuple{1, 16, 2};
default:
+ if (isGenericScalar(intrinsic)) {
+ return Tuple{1, 1, 1};
+ }
return {};
}
}
+// Bit-layout constants for the `MMAIntrinsic` enum value (see IREECPUEnums.td
+// for the 0xABCD scheme). The high byte (`kMMAIntrinsicISAMask`) encodes the
+// architecture (nibble A) and the ISA-extension family (nibble B) together
+// — the abbreviation is spelled `ISA` (uppercase) here to avoid reading as
+// the English "is a". The generic-scalar family uses A=F, B=0; the low byte
+// then holds the register-budget heuristic value.
+constexpr uint32_t kMMAIntrinsicISAMask = 0xFF00;
+constexpr uint32_t kMMAIntrinsicISAGeneric = 0xF000;
+constexpr uint32_t kMMAIntrinsicGenericBudgetMask = 0x00FF;
+constexpr uint32_t kMMAIntrinsicISAX86Avx2 = 0x1200;
+constexpr uint32_t kMMAIntrinsicISAX86Avx512 = 0x1300;
+constexpr uint32_t kMMAIntrinsicISAArmSve = 0x2200;
+
+bool isGenericScalar(MMAIntrinsic intr) {
+ return (static_cast<uint32_t>(intr) & kMMAIntrinsicISAMask) ==
+ kMMAIntrinsicISAGeneric;
+}
+
+int64_t getGenericScalarRegisterBudget(MMAIntrinsic intr) {
+ assert(isGenericScalar(intr));
+ return static_cast<uint32_t>(intr) & kMMAIntrinsicGenericBudgetMask;
+}
+
int64_t getRegisterSpaceBytes(MMAIntrinsic intrinsic) {
// Total architectural vector register file size, in bytes. The inner-tiled
// cost model uses this as the capacity for the union of the ACC, LHS and
@@ -364,13 +390,13 @@
// simplification — the resulting `intrinsics_m`/`intrinsics_n` choices are
// good enough in practice and avoid propagating scalability into the cost
// model.
- uint32_t arch = static_cast<uint32_t>(intrinsic) & 0xFF00;
- switch (arch) {
- case 0x1200: // AVX/AVX2: 16 YMM × 32 B.
+ uint32_t isa = static_cast<uint32_t>(intrinsic) & kMMAIntrinsicISAMask;
+ switch (isa) {
+ case kMMAIntrinsicISAX86Avx2: // 16 YMM × 32 B.
return 16 * 32;
- case 0x1300: // AVX-512: 32 ZMM × 64 B.
+ case kMMAIntrinsicISAX86Avx512: // 32 ZMM × 64 B.
return 32 * 64;
- case 0x2200: // Arm SVE/SVE2: 32 Z × (VL treated as 128 bits).
+ case kMMAIntrinsicISAArmSve: // 32 Z × (VL treated as 128 bits).
return 32 * 16;
default:
// Plausible default, but override it on each arch you care for.
@@ -544,6 +570,32 @@
}
}
+/// Returns the (LHS, RHS, ACC) *storage* element types for `attr` — what
+/// `getDistributedTileTypes` should plumb into vector types and what the
+/// inner_tiled op's operand types must agree with. For most `MMAIntrinsic`
+/// values these are baked into the enum and the `MMAIntrinsic`-only
+/// overload above suffices. The `MMA_GENERIC_SCALAR_1x1x1_REG*` family is
+/// type-polymorphic — its element types live on the attr's `lhs_type` /
+/// `rhs_type` / `acc_type` parameters. We strip integer signedness here:
+/// storage is always signless, the attr keeps the `siN` / `uiN` annotation
+/// for the lowering to pick `arith.extsi` vs `arith.extui`.
+static std::tuple<Type, Type, Type>
+getABCElementTypes(MLIRContext *context, IREE::CPU::DataTiledMMAAttr attr) {
+ MMAIntrinsic intrinsic = attr.getIntrinsic();
+ if (isGenericScalar(intrinsic)) {
+ auto signless = [&](Type t) -> Type {
+ if (auto intTy = dyn_cast_if_present<IntegerType>(t);
+ intTy && !intTy.isSignless()) {
+ return IntegerType::get(context, intTy.getWidth());
+ }
+ return t;
+ };
+ return {signless(attr.getLhsType()), signless(attr.getRhsType()),
+ signless(attr.getAccType())};
+ }
+ return getABCElementTypes(context, intrinsic);
+}
+
//===----------------------------------------------------------------------===//
// DataTiledMMA Attributes
//===----------------------------------------------------------------------===//
@@ -580,7 +632,7 @@
result.clear();
return;
}
- auto [aType, bType, cType] = getABCElementTypes(ctx, intrinsic);
+ auto [aType, bType, cType] = getABCElementTypes(ctx, *this);
// Each operand's swizzle encodes its tile shape as (outer physical dim,
// inner physical dim) in expandShape[0], expandShape[1]. This mirrors GPU's
// DataTiledMMA, where the tile types encode the layout directly and no
@@ -726,6 +778,60 @@
}
}
+/// Lowers a `DataTiledMMAAttr` whose intrinsic is one of the
+/// `MMA_GENERIC_SCALAR_1x1x1_REG*` cases directly to a single
+/// `vector.contract`. Since the intrinsic's tile shape is 1×1×1, the
+/// operand tiles after applying `intrinsics_m` / `intrinsics_n` /
+/// `intrinsics_k` are simple row-major (M, K)/(N, K)/(M, N) matmul
+/// tiles — no swizzle-based distribution is needed.
+static LogicalResult lowerGenericScalarToVectorContract(
+ OpBuilder &builder, Location loc, IREE::CPU::DataTiledMMAAttr attr,
+ ValueRange inputs, ValueRange outputs, SmallVectorImpl<Value> &results) {
+ assert(isGenericScalar(attr.getIntrinsic()) &&
+ "lowerGenericScalarToVectorContract only handles "
+ "MMA_GENERIC_SCALAR_1x1x1_REG* intrinsics");
+ Value lhs = inputs[0];
+ Value rhs = inputs[1];
+ Value acc = outputs[0];
+ Type accElem = cast<VectorType>(acc.getType()).getElementType();
+ // For the generic intrinsic, ACC is always at least as wide as LHS/RHS,
+ // and they're either all float or all integer (the cost model only picks
+ // this intrinsic when element types are mutually consistent that way).
+ // Signedness lives on the attr's `lhs_type` / `rhs_type` (the operand
+ // vector types are signless storage) and picks ExtSI vs. ExtUI; ACC is
+ // treated as signed.
+ auto isUnsigned = [](Type t) {
+ auto intTy = dyn_cast_if_present<IntegerType>(t);
+ return intTy && intTy.isUnsigned();
+ };
+ auto widenToAcc = [&](Value v, bool unsignedSrc) -> Value {
+ auto vt = cast<VectorType>(v.getType());
+ if (vt.getElementType() == accElem) {
+ return v;
+ }
+ auto wideTy = VectorType::get(vt.getShape(), accElem);
+ if (isa<FloatType>(accElem)) {
+ return arith::ExtFOp::create(builder, loc, wideTy, v);
+ }
+ return unsignedSrc ? Value(arith::ExtUIOp::create(builder, loc, wideTy, v))
+ : Value(arith::ExtSIOp::create(builder, loc, wideTy, v));
+ };
+ lhs = widenToAcc(lhs, isUnsigned(attr.getLhsType()));
+ rhs = widenToAcc(rhs, isUnsigned(attr.getRhsType()));
+ AffineExpr m = builder.getAffineDimExpr(0);
+ AffineExpr n = builder.getAffineDimExpr(1);
+ AffineExpr k = builder.getAffineDimExpr(2);
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ // LHS is (M, K), RHS is (N, K), ACC is (M, N) — same as the
+ // `iree_codegen.inner_tiled` op's own indexing maps for CPU.
+ vector::IteratorType par = vector::IteratorType::parallel;
+ vector::IteratorType red = vector::IteratorType::reduction;
+ results.push_back(vector::ContractionOp::create(
+ builder, loc, lhs, rhs, acc, MapList{{m, k}, {n, k}, {m, n}},
+ ArrayRef<vector::IteratorType>{par, par, red}));
+ return success();
+}
+
LogicalResult DataTiledMMAAttr::buildUnderlyingOperations(
OpBuilder &builder, Location loc, ValueRange inputs, ValueRange outputs,
SmallVectorImpl<Value> &results) const {
@@ -746,6 +852,14 @@
}
MMAIntrinsic intrinsic = getIntrinsic();
+ // The type-polymorphic generic intrinsic is row-major (1×1×1 base) and
+ // bypasses the swizzle/distribution machinery: it lowers directly to a
+ // single `vector.contract` over the unrolled operand tiles, the way
+ // `linalg.mmt4d` would.
+ if (isGenericScalar(intrinsic)) {
+ return lowerGenericScalarToVectorContract(builder, loc, *this, inputs,
+ outputs, results);
+ }
auto emitIntrinsic = [&](OpBuilder &b, Location loc, Value lhs, Value rhs,
Value acc) -> Value {
return createCpuMmaIntrinsicCall(b, loc, intrinsic, lhs, rhs, acc);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.td
index 80548de..3845d9f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.td
@@ -89,6 +89,16 @@
the intrinsic itself: any intrinsic may be used transposed, including
square ones where the effect is only the accumulator layout change.
+ For most `intrinsic` values the (LHS, RHS, ACC) element types are baked
+ into the enum and `lhs_type` / `rhs_type` / `acc_type` are unused. The
+  exception is the `MMA_GENERIC_SCALAR_1x1x1_REG*` family: a type-polymorphic
+ fallback used when no element-type-specific intrinsic matches the
+ target, and it carries its element types in those three optional
+ parameters instead. This deliberately breaks the otherwise-strong
+ invariant that an `MMAIntrinsic` enum value pins down a specific element
+ type triple, in exchange for not having to add one enum value per
+ supported (LHS, RHS, ACC) combination.
+
Some GPU-specific methods in IREECodegen_InnerTileDescAttrInterface are left
here but are unused.
}];
@@ -110,7 +120,19 @@
"bool", "false",
"If true, the intrinsic is used in an M↔N-swapped orientation "
"(LHS/RHS roles exchanged, accumulator "
- "column-major).">:$transposed_intrinsic);
+ "column-major).">:$transposed_intrinsic,
+ DefaultValuedParameter<
+ "::mlir::Type", "::mlir::Type()",
+ "LHS element type, used only by type-polymorphic intrinsics "
+ "such as MMA_GENERIC_SCALAR_1x1x1.">:$lhs_type,
+ DefaultValuedParameter<
+ "::mlir::Type", "::mlir::Type()",
+ "RHS element type, used only by type-polymorphic intrinsics "
+ "such as MMA_GENERIC_SCALAR_1x1x1.">:$rhs_type,
+ DefaultValuedParameter<
+ "::mlir::Type", "::mlir::Type()",
+ "ACC element type, used only by type-polymorphic intrinsics "
+ "such as MMA_GENERIC_SCALAR_1x1x1.">:$acc_type);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.td
index b88c59a..184060d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.td
@@ -132,6 +132,25 @@
def MMA_ARM_SVE_FMLA_1x4VLx1_F32_F32
: I32EnumAttrCase<"MMA_ARM_SVE_FMLA_1x4VLx1_F32_F32", 0x2210>;
+// Architecture-agnostic, type-polymorphic generic-scalar fallback. Tile
+// shape is 1×1×1, lowering is a single `vector.contract` over the
+// unrolled (`intrinsics_m`, `intrinsics_n`, `intrinsics_k`) tile. Unlike
+// every other `MMAIntrinsic` value this does *not* pin down (LHS, RHS,
+// ACC) element types; those live on the `DataTiledMMAAttr`'s
+// `lhs_type` / `rhs_type` / `acc_type` parameters. The cost model picks
+// it only when no element-type-specific intrinsic on the target matches
+// the encoding's element types.
+//
+// Two enum cases, one per cost-model register-budget: 8 (32-bit ISAs) or
+// 16 (64-bit ISAs). One element of any width occupies one register (a
+// GPR or lane 0 of a SIMD register, depending on what LLVM picks) — the
+// architectural register file isn't modeled. The budget is encoded in
+// the low byte: bits 8..15 = 0xF0 (generic), bits 0..7 = budget.
+def MMA_GENERIC_SCALAR_1x1x1_REG8
+ : I32EnumAttrCase<"MMA_GENERIC_SCALAR_1x1x1_REG8", 0xF008>;
+def MMA_GENERIC_SCALAR_1x1x1_REG16
+ : I32EnumAttrCase<"MMA_GENERIC_SCALAR_1x1x1_REG16", 0xF010>;
+
def IREECPU_MMAIntrinsic
: IREECPU_I32EnumAttr<
"MMAIntrinsic", "Descriptor for different MMA intrinsics",
@@ -150,7 +169,10 @@
MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16,
// Arm SVE
- MMA_ARM_SVE_FMLA_1x4VLx1_F32_F32]>;
+ MMA_ARM_SVE_FMLA_1x4VLx1_F32_F32,
+
+ // Architecture-agnostic generic-scalar fallback (type-polymorphic).
+ MMA_GENERIC_SCALAR_1x1x1_REG8, MMA_GENERIC_SCALAR_1x1x1_REG16]>;
//===----------------------------------------------------------------------===//
// CPU Lowering Pipelines
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h
index 787b726..8892895 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h
@@ -82,6 +82,13 @@
// Values: AVX/AVX2 = 16 × 32 B, AVX-512 = 32 × 64 B, SVE/SVE2 = 32 × 16 B.
int64_t getRegisterSpaceBytes(MMAIntrinsic intrinsic);
+// True if `intr` is one of the `MMA_GENERIC_SCALAR_1x1x1_REG*` cases.
+bool isGenericScalar(MMAIntrinsic intr);
+
+// For an `MMA_GENERIC_SCALAR_1x1x1_REG*` intrinsic, returns the register
+// budget encoded in the enum case (8 or 16). Asserts otherwise.
+int64_t getGenericScalarRegisterBudget(MMAIntrinsic intr);
+
} // namespace mlir::iree_compiler::IREE::CPU
// clang-format on
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
index 1ccb6a7..6ab4842 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
@@ -118,3 +118,96 @@
// CHECK: arith.extsi {{.*}} : vector<32xi8> to vector<32xi16>
// CHECK: %[[DOT:.+]] = llvm.call_intrinsic "llvm.x86.avx512.pmaddw.d.512"({{.*}}) : (vector<32xi16>, vector<32xi16>) -> vector<16xi32>
// CHECK: arith.addi {{.*}}, %[[DOT]] : vector<16xi32>
+
+// -----
+
+// The MMA_GENERIC_SCALAR_1x1x1_REG* family is type-polymorphic and 1×1×1;
+// after applying `intrinsics_m` / `intrinsics_n` / `intrinsics_k`, the
+// operand tiles are row-major (M, K) / (N, K) / (M, N) — exactly the shape
+// `linalg.mmt4d` vectorizes to. So these lower directly to a single
+// `vector.contract` over the unrolled tile, bypassing the swizzle-
+// distribute machinery the other (architecture-specific) intrinsics use.
+// Mixed-precision element types (e.g. bf16 inputs, f32 accumulator) are
+// handled by an explicit `arith.extf` widen of LHS/RHS to ACC's element
+// type before the contract; a homogeneous integer case lowers analogously
+// through `arith.extsi`. The `_REG*` budget suffix is a property of the
+// cost model — at lowering time all variants behave identically.
+
+#contraction_accesses_g = [
+ affine_map<() -> ()>,
+ affine_map<() -> ()>,
+ affine_map<() -> ()>
+]
+func.func @lower_generic_bf16_f32(
+ %lhs: vector<8x1xbf16>, %rhs: vector<8x1xbf16>, %acc: vector<8x8xf32>)
+ -> vector<8x8xf32> {
+ %0 = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%acc) {
+ indexing_maps = #contraction_accesses_g,
+ iterator_types = [],
+ kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG8, intrinsics_m = 8, intrinsics_n = 8, lhs_type = bf16, rhs_type = bf16, acc_type = f32>,
+ semantics = #iree_cpu.mma_semantics<>
+ } : vector<8x1xbf16>, vector<8x1xbf16> into vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
+
+func.func @lower_generic_i32_i8(
+ %lhs: vector<4x1xi8>, %rhs: vector<4x1xi8>, %acc: vector<4x4xi32>)
+ -> vector<4x4xi32> {
+ %0 = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%acc) {
+ indexing_maps = #contraction_accesses_g,
+ iterator_types = [],
+ kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG8, intrinsics_m = 4, intrinsics_n = 4, lhs_type = i8, rhs_type = i8, acc_type = i32>,
+ semantics = #iree_cpu.mma_semantics<>
+ } : vector<4x1xi8>, vector<4x1xi8> into vector<4x4xi32>
+ return %0 : vector<4x4xi32>
+}
+
+// Same shape as above but with unsigned LHS/RHS — widening to acc must
+// use `arith.extui`, not `arith.extsi`. Storage stays signless (the inner
+// tile types come from `getABCElementTypes`, which strips signedness);
+// the unsigned annotation lives on the attr's `lhs_type` / `rhs_type`.
+// Acc stays signed.
+func.func @lower_generic_i32_ui8(
+ %lhs: vector<4x1xi8>, %rhs: vector<4x1xi8>, %acc: vector<4x4xi32>)
+ -> vector<4x4xi32> {
+ %0 = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%acc) {
+ indexing_maps = #contraction_accesses_g,
+ iterator_types = [],
+ kind = #iree_cpu.data_tiled_mma_layout<intrinsic = MMA_GENERIC_SCALAR_1x1x1_REG8, intrinsics_m = 4, intrinsics_n = 4, lhs_type = ui8, rhs_type = ui8, acc_type = i32>,
+ semantics = #iree_cpu.mma_semantics<>
+ } : vector<4x1xi8>, vector<4x1xi8> into vector<4x4xi32>
+ return %0 : vector<4x4xi32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.lower_inner_tiled
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @lower_generic_bf16_f32
+// CHECK: %[[LHS_F32:.+]] = arith.extf %{{.+}} : vector<8x1xbf16> to vector<8x1xf32>
+// CHECK: %[[RHS_F32:.+]] = arith.extf %{{.+}} : vector<8x1xbf16> to vector<8x1xf32>
+// CHECK: vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LHS_F32]], %[[RHS_F32]], %{{.+}} : vector<8x1xf32>, vector<8x1xf32> into vector<8x8xf32>
+
+// CHECK-LABEL: func @lower_generic_i32_i8
+// CHECK: %[[LHS_I32:.+]] = arith.extsi %{{.+}} : vector<4x1xi8> to vector<4x1xi32>
+// CHECK: %[[RHS_I32:.+]] = arith.extsi %{{.+}} : vector<4x1xi8> to vector<4x1xi32>
+// CHECK: vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LHS_I32]], %[[RHS_I32]], %{{.+}} : vector<4x1xi32>, vector<4x1xi32> into vector<4x4xi32>
+
+// CHECK-LABEL: func @lower_generic_i32_ui8
+// CHECK: %[[LHS_UI32:.+]] = arith.extui %{{.+}} : vector<4x1xi8> to vector<4x1xi32>
+// CHECK: %[[RHS_UI32:.+]] = arith.extui %{{.+}} : vector<4x1xi8> to vector<4x1xi32>
+// CHECK: vector.contract
+// CHECK-SAME: %[[LHS_UI32]], %[[RHS_UI32]], %{{.+}} : vector<4x1xi32>, vector<4x1xi32> into vector<4x4xi32>
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
index 8561f3c..06d0f9c 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
@@ -345,41 +345,65 @@
}
}
-/// Returns x86 `MMAIntrinsic` cases whose required ISA extensions are all
-/// present in `config` (`cpu_features` / target features). Only the "natural"
-/// (M<=N) intrinsic orientations are listed; the M↔N-swapped orientation is
-/// expressed by the `transposed_intrinsic` flag on DataTiledMMAAttr, enumerated
-/// separately by the cost model.
+/// Returns the `MMA_GENERIC_SCALAR_1x1x1_REG*` variant for `config`. Almost
+/// every 64-bit CPU has at least 16 architectural registers we could put
+/// scalar tiles into (GPRs or lane 0 of a SIMD register, depending on what
+/// LLVM picks), so we pick `_REG16` on 64-bit ISAs and `_REG8` on 32-bit
+/// ones. This is a performance heuristic only — too high a budget would
+/// spill, too low would underutilize — and the generic-scalar fallback is
+/// slow either way.
+static IREE::CPU::MMAIntrinsic
+pickGenericScalarMMAForTarget(DictionaryAttr config) {
+ using IREE::CPU::MMAIntrinsic;
+ std::optional<llvm::Triple> triple = getTargetTriple(config);
+ bool is64Bit = triple ? triple->isArch64Bit() : true;
+ return is64Bit ? MMAIntrinsic::MMA_GENERIC_SCALAR_1x1x1_REG16
+ : MMAIntrinsic::MMA_GENERIC_SCALAR_1x1x1_REG8;
+}
+
+/// Returns the `MMAIntrinsic` cases potentially usable for `config`: the x86
+/// architecture-specific intrinsics whose required ISA extensions are all
+/// present, plus one of the architecture-agnostic type-polymorphic
+/// `MMA_GENERIC_SCALAR_1x1x1_REG*` fallback variants (the one whose register
+/// budget matches `config`'s target). Only the "natural" (M<=N) orientation
+/// is listed; the M↔N-swapped orientation is expressed by the
+/// `transposed_intrinsic` flag on DataTiledMMAAttr and is enumerated
+/// separately by the cost model. The 1×1×1 generic intrinsic naturally
+/// loses to any real intrinsic that fits, so it only wins as a fallback
+/// when no real MMA covers the requested element types.
static SmallVector<IREE::CPU::MMAIntrinsic>
getMmaIntrinsicsForTargetConfig(DictionaryAttr config) {
using IREE::CPU::MMAIntrinsic;
- SmallVector<MMAIntrinsic> out;
if (!config) {
- return out;
+ return {};
}
- if (!isX86(config)) {
- return out;
- }
- static const MMAIntrinsic kAllX86[] = {
- MMAIntrinsic::MMA_X86_AVX2_FMA_1x8x1_F32_F32,
- MMAIntrinsic::MMA_X86_AVX512_1x8x1_F64_F64,
- MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F32,
- MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32,
- MMAIntrinsic::MMA_X86_AVX512FP16_1x32x1_F16_F16,
- MMAIntrinsic::MMA_X86_AVX512BF16_1x16x2_F32_BF16,
- MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I16,
- MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I16,
- MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16,
- MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16,
- };
- for (MMAIntrinsic intr : kAllX86) {
- SmallVector<StringRef> required = getMmaIntrinsicRequiredFeatures(intr);
- if (required.empty()) {
- continue;
- }
- if (llvm::all_of(required,
- [&](StringRef f) { return hasFeature(config, f); })) {
- out.push_back(intr);
+ // Always include the generic-scalar fallback first — it's the only
+ // intrinsic guaranteed to apply on any target, so anchoring the list
+  // with it means the function never returns an empty result for an
+  // otherwise-valid target, even when no x86-specific intrinsic qualifies.
+ SmallVector<MMAIntrinsic> out{pickGenericScalarMMAForTarget(config)};
+ if (isX86(config)) {
+ static const MMAIntrinsic kAllX86[] = {
+ MMAIntrinsic::MMA_X86_AVX2_FMA_1x8x1_F32_F32,
+ MMAIntrinsic::MMA_X86_AVX512_1x8x1_F64_F64,
+ MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F32,
+ MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32,
+ MMAIntrinsic::MMA_X86_AVX512FP16_1x32x1_F16_F16,
+ MMAIntrinsic::MMA_X86_AVX512BF16_1x16x2_F32_BF16,
+ MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I16,
+ MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I16,
+ MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16,
+ MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16,
+ };
+ for (MMAIntrinsic intr : kAllX86) {
+ SmallVector<StringRef> required = getMmaIntrinsicRequiredFeatures(intr);
+ if (required.empty()) {
+ continue;
+ }
+ if (llvm::all_of(required,
+ [&](StringRef f) { return hasFeature(config, f); })) {
+ out.push_back(intr);
+ }
}
}
return out;
@@ -397,13 +421,27 @@
};
/// Returns the IntrinsicInfo for `intr` in the given orientation, if its ABC
-/// element types match `elementTypes`. Returns nullopt otherwise.
+/// element types match `elementTypes`. Returns nullopt otherwise. The
+/// `MMA_GENERIC_SCALAR_1x1x1_REG*` family is type-polymorphic — those
+/// values match any element-type triple by construction; their `lhsBits` /
+/// `rhsBits` / `accBits` stay at the struct's default 0 since the
+/// generic-scalar branch of `chooseUnrolling` doesn't read them.
static std::optional<IntrinsicInfo>
getIntrinsicInfo(MLIRContext *ctx, ArrayRef<Type> elementTypes,
IREE::CPU::MMAIntrinsic intr, bool transposed) {
- auto base = IREE::CPU::DataTiledMMAAttr::get(ctx, intr, /*intrinsics_m=*/1,
- /*intrinsics_n=*/1,
- /*intrinsics_k=*/1, transposed);
+ if (IREE::CPU::isGenericScalar(intr)) {
+ if (elementTypes.size() != 3 || !elementTypes[0] || !elementTypes[1] ||
+ !elementTypes[2]) {
+ return std::nullopt;
+ }
+ IntrinsicInfo info;
+ info.intrinsicM = info.intrinsicN = info.intrinsicK = 1;
+ return info;
+ }
+ auto base = IREE::CPU::DataTiledMMAAttr::get(
+ ctx, intr, /*intrinsics_m=*/1, /*intrinsics_n=*/1,
+ /*intrinsics_k=*/1, transposed, /*lhs_type=*/Type(),
+ /*rhs_type=*/Type(), /*acc_type=*/Type());
SmallVector<VectorType> baseTiles;
base.getUndistributedTileTypes(baseTiles);
if (baseTiles.size() != 3) {
@@ -489,13 +527,28 @@
// Phase 2 of `chooseCpuInnerTiledMmaForEncoding`: for an already-chosen
// (intrinsic, transposed) pair, pick the largest power-of-two unroll
// factors (intrinsicsM, intrinsicsN) such that the three tiles
-// (ACC + LHS + RHS) still fit in the target's vector register file,
+// (ACC + LHS + RHS) still fit in the target's register budget,
// breaking ties with arithmetic intensity (effM*effN)/(effM+effN) so
// approximately-square tiles win.
//
-// Returns nullopt if no feasible (im, in) exists. The returned pair is
-// (intrinsicsM, intrinsicsN).
-static std::optional<std::pair<int64_t, int64_t>>
+// "Register budget" depends on what the lowering will use:
+// * Architecture-specific SIMD intrinsics use the vector register file,
+// measured in bits, with element widths from `IntrinsicInfo` so a
+// wider element type costs proportionally more of the budget.
+// * The type-polymorphic `MMA_GENERIC_SCALAR_1x1x1_REG*` lowers to
+// scalar arithmetic, where one element occupies one register (a GPR
+// or lane 0 of a SIMD register, depending on what LLVM picks)
+// regardless of bit width. So the budget is in registers, with all
+// element "widths" treated as 1.
+//
+// Returns nullopt if no feasible unrolling exists. The returned tuple is
+// (intrinsicsM, intrinsicsN, intrinsicsK). When the (im, in) search loop
+// finds no candidate that fits the budget — possible for the generic
+// intrinsic with a sub-byte LHS/RHS forcing intrinsics_k = 8 against an
+// 8-register budget — we fall back to (1, 1) so we still return *some*
+// valid unrolling. Some register spill is preferable to failing data-
+// tiling outright.
+static std::optional<std::tuple<int64_t, int64_t, int64_t>>
chooseUnrolling(MLIRContext *ctx, ArrayRef<Type> elementTypes,
IREE::CPU::MMAIntrinsic intr, bool transposed,
const IREE::Encoding::BxMxNxKxKb &matmulSizes) {
@@ -504,20 +557,49 @@
if (!info) {
return std::nullopt;
}
- int64_t regBitBudget = IREE::CPU::getRegisterSpaceBytes(intr) * 8;
- int64_t capMPo2 = po2UnrollCap(matmulSizes.M, info->intrinsicM, regBitBudget);
- int64_t capNPo2 = po2UnrollCap(matmulSizes.N, info->intrinsicN, regBitBudget);
- int64_t accTerm = info->intrinsicM * info->intrinsicN * info->accBits;
- int64_t lhsTerm = info->intrinsicM * info->intrinsicK * info->lhsBits;
- int64_t rhsTerm = info->intrinsicN * info->intrinsicK * info->rhsBits;
- std::optional<std::pair<int64_t, int64_t>> best;
+ int64_t intrinsicsK = 1;
+ int64_t budget;
+ int64_t accUnit, lhsUnit, rhsUnit;
+ if (IREE::CPU::isGenericScalar(intr)) {
+ // Real hardware MMA intrinsics bake whatever K-grouping they need into
+ // the intrinsic itself. The generic-scalar fallback doesn't, so for
+ // sub-byte LHS/RHS types we have to group enough K-elements per
+ // contiguous block that a packed group is byte-addressable: pick the
+ // smallest power of two K such that K*lhsBits and K*rhsBits are both
+ // multiples of 8. K∈{1,2,4,8} is enough to cover every type from i8/f8
+ // (K=1) down to i1 (K=8).
+ int64_t lhsBits = elementTypes[0].getIntOrFloatBitWidth();
+ int64_t rhsBits = elementTypes[1].getIntOrFloatBitWidth();
+ while (intrinsicsK <= 8 && (((intrinsicsK * lhsBits) % 8) != 0 ||
+ ((intrinsicsK * rhsBits) % 8) != 0)) {
+ intrinsicsK *= 2;
+ }
+ if (intrinsicsK > 8) {
+ return std::nullopt;
+ }
+ // The budget — chosen at intrinsic-pick time as a function of the
+ // target's pointer width — is encoded in the enum value itself.
+ budget = IREE::CPU::getGenericScalarRegisterBudget(intr);
+ accUnit = lhsUnit = rhsUnit = 1;
+ } else {
+ budget = IREE::CPU::getRegisterSpaceBytes(intr) * 8;
+ accUnit = info->accBits;
+ lhsUnit = info->lhsBits;
+ rhsUnit = info->rhsBits;
+ }
+ int64_t capMPo2 = po2UnrollCap(matmulSizes.M, info->intrinsicM, budget);
+ int64_t capNPo2 = po2UnrollCap(matmulSizes.N, info->intrinsicN, budget);
+ int64_t accTerm = info->intrinsicM * info->intrinsicN * accUnit;
+ int64_t lhsTerm = info->intrinsicM * info->intrinsicK * intrinsicsK * lhsUnit;
+ int64_t rhsTerm = info->intrinsicN * info->intrinsicK * intrinsicsK * rhsUnit;
+ std::optional<std::pair<int64_t, int64_t>> bestMN;
double bestIntensity = -1.0;
// Enumerate power-of-two intrinsicsM; for each, pick the largest feasible
- // power-of-two intrinsicsN under the bit budget and the static N cap.
- // The budget bounds im on its own (im*lhsTerm alone must be < budget),
- // which terminates the loop without any numRegs-style cap.
+ // power-of-two intrinsicsN under the budget and the static N cap. The
+ // budget bounds im on its own (im*lhsTerm alone must be < budget), which
+ // terminates the loop without any numRegs-style cap.
for (int64_t im = 1; im <= capMPo2; im *= 2) {
- int64_t remaining = regBitBudget - im * lhsTerm;
+ int64_t remaining = budget - im * lhsTerm;
if (remaining <= 0) {
break;
}
@@ -532,10 +614,14 @@
double intensity = effM * effN / (effM + effN);
if (intensity > bestIntensity) {
bestIntensity = intensity;
- best = {im, in};
+ bestMN = {im, in};
}
}
- return best;
+ // Fall back to (1, 1) if no (im, in) fit — see function comment.
+ if (!bestMN) {
+ bestMN = {1, 1};
+ }
+ return std::make_tuple(bestMN->first, bestMN->second, intrinsicsK);
}
// Picks a CPU `DataTiledMMAAttr` for `iree_codegen.inner_tiled` given an
@@ -559,14 +645,24 @@
return {};
}
auto [intr, transposed] = *intrChoice;
- std::optional<std::pair<int64_t, int64_t>> unroll =
+ std::optional<std::tuple<int64_t, int64_t, int64_t>> unroll =
chooseUnrolling(ctx, elementTypes, intr, transposed, *matmulSizes);
if (!unroll) {
return {};
}
- auto [intrinsicsM, intrinsicsN] = *unroll;
+ auto [intrinsicsM, intrinsicsN, intrinsicsK] = *unroll;
+ // The type-polymorphic generic intrinsic doesn't bake element types into
+ // its enum value; we have to pin them down on the attr itself so the
+ // lowering and `getABCElementTypes(ctx, attr)` can read them back.
+ Type lhsType, rhsType, accType;
+ if (IREE::CPU::isGenericScalar(intr)) {
+ lhsType = elementTypes[0];
+ rhsType = elementTypes[1];
+ accType = elementTypes[2];
+ }
return IREE::CPU::DataTiledMMAAttr::get(ctx, intr, intrinsicsM, intrinsicsN,
- /*intrinsics_k=*/1, transposed);
+ intrinsicsK, transposed, lhsType,
+ rhsType, accType);
}
/// Lowers a contraction under a `CPUEncodingResolverAttr` with