Refactor vector.multi_reduction into flattening, unrolling, and lowering passes. (#24183)
* Adds an initial vector flattening pass.
* At the moment, the only operation that is flattened is
vector.multi_reduction, but others will be added later.
* Adds vector unrolling for vector.multi_reduction.
* Adds a vector.multi_reduction lowering pass
---------
Co-authored-by: Eric <55723758+efric@users.noreply.github.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index b304f6b..4ce3075 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -110,7 +110,9 @@
"LLVMGPUTensorCoreVectorization.cpp",
"LLVMGPUTileAndDistribute.cpp",
"LLVMGPUVectorDistribute.cpp",
+ "LLVMGPUVectorFlattening.cpp",
"LLVMGPUVectorLowering.cpp",
+ "LLVMGPUVectorMultiReductionLowering.cpp",
"LLVMGPUVectorToGPU.cpp",
"Passes.cpp",
"ROCDLAnnotateKernelForTranslation.cpp",
@@ -237,6 +239,7 @@
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
"@llvm-project//mlir:VectorTransforms",
+ "@llvm-project//mlir:VectorUtils",
],
)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index bc5b5fa..65efc4b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -90,7 +90,9 @@
"LLVMGPUTensorCoreVectorization.cpp"
"LLVMGPUTileAndDistribute.cpp"
"LLVMGPUVectorDistribute.cpp"
+ "LLVMGPUVectorFlattening.cpp"
"LLVMGPUVectorLowering.cpp"
+ "LLVMGPUVectorMultiReductionLowering.cpp"
"LLVMGPUVectorToGPU.cpp"
"Passes.cpp"
"ROCDLAnnotateKernelForTranslation.cpp"
@@ -177,6 +179,7 @@
MLIRVectorToLLVM
MLIRVectorToSCF
MLIRVectorTransforms
+ MLIRVectorUtils
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Common::GPU::GPUHeuristics
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp
index 0163700..72b8dba 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULegalizeNDVectors.cpp
@@ -710,6 +710,63 @@
}
};
+/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops.
+/// This assumes that the src rank will always be two dimensional.
+///
+/// The reduction dimension must be the inner-most dimension.
+///
+/// BEFORE:
+/// vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to
+/// vector<2xf32>
+///
+/// AFTER:
+/// // 1st reduction
+/// %v_0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
+/// %a_0 = vector.extract %acc[0] : f32 from vector<2xf32>
+/// %red_1 = vector.multi_reduction <mul>, %v_0, %a_1 [0] : vector<4xf32> into
+/// f32 %res_tmp = vector.insert %red_1, %res [0] : f32 into vector<2xf32>
+///
+/// // 2nd reduction
+/// %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
+/// %a_1 = vector.extract %acc[1] : f32 from vector<2xf32>
+/// %red_2 = vector.multi_reduction <mul>, %v_1, %a_1 [0] : vector<4xf32> into
+/// f32 %res_final = vector.insert %red_2, %res_tmp [1] : f32 into
+/// vector<2xf32>
+struct ConvertVectorMultiReduction final
+ : public OpConversionPattern<vector::MultiDimReductionOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType srcType = cast<VectorType>(op.getSource().getType());
+ if (srcType.getRank() != 2) {
+ return failure();
+ }
+
+ if (op.isReducedDim(0) || !op.isReducedDim(1)) {
+ return failure();
+ }
+
+ Location loc = op.getLoc();
+ Value acc = adaptor.getAcc()[0];
+ Type resultType = op.getResult().getType();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultType);
+
+ SmallVector<Value> srcs(adaptor.getSource());
+ for (int64_t i = 0, e = srcs.size(); i < e; i++) {
+ Value accElem = vector::ExtractOp::create(rewriter, loc, acc, i);
+ auto reduced = vector::MultiDimReductionOp::create(
+ rewriter, loc, op.getKind(), srcs[i], accElem, ArrayRef<int64_t>{0});
+ result = vector::InsertOp::create(rewriter, loc, reduced, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
/// Convert vector.interleave on n-D vectors. The lhs and rhs are already
/// split into flat 1-D vectors by the type converter; create a 1-D interleave
/// for each corresponding pair.
@@ -814,7 +871,8 @@
ConvertVectorInsertStridedSlice, ConvertArithConstant, ConvertUBPoison,
ConvertVectorToElements, ConvertVectorFromElements,
ConvertVectorBroadcast, ConvertVectorBitcast, ConvertVectorInterleave,
- ConvertVectorDeinterleave>(typeConverter, ctx);
+ ConvertVectorDeinterleave, ConvertVectorMultiReduction>(typeConverter,
+ ctx);
// Some nvgpu ops abuse n-D vector types to represent a "struct of
// vectors". These ops are legal despite having n-D vectors — the
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorFlattening.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorFlattening.cpp
new file mode 100644
index 0000000..a932b0a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorFlattening.cpp
@@ -0,0 +1,38 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-vector-flattening"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUVECTORFLATTENINGPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+struct LLVMGPUVectorFlatteningPass final
+ : impl::LLVMGPUVectorFlatteningPassBase<LLVMGPUVectorFlatteningPass> {
+
+ void runOnOperation() override {
+ mlir::FunctionOpInterface funcOp = getOperation();
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ vector::populateVectorMultiReductionFlatteningPatterns(
+ patterns, vector::VectorMultiReductionLowering::InnerReduction);
+ if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
index c652cd7..9316306 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
@@ -501,12 +501,6 @@
vector::populateVectorMultiReductionReorderPatterns(
contractLoweringPatterns,
vector::VectorMultiReductionLowering::InnerReduction);
- vector::populateVectorMultiReductionFlatteningPatterns(
- contractLoweringPatterns,
- vector::VectorMultiReductionLowering::InnerReduction);
- vector::populateVectorMultiReductionUnrollingPatterns(
- contractLoweringPatterns,
- vector::VectorMultiReductionLowering::InnerReduction);
// Unroll transfer_gather ops to rank 1 and lower contiguous ones to
// vector.transfer_read.
IREE::VectorExt::populateVectorTransferGatherScatterLoweringPatterns(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorMultiReductionLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorMultiReductionLowering.cpp
new file mode 100644
index 0000000..0e07a06
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorMultiReductionLowering.cpp
@@ -0,0 +1,82 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-vector-multi-reduction-lowering"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUVECTORMULTIREDUCTIONLOWERINGPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+/// Converts 1D vector.multi_reduction directly to vector.reduction.
+///
+/// Example:
+/// ```mlir
+/// // Before
+/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
+///
+/// // After
+/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
+/// ```
+struct OneDimMultiReductionToReduction
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (srcRank != 1) {
+ return failure();
+ }
+
+ if (!multiReductionOp.isReducedDim(0)) {
+ return failure();
+ }
+
+ auto loc = multiReductionOp.getLoc();
+ Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+ Operation *reductionOp = vector::ReductionOp::create(
+ rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
+ multiReductionOp.getAcc());
+
+ if (mask) {
+ reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+ }
+
+ return reductionOp->getResult(0);
+ }
+};
+
+} // namespace
+
+struct LLVMGPUVectorMultiReductionLoweringPass final
+ : impl::LLVMGPUVectorMultiReductionLoweringPassBase<
+ LLVMGPUVectorMultiReductionLoweringPass> {
+
+ void runOnOperation() override {
+ mlir::FunctionOpInterface funcOp = getOperation();
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ patterns.add<OneDimMultiReductionToReduction>(ctx);
+ if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index b8cf60c..718d8ac 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -995,7 +995,9 @@
// This pass needs to run before SCF -> CF.
// Lower vector operations and legalize all operations to 1D vectors.
funcPassManager.addPass(createLLVMGPUVectorLoweringPass)
+ .addPass(createLLVMGPUVectorFlatteningPass)
.addPass(createLLVMGPULegalizeNDVectorsPass)
+ .addPass(createLLVMGPUVectorMultiReductionLoweringPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index 901dbd5..0277b27 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -77,6 +77,22 @@
];
}
+def LLVMGPUVectorFlatteningPass :
+ InterfacePass<"iree-llvmgpu-vector-flattening", "mlir::FunctionOpInterface"> {
+ let summary = "Flatten n-D vectors.";
+ let dependentDialects = [
+ "vector::VectorDialect",
+ ];
+}
+
+def LLVMGPUVectorMultiReductionLoweringPass :
+ InterfacePass<"iree-llvmgpu-vector-multi-reduction-lowering", "mlir::FunctionOpInterface"> {
+ let summary = "Lower vector.multi_reduction ops.";
+ let dependentDialects = [
+ "vector::VectorDialect",
+ ];
+}
+
def LLVMGPULinkExecutablesPass :
Pass<"iree-llvmgpu-link-executables", "mlir::ModuleOp"> {
let summary = "Links LLVMGPU HAL executables within the top-level program module.";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 9f6f8e4..a05fa14 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -75,7 +75,9 @@
"transform_gpu_pipelining.mlir",
"transform_vector_to_mma.mlir",
"transpose_pipeline_test.mlir",
+ "vector_flattening.mlir",
"vector_lowering.mlir",
+ "vector_multi_reduction_lowering.mlir",
"vector_to_gpu.mlir",
"winograd_pipeline_test.mlir",
],
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 8b60996..a4210cf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -70,7 +70,9 @@
"transform_gpu_pipelining.mlir"
"transform_vector_to_mma.mlir"
"transpose_pipeline_test.mlir"
+ "vector_flattening.mlir"
"vector_lowering.mlir"
+ "vector_multi_reduction_lowering.mlir"
"vector_to_gpu.mlir"
"winograd_pipeline_test.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir
index acb4a6e..030e793 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/legalize_nd_vectors.mlir
@@ -339,3 +339,31 @@
// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[B0]] : vector<4xf32>
// CHECK: %[[R1:.+]] = arith.addf %[[A1]], %[[B1]] : vector<4xf32>
// CHECK: util.return %[[R0]], %[[R1]] : vector<4xf32>, vector<4xf32>
+
+// -----
+
+func.func @negative_vector_multi_reduction_rank_one(%arg0: vector<2xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<2xf32> to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @negative_vector_multi_reduction_rank_one
+// CHECK: vector.multi_reduction
+// CHECK: return
+
+// -----
+
+func.func @vector_multi_reduction_2d(%src: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @vector_multi_reduction_2d
+// CHECK-SAME: (%[[S0:.+]]: vector<4xf32>, %[[S1:.+]]: vector<4xf32>, %[[ACC:.+]]: vector<2xf32>)
+// CHECK: %[[A0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK: %[[R0:.+]] = vector.multi_reduction <mul>, %[[S0]], %[[A0]] [0] : vector<4xf32> to f32
+// CHECK: %[[INS0:.+]] = vector.insert %[[R0]], %{{.*}} [0] : f32 into vector<2xf32>
+// CHECK: %[[A1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK: %[[R1:.+]] = vector.multi_reduction <mul>, %[[S1]], %[[A1]] [0] : vector<4xf32> to f32
+// CHECK: %[[INS1:.+]] = vector.insert %[[R1]], %[[INS0]] [1] : f32 into vector<2xf32>
+// CHECK: return %[[INS1]] : vector<2xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_flattening.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_flattening.mlir
new file mode 100644
index 0000000..014fa51
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_flattening.mlir
@@ -0,0 +1,12 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-flattening))" \
+// RUN: --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @vector_multi_reduction_flattening
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir
index 05a2140..cf6e61d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir
@@ -39,28 +39,6 @@
// -----
-// Test multi_reduction lowering.
-
-func.func @multi_reduction_f32(%a: vector<2x1x8xf32>, %b: vector<2x1x8xf32>) -> vector<2x1xf32> {
- %cst_4 = arith.constant dense<0.000000e+00> : vector<2x1xf32>
- %cst_5 = arith.constant dense<0.000000e+00> : vector<2x1x8xf32>
- %22 = arith.mulf %a, %b : vector<2x1x8xf32>
- %23 = arith.addf %22, %cst_5 : vector<2x1x8xf32>
- %24 = vector.multi_reduction <add>, %23, %cst_4 [2] : vector<2x1x8xf32> to vector<2x1xf32>
- return %24 : vector<2x1xf32>
-}
-
-// CHECK-LABEL: func.func @multi_reduction_f32
-// CHECK-SAME: %[[ARG0:.+]]: vector<2x1x8xf32>, %[[ARG1:.+]]: vector<2x1x8xf32>)
-// CHECK-DAG: %[[FMA:.+]] = math.fma %[[ARG0]], %[[ARG1]], %{{.*}} fastmath<contract> : vector<2x1x8xf32>
-// CHECK-DAG: %[[FMA1:.+]] = vector.extract %[[FMA]][0, 0] : vector<8xf32> from vector<2x1x8xf32>
-// CHECK-DAG: %[[RED1:.+]] = vector.reduction <add>, %[[FMA1]], %{{.*}} : vector<8xf32> into f32
-// CHECK-DAG: %[[FMA2:.+]] = vector.extract %[[FMA]][1, 0] : vector<8xf32> from vector<2x1x8xf32>
-// CHECK-DAG: %[[RED2:.+]] = vector.reduction <add>, %[[FMA2]], %{{.*}} : vector<8xf32> into f32
-// CHECK: vector.from_elements %[[RED1]], %[[RED2]] : vector<2x1xf32>
-
-// -----
-
func.func @multi_reduction_no_uplift(%a: vector<2x1x8xf32>, %b: vector<2x1x8xf32>) -> vector<2x1xf32> {
%cst_4 = arith.constant dense<0.000000e+00> : vector<2x1xf32>
%cst_5 = arith.constant dense<0.000000e+00> : vector<2x1x8xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_multi_reduction_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_multi_reduction_lowering.mlir
new file mode 100644
index 0000000..93baa06
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_multi_reduction_lowering.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-multi-reduction-lowering))" --split-input-file %s | FileCheck %s --check-prefixes=ALL
+
+// ALL-LABEL: func @one_dim_reduction
+// ALL-SAME: %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32
+func.func @one_dim_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
+ // ALL: %[[RESULT:.+]] = vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
+ // ALL: return %[[RESULT]]
+ return %0 : f32
+}