[Codegen][LLVMGPU] Set global read layouts at linalg level (#18860)

- Operand promotion is now done the same way as in the TileAndFuse
pipeline, by reading the promote_operands config from the
lowering_config.
- Moves global read layout setting from the LLVMGPUConfigureVectorLayouts
pass to LLVMGPUConfigureTensorLayouts, anchoring layouts based on the
lowering config.

As a side effect, these changes allow setting layouts on gather
operations in the VectorDistribute pipeline.
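
For reference, the promotion decision now lives on the op's
lowering_config and is read back when anchoring tensor layouts. A
minimal sketch of the attribute shapes involved, with values taken from
the updated lit tests below:

    // Matmul: promote the LHS and RHS operands (operand indices 0 and 1).
    #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0],
                                         reduction = [0, 0, 128],
                                         promote_operands = [0, 1]}>

    // Attention: per-matmul promotion configs are attached through the
    // decomposition_config.
    decomposition_config = {
      qk_attrs = {attention_qk_matmul,
                  lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
      pv_attrs = {attention_pv_matmul,
                  lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
    }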
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 19af0c4..05b4405 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -95,7 +95,6 @@
         "LLVMGPUCastAddressSpaceFunction.cpp",
         "LLVMGPUCastTypeToFitMMA.cpp",
         "LLVMGPUConfigureTensorLayouts.cpp",
-        "LLVMGPUConfigureVectorLayouts.cpp",
         "LLVMGPUConvolutionToIGEMM.cpp",
         "LLVMGPULinkExecutables.cpp",
         "LLVMGPULowerExecutableTarget.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index aa2c5a5..0566481 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -80,7 +80,6 @@
     "LLVMGPUCastAddressSpaceFunction.cpp"
     "LLVMGPUCastTypeToFitMMA.cpp"
     "LLVMGPUConfigureTensorLayouts.cpp"
-    "LLVMGPUConfigureVectorLayouts.cpp"
     "LLVMGPUConvolutionToIGEMM.cpp"
     "LLVMGPULinkExecutables.cpp"
     "LLVMGPULowerExecutableTarget.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 63d27ae..ede2d0b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -386,6 +386,7 @@
                      b.getI64ArrayAttr(workgroupTileSizes));
   attrs.emplace_back(StringAttr::get(context, "reduction"),
                      b.getI64ArrayAttr(reductionTileSizes));
+  IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
 
   auto configDict = DictionaryAttr::get(context, attrs);
   auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
@@ -633,6 +634,7 @@
                      b.getI64ArrayAttr(workgroupTileSizes));
   attrs.emplace_back(StringAttr::get(context, "reduction"),
                      b.getI64ArrayAttr(reductionTileSizes));
+  IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
 
   auto configDict = DictionaryAttr::get(context, attrs);
   auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
@@ -830,6 +832,15 @@
                      b.getI64ArrayAttr(workgroupTileSizes));
   attrs.emplace_back(StringAttr::get(context, "reduction"),
                      b.getI64ArrayAttr(reductionTileSizes));
+  IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs,
+                                                        {0, 1, 2});
+
+  SmallVector<NamedAttribute, 2> qkConfig;
+  SmallVector<NamedAttribute, 2> pvConfig;
+
+  IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig,
+                                                        {0, 1});
+  IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1});
 
   SmallVector<NamedAttribute, 2> qkAttrs;
   SmallVector<NamedAttribute, 2> pvAttrs;
@@ -837,6 +848,17 @@
   qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr()));
   pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr()));
 
+  auto qkConfigDict = b.getDictionaryAttr(qkConfig);
+  auto pvConfigDict = b.getDictionaryAttr(pvConfig);
+
+  auto qkLoweringConfig =
+      IREE::GPU::LoweringConfigAttr::get(context, qkConfigDict);
+  auto pvLoweringConfig =
+      IREE::GPU::LoweringConfigAttr::get(context, pvConfigDict);
+
+  qkAttrs.emplace_back(b.getNamedAttr("lowering_config", qkLoweringConfig));
+  pvAttrs.emplace_back(b.getNamedAttr("lowering_config", pvLoweringConfig));
+
   auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
   auto pvAttrDict = b.getDictionaryAttr(pvAttrs);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 22b570b..4945e66 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -21,13 +21,37 @@
 #define GEN_PASS_DEF_LLVMGPUCONFIGURETENSORLAYOUTSPASS
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
 
+using IREE::VectorExt::NestedLayoutAttr;
+using IREE::VectorExt::ToLayoutOp;
+using IREE::VectorExt::VectorLayoutInterface;
+
 namespace {
 
+static SmallVector<bool> getPromotedOperands(Operation *op) {
+  SmallVector<bool> promotedOperands(op->getNumOperands(), false);
+
+  auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
+  if (!config) {
+    return promotedOperands;
+  }
+
+  std::optional<SmallVector<int64_t>> promoteConfig =
+      config.getPromotedOperandList();
+  if (!promoteConfig) {
+    return promotedOperands;
+  }
+
+  for (int64_t operand : promoteConfig.value()) {
+    promotedOperands[operand] = true;
+  }
+
+  return promotedOperands;
+}
+
 static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                          SmallVector<bool> promotedOperands,
                                           RewriterBase &rewriter,
-                                          linalg::LinalgOp contract,
-                                          bool promoteLhs = true,
-                                          bool promoteRhs = true) {
+                                          linalg::LinalgOp contract) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
     return contract->emitError("missing mma schedule for contraction");
@@ -56,33 +80,36 @@
 
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(contract);
-  auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, lhs, aLayout, schedule.getIntrinsic());
-  auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, rhs, bLayout, schedule.getIntrinsic());
-  auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, acc, cLayout, schedule.getIntrinsic());
+  auto layoutedLhs =
+      rewriter.create<ToLayoutOp>(loc, lhs, aLayout, schedule.getIntrinsic());
+  auto layoutedRhs =
+      rewriter.create<ToLayoutOp>(loc, rhs, bLayout, schedule.getIntrinsic());
+  auto layoutedAcc =
+      rewriter.create<ToLayoutOp>(loc, acc, cLayout, schedule.getIntrinsic());
 
   // Promote matmul lhs and rhs.
-  // TODO: We should read this from the lowering_config on the operation.
   // TODO: This is a hack until layout analysis is improved. The layout analysis
   // should decide where to put these shared memory conversions.
-  if (promoteLhs) {
+  if (promotedOperands[0]) {
     layoutedLhs.setSharedMemoryConversion(true);
   }
 
-  if (promoteRhs) {
+  if (promotedOperands[1]) {
     layoutedRhs.setSharedMemoryConversion(true);
   }
 
+  if (promotedOperands[2]) {
+    layoutedAcc.setSharedMemoryConversion(true);
+  }
+
   contract->setOperand(0, layoutedLhs.getResult());
   contract->setOperand(1, layoutedRhs.getResult());
   contract->setOperand(2, layoutedAcc.getResult());
 
   // Set layout for result.
   rewriter.setInsertionPointAfter(contract);
-  auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, contract->getResult(0), cLayout, schedule.getIntrinsic());
+  auto toLayout = rewriter.create<ToLayoutOp>(loc, contract->getResult(0),
+                                              cLayout, schedule.getIntrinsic());
   rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
                                 toLayout);
 
@@ -90,6 +117,7 @@
 }
 
 static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                          SmallVector<bool> promotedOperands,
                                           RewriterBase &rewriter,
                                           linalg::LinalgOp conv) {
   // TODO: Add SIMT fallback.
@@ -139,19 +167,27 @@
 
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(conv);
-  auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, lhs, aLayout, schedule.getIntrinsic());
-  auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, rhs, bLayout, schedule.getIntrinsic());
-  auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, acc, cLayout, schedule.getIntrinsic());
+  auto layoutedLhs =
+      rewriter.create<ToLayoutOp>(loc, lhs, aLayout, schedule.getIntrinsic());
+  auto layoutedRhs =
+      rewriter.create<ToLayoutOp>(loc, rhs, bLayout, schedule.getIntrinsic());
+  auto layoutedAcc =
+      rewriter.create<ToLayoutOp>(loc, acc, cLayout, schedule.getIntrinsic());
 
   // Promote matmul lhs and rhs.
-  // TODO: We should read this from the lowering_config on the operation.
   // TODO: This is a hack until layout analysis is improved. The layout analysis
   // should decide where to put these shared memory conversions.
-  layoutedLhs.setSharedMemoryConversion(true);
-  layoutedRhs.setSharedMemoryConversion(true);
+  if (promotedOperands[0]) {
+    layoutedLhs.setSharedMemoryConversion(true);
+  }
+
+  if (promotedOperands[1]) {
+    layoutedRhs.setSharedMemoryConversion(true);
+  }
+
+  if (promotedOperands[2]) {
+    layoutedAcc.setSharedMemoryConversion(true);
+  }
 
   conv->setOperand(0, layoutedLhs.getResult());
   conv->setOperand(1, layoutedRhs.getResult());
@@ -159,8 +195,8 @@
 
   // Set layout for result.
   rewriter.setInsertionPointAfter(conv);
-  auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, conv->getResult(0), cLayout, schedule.getIntrinsic());
+  auto toLayout = rewriter.create<ToLayoutOp>(loc, conv->getResult(0), cLayout,
+                                              schedule.getIntrinsic());
   rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
                                 toLayout);
 
@@ -278,6 +314,12 @@
           /*subgroup_n_count=*/1);
   IREE::GPU::MMAScheduleAttr pvSchedule = schedule;
 
+  SmallVector<bool> promotedQKOperands = getPromotedOperands(qkMatmul);
+  SmallVector<bool> promotedPVOperands = getPromotedOperands(pvMatmul);
+
+  // Do not promote lhs of pvMatmul if we are reusing the intrinsic output.
+  promotedPVOperands[0] = !reuseIntrinsicOutput;
+
   // Transpose the intrinsic if requested. See docs for
   // swapOperandsToTransposeIntrinsic for more information on why this is done.
   if (transposeIntrinsic) {
@@ -292,21 +334,114 @@
     swapOperandsToTransposeIntrinsic(rewriter, pvGeneric);
     qkSchedule = transposeSchedule(rewriter, qkSchedule);
     pvSchedule = transposeSchedule(rewriter, pvSchedule);
+
+    // Swap promoted operands.
+    std::swap(promotedQKOperands[0], promotedQKOperands[1]);
+    std::swap(promotedPVOperands[0], promotedPVOperands[1]);
   }
 
-  if (failed(setContractionAnchor(qkSchedule, rewriter, qkMatmul))) {
+  if (failed(setContractionAnchor(qkSchedule, promotedQKOperands, rewriter,
+                                  qkMatmul))) {
     return failure();
   }
 
-  // Do not promote lhs of pvMatmul if we are reusing the intrinsic output.
-  bool promoteLhs = !reuseIntrinsicOutput;
-  bool promoteRhs = true;
-  if (transposeIntrinsic) {
-    std::swap(promoteLhs, promoteRhs);
+  return setContractionAnchor(pvSchedule, promotedPVOperands, rewriter,
+                              pvMatmul);
+}
+
+// Apply the permuted projection map to the layout.
+static IREE::VectorExt::VectorLayoutInterface
+getLayoutForMap(VectorLayoutInterface layout, AffineMap map) {
+  // Project out unused dims in the layout.
+  SmallVector<bool> projectedDims(layout.getRank(), false);
+  for (int dim : getUnusedDimsBitVector(map).set_bits()) {
+    projectedDims[dim] = true;
+  }
+  IREE::VectorExt::VectorLayoutInterface projectedLayout =
+      layout.project(projectedDims);
+
+  // Transpose dims in layout.
+  AffineMap permMap = compressUnusedDims(map);
+  SmallVector<int64_t> identity =
+      llvm::to_vector(llvm::seq<int64_t>(permMap.getNumDims()));
+  SmallVector<int64_t> perm = applyPermutationMap<int64_t>(permMap, identity);
+  return projectedLayout.permute(perm);
+}
+
+static LogicalResult setDerivedThreadConfigLayout(
+    IREE::GPU::DerivedThreadConfigAttr config, linalg::LinalgOp linalgOp,
+    ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
+
+  int64_t opRank = linalgOp.getNumLoops();
+
+  SmallVector<int64_t> elementTile = config.getStaticTilingLevelSizes(
+      static_cast<unsigned>(IREE::GPU::TilingLevel::Thread), linalgOp);
+
+  SmallVector<int64_t> opShape = linalgOp.getStaticLoopRanges();
+  for (auto [index, size, element] : llvm::enumerate(opShape, elementTile)) {
+    if (ShapedType::isDynamic(size)) {
+      linalgOp->emitError() << "Cannot set layouts for dynamic loop ranges";
+      return failure();
+    }
+
+    if (size % element != 0) {
+      linalgOp->emitError()
+          << "Operation with unsupported number of elements. "
+             "Chosen vector tile sizes for operation are not "
+             "divisible by operation loop ranges at dim: "
+          << index << ", size=" << size << ", vector size = " << element;
+      return failure();
+    }
+
+    size /= element;
   }
 
-  return setContractionAnchor(pvSchedule, rewriter, pvMatmul, promoteLhs,
-                              promoteRhs);
+  SmallVector<int64_t> threadTile(opRank, 1);
+  SmallVector<int64_t> threadStrides(opRank, 0);
+
+  int64_t residualThreads = ShapedType::getNumElements(workgroupSize);
+  int64_t currStride = 1;
+
+  for (auto [tile, stride, size] :
+       llvm::reverse(llvm::zip(threadTile, threadStrides, opShape))) {
+    int64_t threadBlock;
+    if (residualThreads % size == 0) {
+      threadBlock = size;
+    } else if (size % residualThreads == 0) {
+      threadBlock = residualThreads;
+    } else {
+      linalgOp->emitError() << "Operation with unsupported number of elements.";
+      return failure();
+    }
+
+    tile = threadBlock;
+    stride = currStride;
+    size /= threadBlock;
+
+    currStride *= threadBlock;
+    residualThreads /= threadBlock;
+  }
+
+  SmallVector<int64_t> subgroupTile(opRank, 1);
+  SmallVector<int64_t> subgroupStrides(opRank, 0);
+  SmallVector<int64_t> outerTile(opRank, 1);
+
+  MLIRContext *context = rewriter.getContext();
+  auto layout = IREE::VectorExt::NestedLayoutAttr::get(
+      context, subgroupTile, opShape, outerTile, threadTile, elementTile,
+      subgroupStrides, threadStrides);
+
+  Location loc = linalgOp.getLoc();
+
+  rewriter.setInsertionPointAfter(linalgOp);
+  for (OpResult result : linalgOp->getResults()) {
+    VectorLayoutInterface resultLayout =
+        getLayoutForMap(layout, linalgOp.getIndexingMapMatchingResult(result));
+    auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
+    rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
+  }
+
+  return success();
 }
 
 static Operation *getOpWithAttr(Operation *root, StringRef attr) {
@@ -330,13 +465,28 @@
 struct LLVMGPUConfigureTensorLayoutsPass final
     : impl::LLVMGPUConfigureTensorLayoutsPassBase<
           LLVMGPUConfigureTensorLayoutsPass> {
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
     registry.insert<vector::VectorDialect>();
   }
 
   void runOnOperation() override {
-    auto func = getOperation();
+    FunctionOpInterface func = getOperation();
+    IRRewriter rewriter(func);
+
+    std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
+        getWorkgroupSize(func);
+    if (!maybeWorkgroupSize) {
+      func->emitOpError()
+          << "unable to query workgroup_size information from entry point";
+      return signalPassFailure();
+    }
+
+    if (failed(setDerivedConfigLayouts(func, maybeWorkgroupSize.value(),
+                                       rewriter))) {
+      return signalPassFailure();
+    }
 
     llvm::StringLiteral scheduleAttrName =
         IREE::GPU::MMAScheduleAttr::getMnemonic();
@@ -377,16 +527,18 @@
       return WalkResult::advance();
     });
 
-    IRRewriter rewriter(func);
-
     for (linalg::LinalgOp contract : contracts) {
-      if (failed(setContractionAnchor(scheduleAttr, rewriter, contract))) {
+      SmallVector<bool> promotedOperands = getPromotedOperands(contract);
+      if (failed(setContractionAnchor(scheduleAttr, promotedOperands, rewriter,
+                                      contract))) {
         return signalPassFailure();
       }
     }
 
     for (linalg::LinalgOp conv : convs) {
-      if (failed(setConvolutionAnchor(scheduleAttr, rewriter, conv))) {
+      SmallVector<bool> promotedOperands = getPromotedOperands(conv);
+      if (failed(setConvolutionAnchor(scheduleAttr, promotedOperands, rewriter,
+                                      conv))) {
         return signalPassFailure();
       }
     }
@@ -398,6 +550,31 @@
       }
     }
   }
+
+  LogicalResult setDerivedConfigLayouts(FunctionOpInterface funcOp,
+                                        ArrayRef<int64_t> workgroupSize,
+                                        RewriterBase &rewriter) {
+    SmallVector<linalg::LinalgOp> candidates;
+    funcOp->walk([&](linalg::LinalgOp op) {
+      auto config = dyn_cast_or_null<IREE::GPU::DerivedThreadConfigAttr>(
+          getLoweringConfig(op));
+      if (config) {
+        candidates.push_back(op);
+      }
+    });
+
+    for (linalg::LinalgOp candidate : candidates) {
+      auto config = dyn_cast_or_null<IREE::GPU::DerivedThreadConfigAttr>(
+          getLoweringConfig(candidate));
+      assert(config);
+      if (failed(setDerivedThreadConfigLayout(config, candidate, workgroupSize,
+                                              rewriter))) {
+        return failure();
+      }
+    }
+
+    return success();
+  }
 };
 } // namespace
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp
deleted file mode 100644
index f98c642..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp
+++ /dev/null
@@ -1,299 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <algorithm>
-
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
-#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
-#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/MathExtras.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-
-#define DEBUG_TYPE "iree-llvmgpu-configure-vector-layouts"
-
-namespace mlir::iree_compiler {
-
-#define GEN_PASS_DEF_LLVMGPUCONFIGUREVECTORLAYOUTSPASS
-#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
-
-namespace {
-
-// Sets a layout anchor for reads from global memory.
-// The layout this generates is approximately the following:
-//
-// #layout = #iree_vector_ext.nested_layout<
-//    subgroup_tile = [1, ..., 1]
-//    batch_tile    = [<remaining undistributed elements>]
-//    outer_tile    = [1, ..., 1]
-//    thread_tile   = [<greedy from innermost memref dim>]
-//    element_tile  = [1, ..., 128/element_bitwidth, ..., 1]
-//            innermost_memref_dimension ^^^^^^
-//
-// (All orders are the same)
-//    *_order = [<broadcasted_dims>, <transfer_permutation>]>
-//
-// So for the following transfer_read with 64 threads:
-//  vector.transfer_read ... : memref<16x256xf16>, vector<16x32xf16>
-//
-// We use the following layout:
-// #layout = #iree_vector_ext.nested_layout<
-//    subgroup_tile = [1, 1]
-//    batch_tile    = [1, 1]
-//    outer_tile    = [1, 1]
-//    thread_tile   = [16, 4]
-//    element_tile  = [1, 8]
-LogicalResult setTransferReadAnchor(ArrayRef<int64_t> workgroupSize,
-                                    RewriterBase &rewriter,
-                                    vector::TransferReadOp transfer) {
-  MLIRContext *context = rewriter.getContext();
-
-  // Get the forward slice of the transfer to approximate whether it will take
-  // the layout of a contraction instead. Transfer_read ops used directly by a
-  // contraction (i.e. without a copy to shared memory in between) should take
-  // the layout of the contraction op. This is common for cases where the
-  // initial values of the accumulator in a linalg.matmul is read from memory
-  // instead of just being a zerofill.
-  ForwardSliceOptions forwardOptions;
-  forwardOptions.filter = [&](Operation *op) -> bool {
-    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
-  };
-  BackwardSliceOptions backwardOptions;
-  backwardOptions.filter = [&](Operation *op) -> bool {
-    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
-  };
-  SetVector<Operation *> slice =
-      getSlice(transfer, backwardOptions, forwardOptions);
-
-  if (llvm::any_of(slice, llvm::IsaPred<vector::ContractionOp>)) {
-    return success();
-  }
-
-  // Shared memory loads are expected to take the layout of the contraction.
-  auto sourceMemRefType = dyn_cast<MemRefType>(transfer.getSource().getType());
-  if (!sourceMemRefType || hasSharedMemoryAddressSpace(sourceMemRefType)) {
-    return success();
-  }
-
-  // Take on layout of broadcast.
-  if (transfer->hasOneUse() &&
-      dyn_cast<vector::BroadcastOp>(*transfer->getUsers().begin())) {
-    return success();
-  }
-
-  // TODO: Support masking.
-  if (transfer.getMask()) {
-    transfer->emitOpError(
-        "Anchoring on transfer_read with masks is not yet implemented.");
-    return failure();
-  }
-
-  int64_t bitWidth = IREE::Util::getTypeBitWidth(
-      getElementTypeOrSelf(transfer.getVectorType()));
-  if (!llvm::isPowerOf2_64(bitWidth) || bitWidth > 128) {
-    transfer->emitOpError(
-        "Anchoring on transfer_read with element type of bitwidth " +
-        std::to_string(bitWidth) + " is not yet implemented");
-    return failure();
-  }
-  int64_t numElementTile = 128 / bitWidth;
-  int64_t flatNumElements =
-      ShapedType::getNumElements(transfer.getVectorType().getShape());
-  int64_t flatNumThreads = ShapedType::getNumElements(workgroupSize);
-  if (flatNumElements % flatNumThreads != 0) {
-    transfer->emitOpError()
-        << "Anchoring on transfer_read with unsupported number of elements "
-           "(not divisible by workgroup size)"
-        << ", number of elements: " << flatNumElements
-        << ", workgroup size: " << flatNumThreads;
-    return failure();
-  }
-  numElementTile = std::min(numElementTile, flatNumElements / flatNumThreads);
-
-  AffineMap transferMap = transfer.getPermutationMap();
-  if (transferMap.getNumDims() == 0) {
-    transfer->emitOpError("Anchoring on transfer_read with zero-rank "
-                          "permutation map is not supported.");
-    return failure();
-  }
-
-  // Select the innermost dim of the memref as the contiguous dim to load
-  // from.
-  int64_t transferRank = transfer.getVectorType().getRank();
-  std::optional<unsigned> maybeDim = transferMap.getResultPosition(
-      getAffineDimExpr(transferMap.getNumDims() - 1, context));
-  int64_t distXDim = maybeDim ? *maybeDim : transferRank - 1;
-
-  ArrayRef<int64_t> vectorShape = transfer.getVectorType().getShape();
-
-  // Limit the maximum inner vector read width to the innermost contiguous
-  // dimension. We could try to be clever and extend this to adjacent
-  // dimensions in cases where the innermost read vector dimension is small,
-  // but that requires comparing memref strides and is uncommon. For now
-  // prioritize warp contiguity over 128-bit read granularity.
-  numElementTile = std::min(numElementTile, vectorShape[distXDim]);
-
-  llvm::SetVector<unsigned> vectorDimDistributionOrder;
-  // Get the order in which to distribute vector dimensions to threads, going
-  // from innermost to outermost memref dimension. It's important to note
-  // that this heuristic only applies to matrix multiplication cases where
-  // we are promoting the operands of a contraction to shared memory and we
-  // have no producers fused with the matmul. In general there is no universal
-  // way to set an anchoring layout for reads without doing an analysis of how
-  // the read values are used.
-  for (int i = transferMap.getNumDims() - 1; i >= 0; --i) {
-    std::optional<unsigned> maybeDim =
-        transferMap.getResultPosition(getAffineDimExpr(i, context));
-    if (maybeDim) {
-      vectorDimDistributionOrder.insert(*maybeDim);
-    }
-  }
-  // Add all remaining (broadcasted) dimensions
-  for (auto dim : llvm::seq(static_cast<int64_t>(0), transferRank)) {
-    if (!vectorDimDistributionOrder.contains(dim))
-      vectorDimDistributionOrder.insert(dim);
-  }
-
-  int64_t residualThreads = flatNumThreads;
-  int64_t residualElements = numElementTile;
-
-  SmallVector<int64_t> order(vectorDimDistributionOrder.rbegin(),
-                             vectorDimDistributionOrder.rend());
-
-  // Distribute all threads in the workgroup to the "threads" dimension,
-  // meaning subgroup counts is unit here, even though the read is being
-  // distributed to multiple subgroups. This is in an attempt to do a
-  // workgroup contiguous load.
-  SmallVector<int64_t> subgroupCounts(transferRank, 1);
-  SmallVector<int64_t> batchSizes(transferRank, 1);
-  SmallVector<int64_t> outerSizes(transferRank, 1);
-  SmallVector<int64_t> threadCounts(transferRank, 1);
-  SmallVector<int64_t> elementSizes(transferRank, 1);
-
-  SmallVector<int64_t> subgroupStrides(transferRank, 1);
-  SmallVector<int64_t> threadStrides(transferRank, 1);
-
-  int64_t currStrides = 1;
-  for (auto dim : llvm::reverse(order)) {
-    int64_t vectorSize = vectorShape[dim];
-    // Set the element count for the innermost vector dimension.
-    if (residualElements != 1) {
-      elementSizes[dim] = residualElements;
-      vectorSize /= residualElements;
-      residualElements = 1;
-    }
-
-    assert((residualThreads % vectorSize == 0 ||
-            vectorSize % residualThreads == 0) &&
-           "dividing threads to incompatible vector");
-    if (residualThreads <= vectorSize) {
-      vectorSize /= residualThreads;
-      threadCounts[dim] = residualThreads;
-      threadStrides[dim] = currStrides;
-      currStrides *= residualThreads;
-      residualThreads = 1;
-    } else {
-      residualThreads /= vectorSize;
-      threadCounts[dim] = vectorSize;
-      threadStrides[dim] = currStrides;
-      currStrides *= vectorSize;
-      vectorSize = 1;
-    }
-
-    batchSizes[dim] = vectorSize;
-  }
-
-  auto layout = IREE::VectorExt::NestedLayoutAttr::get(
-      context, subgroupCounts, batchSizes, outerSizes, threadCounts,
-      elementSizes, subgroupStrides, threadStrides);
-
-  Location loc = transfer.getLoc();
-  rewriter.setInsertionPointAfter(transfer);
-  auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, transfer.getResult(), layout);
-  rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout);
-
-  return success();
-}
-
-struct LLVMGPUConfigureVectorLayoutsPass final
-    : impl::LLVMGPUConfigureVectorLayoutsPassBase<
-          LLVMGPUConfigureVectorLayoutsPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
-    registry.insert<vector::VectorDialect>();
-  }
-
-  void runOnOperation() override {
-    auto func = getOperation();
-
-    std::array<int64_t, 3> workgroupSize;
-    if (func->hasAttr("workgroup_size")) {
-      auto tmpSizes =
-          llvm::cast<ArrayAttr>(func->getAttr("workgroup_size")).getValue();
-      for (auto [i, size] : llvm::enumerate(tmpSizes)) {
-        workgroupSize[i] = llvm::cast<IntegerAttr>(size).getInt();
-      }
-    } else {
-      std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
-          getWorkgroupSize(func);
-      if (!maybeWorkgroupSize) {
-        func->emitOpError()
-            << "unable to query workgroup_size information from entry point";
-        return signalPassFailure();
-      }
-      for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) {
-        workgroupSize[index] = value;
-      }
-      for (auto index : llvm::seq<size_t>(maybeWorkgroupSize->size(), 3)) {
-        workgroupSize[index] = 1;
-      }
-    }
-
-    llvm::StringLiteral scheduleAttrName =
-        IREE::GPU::MMAScheduleAttr::getMnemonic();
-    auto scheduleAttr =
-        func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
-    if (!scheduleAttr) {
-      DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-      scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-          configDict.get(scheduleAttrName));
-    }
-
-    // Vector layout option setter aimed at contractions. Currently this only
-    // sets anchors for two types of operations; vector.contract and
-    // vector.transfer_read from non-shared memory. The assumption in this case
-    // is that all IR input to this pass has a leaf rooted on a transfer_read or
-    // includes a contraction in the program slice, meaning all operations
-    // should receive layouts. Layout setting for other problems like reductions
-    // is TODO.
-    SmallVector<vector::TransferReadOp> reads;
-
-    func->walk([&](Operation *op) {
-      llvm::TypeSwitch<Operation *>(op).Case(
-          [&](vector::TransferReadOp transfer) { reads.push_back(transfer); });
-    });
-
-    IRRewriter rewriter(func);
-
-    for (vector::TransferReadOp read : reads) {
-      if (failed(setTransferReadAnchor(workgroupSize, rewriter, read))) {
-        return signalPassFailure();
-      }
-    }
-  }
-};
-} // namespace
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 3c7eaf8..59d86f4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -851,15 +851,17 @@
   funcPassManager.addPass(createReorderWorkgroups(
       reorderStrategy, clReorderWorkgroupsLogSwizzleTile,
       canReorderWorkgroups));
+
+  if (usePadToModelSharedMemcpy) {
+    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
+  }
+
   funcPassManager.addPass(
       IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
 
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
-
-  if (usePadToModelSharedMemcpy) {
-    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
-  }
+  funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
 
   // Tile to reduction loops.
   {
@@ -918,7 +920,6 @@
   funcPassManager.addPass(createLLVMGPUCastTypeToFitMMAPass());
 
   // Vector SIMD -> Vector SIMT
-  funcPassManager.addPass(createLLVMGPUConfigureVectorLayoutsPass());
   funcPassManager.addPass(createLLVMGPUVectorDistributePass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index 0b8df81..6219d41 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -87,11 +87,6 @@
   let summary = "Pass to set layouts on tensors for later vector distribution";
 }
 
-def LLVMGPUConfigureVectorLayoutsPass :
-    InterfacePass<"iree-llvmgpu-configure-vector-layouts", "mlir::FunctionOpInterface"> {
-  let summary = "Pass to set layouts for vector distribution";
-}
-
 def LLVMGPUConvolutionToIGEMMPass :
     InterfacePass<"iree-llvmgpu-convolution-to-igemm", "mlir::FunctionOpInterface"> {
   let summary = "Pass to convert conv_2d ops to igemm and set a lowering configuration.";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 1088035..ceab1f9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -70,7 +70,6 @@
             "transform_vector_to_mma.mlir",
             "transpose_pipeline_test.mlir",
             "ukernel_pipeline_transform.mlir",
-            "configure_vector_layout.mlir",
             "configure_tensor_layout.mlir",
             "vector_lowering.mlir",
             "vector_to_gpu.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 795ee25..a39df72 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -24,7 +24,6 @@
     "config_matvec.mlir"
     "config_winograd.mlir"
     "configure_tensor_layout.mlir"
-    "configure_vector_layout.mlir"
     "conv_pipeline_test_cuda.mlir"
     "conv_pipeline_test_rocm.mlir"
     "convert_to_nvvm.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
index c439e04..a1b1627 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
@@ -16,7 +16,7 @@
 // OPT-IN:       #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
 // OPT-IN-SAME:    gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = true>
 // OPT-IN-SAME:    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32]}>
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1]}>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
@@ -92,7 +92,7 @@
 // OPT-IN:       #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
 // OPT-IN-SAME:    gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>
 // OPT-IN-SAME:    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32]}>
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1]}>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
@@ -164,7 +164,7 @@
 // OPT-OUT:       #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
 // OPT-OUT-SAME:    gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <None>>
 // OPT-OUT-SAME:    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32]}>
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1]}>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index e042780..7e1ab62 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -8,7 +8,7 @@
 // RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
 // RUN:   %s | FileCheck %s --check-prefix=MEMORY
 
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -54,7 +54,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -98,7 +98,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 64, 0], reduction = [0, 0, 0, 0, 128]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 64, 0], reduction = [0, 0, 0, 0, 128], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -190,7 +190,7 @@
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [10, 128, 64, 2048], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<10x128x64x2048xf16>> -> tensor<10x128x64x2048xf16>
         %5 = tensor.empty() : tensor<2x10x64x64xf16>
         %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
-        %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<2x128x64x2048xf16>, tensor<10x128x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) attrs =  {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 128], workgroup = [1, 1, 64, 64, 0, 0]}>} {
+        %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<2x128x64x2048xf16>, tensor<10x128x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) attrs =  {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 128], workgroup = [1, 1, 64, 64, 0, 0], promote_operands = [0, 1]}>} {
         ^bb0(%in: f16, %in_0: f16, %out: f16):
           %8 = arith.mulf %in, %in_0 : f16
           %9 = arith.addf %8, %out : f16
@@ -217,7 +217,7 @@
 
 // Basic f8, f8 -> f32 matmul.
 
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -263,7 +263,7 @@
 
 // Basic i8, i8 -> i32 matmul.
 
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -309,7 +309,7 @@
 
 // Basic i8, i8 -> i32 matmul_transpose_b.
 
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -353,7 +353,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 1, 1, 32]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 1, 1, 32], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -396,7 +396,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 1, 64, 0], reduction = [0, 0, 0, 0, 128]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 1, 64, 0], reduction = [0, 0, 0, 0, 128], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [
@@ -462,7 +462,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -506,15 +506,15 @@
 // CHECK-DAG:     %[[RHS_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<64x1281x1281xf16, #hal.descriptor_type<storage_buffer>>
 // CHECK-DAG:     %[[OUT_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : memref<64x968x1281xf16, #hal.descriptor_type<storage_buffer>>
 // CHECK-DAG:     %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
-// CHECK-DAG:     %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:         %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
+// CHECK-DAG:     %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:         %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:         vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
 // CHECK:         vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
 // CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
 // CHECK-DAG:       %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
-// CHECK-DAG:       %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:           %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
+// CHECK-DAG:       %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:           %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:           gpu.barrier
 // CHECK-DAG:       %{{.+}} = vector.transfer_read %[[LHS_SHARED]]
@@ -533,7 +533,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 1, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -604,7 +604,7 @@
 // NOTE: This test is not exhaustive of all possible ways the above condition is breaking,
 //       but rather is an example of a matmul shape from a model that broke our compilation heuristic.
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 128, 0], reduction = [0, 0, 0, 128]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 128, 0], reduction = [0, 0, 0, 128], promote_operands = [0, 1]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
 
 #pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
@@ -655,7 +655,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -689,8 +689,10 @@
                      affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
                      lowering_config = #config,
-                     decomposition_config = {qk_attrs = {attention_qk_matmul},
-                                             pv_attrs = {attention_pv_matmul}}}
+                     decomposition_config = {
+                      qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
+                      pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
+                     }}
                      ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
                       ^bb0(%score: f32):
                         iree_linalg_ext.yield %score : f32
@@ -726,7 +728,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 64], reduction = [0, 0, 0, 0, 64, 0]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 64], reduction = [0, 0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -761,8 +763,10 @@
                                                          affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
                                                          affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
                                                          lowering_config = #config,
-                                                         decomposition_config = {qk_attrs = {attention_qk_matmul},
-                                                                                 pv_attrs = {attention_pv_matmul}}}
+                                                         decomposition_config = {
+                                                          qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
+                                                          pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
+                                                         }}
         ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
               ^bb0(%score: f32):
                 iree_linalg_ext.yield %score : f32
@@ -792,7 +796,7 @@
 
 // -----
 
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 64], reduction = [0, 0, 0, 0, 32, 0]}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 64], reduction = [0, 0, 0, 0, 32, 0], promote_operands = [0, 1, 2]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -827,8 +831,10 @@
                                                          affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
                                                          affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
                                                          lowering_config = #config,
-                                                         decomposition_config = {qk_attrs = {attention_qk_matmul},
-                                                                                 pv_attrs = {attention_pv_matmul}}}
+                                                         decomposition_config = {
+                                                          qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
+                                                          pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
+                                                         }}
         ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
               ^bb0(%score: f32):
                 iree_linalg_ext.yield %score : f32
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
index e9b597d..032bd68 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
@@ -41,7 +41,7 @@
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
 //      CHECK:   ^bb
 //      CHECK:     linalg.matmul
-// CHECK-SAME:         lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 32], workgroup = [64, 64, 0]}>
+// CHECK-SAME:         lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], reduction = [0, 0, 32], workgroup = [64, 64, 0]}>
 //      CHECK:   iree_linalg_ext.yield
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index 8b761f8..6c96d15 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -42,8 +42,8 @@
 
 // CHECK-LABEL: func.func @matmul_96x64x16_mfma
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}
 // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
@@ -93,8 +93,8 @@
 
 // CHECK-LABEL: func.func @matmul_96x64x16_wmma
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>}
 // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
@@ -144,8 +144,8 @@
 
 // CHECK-LABEL: func.func @matmul_128x64x16_multi_subgroup
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
 // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
@@ -168,7 +168,8 @@
 
 #traits = {
   indexing_maps = #maps,
-  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
+  lowering_config = #iree_gpu.lowering_config<{promote_operands = [0]}>
 }
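+// Note (added for clarity): with promote_operands = [0], only the LHS read below is
+// expected to keep the shared_memory_conversion marker in its layout; the RHS is not promoted.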
 
 func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,
@@ -196,8 +197,80 @@
 // CHECK-LABEL: func.func @packed_matmul_128x128x128
 
 // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
 // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
+
+// -----
+
+// TODO: We shouldn't have to specify mma_schedule here.
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64,
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+                                             subgroup_m_count = 1,
+                                             subgroup_n_count = 1>}>
+
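+// A standalone linalg.copy carrying #iree_gpu.derived_thread_config: the derived
+// thread layout should be anchored on the copy result.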
+func.func @linalg_copy(%in : tensor<16x16x16xf16>) -> tensor<16x16x16xf16>
+                      attributes { translation_info = #translation } {
+  %empty = tensor.empty() : tensor<16x16x16xf16>
+  %copied = linalg.copy
+            { lowering_config = #iree_gpu.derived_thread_config }
+            ins(%in : tensor<16x16x16xf16>)
+            outs(%empty : tensor<16x16x16xf16>) -> tensor<16x16x16xf16>
+  func.return %copied : tensor<16x16x16xf16>
+}
+
+// CHECK-DAG: #[[$LAYOUT:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [1, 1, 1], batch_tile = [8, 1, 1], outer_tile = [1, 1, 1], thread_tile = [2, 16, 2], element_tile = [1, 1, 8], subgroup_strides = [0, 0, 0], thread_strides = [32, 2, 1]>
+
+// CHECK-LABEL: func.func @linalg_copy
+// CHECK: %[[OUT:.+]] = linalg.copy
+// CHECK: to_layout %[[OUT]] to layout(#[[$LAYOUT]])
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64,
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+                                             subgroup_m_count = 1,
+                                             subgroup_n_count = 1>}>
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+
+#gather_trait = {
+    indexing_maps = [#map, #map1],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"],
+    lowering_config = #iree_gpu.derived_thread_config
+}
+
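+// A gather-like generic (tensor.extract indexed by a dynamically loaded value) carrying
+// #iree_gpu.derived_thread_config: the derived thread layout should be anchored on the
+// generic's result.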
+func.func @gather_like(%base : tensor<16384x16x32x128xf16>,
+                       %indices : tensor<4x64x4xi64>)
+                       -> tensor<4x64x4x16x32x128xf16>
+                       attributes { translation_info = #translation } {
+
+  %empty = tensor.empty() : tensor<4x64x4x16x32x128xf16>
+  %gather = linalg.generic #gather_trait
+            ins(%indices : tensor<4x64x4xi64>)
+            outs(%empty : tensor<4x64x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %idx = arith.index_cast %in : i64 to index
+    %iv3 = linalg.index 3 : index
+    %iv4 = linalg.index 4 : index
+    %iv5 = linalg.index 5 : index
+    %extracted = tensor.extract %base[%idx, %iv3, %iv4, %iv5] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x64x4x16x32x128xf16>
+
+  func.return %gather : tensor<4x64x4x16x32x128xf16>
+}
+
+// CHECK-DAG: #[[$LAYOUT:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [1, 1, 1, 1, 1, 1], batch_tile = [4, 64, 4, 16, 8, 1], outer_tile = [1, 1, 1, 1, 1, 1], thread_tile = [1, 1, 1, 1, 4, 16], element_tile = [1, 1, 1, 1, 1, 8], subgroup_strides = [0, 0, 0, 0, 0, 0], thread_strides = [0, 0, 0, 0, 16, 1]>
+
+// CHECK-LABEL: func.func @gather_like
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK: to_layout %[[OUT]] to layout(#[[$LAYOUT]])
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir
deleted file mode 100644
index 3bac228..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir
+++ /dev/null
@@ -1,75 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-configure-vector-layouts, canonicalize, cse))' %s | FileCheck %s
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-                                              workgroup_size = [64, 1, 1]
-                                              subgroup_size = 64,
-      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1>}>
-
-// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [1, 1], outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 8], subgroup_strides = [0, 0], thread_strides = [4, 1]>
-// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [1, 1], outer_tile = [1, 1], thread_tile = [4, 16], element_tile = [8, 1], subgroup_strides = [0, 0], thread_strides = [1, 4]>
-
-// CHECK-LABEL: func.func @transfer_read_permute
-func.func @transfer_read_permute(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
-                                         %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
-                                         %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes { translation_info = #translation } {
-
-  %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
-  %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
-  %cst = arith.constant 0.000000e+00 : f16
-  %cst_f32 = arith.constant 0.000000e+00 : f32
-  %c32 = arith.constant 32 : index
-  %c256 = arith.constant 256 : index
-  %c0 = arith.constant 0 : index
-  %6 = vector.transfer_read %lhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf16>
-  %7 = vector.transfer_read %rhs[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x16xf16>
-  // CHECK: %[[READ0:.+]] = vector.transfer_read
-  // CHECK: to_layout %[[READ0]] to layout(#[[$NESTED]])
-  // CHECK: %[[READ1:.+]] = vector.transfer_read
-  // CHECK: to_layout %[[READ1]] to layout(#[[$NESTED1]])
-  vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
-  vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space<workgroup>>
-  memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space<workgroup>>
-  memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space<workgroup>>
-  return
-}
-
-// -----
-
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-                                              workgroup_size = [32, 4, 1]
-                                              subgroup_size = 32,
-      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
-
-// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [4, 1], outer_tile = [1, 1], thread_tile = [32, 4], element_tile = [1, 32], subgroup_strides = [0, 0], thread_strides = [4, 1]>
-
-// CHECK-LABEL: func.func @dequant_anchors_on_quant_only
-func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
-                                  %scale: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>,
-                                  %zp: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes { translation_info = #translation } {
-  %alloc = memref.alloc() : memref<128x128xf16, #gpu.address_space<workgroup>>
-  %cst = arith.constant 0.000000e+00 : f16
-  %cst_0 = arith.constant 0.000000e+00 : f32
-  %c32 = arith.constant 32 : index
-  %c256 = arith.constant 256 : index
-  %c0_i4 = arith.constant 0 : i4
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_read %quant[%c0, %c0], %c0_i4 {in_bounds = [true, true]} : memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128x128xi4>
-  // CHECK: %[[READ:.+]] = vector.transfer_read
-  // CHECK: to_layout %[[READ]] to layout(#[[$NESTED]])
-  %1 = vector.transfer_read %scale[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128xf16>
-  %2 = vector.broadcast %1 : vector<128xf16> to vector<128x128xf16>
-  %3 = vector.transpose %2, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
-  %4 = vector.transfer_read %zp[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128xf16>
-  %5 = vector.broadcast %4 : vector<128xf16> to vector<128x128xf16>
-  %6 = vector.transpose %5, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
-  %7 = arith.extui %0 : vector<128x128xi4> to vector<128x128xi32>
-  %8 = arith.uitofp %7 : vector<128x128xi32> to vector<128x128xf16>
-  %9 = arith.subf %8, %6 : vector<128x128xf16>
-  %10 = arith.mulf %9, %3 : vector<128x128xf16>
-  vector.transfer_write %10, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<128x128xf16>, memref<128x128xf16, #gpu.address_space<workgroup>>
-  return
-}