[Codegen] Add transfer_{gather/scatter} to vector.{gather/scatter} lowering (#24104)
According to discussion with @Groverkss , this pass is supposed to run
at the point where all N-D vectors have been legalized to 1-D so this
pass should only handle the 1-D `transfer_{gather/scatter}` case.
---------
Signed-off-by: NoumanAmir657 <noumanamir453@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel
index a31b928..c02be8d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel
@@ -38,6 +38,7 @@
"BufferizationInterfaces.cpp",
"DistributionPatterns.cpp",
"LowerTransferGatherScatterOps.cpp",
+ "LowerTransferGatherScatterToVector.cpp",
"Passes.cpp",
"VectorExtFoldUnitExtentDims.cpp",
],
@@ -56,6 +57,7 @@
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationInterfaces",
"@llvm-project//mlir:BufferizationTransforms",
+ "@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt
index 1e53aa4..63ea2a3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@
"BufferizationInterfaces.cpp"
"DistributionPatterns.cpp"
"LowerTransferGatherScatterOps.cpp"
+ "LowerTransferGatherScatterToVector.cpp"
"Passes.cpp"
"VectorExtFoldUnitExtentDims.cpp"
DEPS
@@ -40,6 +41,7 @@
MLIRArithDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
+ MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
MLIRPass
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/LowerTransferGatherScatterToVector.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/LowerTransferGatherScatterToVector.cpp
new file mode 100644
index 0000000..7a7e237
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/LowerTransferGatherScatterToVector.cpp
@@ -0,0 +1,142 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::VectorExt;
+
+namespace mlir::iree_compiler::IREE::VectorExt {
+
+#define GEN_PASS_DEF_LOWERTRANSFERGATHERSCATTERTOVECTORPASS
+#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"
+
+namespace {
+
+static LogicalResult validateMaps(Operation *op, OperandRange indexVecs,
+ AffineMap sourceMap,
+ PatternRewriter &rewriter) {
+ // TODO: vector.gather/scatter requires a single index vector.
+ // We could flatten everything to 1-D to make it work.
+ if (indexVecs.size() != 1) {
+ return rewriter.notifyMatchFailure(op, "expected exactly one index vec");
+ }
+
+ if (!isa<VectorType>(indexVecs[0].getType())) {
+ return rewriter.notifyMatchFailure(op, "index vec must be a vector type");
+ }
+
+ unsigned numResults = sourceMap.getNumResults();
+ for (unsigned i = 0; i < numResults - 1; ++i) {
+ if (!isa<AffineConstantExpr>(sourceMap.getResult(i))) {
+ return rewriter.notifyMatchFailure(op,
+ "non-gathered dims must be constants");
+ }
+ }
+
+ if (!isa<AffineSymbolExpr>(sourceMap.getResult(numResults - 1))) {
+ return rewriter.notifyMatchFailure(op,
+ "last dim must be a pure symbol expr");
+ }
+ return success();
+}
+
+struct LowerTransferGatherToVectorGather final
+ : OpRewritePattern<TransferGatherOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(TransferGatherOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
+ if (failed(
+ validateMaps(op, op.getIndexVecs(), indexingMaps[0], rewriter))) {
+ return failure();
+ }
+
+ Location loc = op.getLoc();
+ VectorType resultType = op.getVectorType();
+ Value indexVec = op.getIndexVecs()[0];
+
+ Value mask = op.getMask();
+ if (!mask) {
+ auto maskType =
+ VectorType::get(resultType.getShape(), rewriter.getI1Type());
+ mask = arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(maskType, true));
+ }
+
+ Value passthru =
+ vector::BroadcastOp::create(rewriter, loc, resultType, op.getPadding());
+
+ rewriter.replaceOpWithNewOp<vector::GatherOp>(op, resultType, op.getBase(),
+ op.getOffsets(), indexVec,
+ mask, passthru);
+ return success();
+ }
+};
+
+struct LowerTransferScatterToVectorScatter final
+ : OpRewritePattern<TransferScatterOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(TransferScatterOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
+ if (failed(
+ validateMaps(op, op.getIndexVecs(), indexingMaps[0], rewriter))) {
+ return failure();
+ }
+
+ Location loc = op.getLoc();
+ VectorType vectorType = op.getVectorType();
+ Value indexVec = op.getIndexVecs()[0];
+
+ Value mask = op.getMask();
+ if (!mask) {
+ auto maskType =
+ VectorType::get(vectorType.getShape(), rewriter.getI1Type());
+ mask = arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(maskType, true));
+ }
+
+ Type resultType = op.hasTensorSemantics() ? op.getBase().getType() : Type{};
+ auto scatterOp = vector::ScatterOp::create(rewriter, loc, resultType,
+ op.getBase(), op.getOffsets(),
+ indexVec, mask, op.getVector());
+ if (op.hasTensorSemantics()) {
+ rewriter.replaceOp(op, scatterOp.getResult());
+ } else {
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
+struct LowerTransferGatherScatterToVectorPass final
+ : impl::LowerTransferGatherScatterToVectorPassBase<
+ LowerTransferGatherScatterToVectorPass> {
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ patterns.add<LowerTransferGatherToVectorGather,
+ LowerTransferScatterToVectorScatter>(ctx);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::VectorExt
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td
index 8c54b88..978a941 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td
@@ -18,4 +18,14 @@
];
}
+def LowerTransferGatherScatterToVectorPass :
+ Pass<"iree-vector-ext-lower-transfer-gather-scatter-to-vector", "func::FuncOp"> {
+ let summary = "Lower transfer_gather/transfer_scatter to vector.gather/vector.scatter";
+ let dependentDialects = [
+ "::mlir::arith::ArithDialect",
+ "::mlir::vector::VectorDialect",
+ "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect"
+ ];
+}
+
#endif // IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel
index 55e1e50..8be6ff5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
srcs = enforce_glob(
# keep sorted
[
+ "lower_transfer_gather_scatter_to_vector.mlir",
"vector_ext_fold_unit_extent_dims.mlir",
"vectorize_vector_ext_ops.mlir",
],
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt
index b99eb0a..6c3b4bc 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "lower_transfer_gather_scatter_to_vector.mlir"
"vector_ext_fold_unit_extent_dims.mlir"
"vectorize_vector_ext_ops.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/lower_transfer_gather_scatter_to_vector.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/lower_transfer_gather_scatter_to_vector.mlir
new file mode 100644
index 0000000..ae80017
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/lower_transfer_gather_scatter_to_vector.mlir
@@ -0,0 +1,286 @@
+// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-lower-transfer-gather-scatter-to-vector))' --split-input-file --mlir-print-local-scope | FileCheck %s
+
+#map = affine_map<(d0)[s0] -> (0, 0, s0)>
+#map1 = affine_map<(d0)[s0] -> (d0)>
+module {
+ func.func @lower_transfer_gather_to_vector_gather(%arg0: tensor<1x1x31xf32>, %arg1: tensor<1x1x1x1x16xf32>) -> tensor<1x1x1x1x16xf32> {
+ %0 = ub.poison : vector<1x16xf32>
+ %1 = ub.poison : vector<1x1x16xf32>
+ %2 = ub.poison : vector<1x1x1x16xf32>
+ %3 = ub.poison : vector<1x1x1x1x16xf32>
+ %cst = arith.constant dense<2> : vector<16xindex>
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %4 = vector.step : vector<16xindex>
+ %5 = arith.muli %4, %cst : vector<16xindex>
+ %6 = iree_vector_ext.transfer_gather %arg0[%c0, %c0, %c0] [%5 : vector<16xindex>], %cst_0 {indexing_maps = [#map, #map1]} : tensor<1x1x31xf32>, vector<16xf32>
+ %7 = vector.insert_strided_slice %6, %0 {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x16xf32>
+ %8 = vector.insert_strided_slice %7, %1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x1x16xf32>
+ %9 = vector.insert_strided_slice %8, %2 {offsets = [0, 0, 0, 0], strides = [1, 1, 1]} : vector<1x1x16xf32> into vector<1x1x1x16xf32>
+ %10 = vector.insert_strided_slice %9, %3 {offsets = [0, 0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x1x1x16xf32> into vector<1x1x1x1x16xf32>
+ %11 = vector.transfer_write %10, %arg1[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true]} : vector<1x1x1x1x16xf32>, tensor<1x1x1x1x16xf32>
+ return %11 : tensor<1x1x1x1x16xf32>
+ }
+}
+
+// CHECK-LABEL: func.func @lower_transfer_gather_to_vector_gather
+// CHECK-SAME: %[[SRC:.+]]: tensor<1x1x31xf32>
+// CHECK-DAG: %[[PASS_THRU:.+]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-DAG: %[[MASK:.+]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-DAG: %[[STRIDE:.+]] = arith.constant dense<2> : vector<16xindex>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[STEP:.+]] = vector.step : vector<16xindex>
+// CHECK: %[[INDICES:.+]] = arith.muli %[[STEP]], %[[STRIDE]] : vector<16xindex>
+// CHECK: %[[GATHER:.+]] = vector.gather %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASS_THRU]]
+// CHECK-SAME: : tensor<1x1x31xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+// -----
+
+func.func @lower_gather_nontrivial_leading_dims(
+ %src: tensor<4x16xf32>, %idx: vector<8xindex>) -> vector<8xf32> {
+ %pad = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %out = iree_vector_ext.transfer_gather %src[%c2, %c0]
+ [%idx : vector<8xindex>], %pad {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : tensor<4x16xf32>, vector<8xf32>
+ return %out : vector<8xf32>
+}
+// CHECK-LABEL: @lower_gather_nontrivial_leading_dims
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: vector.gather %{{.+}}[%[[C2]], %[[C0]]]
+
+// -----
+
+func.func @lower_masked_gather(
+ %src: tensor<8x16xf16>, %idx: vector<16xindex>,
+ %mask: vector<16xi1>) -> vector<16xf16> {
+ %pad = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+ [%idx : vector<16xindex>], %pad, %mask {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : tensor<8x16xf16>, vector<16xf16>, vector<16xi1>
+ return %out : vector<16xf16>
+}
+// CHECK-LABEL: @lower_masked_gather
+// CHECK: vector.gather
+// CHECK-NOT: iree_vector_ext.transfer_gather
+
+// -----
+
+func.func @negative_lower_gather_multiple_index_vecs(
+ %src: tensor<8x16xf16>,
+ %i0: vector<8xindex>, %i1: vector<16xindex>) -> vector<8x16xf16> {
+ %pad = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+ [%i0, %i1 : vector<8xindex>, vector<16xindex>], %pad {
+ indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+ affine_map<(d0, d1)[s0, s1] -> (d0)>,
+ affine_map<(d0, d1)[s0, s1] -> (d1)>]
+ } : tensor<8x16xf16>, vector<8x16xf16>
+ return %out : vector<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_multiple_index_vecs
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @negative_lower_gather_scalar_index(
+ %src: memref<8x16xf16>, %idx: index) -> vector<8xf16> {
+ %pad = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+ [%idx : index], %pad {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> ()>]
+ } : memref<8x16xf16>, vector<8xf16>
+ return %out : vector<8xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_scalar_index
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @negative_lower_gather_symbol_in_leading_dim(
+ %src: tensor<8x16xf16>, %idx: vector<8xindex>) -> vector<8x16xf16> {
+ %pad = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+ [%idx : vector<8xindex>], %pad {
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d0)>]
+ } : tensor<8x16xf16>, vector<8x16xf16>
+ return %out : vector<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_symbol_in_leading_dim
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @negative_lower_gather_nonconstant_leading_dim(
+ %src: tensor<16x16xf16>, %idx: vector<16xindex>) -> vector<16xf16> {
+ %pad = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+ [%idx : vector<16xindex>], %pad {
+ indexing_maps = [affine_map<(d0)[s0] -> (d0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : tensor<16x16xf16>, vector<16xf16>
+ return %out : vector<16xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_nonconstant_leading_dim
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @lower_transfer_scatter_to_vector_scatter(
+ %src: vector<16xf32>, %dest: tensor<1x1x32xf32>,
+ %idx: vector<16xindex>) -> tensor<1x1x32xf32> {
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0, %c0]
+ [%idx : vector<16xindex>] {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, 0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : vector<16xf32>, tensor<1x1x32xf32> -> tensor<1x1x32xf32>
+ return %out : tensor<1x1x32xf32>
+}
+// CHECK-LABEL: @lower_transfer_scatter_to_vector_scatter
+// CHECK-DAG: %[[MASK:.+]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: vector.scatter %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-NOT: iree_vector_ext.transfer_scatter
+
+// -----
+
+func.func @lower_scatter_nontrivial_leading_dims(
+ %src: vector<8xf32>, %dest: tensor<4x16xf32>,
+ %idx: vector<8xindex>) -> tensor<4x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c2, %c0]
+ [%idx : vector<8xindex>] {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : vector<8xf32>, tensor<4x16xf32> -> tensor<4x16xf32>
+ return %out : tensor<4x16xf32>
+}
+// CHECK-LABEL: @lower_scatter_nontrivial_leading_dims
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: vector.scatter %{{.+}}[%[[C2]], %[[C0]]]
+
+// -----
+
+func.func @lower_masked_scatter(
+ %src: vector<16xf16>, %dest: tensor<8x16xf16>,
+ %idx: vector<16xindex>, %mask: vector<16xi1>) -> tensor<8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+ [%idx : vector<16xindex>], %mask {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : vector<16xf16>, tensor<8x16xf16>, vector<16xi1> -> tensor<8x16xf16>
+ return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @lower_masked_scatter
+// CHECK: vector.scatter
+// CHECK-NOT: iree_vector_ext.transfer_scatter
+
+// -----
+
+func.func @lower_scatter_memref(
+ %src: vector<8xf32>, %dest: memref<4x16xf32>,
+ %idx: vector<8xindex>) {
+ %c0 = arith.constant 0 : index
+ iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+ [%idx : vector<8xindex>] {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : vector<8xf32>, memref<4x16xf32>
+ return
+}
+// CHECK-LABEL: @lower_scatter_memref
+// CHECK-DAG: %[[MASK:.+]] = arith.constant dense<true> : vector<8xi1>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: vector.scatter
+// CHECK-NOT: iree_vector_ext.transfer_scatter
+
+// -----
+
+func.func @negative_lower_scatter_multiple_index_vecs(
+ %src: vector<8x16xf16>, %dest: tensor<8x16xf16>,
+ %i0: vector<8xindex>, %i1: vector<16xindex>) -> tensor<8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+ [%i0, %i1 : vector<8xindex>, vector<16xindex>] {
+ indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+ affine_map<(d0, d1)[s0, s1] -> (d0)>,
+ affine_map<(d0, d1)[s0, s1] -> (d1)>]
+ } : vector<8x16xf16>, tensor<8x16xf16> -> tensor<8x16xf16>
+ return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_multiple_index_vecs
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter
+
+// -----
+
+func.func @negative_lower_scatter_scalar_index(
+ %src: vector<8xf16>, %dest: tensor<8x16xf16>, %idx: index) -> tensor<8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+ [%idx : index] {
+ indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+ affine_map<(d0)[s0] -> ()>]
+ } : vector<8xf16>, tensor<8x16xf16> -> tensor<8x16xf16>
+ return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_scalar_index
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter
+
+// -----
+
+func.func @negative_lower_scatter_symbol_in_leading_dim(
+ %src: vector<8x16xf16>, %dest: tensor<8x16xf16>,
+ %idx: vector<8xindex>) -> tensor<8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+ [%idx : vector<8xindex>] {
+ indexing_maps = [affine_map<(d0, d1)[s0] -> (s0, d1)>,
+ affine_map<(d0, d1)[s0] -> (d0)>]
+ } : vector<8x16xf16>, tensor<8x16xf16> -> tensor<8x16xf16>
+ return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_symbol_in_leading_dim
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter
+
+// -----
+
+func.func @negative_lower_scatter_nonconstant_leading_dim(
+ %src: vector<16xf16>, %dest: tensor<16x16xf16>,
+ %idx: vector<16xindex>) -> tensor<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+ [%idx : vector<16xindex>] {
+ indexing_maps = [affine_map<(d0)[s0] -> (d0, s0)>,
+ affine_map<(d0)[s0] -> (d0)>]
+ } : vector<16xf16>, tensor<16x16xf16> -> tensor<16x16xf16>
+ return %out : tensor<16x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_nonconstant_leading_dim
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter