[Codegen][CPU] Route inner_tiled broadcast into m_bcst-foldable slot. (#24516)
The CPU `inner_tiled` lowering replicates whichever of LHS/RHS has fewer
lanes up to the other's lane count before calling the LLVM intrinsic.
The previous lowering emitted this as `vector.broadcast` to a
`(replicate, K)` 2-D shape followed by `vector.shape_cast` to flat, with
a comment claiming the x86 backend's instruction selector would recover
the `{1toN}` broadcast-from-memory form on its own.
Empirically that did not work for bf16 matmul codegen on Zen 4: every
`vdpbf16ps` instruction was preceded by a separate `vbroadcastss`,
doubling the per-row uop count of the hot inner loop. Two structural
reasons:
1. The IR shape mattered. LLVM's x86 ISel `m_bcst` patterns key on the
canonical `_mm512_set1_ps`-style splat: a scalar fed into `insertelement
<N x T> poison, T, 0` followed by `shufflevector <N x T>, poison, <N x
i32> zeroinitializer`, with `T` a float. Our `vector.broadcast` to a
`(replicate, K)` 2-D shape + `vector.shape_cast` lowered to a different
shufflevector pattern (or a direct `<K x elem> -> <N*K x elem>`
interleaved shuffle) that did not pattern-match.
2. The intrinsic operand position mattered. The ISA-level `m_bcst` EVEX
operand is on the *third* source of `dpbf16ps`/`vpdpwssd`/`pmaddwd`,
and on the `b` operand (second multiplicand) of FMA's `a*b+c`. We passed
the broadcasted operand into the LHS slot, putting it where ISel cannot
fold a memory broadcast.
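
The operand-position point in item 2 can be illustrated with a small standalone sketch. It is not the code in this change: `Kind` and `orderOperands` are hypothetical names, and operands are modeled as plain strings so that only the argument order is visible. It assumes the slot positions stated above (third operand for the dot-accumulate intrinsics, `b` for FMA's `a*b+c`).

```cpp
#include <array>
#include <cassert>
#include <string>

// Hypothetical classification of the symmetric intrinsics by argument layout.
enum class Kind {
  Fma,            // fma(a, b, c) computes a*b+c; the foldable slot is `b`.
  DotAccumulate,  // e.g. dpbf16ps(acc, x, y); the foldable slot is `y`.
};

// Returns operand names in call order, placing the broadcast operand in the
// slot where, per the description above, a memory broadcast can be folded.
std::array<std::string, 3> orderOperands(Kind kind, bool lhsIsBroadcast) {
  std::string bcst = lhsIsBroadcast ? "lhs" : "rhs";
  std::string full = lhsIsBroadcast ? "rhs" : "lhs";
  switch (kind) {
    case Kind::Fma:
      return {full, bcst, "acc"};
    case Kind::DotAccumulate:
      return {"acc", full, bcst};
  }
  assert(false && "unhandled kind");
  return {};
}
```

With `lhsIsBroadcast == true`, the FMA order becomes `(rhs, lhs, acc)` rather than the previous `(lhs, rhs, acc)`, which is valid only because these intrinsics are symmetric in their two multiplicands.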
Rewrite the replication to bitcast the source to a 1-lane vector of
width `K * elem_bits` (the element's own type for K=1, otherwise an
integer of that width, as in the diff below), extract the scalar,
`vector.broadcast` it to `replicate` lanes, then bitcast back. Track
whether the broadcast landed on lhs and, for the symmetric LLVM
intrinsics, route the broadcasted operand into the m_bcst-foldable slot.
For K=1 the bitcast pair is a no-op that LLVM elides. `vpdpbusd` is
asymmetric (UI8 must stay in the second slot); its existing sign-aware
routing happens to put the broadcast in the m_bcst slot precisely in the
two orientations where the ISA allows the fold, so no change is needed there.
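
As a sanity check on the rewrite described above, the following standalone sketch shows that splatting one packed lane yields the same bytes as the interleaved replication it replaces. It is an illustration only, not the lowering itself: the function names are hypothetical, it handles only K = 2 with 16-bit elements (the bf16/i16 case), and `std::memcpy` stands in for the bitcasts.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Interleaved replication: the K-element source repeated `replicate` times,
// i.e. the element order [0..K-1, 0..K-1, ...].
std::vector<uint16_t> replicateInterleaved(const std::vector<uint16_t> &src,
                                           int replicate) {
  std::vector<uint16_t> out;
  out.reserve(src.size() * replicate);
  for (int r = 0; r < replicate; ++r)
    out.insert(out.end(), src.begin(), src.end());
  return out;
}

// Packed-lane splat: view the two 16-bit elements as one 32-bit lane, splat
// that lane, then view the result as 16-bit elements again.
std::vector<uint16_t> replicateViaPackedLane(const std::vector<uint16_t> &src,
                                             int replicate) {
  assert(src.size() == 2 && "sketch only handles K = 2, 16-bit elements");
  uint32_t lane;
  std::memcpy(&lane, src.data(), sizeof(lane));  // <2 x i16> -> one i32 lane
  std::vector<uint32_t> splat(replicate, lane);  // splat to `replicate` lanes
  std::vector<uint16_t> out(2 * static_cast<size_t>(replicate));
  std::memcpy(out.data(), splat.data(), out.size() * sizeof(uint16_t));
  return out;
}
```

Both functions return identical vectors for any two-element input, which is why the lowering is free to pick the form that the instruction selector recognizes.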
Measured on a 4096×4096 dynamic-shape bf16×bf16 -> f32 matmul on Zen 4
(avx512_bf16, no AMX), with `--iree-opt-data-tiling
--iree-llvmcpu-enable-inner-tiled`:
- All 29 `vdpbf16ps` in the inner loop now use the `{1to16}`
memory-broadcast form (vs 0 before); all 29 separate `vbroadcastss` are
gone.
- End-to-end matmul: 80.8 ms -> 62.7 ms (1.29x faster, 12.4 it/s ->
16.0 it/s), closing ~60% of the gap to the precompiled mmt4d ukernel
(50.5 ms).
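
For reference, the derived figures in the last bullet follow from the three measured latencies. A minimal sketch of that arithmetic, with hypothetical helper names:

```cpp
#include <cassert>
#include <cmath>

// Speedup factor of `afterMs` relative to `beforeMs`.
double speedup(double beforeMs, double afterMs) { return beforeMs / afterMs; }

// Iterations per second for a given per-iteration latency in milliseconds.
double itersPerSecond(double ms) { return 1000.0 / ms; }

// Fraction of the before-to-target latency gap that the change closes.
double gapClosed(double beforeMs, double afterMs, double targetMs) {
  assert(beforeMs > targetMs && "target must be faster than the baseline");
  return (beforeMs - afterMs) / (beforeMs - targetMs);
}
```

With the numbers above, `speedup(80.8, 62.7)` is about 1.29, `itersPerSecond` gives about 12.4 and 16.0, and `gapClosed(80.8, 62.7, 50.5)` is 18.1 / 30.3, or about 0.60.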
Progress towards #24515.
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
index 081dce6..bddf8f7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
@@ -692,47 +692,6 @@
static Value createCpuMmaIntrinsicCall(OpBuilder &builder, Location loc,
MMAIntrinsic intrinsic, Value lhs,
Value rhs, Value acc) {
- // Replicate whichever of LHS/RHS has fewer lanes up to the other's lane
- // count, so both reach the intrinsic's flat lane width. In the natural
- // 1×N×K orientation that's the LHS (M=1) broadcast across the N lanes of
- // the RHS; in the M↔N-swapped N×1×K orientation it's the RHS (N=1) that
- // gets broadcast. Going through a (lanes, K) 2-D form keeps the K-pair
- // contiguous; a final shape_cast collapses to the flat 1-D vector the
- // LLVM intrinsic expects. All x86 LLVM intrinsics we target here (AVX-512
- // FMA, VNNI vpdpwssd, BF16 dpbf16ps, integer pmaddwd) take same-width
- // vector operands — there is no single-lane-input variant. The ISA-level
- // `{1toN}` broadcast-from-memory form is recovered later by the x86
- // backend's instruction selector pattern-matching this `vector.broadcast`-
- // of-load into the EVEX broadcast operand, so the explicit broadcast here
- // is what *enables* that, not a perf liability.
- // TODO(24311): Arm's by-element FMA (`fmla.4s vd, vn, vm[idx]`) is exposed
- // via separate intrinsics (e.g. `llvm.aarch64.neon.fma.lane.v4f32`) that
- // take `(vector, vector, lane_idx)`; when we add Arm support, those cases
- // should bypass this broadcast and emit the lane-index intrinsic directly.
- auto lhsType = cast<VectorType>(lhs.getType());
- auto rhsType = cast<VectorType>(rhs.getType());
- if (lhsType.getNumElements() != rhsType.getNumElements()) {
- // `bcastSrc` is the operand fed into `vector.broadcast`; `bcastDst` is
- // the operand whose lane count we match. They alias the underlying
- // lhs/rhs through pointers so the broadcast result flows back into the
- // variable used by the switch below.
- Value *bcastSrc = &lhs;
- Value *bcastDst = &rhs;
- VectorType bcastSrcType = lhsType;
- VectorType bcastDstType = rhsType;
- if (bcastSrcType.getNumElements() > bcastDstType.getNumElements()) {
- std::swap(bcastSrc, bcastDst);
- std::swap(bcastSrcType, bcastDstType);
- }
- int64_t replicate =
- bcastDstType.getNumElements() / bcastSrcType.getNumElements();
- auto bcastType = VectorType::get({replicate, bcastSrcType.getNumElements()},
- bcastSrcType.getElementType());
- Value bcast =
- vector::BroadcastOp::create(builder, loc, bcastType, *bcastSrc);
- *bcastSrc = vector::ShapeCastOp::create(builder, loc, bcastDstType, bcast);
- }
-
// Sign-/float-extend a vector to a wider element type. Used by the
// *_CASTF32 (f16 → f32) and *_CASTI16 (i8 → i16) variants where the
// intrinsic only exists at the wider type.
@@ -745,6 +704,111 @@
return arith::ExtSIOp::create(builder, loc, wideTy, v);
};
+ // For *_CAST* intrinsics, widen lhs/rhs to the intrinsic's element type
+ // *before* the broadcast below. The alternative — widening after the
+ // broadcast, when the smaller operand has already been replicated to the
+ // full lane count — would re-do the i8 → i16 (or f16 → f32) extension on
+ // every M row of the unrolled tile and force LLVM to scalarize the widen
+ // into a per-row `vpbroadcastw` + `vpandq` + `vpsrlvw` sequence around
+ // each FMA. Widening the narrow source vector once before the broadcast
+ // keeps the per-row inner loop down to just the FMA (with `m_bcst`).
+ Type widenTo;
+ switch (intrinsic) {
+ case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32:
+ case MMAIntrinsic::MMA_X86_AVX512_16x1x1_F32_F16_CASTF32:
+ widenTo = builder.getF32Type();
+ break;
+ case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16:
+ case MMAIntrinsic::MMA_X86_AVX512VNNI_16x1x2_I32_I8_CASTI16:
+ case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16:
+ case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I8_CASTI16:
+ widenTo = builder.getI16Type();
+ break;
+ default:
+ break;
+ }
+ if (widenTo) {
+ lhs = widen(lhs, widenTo);
+ rhs = widen(rhs, widenTo);
+ }
+
+ // Replicate whichever of LHS/RHS has fewer lanes up to the other's lane
+ // count, so both reach the intrinsic's flat lane width. In the natural
+ // 1×N×K orientation that's the LHS (M=1) broadcast across the N lanes of
+ // the RHS; in the M↔N-swapped N×1×K orientation it's the RHS (N=1) that
+ // gets broadcast. All x86 LLVM intrinsics we target here (AVX-512 FMA,
+ // VNNI vpdpwssd, BF16 dpbf16ps, integer pmaddwd) take same-width vector
+ // operands — there is no single-lane-input variant.
+ //
+ // We emit the replication as the splat-of-a-single-packed-lane idiom:
+ // bitcast the K-element source to a 1-lane vector whose lane covers the
+ // full `K * elem_bits` broadcast unit, extract that lane as a scalar,
+ // `vector.broadcast` it to `replicate` lanes, and bitcast back to flat.
+ // This is the canonical IR shape that LLVM's x86 instruction selector
+ // pattern-matches to recover the ISA-level `{1toN}` broadcast-from-memory
+ // EVEX operand (`m32bcst`/`m64bcst`). The natural alternative — directly
+ // shuffling `<K x elem>` to `<replicate*K x elem>` with the interleaved
+ // mask `[0,...,K-1, 0,...,K-1, ...]`, or going through `vector.broadcast`
+ // to a (replicate, K) 2-D shape and then `vector.shape_cast`-ing to flat
+ // — is semantically equivalent but lowers to a `<K x elem>` shufflevector
+ // that ISel does *not* recognize as a broadcast and so emits as a
+ // separate `vbroadcastss`/`vbroadcastsd` before each FMA, doubling the
+ // per-row uop count of the hot inner loop. For K=1 the bitcast pair is a
+ // width-preserving no-op LLVM elides.
+ // TODO(24311): Arm's by-element FMA (`fmla.4s vd, vn, vm[idx]`) is exposed
+ // via separate intrinsics (e.g. `llvm.aarch64.neon.fma.lane.v4f32`) that
+ // take `(vector, vector, lane_idx)`; when we add Arm support, those cases
+ // should bypass this replication and emit the lane-index intrinsic directly.
+ auto lhsType = cast<VectorType>(lhs.getType());
+ auto rhsType = cast<VectorType>(rhs.getType());
+ // Tracks whether the broadcast landed on lhs; used by the symmetric
+ // intrinsic cases below to route the broadcast operand into the
+ // m_bcst-foldable slot (third operand for dpbf16ps/vpdpwssd/pmaddwd, second
+ // operand for FMA's `a*b+c` form).
+ bool lhsIsBroadcast = false;
+ if (lhsType.getNumElements() != rhsType.getNumElements()) {
+ // `bcastSrc` is the operand being replicated; `bcastDst` is the operand
+ // whose lane count we match. They alias the underlying lhs/rhs through
+ // pointers so the result flows back into the variable used by the switch
+ // below.
+ Value *bcastSrc = &lhs;
+ Value *bcastDst = &rhs;
+ VectorType bcastSrcType = lhsType;
+ VectorType bcastDstType = rhsType;
+ if (bcastSrcType.getNumElements() > bcastDstType.getNumElements()) {
+ std::swap(bcastSrc, bcastDst);
+ std::swap(bcastSrcType, bcastDstType);
+ }
+ lhsIsBroadcast = (bcastSrc == &lhs);
+ int64_t srcN = bcastSrcType.getNumElements();
+ int64_t replicate = bcastDstType.getNumElements() / srcN;
+ // The broadcast unit is `srcN` packed source elements. When it is a single
+ // element, splat through that element's own type; when it packs several,
+ // no scalar type names the unit, so splat through an integer of the unit's
+ // bit width. x86 ISel's m_bcst fold keys on the splat *shape* (below) and
+ // the broadcast-load width (`m32bcst`/`m64bcst` match any 32-/64-bit load,
+ // integer or float) -- not on a float element type.
+ Type srcElemTy = bcastSrcType.getElementType();
+ Type laneTy = srcN == 1 ? srcElemTy
+ : Type(builder.getIntegerType(
+ srcElemTy.getIntOrFloatBitWidth() * srcN));
+ auto singleLaneTy = VectorType::get({1}, laneTy);
+ auto replicatedTy = VectorType::get({replicate}, laneTy);
+ Value asSingleLane =
+ vector::BitCastOp::create(builder, loc, singleLaneTy, *bcastSrc);
+ // Extract scalar + `vector.broadcast` so this lowers to LLVM's canonical
+ // `insertelement <N x T> poison, T, 0` + `shufflevector <N x T> ...,
+ // <N x i32> zeroinitializer` splat shape (the same shape the ukernel's
+ // `_mm512_set1_ps` produces, and the one x86 ISel folds into m_bcst).
+ // A direct `shufflevector <1 x T> -> <N x T>` is semantically equivalent
+ // but does not pattern-match.
+ Value scalar = vector::ExtractOp::create(builder, loc, asSingleLane,
+ ArrayRef<int64_t>{0});
+ Value splatted =
+ vector::BroadcastOp::create(builder, loc, replicatedTy, scalar);
+ *bcastSrc = vector::BitCastOp::create(builder, loc, bcastDstType, splatted);
+ }
+
// Emit llvm.call_intrinsic.
auto call = [&builder, loc](StringRef name, Type resultType,
ValueRange args) -> Value {
@@ -753,56 +817,55 @@
.getResult(0);
};
- Type f32 = builder.getF32Type();
- Type i16 = builder.getI16Type();
Type accType = acc.getType();
- // The x86 MMAs supported here are mostly LHS/RHS-symmetric (FMA, dpbf16ps,
+ // Most x86 MMAs supported here are LHS/RHS-symmetric (FMA, dpbf16ps,
// vpdpwssd, pmaddw): natural and swapped orientations share the same LLVM
- // intrinsic and arg order, since after the broadcast above both operands
- // have the same lane count. The exception is vpdpbusd, which is asymmetric
- // (first byte source is unsigned, second is signed); for its swapped
- // sibling we route `rhs` (the unsigned operand) into the first slot.
+ // intrinsic and arg order. For these we route the broadcasted operand into
+ // the m_bcst-foldable slot: the third operand for dpbf16ps/vpdpwssd/
+ // pmaddwd, and the `b` operand (= second mul) for FMA's `a*b+c`. The
+ // exception is vpdpbusd, which is asymmetric (first byte source unsigned,
+ // second signed): we keep its existing sign-aware routing — when the
+ // broadcast happens to land in the signed slot, that maps to vpdpbusd's
+ // m_bcst-foldable slot too; when it lands on the unsigned side, no fold is
+ // available (the ISA's m_bcst is on the signed operand).
+ //
+ // For *_CAST* variants (f16 → f32, i8 → i16), the widening already ran at
+ // the top of this function, so lhs/rhs reach the intrinsic call at the
+ // wider type without any per-row widen in the unrolled tile.
+ Value bcst = lhsIsBroadcast ? lhs : rhs;
+ Value full = lhsIsBroadcast ? rhs : lhs;
switch (intrinsic) {
case MMAIntrinsic::MMA_X86_AVX2_FMA_1x8x1_F32_F32:
case MMAIntrinsic::MMA_X86_AVX2_FMA_8x1x1_F32_F32:
- return call("llvm.fma.v8f32", accType, ValueRange{lhs, rhs, acc});
+ return call("llvm.fma.v8f32", accType, ValueRange{full, bcst, acc});
case MMAIntrinsic::MMA_X86_AVX512_1x8x1_F64_F64:
case MMAIntrinsic::MMA_X86_AVX512_8x1x1_F64_F64:
- return call("llvm.fma.v8f64", accType, ValueRange{lhs, rhs, acc});
+ return call("llvm.fma.v8f64", accType, ValueRange{full, bcst, acc});
case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F32:
case MMAIntrinsic::MMA_X86_AVX512_16x1x1_F32_F32:
- return call("llvm.fma.v16f32", accType, ValueRange{lhs, rhs, acc});
case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32:
case MMAIntrinsic::MMA_X86_AVX512_16x1x1_F32_F16_CASTF32:
- return call("llvm.fma.v16f32", accType,
- ValueRange{widen(lhs, f32), widen(rhs, f32), acc});
+ return call("llvm.fma.v16f32", accType, ValueRange{full, bcst, acc});
case MMAIntrinsic::MMA_X86_AVX512FP16_1x32x1_F16_F16:
case MMAIntrinsic::MMA_X86_AVX512FP16_32x1x1_F16_F16:
- return call("llvm.fma.v32f16", accType, ValueRange{lhs, rhs, acc});
+ return call("llvm.fma.v32f16", accType, ValueRange{full, bcst, acc});
case MMAIntrinsic::MMA_X86_AVX512BF16_1x16x2_F32_BF16:
case MMAIntrinsic::MMA_X86_AVX512BF16_16x1x2_F32_BF16:
return call("llvm.x86.avx512bf16.dpbf16ps.512", accType,
- ValueRange{acc, lhs, rhs});
+ ValueRange{acc, full, bcst});
case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I16:
case MMAIntrinsic::MMA_X86_AVX512VNNI_16x1x2_I32_I16:
- return call("llvm.x86.avx512.vpdpwssd.512", accType,
- ValueRange{acc, lhs, rhs});
case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16:
case MMAIntrinsic::MMA_X86_AVX512VNNI_16x1x2_I32_I8_CASTI16:
return call("llvm.x86.avx512.vpdpwssd.512", accType,
- ValueRange{acc, widen(lhs, i16), widen(rhs, i16)});
+ ValueRange{acc, full, bcst});
case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I16:
- case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I16: {
- return arith::AddIOp::create(
- builder, loc, acc,
- call("llvm.x86.avx512.pmaddw.d.512", accType, ValueRange{lhs, rhs}));
- }
+ case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I16:
case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16:
case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I8_CASTI16: {
return arith::AddIOp::create(
builder, loc, acc,
- call("llvm.x86.avx512.pmaddw.d.512", accType,
- ValueRange{widen(lhs, i16), widen(rhs, i16)}));
+ call("llvm.x86.avx512.pmaddw.d.512", accType, ValueRange{full, bcst}));
}
case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x4_I32_UI8_I8:
// LHS unsigned, RHS signed — vpdpbusd takes (acc, unsigned, signed).
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
index 43cf44d..990017d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
@@ -32,9 +32,9 @@
}
// CHECK-LABEL: func @lower_avx512_1x16x1_f32
-// CHECK: vector.broadcast {{.*}} : vector<{{1x1|1}}xf32> to vector<16x1xf32>
-// CHECK: vector.shape_cast {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK: llvm.call_intrinsic "llvm.fma.v16f32"({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
+// CHECK: %[[SCALAR:.+]] = vector.extract {{.*}} : f32 from vector<{{1x1|1}}xf32>
+// CHECK: %[[BCST:.+]] = vector.broadcast %[[SCALAR]] : f32 to vector<16xf32>
+// CHECK: llvm.call_intrinsic "llvm.fma.v16f32"(%{{.+}}, %[[BCST]], %{{.+}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
// -----
@@ -70,10 +70,10 @@
}
// CHECK-LABEL: func @lower_avx512_1x16x1_f16_castf32
-// CHECK: vector.broadcast {{.*}} : vector<{{1x1|1}}xf16> to vector<16x1xf16>
-// CHECK: vector.shape_cast {{.*}} : vector<16x1xf16> to vector<16xf16>
+// CHECK: arith.extf {{.*}} : vector<1xf16> to vector<1xf32>
// CHECK: arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
-// CHECK: arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
+// CHECK: %[[SCALAR:.+]] = vector.extract {{.*}} : f32 from vector<1xf32>
+// CHECK: %[[BCST:.+]] = vector.broadcast %[[SCALAR]] : f32 to vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.fma.v16f32"({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
// -----
@@ -112,10 +112,11 @@
}
// CHECK-LABEL: func @lower_avx512_1x16x2_i8_casti16
-// CHECK: vector.broadcast {{.*}} : vector<{{1x2|2}}xi8> to vector<16x2xi8>
-// CHECK: vector.shape_cast {{.*}} : vector<16x2xi8> to vector<32xi8>
+// CHECK: arith.extsi {{.*}} : vector<2xi8> to vector<2xi16>
// CHECK: arith.extsi {{.*}} : vector<32xi8> to vector<32xi16>
-// CHECK: arith.extsi {{.*}} : vector<32xi8> to vector<32xi16>
+// CHECK: %[[SCALAR:.+]] = vector.extract {{.*}} : i32 from vector<1xi32>
+// CHECK: %[[BCST:.+]] = vector.broadcast %[[SCALAR]] : i32 to vector<16xi32>
+// CHECK: vector.bitcast %[[BCST]] : vector<16xi32> to vector<32xi16>
// CHECK: %[[DOT:.+]] = llvm.call_intrinsic "llvm.x86.avx512.pmaddw.d.512"({{.*}}) : (vector<32xi16>, vector<32xi16>) -> vector<16xi32>
// CHECK: arith.addi {{.*}}, %[[DOT]] : vector<16xi32>
@@ -303,7 +304,7 @@
// CHECK-LABEL: func @lower_avx512_16x1x1_f32
// CHECK: util.hoistable_conversion "shape_cast_to_intrinsic"
// CHECK: vector.shape_cast {{.*}} : vector<2x4x16xf32> to vector<2x4x16x1xf32>
-// CHECK: vector.broadcast {{.*}} : vector<1xf32> to vector<16x1xf32>
+// CHECK: vector.broadcast {{.*}} : f32 to vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.fma.v16f32"
// CHECK: util.hoistable_conversion "shape_cast_from_intrinsic"
// CHECK: vector.shape_cast {{.*}} : vector<2x4x16x1xf32> to vector<2x4x16xf32>
@@ -360,8 +361,8 @@
// CHECK-SAME: %[[N_LHS:[a-zA-Z0-9]+]]: vector<1x4xi8>
// CHECK-SAME: %[[N_RHS:[a-zA-Z0-9]+]]: vector<16x4xi8>
// CHECK: %[[N_RHS_FLAT:.+]] = vector.shape_cast %[[N_RHS]] : vector<16x4xi8> to vector<64xi8>
-// CHECK: %[[N_BCAST:.+]] = vector.broadcast %[[N_LHS]] : vector<1x4xi8> to vector<16x4xi8>
-// CHECK: %[[N_LHS_FLAT:.+]] = vector.shape_cast %[[N_BCAST]] : vector<16x4xi8> to vector<64xi8>
+// CHECK: %[[N_BCAST_I32:.+]] = vector.broadcast %{{.+}} : i32 to vector<16xi32>
+// CHECK: %[[N_LHS_FLAT:.+]] = vector.bitcast %[[N_BCAST_I32]] : vector<16xi32> to vector<64xi8>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpbusd.512"(%{{.+}}, %[[N_LHS_FLAT]], %[[N_RHS_FLAT]])
// Swapped: RHS (ui8, narrow) is broadcast; the lowering swaps arg order at
@@ -370,6 +371,6 @@
// CHECK-SAME: %[[S_LHS:[a-zA-Z0-9]+]]: vector<16x4xi8>
// CHECK-SAME: %[[S_RHS:[a-zA-Z0-9]+]]: vector<1x4xi8>
// CHECK: %[[S_LHS_FLAT:.+]] = vector.shape_cast %[[S_LHS]] : vector<16x4xi8> to vector<64xi8>
-// CHECK: %[[S_BCAST:.+]] = vector.broadcast %[[S_RHS]] : vector<1x4xi8> to vector<16x4xi8>
-// CHECK: %[[S_RHS_FLAT:.+]] = vector.shape_cast %[[S_BCAST]] : vector<16x4xi8> to vector<64xi8>
+// CHECK: %[[S_BCAST_I32:.+]] = vector.broadcast %{{.+}} : i32 to vector<16xi32>
+// CHECK: %[[S_RHS_FLAT:.+]] = vector.bitcast %[[S_BCAST_I32]] : vector<16xi32> to vector<64xi8>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpbusd.512"(%{{.+}}, %[[S_RHS_FLAT]], %[[S_LHS_FLAT]])