[Codegen][CPU] Route inner_tiled broadcast into m_bcst-foldable slot. (#24516)

The CPU `inner_tiled` lowering replicates whichever of LHS/RHS has fewer
lanes up to the other's lane count before calling the LLVM intrinsic.
The previous lowering emitted this as `vector.broadcast` to a
`(replicate, K)` 2-D shape followed by `vector.shape_cast` to flat, with
a comment claiming the x86 backend's instruction selector would recover
the `{1toN}` broadcast-from-memory form on its own.

Empirically that did not work for bf16 matmul codegen on Zen 4: every
`vdpbf16ps` instruction was preceded by a separate `vbroadcastss`,
doubling the per-row uop count of the hot inner loop. Two structural
reasons:

1. The IR shape mattered. LLVM's x86 ISel `m_bcst` patterns key on the
canonical `_mm512_set1_ps`-style splat: a scalar fed into `insertelement
<N x T> poison, T, 0` followed by `shufflevector <N x T>, poison, <N x
i32> zeroinitializer`, with `T` a float. Our `vector.broadcast` to a
`(replicate, K)` 2-D shape + `vector.shape_cast` lowered to a different
shufflevector pattern (or a direct `<K x elem> -> <N*K x elem>`
interleaved shuffle) that did not pattern-match.

2. The intrinsic operand position mattered. The ISA-level `m_bcst` EVEX
operand is on the *third* source of `dpbf16ps`/`vpdpwssd`/ `pmaddwd`,
and on the `b` operand (second multiplicand) of FMA's `a*b+c`. We passed
the broadcasted operand into the LHS slot, putting it where ISel cannot
fold a memory broadcast.

Rewrite the replication to bitcast the source to a 1-lane vector of
width `K * elem_bits` (with a float lane type when that width is 32 or
64 bits, matching the `_mm512_set1_ps` shape), extract the scalar,
`vector.broadcast` it to `replicate` lanes, then bitcast back. Track
whether the broadcast landed on lhs and, for the symmetric LLVM
intrinsics, route the broadcasted operand into the m_bcst-foldable slot.
For K=1 the bitcast pair is a no-op LLVM elides. vpdpbusd is asymmetric
(UI8 must stay in the second slot); its existing sign-aware routing
happens to put the broadcast in the m_bcst slot precisely in the two
orientations where the ISA allows the fold, so no change needed there.

Measured on a 4096×4096 dynamic-shape bf16×bf16 -> f32 matmul on Zen 4
(avx512_bf16, no AMX), with `--iree-opt-data-tiling
--iree-llvmcpu-enable-inner-tiled`:

- All 29 `vdpbf16ps` in the inner loop now use the `{1to16}`
memory-broadcast form (vs 0 before); all 29 separate `vbroadcastss` are
gone.
- End-to-end matmul: 80.8 ms -> 62.7 ms (1.29x faster, 16.0 it/s -> from
12.4 it/s), closing ~60% of the gap to the precompiled mmt4d ukernel
(50.5 ms).

Progress towards #24515.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
index 081dce6..bddf8f7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
@@ -692,47 +692,6 @@
 static Value createCpuMmaIntrinsicCall(OpBuilder &builder, Location loc,
                                        MMAIntrinsic intrinsic, Value lhs,
                                        Value rhs, Value acc) {
-  // Replicate whichever of LHS/RHS has fewer lanes up to the other's lane
-  // count, so both reach the intrinsic's flat lane width. In the natural
-  // 1×N×K orientation that's the LHS (M=1) broadcast across the N lanes of
-  // the RHS; in the M↔N-swapped N×1×K orientation it's the RHS (N=1) that
-  // gets broadcast. Going through a (lanes, K) 2-D form keeps the K-pair
-  // contiguous; a final shape_cast collapses to the flat 1-D vector the
-  // LLVM intrinsic expects. All x86 LLVM intrinsics we target here (AVX-512
-  // FMA, VNNI vpdpwssd, BF16 dpbf16ps, integer pmaddwd) take same-width
-  // vector operands — there is no single-lane-input variant. The ISA-level
-  // `{1toN}` broadcast-from-memory form is recovered later by the x86
-  // backend's instruction selector pattern-matching this `vector.broadcast`-
-  // of-load into the EVEX broadcast operand, so the explicit broadcast here
-  // is what *enables* that, not a perf liability.
-  // TODO(24311): Arm's by-element FMA (`fmla.4s vd, vn, vm[idx]`) is exposed
-  // via separate intrinsics (e.g. `llvm.aarch64.neon.fma.lane.v4f32`) that
-  // take `(vector, vector, lane_idx)`; when we add Arm support, those cases
-  // should bypass this broadcast and emit the lane-index intrinsic directly.
-  auto lhsType = cast<VectorType>(lhs.getType());
-  auto rhsType = cast<VectorType>(rhs.getType());
-  if (lhsType.getNumElements() != rhsType.getNumElements()) {
-    // `bcastSrc` is the operand fed into `vector.broadcast`; `bcastDst` is
-    // the operand whose lane count we match. They alias the underlying
-    // lhs/rhs through pointers so the broadcast result flows back into the
-    // variable used by the switch below.
-    Value *bcastSrc = &lhs;
-    Value *bcastDst = &rhs;
-    VectorType bcastSrcType = lhsType;
-    VectorType bcastDstType = rhsType;
-    if (bcastSrcType.getNumElements() > bcastDstType.getNumElements()) {
-      std::swap(bcastSrc, bcastDst);
-      std::swap(bcastSrcType, bcastDstType);
-    }
-    int64_t replicate =
-        bcastDstType.getNumElements() / bcastSrcType.getNumElements();
-    auto bcastType = VectorType::get({replicate, bcastSrcType.getNumElements()},
-                                     bcastSrcType.getElementType());
-    Value bcast =
-        vector::BroadcastOp::create(builder, loc, bcastType, *bcastSrc);
-    *bcastSrc = vector::ShapeCastOp::create(builder, loc, bcastDstType, bcast);
-  }
-
   // Sign-/float-extend a vector to a wider element type. Used by the
   // *_CASTF32 (f16 → f32) and *_CASTI16 (i8 → i16) variants where the
   // intrinsic only exists at the wider type.
@@ -745,6 +704,111 @@
     return arith::ExtSIOp::create(builder, loc, wideTy, v);
   };
 
+  // For *_CAST* intrinsics, widen lhs/rhs to the intrinsic's element type
+  // *before* the broadcast below. The alternative — widening after the
+  // broadcast, when the smaller operand has already been replicated to the
+  // full lane count — would re-do the i8 → i16 (or f16 → f32) extension on
+  // every M row of the unrolled tile and force LLVM to scalarize the widen
+  // into a per-row `vpbroadcastw` + `vpandq` + `vpsrlvw` sequence around
+  // each FMA. Widening the narrow source vector once before the broadcast
+  // keeps the per-row inner loop down to just the FMA (with `m_bcst`).
+  Type widenTo;
+  switch (intrinsic) {
+  case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32:
+  case MMAIntrinsic::MMA_X86_AVX512_16x1x1_F32_F16_CASTF32:
+    widenTo = builder.getF32Type();
+    break;
+  case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16:
+  case MMAIntrinsic::MMA_X86_AVX512VNNI_16x1x2_I32_I8_CASTI16:
+  case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16:
+  case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I8_CASTI16:
+    widenTo = builder.getI16Type();
+    break;
+  default:
+    break;
+  }
+  if (widenTo) {
+    lhs = widen(lhs, widenTo);
+    rhs = widen(rhs, widenTo);
+  }
+
+  // Replicate whichever of LHS/RHS has fewer lanes up to the other's lane
+  // count, so both reach the intrinsic's flat lane width. In the natural
+  // 1×N×K orientation that's the LHS (M=1) broadcast across the N lanes of
+  // the RHS; in the M↔N-swapped N×1×K orientation it's the RHS (N=1) that
+  // gets broadcast. All x86 LLVM intrinsics we target here (AVX-512 FMA,
+  // VNNI vpdpwssd, BF16 dpbf16ps, integer pmaddwd) take same-width vector
+  // operands — there is no single-lane-input variant.
+  //
+  // We emit the replication as the splat-of-a-single-packed-lane idiom:
+  // bitcast the K-element source to a 1-lane vector whose lane covers the
+  // full `K * elem_bits` broadcast unit, splat to `replicate` lanes via a
+  // zero-mask `vector.shuffle`, and bitcast back to the flat element type.
+  // This is the canonical IR shape that LLVM's x86 instruction selector
+  // pattern-matches to recover the ISA-level `{1toN}` broadcast-from-memory
+  // EVEX operand (`m32bcst`/`m64bcst`). The natural alternative — directly
+  // shuffling `<K x elem>` to `<replicate*K x elem>` with the interleaved
+  // mask `[0,...,K-1, 0,...,K-1, ...]`, or going through `vector.broadcast`
+  // to a (replicate, K) 2-D shape and then `vector.shape_cast`-ing to flat
+  // — is semantically equivalent but lowers to a `<K x elem>` shufflevector
+  // that ISel does *not* recognize as a broadcast and so emits as a
+  // separate `vbroadcastss`/`vbroadcastsd` before each FMA, doubling the
+  // per-row uop count of the hot inner loop. For K=1 the bitcast pair is a
+  // width-preserving no-op LLVM elides.
+  // TODO(24311): Arm's by-element FMA (`fmla.4s vd, vn, vm[idx]`) is exposed
+  // via separate intrinsics (e.g. `llvm.aarch64.neon.fma.lane.v4f32`) that
+  // take `(vector, vector, lane_idx)`; when we add Arm support, those cases
+  // should bypass this replication and emit the lane-index intrinsic directly.
+  auto lhsType = cast<VectorType>(lhs.getType());
+  auto rhsType = cast<VectorType>(rhs.getType());
+  // Tracks whether the broadcast landed on lhs; used by the symmetric
+  // intrinsic cases below to route the broadcast operand into the
+  // m_bcst-foldable slot (third operand for dpbf16ps/vpdpwssd/pmaddwd, second
+  // operand for FMA's `a*b+c` form).
+  bool lhsIsBroadcast = false;
+  if (lhsType.getNumElements() != rhsType.getNumElements()) {
+    // `bcastSrc` is the operand being replicated; `bcastDst` is the operand
+    // whose lane count we match. They alias the underlying lhs/rhs through
+    // pointers so the result flows back into the variable used by the switch
+    // below.
+    Value *bcastSrc = &lhs;
+    Value *bcastDst = &rhs;
+    VectorType bcastSrcType = lhsType;
+    VectorType bcastDstType = rhsType;
+    if (bcastSrcType.getNumElements() > bcastDstType.getNumElements()) {
+      std::swap(bcastSrc, bcastDst);
+      std::swap(bcastSrcType, bcastDstType);
+    }
+    lhsIsBroadcast = (bcastSrc == &lhs);
+    int64_t srcN = bcastSrcType.getNumElements();
+    int64_t replicate = bcastDstType.getNumElements() / srcN;
+    // The broadcast unit is `srcN` packed source elements. When it is a single
+    // element, splat through that element's own type; when it packs several,
+    // no scalar type names the unit, so splat through an integer of the unit's
+    // bit width. x86 ISel's m_bcst fold keys on the splat *shape* (below) and
+    // the broadcast-load width (`m32bcst`/`m64bcst` match any 32-/64-bit load,
+    // integer or float) -- not on a float element type.
+    Type srcElemTy = bcastSrcType.getElementType();
+    Type laneTy = srcN == 1 ? srcElemTy
+                            : Type(builder.getIntegerType(
+                                  srcElemTy.getIntOrFloatBitWidth() * srcN));
+    auto singleLaneTy = VectorType::get({1}, laneTy);
+    auto replicatedTy = VectorType::get({replicate}, laneTy);
+    Value asSingleLane =
+        vector::BitCastOp::create(builder, loc, singleLaneTy, *bcastSrc);
+    // Extract scalar + `vector.broadcast` so this lowers to LLVM's canonical
+    // `insertelement <N x T> poison, T, 0` + `shufflevector <N x T> ...,
+    // <N x i32> zeroinitializer` splat shape (the same shape the ukernel's
+    // `_mm512_set1_ps` produces, and the one x86 ISel folds into m_bcst).
+    // A direct `shufflevector <1 x T> -> <N x T>` is semantically equivalent
+    // but does not pattern-match.
+    Value scalar = vector::ExtractOp::create(builder, loc, asSingleLane,
+                                             ArrayRef<int64_t>{0});
+    Value splatted =
+        vector::BroadcastOp::create(builder, loc, replicatedTy, scalar);
+    *bcastSrc = vector::BitCastOp::create(builder, loc, bcastDstType, splatted);
+  }
+
   // Emit llvm.call_intrinsic.
   auto call = [&builder, loc](StringRef name, Type resultType,
                               ValueRange args) -> Value {
@@ -753,56 +817,55 @@
         .getResult(0);
   };
 
-  Type f32 = builder.getF32Type();
-  Type i16 = builder.getI16Type();
   Type accType = acc.getType();
-  // The x86 MMAs supported here are mostly LHS/RHS-symmetric (FMA, dpbf16ps,
+  // Most x86 MMAs supported here are LHS/RHS-symmetric (FMA, dpbf16ps,
   // vpdpwssd, pmaddw): natural and swapped orientations share the same LLVM
-  // intrinsic and arg order, since after the broadcast above both operands
-  // have the same lane count. The exception is vpdpbusd, which is asymmetric
-  // (first byte source is unsigned, second is signed); for its swapped
-  // sibling we route `rhs` (the unsigned operand) into the first slot.
+  // intrinsic and arg order. For these we route the broadcasted operand into
+  // the m_bcst-foldable slot: the third operand for dpbf16ps/vpdpwssd/
+  // pmaddwd, and the `b` operand (= second mul) for FMA's `a*b+c`. The
+  // exception is vpdpbusd, which is asymmetric (first byte source unsigned,
+  // second signed): we keep its existing sign-aware routing — when the
+  // broadcast happens to land in the signed slot, that maps to vpdpbusd's
+  // m_bcst-foldable slot too; when it lands on the unsigned side, no fold is
+  // available (the ISA's m_bcst is on the signed operand).
+  //
+  // For *_CAST* variants (f16 → f32, i8 → i16), the widening already ran at
+  // the top of this function, so lhs/rhs reach the intrinsic call at the
+  // wider type without any per-row widen in the unrolled tile.
+  Value bcst = lhsIsBroadcast ? lhs : rhs;
+  Value full = lhsIsBroadcast ? rhs : lhs;
   switch (intrinsic) {
   case MMAIntrinsic::MMA_X86_AVX2_FMA_1x8x1_F32_F32:
   case MMAIntrinsic::MMA_X86_AVX2_FMA_8x1x1_F32_F32:
-    return call("llvm.fma.v8f32", accType, ValueRange{lhs, rhs, acc});
+    return call("llvm.fma.v8f32", accType, ValueRange{full, bcst, acc});
   case MMAIntrinsic::MMA_X86_AVX512_1x8x1_F64_F64:
   case MMAIntrinsic::MMA_X86_AVX512_8x1x1_F64_F64:
-    return call("llvm.fma.v8f64", accType, ValueRange{lhs, rhs, acc});
+    return call("llvm.fma.v8f64", accType, ValueRange{full, bcst, acc});
   case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F32:
   case MMAIntrinsic::MMA_X86_AVX512_16x1x1_F32_F32:
-    return call("llvm.fma.v16f32", accType, ValueRange{lhs, rhs, acc});
   case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32:
   case MMAIntrinsic::MMA_X86_AVX512_16x1x1_F32_F16_CASTF32:
-    return call("llvm.fma.v16f32", accType,
-                ValueRange{widen(lhs, f32), widen(rhs, f32), acc});
+    return call("llvm.fma.v16f32", accType, ValueRange{full, bcst, acc});
   case MMAIntrinsic::MMA_X86_AVX512FP16_1x32x1_F16_F16:
   case MMAIntrinsic::MMA_X86_AVX512FP16_32x1x1_F16_F16:
-    return call("llvm.fma.v32f16", accType, ValueRange{lhs, rhs, acc});
+    return call("llvm.fma.v32f16", accType, ValueRange{full, bcst, acc});
   case MMAIntrinsic::MMA_X86_AVX512BF16_1x16x2_F32_BF16:
   case MMAIntrinsic::MMA_X86_AVX512BF16_16x1x2_F32_BF16:
     return call("llvm.x86.avx512bf16.dpbf16ps.512", accType,
-                ValueRange{acc, lhs, rhs});
+                ValueRange{acc, full, bcst});
   case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I16:
   case MMAIntrinsic::MMA_X86_AVX512VNNI_16x1x2_I32_I16:
-    return call("llvm.x86.avx512.vpdpwssd.512", accType,
-                ValueRange{acc, lhs, rhs});
   case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16:
   case MMAIntrinsic::MMA_X86_AVX512VNNI_16x1x2_I32_I8_CASTI16:
     return call("llvm.x86.avx512.vpdpwssd.512", accType,
-                ValueRange{acc, widen(lhs, i16), widen(rhs, i16)});
+                ValueRange{acc, full, bcst});
   case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I16:
-  case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I16: {
-    return arith::AddIOp::create(
-        builder, loc, acc,
-        call("llvm.x86.avx512.pmaddw.d.512", accType, ValueRange{lhs, rhs}));
-  }
+  case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I16:
   case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16:
   case MMAIntrinsic::MMA_X86_AVX512_16x1x2_I32_I8_CASTI16: {
     return arith::AddIOp::create(
         builder, loc, acc,
-        call("llvm.x86.avx512.pmaddw.d.512", accType,
-             ValueRange{widen(lhs, i16), widen(rhs, i16)}));
+        call("llvm.x86.avx512.pmaddw.d.512", accType, ValueRange{full, bcst}));
   }
   case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x4_I32_UI8_I8:
     // LHS unsigned, RHS signed — vpdpbusd takes (acc, unsigned, signed).
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
index 43cf44d..990017d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/lower_inner_tiled.mlir
@@ -32,9 +32,9 @@
 }
 
 // CHECK-LABEL: func @lower_avx512_1x16x1_f32
-//       CHECK:   vector.broadcast {{.*}} : vector<{{1x1|1}}xf32> to vector<16x1xf32>
-//       CHECK:   vector.shape_cast {{.*}} : vector<16x1xf32> to vector<16xf32>
-//       CHECK:   llvm.call_intrinsic "llvm.fma.v16f32"({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
+//       CHECK:   %[[SCALAR:.+]] = vector.extract {{.*}} : f32 from vector<{{1x1|1}}xf32>
+//       CHECK:   %[[BCST:.+]] = vector.broadcast %[[SCALAR]] : f32 to vector<16xf32>
+//       CHECK:   llvm.call_intrinsic "llvm.fma.v16f32"(%{{.+}}, %[[BCST]], %{{.+}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
 
 // -----
 
@@ -70,10 +70,10 @@
 }
 
 // CHECK-LABEL: func @lower_avx512_1x16x1_f16_castf32
-//       CHECK:   vector.broadcast {{.*}} : vector<{{1x1|1}}xf16> to vector<16x1xf16>
-//       CHECK:   vector.shape_cast {{.*}} : vector<16x1xf16> to vector<16xf16>
+//       CHECK:   arith.extf {{.*}} : vector<1xf16> to vector<1xf32>
 //       CHECK:   arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
-//       CHECK:   arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
+//       CHECK:   %[[SCALAR:.+]] = vector.extract {{.*}} : f32 from vector<1xf32>
+//       CHECK:   %[[BCST:.+]] = vector.broadcast %[[SCALAR]] : f32 to vector<16xf32>
 //       CHECK:   llvm.call_intrinsic "llvm.fma.v16f32"({{.*}}) : (vector<16xf32>, vector<16xf32>, vector<16xf32>) -> vector<16xf32>
 
 // -----
@@ -112,10 +112,11 @@
 }
 
 // CHECK-LABEL: func @lower_avx512_1x16x2_i8_casti16
-//       CHECK:   vector.broadcast {{.*}} : vector<{{1x2|2}}xi8> to vector<16x2xi8>
-//       CHECK:   vector.shape_cast {{.*}} : vector<16x2xi8> to vector<32xi8>
+//       CHECK:   arith.extsi {{.*}} : vector<2xi8> to vector<2xi16>
 //       CHECK:   arith.extsi {{.*}} : vector<32xi8> to vector<32xi16>
-//       CHECK:   arith.extsi {{.*}} : vector<32xi8> to vector<32xi16>
+//       CHECK:   %[[SCALAR:.+]] = vector.extract {{.*}} : i32 from vector<1xi32>
+//       CHECK:   %[[BCST:.+]] = vector.broadcast %[[SCALAR]] : i32 to vector<16xi32>
+//       CHECK:   vector.bitcast %[[BCST]] : vector<16xi32> to vector<32xi16>
 //       CHECK:   %[[DOT:.+]] = llvm.call_intrinsic "llvm.x86.avx512.pmaddw.d.512"({{.*}}) : (vector<32xi16>, vector<32xi16>) -> vector<16xi32>
 //       CHECK:   arith.addi {{.*}}, %[[DOT]] : vector<16xi32>
 
@@ -303,7 +304,7 @@
 // CHECK-LABEL: func @lower_avx512_16x1x1_f32
 //       CHECK:   util.hoistable_conversion "shape_cast_to_intrinsic"
 //       CHECK:     vector.shape_cast {{.*}} : vector<2x4x16xf32> to vector<2x4x16x1xf32>
-//       CHECK:   vector.broadcast {{.*}} : vector<1xf32> to vector<16x1xf32>
+//       CHECK:   vector.broadcast {{.*}} : f32 to vector<16xf32>
 //       CHECK:   llvm.call_intrinsic "llvm.fma.v16f32"
 //       CHECK:   util.hoistable_conversion "shape_cast_from_intrinsic"
 //       CHECK:     vector.shape_cast {{.*}} : vector<2x4x16x1xf32> to vector<2x4x16xf32>
@@ -360,8 +361,8 @@
 //  CHECK-SAME:   %[[N_LHS:[a-zA-Z0-9]+]]: vector<1x4xi8>
 //  CHECK-SAME:   %[[N_RHS:[a-zA-Z0-9]+]]: vector<16x4xi8>
 //       CHECK:   %[[N_RHS_FLAT:.+]] = vector.shape_cast %[[N_RHS]] : vector<16x4xi8> to vector<64xi8>
-//       CHECK:   %[[N_BCAST:.+]] = vector.broadcast %[[N_LHS]] : vector<1x4xi8> to vector<16x4xi8>
-//       CHECK:   %[[N_LHS_FLAT:.+]] = vector.shape_cast %[[N_BCAST]] : vector<16x4xi8> to vector<64xi8>
+//       CHECK:   %[[N_BCAST_I32:.+]] = vector.broadcast %{{.+}} : i32 to vector<16xi32>
+//       CHECK:   %[[N_LHS_FLAT:.+]] = vector.bitcast %[[N_BCAST_I32]] : vector<16xi32> to vector<64xi8>
 //       CHECK:   llvm.call_intrinsic "llvm.x86.avx512.vpdpbusd.512"(%{{.+}}, %[[N_LHS_FLAT]], %[[N_RHS_FLAT]])
 
 // Swapped: RHS (ui8, narrow) is broadcast; the lowering swaps arg order at
@@ -370,6 +371,6 @@
 //  CHECK-SAME:   %[[S_LHS:[a-zA-Z0-9]+]]: vector<16x4xi8>
 //  CHECK-SAME:   %[[S_RHS:[a-zA-Z0-9]+]]: vector<1x4xi8>
 //       CHECK:   %[[S_LHS_FLAT:.+]] = vector.shape_cast %[[S_LHS]] : vector<16x4xi8> to vector<64xi8>
-//       CHECK:   %[[S_BCAST:.+]] = vector.broadcast %[[S_RHS]] : vector<1x4xi8> to vector<16x4xi8>
-//       CHECK:   %[[S_RHS_FLAT:.+]] = vector.shape_cast %[[S_BCAST]] : vector<16x4xi8> to vector<64xi8>
+//       CHECK:   %[[S_BCAST_I32:.+]] = vector.broadcast %{{.+}} : i32 to vector<16xi32>
+//       CHECK:   %[[S_RHS_FLAT:.+]] = vector.bitcast %[[S_BCAST_I32]] : vector<16xi32> to vector<64xi8>
 //       CHECK:   llvm.call_intrinsic "llvm.x86.avx512.vpdpbusd.512"(%{{.+}}, %[[S_RHS_FLAT]], %[[S_LHS_FLAT]])