[Codegen] Add transfer_{gather/scatter} to vector.{gather/scatter} lowering (#24104)

According to discussion with @Groverkss, this pass is supposed to run
at the point where all N-D vectors have been legalized to 1-D, so this
pass only needs to handle the 1-D `transfer_{gather/scatter}` case.

---------

Signed-off-by: NoumanAmir657 <noumanamir453@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel
index a31b928..c02be8d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel
@@ -38,6 +38,7 @@
         "BufferizationInterfaces.cpp",
         "DistributionPatterns.cpp",
         "LowerTransferGatherScatterOps.cpp",
+        "LowerTransferGatherScatterToVector.cpp",
         "Passes.cpp",
         "VectorExtFoldUnitExtentDims.cpp",
     ],
@@ -56,6 +57,7 @@
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:BufferizationInterfaces",
         "@llvm-project//mlir:BufferizationTransforms",
+        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt
index 1e53aa4..63ea2a3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@
     "BufferizationInterfaces.cpp"
     "DistributionPatterns.cpp"
     "LowerTransferGatherScatterOps.cpp"
+    "LowerTransferGatherScatterToVector.cpp"
     "Passes.cpp"
     "VectorExtFoldUnitExtentDims.cpp"
   DEPS
@@ -40,6 +41,7 @@
     MLIRArithDialect
     MLIRBufferizationDialect
     MLIRBufferizationTransforms
+    MLIRFuncDialect
     MLIRFunctionInterfaces
     MLIRIR
     MLIRPass
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/LowerTransferGatherScatterToVector.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/LowerTransferGatherScatterToVector.cpp
new file mode 100644
index 0000000..7a7e237
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/LowerTransferGatherScatterToVector.cpp
@@ -0,0 +1,142 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::VectorExt;
+
+namespace mlir::iree_compiler::IREE::VectorExt {
+
+#define GEN_PASS_DEF_LOWERTRANSFERGATHERSCATTERTOVECTORPASS
+#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"
+
+namespace {
+
+static LogicalResult validateMaps(Operation *op, OperandRange indexVecs,
+                                  AffineMap sourceMap,
+                                  PatternRewriter &rewriter) {
+  // TODO: vector.gather/scatter accept only a single index vector, so bail
+  // out on multiple index vecs; flattening everything to 1-D could lift this.
+  if (indexVecs.size() != 1) {
+    return rewriter.notifyMatchFailure(op, "expected exactly one index vec");
+  }
+
+  if (!isa<VectorType>(indexVecs[0].getType())) {
+    return rewriter.notifyMatchFailure(op, "index vec must be a vector type");
+  }
+
+  unsigned numResults = sourceMap.getNumResults();
+  for (unsigned i = 0; i < numResults - 1; ++i) {
+    if (!isa<AffineConstantExpr>(sourceMap.getResult(i))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-gathered dims must be constants");
+    }
+  }
+
+  if (!isa<AffineSymbolExpr>(sourceMap.getResult(numResults - 1))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "last dim must be a pure symbol expr");
+  }
+  return success();
+}
+
+struct LowerTransferGatherToVectorGather final
+    : OpRewritePattern<TransferGatherOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(TransferGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
+    if (failed(
+            validateMaps(op, op.getIndexVecs(), indexingMaps[0], rewriter))) {
+      return failure();
+    }
+
+    Location loc = op.getLoc();
+    VectorType resultType = op.getVectorType();
+    Value indexVec = op.getIndexVecs()[0];
+
+    Value mask = op.getMask();
+    if (!mask) {
+      auto maskType =
+          VectorType::get(resultType.getShape(), rewriter.getI1Type());
+      mask = arith::ConstantOp::create(rewriter, loc,
+                                       DenseElementsAttr::get(maskType, true));
+    }
+
+    Value passthru =
+        vector::BroadcastOp::create(rewriter, loc, resultType, op.getPadding());
+
+    rewriter.replaceOpWithNewOp<vector::GatherOp>(op, resultType, op.getBase(),
+                                                  op.getOffsets(), indexVec,
+                                                  mask, passthru);
+    return success();
+  }
+};
+
+struct LowerTransferScatterToVectorScatter final
+    : OpRewritePattern<TransferScatterOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(TransferScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
+    if (failed(
+            validateMaps(op, op.getIndexVecs(), indexingMaps[0], rewriter))) {
+      return failure();
+    }
+
+    Location loc = op.getLoc();
+    VectorType vectorType = op.getVectorType();
+    Value indexVec = op.getIndexVecs()[0];
+
+    Value mask = op.getMask();
+    if (!mask) {
+      auto maskType =
+          VectorType::get(vectorType.getShape(), rewriter.getI1Type());
+      mask = arith::ConstantOp::create(rewriter, loc,
+                                       DenseElementsAttr::get(maskType, true));
+    }
+
+    Type resultType = op.hasTensorSemantics() ? op.getBase().getType() : Type{};
+    auto scatterOp = vector::ScatterOp::create(rewriter, loc, resultType,
+                                               op.getBase(), op.getOffsets(),
+                                               indexVec, mask, op.getVector());
+    if (op.hasTensorSemantics()) {
+      rewriter.replaceOp(op, scatterOp.getResult());
+    } else {
+      rewriter.eraseOp(op);
+    }
+    return success();
+  }
+};
+
+struct LowerTransferGatherScatterToVectorPass final
+    : impl::LowerTransferGatherScatterToVectorPassBase<
+          LowerTransferGatherScatterToVectorPass> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<LowerTransferGatherToVectorGather,
+                 LowerTransferScatterToVectorScatter>(ctx);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::VectorExt
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td
index 8c54b88..978a941 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td
@@ -18,4 +18,14 @@
   ];
 }
 
+def LowerTransferGatherScatterToVectorPass :
+    Pass<"iree-vector-ext-lower-transfer-gather-scatter-to-vector", "func::FuncOp"> {
+  let summary = "Lower transfer_gather/transfer_scatter to vector.gather/vector.scatter";
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect",
+    "::mlir::vector::VectorDialect",
+    "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect"
+  ];
+}
+
 #endif // IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel
index 55e1e50..8be6ff5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
     srcs = enforce_glob(
         # keep sorted
         [
+            "lower_transfer_gather_scatter_to_vector.mlir",
             "vector_ext_fold_unit_extent_dims.mlir",
             "vectorize_vector_ext_ops.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt
index b99eb0a..6c3b4bc 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "lower_transfer_gather_scatter_to_vector.mlir"
     "vector_ext_fold_unit_extent_dims.mlir"
     "vectorize_vector_ext_ops.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/lower_transfer_gather_scatter_to_vector.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/lower_transfer_gather_scatter_to_vector.mlir
new file mode 100644
index 0000000..ae80017
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/lower_transfer_gather_scatter_to_vector.mlir
@@ -0,0 +1,286 @@
+// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-lower-transfer-gather-scatter-to-vector))' --split-input-file --mlir-print-local-scope | FileCheck %s
+
+#map = affine_map<(d0)[s0] -> (0, 0, s0)>
+#map1 = affine_map<(d0)[s0] -> (d0)>
+module {
+  func.func @lower_transfer_gather_to_vector_gather(%arg0: tensor<1x1x31xf32>, %arg1: tensor<1x1x1x1x16xf32>) -> tensor<1x1x1x1x16xf32> {
+    %0 = ub.poison : vector<1x16xf32>
+    %1 = ub.poison : vector<1x1x16xf32>
+    %2 = ub.poison : vector<1x1x1x16xf32>
+    %3 = ub.poison : vector<1x1x1x1x16xf32>
+    %cst = arith.constant dense<2> : vector<16xindex>
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %4 = vector.step : vector<16xindex>
+    %5 = arith.muli %4, %cst : vector<16xindex>
+    %6 = iree_vector_ext.transfer_gather %arg0[%c0, %c0, %c0] [%5 : vector<16xindex>], %cst_0 {indexing_maps = [#map, #map1]} : tensor<1x1x31xf32>, vector<16xf32>
+    %7 = vector.insert_strided_slice %6, %0 {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x16xf32>
+    %8 = vector.insert_strided_slice %7, %1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x1x16xf32>
+    %9 = vector.insert_strided_slice %8, %2 {offsets = [0, 0, 0, 0], strides = [1, 1, 1]} : vector<1x1x16xf32> into vector<1x1x1x16xf32>
+    %10 = vector.insert_strided_slice %9, %3 {offsets = [0, 0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x1x1x16xf32> into vector<1x1x1x1x16xf32>
+    %11 = vector.transfer_write %10, %arg1[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true]} : vector<1x1x1x1x16xf32>, tensor<1x1x1x1x16xf32>
+    return %11 : tensor<1x1x1x1x16xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @lower_transfer_gather_to_vector_gather
+// CHECK-SAME:  %[[SRC:.+]]: tensor<1x1x31xf32>
+// CHECK-DAG:     %[[PASS_THRU:.+]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-DAG:     %[[MASK:.+]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-DAG:     %[[STRIDE:.+]] = arith.constant dense<2> : vector<16xindex>
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[STEP:.+]] = vector.step : vector<16xindex>
+// CHECK:         %[[INDICES:.+]] = arith.muli %[[STEP]], %[[STRIDE]] : vector<16xindex>
+// CHECK:         %[[GATHER:.+]] = vector.gather %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASS_THRU]]
+// CHECK-SAME:      : tensor<1x1x31xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+// -----
+
+func.func @lower_gather_nontrivial_leading_dims(
+    %src: tensor<4x16xf32>, %idx: vector<8xindex>) -> vector<8xf32> {
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %out = iree_vector_ext.transfer_gather %src[%c2, %c0]
+    [%idx : vector<8xindex>], %pad {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : tensor<4x16xf32>, vector<8xf32>
+  return %out : vector<8xf32>
+}
+// CHECK-LABEL: @lower_gather_nontrivial_leading_dims
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK:     vector.gather %{{.+}}[%[[C2]], %[[C0]]]
+
+// -----
+
+func.func @lower_masked_gather(
+    %src: tensor<8x16xf16>, %idx: vector<16xindex>,
+    %mask: vector<16xi1>) -> vector<16xf16> {
+  %pad = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+    [%idx : vector<16xindex>], %pad, %mask {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : tensor<8x16xf16>, vector<16xf16>, vector<16xi1>
+  return %out : vector<16xf16>
+}
+// CHECK-LABEL: @lower_masked_gather
+// CHECK:       vector.gather
+// CHECK-NOT:   iree_vector_ext.transfer_gather
+
+// -----
+
+func.func @negative_lower_gather_multiple_index_vecs(
+    %src: tensor<8x16xf16>,
+    %i0: vector<8xindex>, %i1: vector<16xindex>) -> vector<8x16xf16> {
+  %pad = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+    [%i0, %i1 : vector<8xindex>, vector<16xindex>], %pad {
+      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d1)>]
+    } : tensor<8x16xf16>, vector<8x16xf16>
+  return %out : vector<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_multiple_index_vecs
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @negative_lower_gather_scalar_index(
+    %src: memref<8x16xf16>, %idx: index) -> vector<8xf16> {
+  %pad = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+    [%idx : index], %pad {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> ()>]
+    } : memref<8x16xf16>, vector<8xf16>
+  return %out : vector<8xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_scalar_index
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @negative_lower_gather_symbol_in_leading_dim(
+    %src: tensor<8x16xf16>, %idx: vector<8xindex>) -> vector<8x16xf16> {
+  %pad = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+    [%idx : vector<8xindex>], %pad {
+      indexing_maps = [affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                       affine_map<(d0, d1)[s0] -> (d0)>]
+    } : tensor<8x16xf16>, vector<8x16xf16>
+  return %out : vector<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_symbol_in_leading_dim
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @negative_lower_gather_nonconstant_leading_dim(
+    %src: tensor<16x16xf16>, %idx: vector<16xindex>) -> vector<16xf16> {
+  %pad = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_gather %src[%c0, %c0]
+    [%idx : vector<16xindex>], %pad {
+      indexing_maps = [affine_map<(d0)[s0] -> (d0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : tensor<16x16xf16>, vector<16xf16>
+  return %out : vector<16xf16>
+}
+// CHECK-LABEL: @negative_lower_gather_nonconstant_leading_dim
+// CHECK: iree_vector_ext.transfer_gather
+// CHECK-NOT: vector.gather
+
+// -----
+
+func.func @lower_transfer_scatter_to_vector_scatter(
+    %src: vector<16xf32>, %dest: tensor<1x1x32xf32>,
+    %idx: vector<16xindex>) -> tensor<1x1x32xf32> {
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0, %c0]
+    [%idx : vector<16xindex>] {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, 0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : vector<16xf32>, tensor<1x1x32xf32> -> tensor<1x1x32xf32>
+  return %out : tensor<1x1x32xf32>
+}
+// CHECK-LABEL: @lower_transfer_scatter_to_vector_scatter
+// CHECK-DAG:     %[[MASK:.+]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         vector.scatter %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-NOT:     iree_vector_ext.transfer_scatter
+
+// -----
+
+func.func @lower_scatter_nontrivial_leading_dims(
+    %src: vector<8xf32>, %dest: tensor<4x16xf32>,
+    %idx: vector<8xindex>) -> tensor<4x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c2, %c0]
+    [%idx : vector<8xindex>] {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : vector<8xf32>, tensor<4x16xf32> -> tensor<4x16xf32>
+  return %out : tensor<4x16xf32>
+}
+// CHECK-LABEL: @lower_scatter_nontrivial_leading_dims
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK:     vector.scatter %{{.+}}[%[[C2]], %[[C0]]]
+
+// -----
+
+func.func @lower_masked_scatter(
+    %src: vector<16xf16>, %dest: tensor<8x16xf16>,
+    %idx: vector<16xindex>, %mask: vector<16xi1>) -> tensor<8x16xf16> {
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+    [%idx : vector<16xindex>], %mask {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : vector<16xf16>, tensor<8x16xf16>, vector<16xi1> -> tensor<8x16xf16>
+  return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @lower_masked_scatter
+// CHECK:       vector.scatter
+// CHECK-NOT:   iree_vector_ext.transfer_scatter
+
+// -----
+
+func.func @lower_scatter_memref(
+    %src: vector<8xf32>, %dest: memref<4x16xf32>,
+    %idx: vector<8xindex>) {
+  %c0 = arith.constant 0 : index
+  iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+    [%idx : vector<8xindex>] {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : vector<8xf32>, memref<4x16xf32>
+  return
+}
+// CHECK-LABEL: @lower_scatter_memref
+// CHECK-DAG:   %[[MASK:.+]] = arith.constant dense<true> : vector<8xi1>
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       vector.scatter
+// CHECK-NOT:   iree_vector_ext.transfer_scatter
+
+// -----
+
+func.func @negative_lower_scatter_multiple_index_vecs(
+    %src: vector<8x16xf16>, %dest: tensor<8x16xf16>,
+    %i0: vector<8xindex>, %i1: vector<16xindex>) -> tensor<8x16xf16> {
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+    [%i0, %i1 : vector<8xindex>, vector<16xindex>] {
+      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d1)>]
+    } : vector<8x16xf16>, tensor<8x16xf16> -> tensor<8x16xf16>
+  return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_multiple_index_vecs
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter
+
+// -----
+
+func.func @negative_lower_scatter_scalar_index(
+    %src: vector<8xf16>, %dest: tensor<8x16xf16>, %idx: index) -> tensor<8x16xf16> {
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+    [%idx : index] {
+      indexing_maps = [affine_map<(d0)[s0] -> (0, s0)>,
+                       affine_map<(d0)[s0] -> ()>]
+    } : vector<8xf16>, tensor<8x16xf16> -> tensor<8x16xf16>
+  return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_scalar_index
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter
+
+// -----
+
+func.func @negative_lower_scatter_symbol_in_leading_dim(
+    %src: vector<8x16xf16>, %dest: tensor<8x16xf16>,
+    %idx: vector<8xindex>) -> tensor<8x16xf16> {
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+    [%idx : vector<8xindex>] {
+      indexing_maps = [affine_map<(d0, d1)[s0] -> (s0, d1)>,
+                       affine_map<(d0, d1)[s0] -> (d0)>]
+    } : vector<8x16xf16>, tensor<8x16xf16> -> tensor<8x16xf16>
+  return %out : tensor<8x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_symbol_in_leading_dim
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter
+
+// -----
+
+func.func @negative_lower_scatter_nonconstant_leading_dim(
+    %src: vector<16xf16>, %dest: tensor<16x16xf16>,
+    %idx: vector<16xindex>) -> tensor<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %out = iree_vector_ext.transfer_scatter %src into %dest[%c0, %c0]
+    [%idx : vector<16xindex>] {
+      indexing_maps = [affine_map<(d0)[s0] -> (d0, s0)>,
+                       affine_map<(d0)[s0] -> (d0)>]
+    } : vector<16xf16>, tensor<16x16xf16> -> tensor<16x16xf16>
+  return %out : tensor<16x16xf16>
+}
+// CHECK-LABEL: @negative_lower_scatter_nonconstant_leading_dim
+// CHECK: iree_vector_ext.transfer_scatter
+// CHECK-NOT: vector.scatter