Refactor vector.multi_reduction into flattening, unrolling, and lowering passes. (#24183)

* Adds an initial vector flattening pass.
* At the moment, the only operation that is flattened is
vector.multi_reduction, but others will be added later.
* Adds vector unrolling for vector.multi_reduction.
* Adds a vector.multi_reduction lowering pass

---------

Co-authored-by: Eric <55723758+efric@users.noreply.github.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index b304f6b..4ce3075 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -110,7 +110,9 @@
         "LLVMGPUTensorCoreVectorization.cpp",
         "LLVMGPUTileAndDistribute.cpp",
         "LLVMGPUVectorDistribute.cpp",
+        "LLVMGPUVectorFlattening.cpp",
         "LLVMGPUVectorLowering.cpp",
+        "LLVMGPUVectorMultiReductionLowering.cpp",
         "LLVMGPUVectorToGPU.cpp",
         "Passes.cpp",
         "ROCDLAnnotateKernelForTranslation.cpp",
@@ -237,6 +239,7 @@
         "@llvm-project//mlir:VectorToLLVM",
         "@llvm-project//mlir:VectorToSCF",
         "@llvm-project//mlir:VectorTransforms",
+        "@llvm-project//mlir:VectorUtils",
     ],
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index bc5b5fa..65efc4b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -90,7 +90,9 @@
     "LLVMGPUTensorCoreVectorization.cpp"
     "LLVMGPUTileAndDistribute.cpp"
     "LLVMGPUVectorDistribute.cpp"
+    "LLVMGPUVectorFlattening.cpp"
     "LLVMGPUVectorLowering.cpp"
+    "LLVMGPUVectorMultiReductionLowering.cpp"
     "LLVMGPUVectorToGPU.cpp"
     "Passes.cpp"
     "ROCDLAnnotateKernelForTranslation.cpp"
@@ -177,6 +179,7 @@
     MLIRVectorToLLVM
     MLIRVectorToSCF
     MLIRVectorTransforms
+    MLIRVectorUtils
     iree::compiler::Codegen::Common
     iree::compiler::Codegen::Common::GPU::CommonGPUPasses
     iree::compiler::Codegen::Common::GPU::GPUHeuristics
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp
index 0163700..72b8dba 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp
@@ -710,6 +710,63 @@
   }
 };
 
+/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops.
+/// This assumes that the src rank will always be two dimensional.
+///
+/// The reduction dimension must be the inner-most dimension.
+///
+/// BEFORE:
+///  vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to
+///  vector<2xf32>
+///
+/// AFTER:
+///   // 1st reduction
+///   %v_0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
+///   %a_0 = vector.extract %acc[0] : f32 from vector<2xf32>
+///   %red_1 = vector.multi_reduction <mul>, %v_0, %a_1 [0] : vector<4xf32> into
+///   f32 %res_tmp = vector.insert %red_1, %res [0] : f32 into vector<2xf32>
+///
+///   // 2nd reduction
+///   %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
+///   %a_1 = vector.extract %acc[1] : f32 from vector<2xf32>
+///   %red_2 = vector.multi_reduction <mul>, %v_1, %a_1 [0] : vector<4xf32> into
+///   f32 %res_final = vector.insert %red_2, %res_tmp [1] : f32 into
+///   vector<2xf32>
+struct ConvertVectorMultiReduction final
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType srcType = cast<VectorType>(op.getSource().getType());
+    if (srcType.getRank() != 2) {
+      return failure();
+    }
+
+    if (op.isReducedDim(0) || !op.isReducedDim(1)) {
+      return failure();
+    }
+
+    Location loc = op.getLoc();
+    Value acc = adaptor.getAcc()[0];
+    Type resultType = op.getResult().getType();
+    Value result = ub::PoisonOp::create(rewriter, loc, resultType);
+
+    SmallVector<Value> srcs(adaptor.getSource());
+    for (int64_t i = 0, e = srcs.size(); i < e; i++) {
+      Value accElem = vector::ExtractOp::create(rewriter, loc, acc, i);
+      auto reduced = vector::MultiDimReductionOp::create(
+          rewriter, loc, op.getKind(), srcs[i], accElem, ArrayRef<int64_t>{0});
+      result = vector::InsertOp::create(rewriter, loc, reduced, result, i);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// Convert vector.interleave on n-D vectors. The lhs and rhs are already
 /// split into flat 1-D vectors by the type converter; create a 1-D interleave
 /// for each corresponding pair.
@@ -814,7 +871,8 @@
         ConvertVectorInsertStridedSlice, ConvertArithConstant, ConvertUBPoison,
         ConvertVectorToElements, ConvertVectorFromElements,
         ConvertVectorBroadcast, ConvertVectorBitcast, ConvertVectorInterleave,
-        ConvertVectorDeinterleave>(typeConverter, ctx);
+        ConvertVectorDeinterleave, ConvertVectorMultiReduction>(typeConverter,
+                                                                ctx);
 
     // Some nvgpu ops abuse n-D vector types to represent a "struct of
     // vectors". These ops are legal despite having n-D vectors — the
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorFlattening.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorFlattening.cpp
new file mode 100644
index 0000000..a932b0a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorFlattening.cpp
@@ -0,0 +1,38 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-vector-flattening"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUVECTORFLATTENINGPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+struct LLVMGPUVectorFlatteningPass final
+    : impl::LLVMGPUVectorFlatteningPassBase<LLVMGPUVectorFlatteningPass> {
+
+  void runOnOperation() override {
+    mlir::FunctionOpInterface funcOp = getOperation();
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateVectorMultiReductionFlatteningPatterns(
+        patterns, vector::VectorMultiReductionLowering::InnerReduction);
+    if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
index c652cd7..9316306 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
@@ -501,12 +501,6 @@
       vector::populateVectorMultiReductionReorderPatterns(
           contractLoweringPatterns,
           vector::VectorMultiReductionLowering::InnerReduction);
-      vector::populateVectorMultiReductionFlatteningPatterns(
-          contractLoweringPatterns,
-          vector::VectorMultiReductionLowering::InnerReduction);
-      vector::populateVectorMultiReductionUnrollingPatterns(
-          contractLoweringPatterns,
-          vector::VectorMultiReductionLowering::InnerReduction);
       // Unroll transfer_gather ops to rank 1 and lower contiguous ones to
       // vector.transfer_read.
       IREE::VectorExt::populateVectorTransferGatherScatterLoweringPatterns(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorMultiReductionLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorMultiReductionLowering.cpp
new file mode 100644
index 0000000..0e07a06
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorMultiReductionLowering.cpp
@@ -0,0 +1,82 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-vector-multi-reduction-lowering"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUVECTORMULTIREDUCTIONLOWERINGPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+/// Converts 1D vector.multi_reduction directly to vector.reduction.
+///
+/// Example:
+/// ```mlir
+/// // Before
+/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
+///
+/// // After
+/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
+/// ```
+struct OneDimMultiReductionToReduction
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 1) {
+      return failure();
+    }
+
+    if (!multiReductionOp.isReducedDim(0)) {
+      return failure();
+    }
+
+    auto loc = multiReductionOp.getLoc();
+    Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+    Operation *reductionOp = vector::ReductionOp::create(
+        rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
+        multiReductionOp.getAcc());
+
+    if (mask) {
+      reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+    }
+
+    return reductionOp->getResult(0);
+  }
+};
+
+} // namespace
+
+struct LLVMGPUVectorMultiReductionLoweringPass final
+    : impl::LLVMGPUVectorMultiReductionLoweringPassBase<
+          LLVMGPUVectorMultiReductionLoweringPass> {
+
+  void runOnOperation() override {
+    mlir::FunctionOpInterface funcOp = getOperation();
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<OneDimMultiReductionToReduction>(ctx);
+    if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index b8cf60c..718d8ac 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -995,7 +995,9 @@
   // This pass needs to run before SCF -> CF.
   // Lower vector operations and legalize all operations to 1D vectors.
   funcPassManager.addPass(createLLVMGPUVectorLoweringPass)
+      .addPass(createLLVMGPUVectorFlatteningPass)
       .addPass(createLLVMGPULegalizeNDVectorsPass)
+      .addPass(createLLVMGPUVectorMultiReductionLoweringPass)
       .addPass(createCanonicalizerPass)
       .addPass(createCSEPass);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index 901dbd5..0277b27 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -77,6 +77,22 @@
   ];
 }
 
+def LLVMGPUVectorFlatteningPass :
+    InterfacePass<"iree-llvmgpu-vector-flattening", "mlir::FunctionOpInterface"> {
+  let summary = "Flatten n-D vectors.";
+  let dependentDialects = [
+    "vector::VectorDialect",
+  ];
+}
+
+def LLVMGPUVectorMultiReductionLoweringPass :
+    InterfacePass<"iree-llvmgpu-vector-multi-reduction-lowering", "mlir::FunctionOpInterface"> {
+  let summary = "Lower vector.multi_reduction ops.";
+  let dependentDialects = [
+    "vector::VectorDialect",
+  ];
+}
+
 def LLVMGPULinkExecutablesPass :
     Pass<"iree-llvmgpu-link-executables", "mlir::ModuleOp"> {
   let summary = "Links LLVMGPU HAL executables within the top-level program module.";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 9f6f8e4..a05fa14 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -75,7 +75,9 @@
             "transform_gpu_pipelining.mlir",
             "transform_vector_to_mma.mlir",
             "transpose_pipeline_test.mlir",
+            "vector_flattening.mlir",
             "vector_lowering.mlir",
+            "vector_multi_reduction_lowering.mlir",
             "vector_to_gpu.mlir",
             "winograd_pipeline_test.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 8b60996..a4210cf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -70,7 +70,9 @@
     "transform_gpu_pipelining.mlir"
     "transform_vector_to_mma.mlir"
     "transpose_pipeline_test.mlir"
+    "vector_flattening.mlir"
     "vector_lowering.mlir"
+    "vector_multi_reduction_lowering.mlir"
     "vector_to_gpu.mlir"
     "winograd_pipeline_test.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir
index acb4a6e..030e793 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir
@@ -339,3 +339,31 @@
 //       CHECK:   %[[R0:.+]] = arith.addf %[[A0]], %[[B0]] : vector<4xf32>
 //       CHECK:   %[[R1:.+]] = arith.addf %[[A1]], %[[B1]] : vector<4xf32>
 //       CHECK:   util.return %[[R0]], %[[R1]] : vector<4xf32>, vector<4xf32>
+
+// -----
+
+func.func @negative_vector_multi_reduction_rank_one(%arg0: vector<2xf32>, %acc: f32) -> f32 {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<2xf32> to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: func.func @negative_vector_multi_reduction_rank_one
+//       CHECK:   vector.multi_reduction
+//       CHECK:   return
+
+// -----
+
+func.func @vector_multi_reduction_2d(%src: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @vector_multi_reduction_2d
+//  CHECK-SAME:   (%[[S0:.+]]: vector<4xf32>, %[[S1:.+]]: vector<4xf32>, %[[ACC:.+]]: vector<2xf32>)
+//       CHECK:   %[[A0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+//       CHECK:   %[[R0:.+]] = vector.multi_reduction <mul>, %[[S0]], %[[A0]] [0] : vector<4xf32> to f32
+//       CHECK:   %[[INS0:.+]] = vector.insert %[[R0]], %{{.*}} [0] : f32 into vector<2xf32>
+//       CHECK:   %[[A1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+//       CHECK:   %[[R1:.+]] = vector.multi_reduction <mul>, %[[S1]], %[[A1]] [0] : vector<4xf32> to f32
+//       CHECK:   %[[INS1:.+]] = vector.insert %[[R1]], %[[INS0]] [1] : f32 into vector<2xf32>
+//       CHECK:   return %[[INS1]] : vector<2xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_flattening.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_flattening.mlir
new file mode 100644
index 0000000..014fa51
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_flattening.mlir
@@ -0,0 +1,12 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-flattening))" \
+// RUN:   --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @vector_multi_reduction_flattening
+// CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    // CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    // CHECK: return %[[RESULT]]
+    return %0 : f32
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir
index 05a2140..cf6e61d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir
@@ -39,28 +39,6 @@
 
 // -----
 
-// Test multi_reduction lowering.
-
-func.func @multi_reduction_f32(%a: vector<2x1x8xf32>, %b: vector<2x1x8xf32>) -> vector<2x1xf32> {
-  %cst_4 = arith.constant dense<0.000000e+00> : vector<2x1xf32>
-  %cst_5 = arith.constant dense<0.000000e+00> : vector<2x1x8xf32>
-  %22 = arith.mulf %a, %b : vector<2x1x8xf32>
-  %23 = arith.addf %22, %cst_5 : vector<2x1x8xf32>
-  %24 = vector.multi_reduction <add>, %23, %cst_4 [2] : vector<2x1x8xf32> to vector<2x1xf32>
-  return %24 : vector<2x1xf32>
-}
-
-// CHECK-LABEL: func.func @multi_reduction_f32
-// CHECK-SAME: %[[ARG0:.+]]: vector<2x1x8xf32>, %[[ARG1:.+]]: vector<2x1x8xf32>)
-// CHECK-DAG: %[[FMA:.+]] = math.fma %[[ARG0]], %[[ARG1]], %{{.*}} fastmath<contract> : vector<2x1x8xf32>
-// CHECK-DAG: %[[FMA1:.+]] = vector.extract %[[FMA]][0, 0] : vector<8xf32> from vector<2x1x8xf32>
-// CHECK-DAG: %[[RED1:.+]] = vector.reduction <add>, %[[FMA1]], %{{.*}} : vector<8xf32> into f32
-// CHECK-DAG: %[[FMA2:.+]] = vector.extract %[[FMA]][1, 0] : vector<8xf32> from vector<2x1x8xf32>
-// CHECK-DAG: %[[RED2:.+]] = vector.reduction <add>, %[[FMA2]], %{{.*}} : vector<8xf32> into f32
-// CHECK: vector.from_elements %[[RED1]], %[[RED2]] : vector<2x1xf32>
-
-// -----
-
 func.func @multi_reduction_no_uplift(%a: vector<2x1x8xf32>, %b: vector<2x1x8xf32>) -> vector<2x1xf32> {
   %cst_4 = arith.constant dense<0.000000e+00> : vector<2x1xf32>
   %cst_5 = arith.constant dense<0.000000e+00> : vector<2x1x8xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_multi_reduction_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_multi_reduction_lowering.mlir
new file mode 100644
index 0000000..93baa06
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_multi_reduction_lowering.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-multi-reduction-lowering))" --split-input-file %s | FileCheck %s --check-prefixes=ALL
+
+// ALL-LABEL: func @one_dim_reduction
+// ALL-SAME:    %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32
+func.func @one_dim_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
+  // ALL: %[[RESULT:.+]] = vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32
+  %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
+  // ALL: return %[[RESULT]]
+  return %0 : f32
+}