[VectorDistribution] Add distribution pattern for vector::MultiDimReductionOp (#17076)

This patch adds support for distributing vector::MultiDimReductionOp.
The distribution works in two steps:

- Local reduction: emit a vector.multi_reduction to reduce across the
  batch, outer, and element dimensions owned by each thread
- Subgroup reduction: use gpu.shuffle to reduce across the threads of a
  subgroup (see the sketch below)
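
For illustration, the step-2 subgroup reduction is a butterfly (XOR)
shuffle over the threads that hold the reduced dimension. Below is a
minimal host-side C++ sketch of that shuffle schedule, assuming
max-combining and one value per lane; butterflyReduce is a hypothetical
helper used only for exposition (the actual pattern emits gpu.shuffle ops):

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  // Simulate an XOR-shuffle butterfly reduction. The masks step through
  // offset, 2*offset, ..., (width/2)*offset, matching the loop in
  // doPackedThreadReductionOnDim below. Assumes lanes.size() covers
  // offset * width.
  std::vector<float> butterflyReduce(std::vector<float> lanes, int64_t offset,
                                     int64_t width) {
    for (int64_t mask = offset; mask < offset * width; mask <<= 1) {
      std::vector<float> shuffled(lanes.size());
      for (size_t lane = 0; lane < lanes.size(); ++lane)
        shuffled[lane] = lanes[lane ^ size_t(mask)]; // gpu.shuffle xor.
      for (size_t lane = 0; lane < lanes.size(); ++lane)
        lanes[lane] = std::max(lanes[lane], shuffled[lane]);
    }
    return lanes;
  }
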
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 20a3d54..f0859e8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -5,18 +5,17 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include <cstdint>
+#include <numeric>
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
-#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
@@ -459,13 +458,186 @@
   }
 };
 
+static int64_t getShuffleOffset(NestedLayoutAttr layout, int64_t dim) {
+  // Get strides for dimensions based on layouts.
+  SmallVector<int64_t> threadBasis(layout.getThreadBasis());
+  SmallVector<int64_t> basisStrides(threadBasis.size());
+  // Take an exclusive product scan over the reversed basis to get strides.
+  std::exclusive_scan(threadBasis.rbegin(), threadBasis.rend(),
+                      basisStrides.rbegin(), 1, std::multiplies<>{});
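+  // For example, a thread_basis of [4, 16] yields basisStrides of [16, 1].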
+  // Remove non-active thread ids.
+  SmallVector<int64_t> activeThreadStrides;
+  for (auto [i, stride] : llvm::enumerate(basisStrides)) {
+    if (layout.getThreadActiveIds()[i]) {
+      activeThreadStrides.push_back(stride);
+    }
+  }
+  // TODO: Do we need to do inversion or not?
+  return activeThreadStrides[layout.getThreadOrder()[dim]];
+}
+
+static int64_t getShuffleWidth(NestedLayoutAttr layout, int64_t dim) {
+  return layout.getThreadsPerOuter()[dim];
+}
+
+/// The lowering for multi_reduction is done in two steps:
+///   1. Local Reduce: Each thread reduces all elements carried by it along
+///      the reduction dimensions, i.e. the batch, outer, and element dims.
+///   2. Thread Reduce: Each thread reduces the result of step 1 across
+///      threads by doing a butterfly shuffle.
+///
+/// Currently, reduction across warps is not supported, but it would just add
+/// another step, Warp Reduce, where threads do an atomic addition on a buffer.
+struct DistributeMultiReduction final
+    : OpDistributionPattern<vector::MultiDimReductionOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  DistributeMultiReduction(MLIRContext *context, int64_t subgroupSize,
+                           int64_t maxBitsPerShuffle, int64_t benefit = 1)
+      : OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
+        maxBitsPerShuffle(maxBitsPerShuffle) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReduceOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    VectorValue srcVector = multiReduceOp.getSource();
+    auto accVector = dyn_cast<VectorValue>(multiReduceOp.getAcc());
+    if (!accVector) {
+      return rewriter.notifyMatchFailure(
+          multiReduceOp, "unimplemented: scalar accumulator distribution");
+    }
+    auto resVector = dyn_cast<VectorValue>(multiReduceOp.getResult());
+    if (!resVector) {
+      return rewriter.notifyMatchFailure(
+          multiReduceOp, "unimplemented: scalar result distribution");
+    }
+
+    auto srcLayout = dyn_cast_or_null<NestedLayoutAttr>(signature[srcVector]);
+    if (!srcLayout) {
+      return rewriter.notifyMatchFailure(multiReduceOp,
+                                         "expected nested layout attr");
+    }
+
+    Type elemTy = srcVector.getType().getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth != maxBitsPerShuffle) {
+      return rewriter.notifyMatchFailure(
+          multiReduceOp,
+          llvm::formatv("unimplemented: packed shuffle for {0}-bit elements "
+                        "(expected {1} bits)",
+                        elemBitwidth, maxBitsPerShuffle));
+    }
+
+    VectorValue disSrc =
+        getDistributed(rewriter, srcVector, signature[srcVector]);
+    VectorValue disAcc =
+        getDistributed(rewriter, accVector, signature[accVector]);
+
+    Location loc = multiReduceOp.getLoc();
+
+    SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
+    int64_t rank = srcVector.getType().getRank();
+
+    // Do thread local reduce.
+
+    SmallVector<bool> distributedReductionMask;
+    distributedReductionMask.reserve(3 * rank);
+    distributedReductionMask.append(
+        applyPermutation(reducedDims, srcLayout.getBatchOrder()));
+    distributedReductionMask.append(
+        applyPermutation(reducedDims, srcLayout.getOuterOrder()));
+    distributedReductionMask.append(
+        applyPermutation(reducedDims, srcLayout.getElementOrder()));
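+    // For example, reducing dim 1 of a 2-D source with batch/outer orders of
+    // [1, 0] and identity element order reduces dims [0, 2, 5] of the
+    // distributed vector (see the lit tests).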
+
+    auto localReduction = rewriter.create<vector::MultiDimReductionOp>(
+        loc, disSrc, disAcc, distributedReductionMask, multiReduceOp.getKind());
+    auto locallyReduced = dyn_cast<VectorValue>(localReduction.getResult());
+
+    assert(locallyReduced && "result should have been a vector");
+
+    // Flatten the locally reduced value.
+    VectorType shaped = locallyReduced.getType();
+    int64_t numElements = shaped.getNumElements();
+    SmallVector<int64_t> flatShape(1, numElements);
+    VectorType flatVecType = VectorType::get(flatShape, elemTy);
+    VectorValue flat =
+        rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
+
+    FailureOr<VectorValue> threadReduced = doThreadReduction(
+        rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
+    if (failed(threadReduced)) {
+      return failure();
+    }
+
+    VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
+        loc, shaped, threadReduced.value());
+    replaceOpWithDistributedValues(rewriter, multiReduceOp, unflattened);
+
+    return success();
+  }
+
+  FailureOr<VectorValue> doThreadReduction(RewriterBase &rewriter,
+                                           NestedLayoutAttr layout,
+                                           VectorValue flat,
+                                           vector::CombiningKind kind,
+                                           ArrayRef<bool> reductionMask) const {
+    VectorType flatVecType = flat.getType();
+    int64_t numElements = flatVecType.getNumElements();
+    Location loc = flat.getLoc();
+
+    auto constOp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(flatVecType));
+    auto res = llvm::cast<VectorValue>(constOp.getResult());
+
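+    // Shuffle and combine each scalar element individually, one reduction
+    // dimension at a time; the bitwidth check above ensures each element fits
+    // in a single shuffle.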
+    for (unsigned i = 0; i < numElements; ++i) {
+      Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);
+
+      // Reduce across all reduction dimensions 1-by-1.
+      for (unsigned dim = 0; dim < reductionMask.size(); ++dim) {
+        if (reductionMask[dim]) {
+          extracted = doPackedThreadReductionOnDim(rewriter, layout, extracted,
+                                                   kind, dim);
+        }
+      }
+
+      res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
+    }
+
+    return res;
+  }
+
+  Value doPackedThreadReductionOnDim(RewriterBase &rewriter,
+                                     NestedLayoutAttr layout, Value val,
+                                     vector::CombiningKind kind,
+                                     int64_t dim) const {
+    Location loc = val.getLoc();
+    int64_t offset = getShuffleOffset(layout, dim);
+    int64_t width = getShuffleWidth(layout, dim);
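+    // E.g., offset = 16 and width = 4 emit XOR shuffles with masks 16 and 32.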
+
+    for (int i = offset; i < offset * width; i <<= 1) {
+      auto shuffleOp = rewriter.create<gpu::ShuffleOp>(
+          loc, val, i, subgroupSize, gpu::ShuffleMode::XOR);
+      val =
+          makeArithReduction(rewriter, loc, kind, shuffleOp.getShuffleResult(),
+                             val, nullptr, nullptr);
+    }
+
+    return val;
+  }
+
+  int64_t subgroupSize;
+  int64_t maxBitsPerShuffle;
+};
+
 } // namespace
 
-void populateGPUDistributeNestedLayoutAttrPatterns(
-    Value threadId, RewritePatternSet &patterns) {
+void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
+                                                   Value threadId,
+                                                   int64_t subgroupSize,
+                                                   int64_t maxBitsPerShuffle) {
   patterns.add<DistributeTransferRead, DistributeTransferWrite>(
       patterns.getContext(), threadId);
   patterns.add<DistributeBroadcast>(patterns.getContext());
+  patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
+                                         maxBitsPerShuffle);
 }
 
 }; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
index cdda3f9..8730384 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
@@ -37,8 +37,9 @@
 void populateGPUReductionDistributionPatterns(RewritePatternSet &patterns,
                                               int64_t maxBitsPerShuffle = 32);
 
-void populateGPUDistributeNestedLayoutAttrPatterns(Value threadId,
-                                                   RewritePatternSet &patterns);
+void populateGPUDistributeNestedLayoutAttrPatterns(
+    RewritePatternSet &patterns, Value threadId, int64_t subgroupSize,
+    int64_t maxBitsPerShuffle = 32);
 
 // Adds patterns that distributes vector.contract ops with nested layout
 // annotations to amdgpu.mfma ops.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index 06b7bca..d93b3c2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -1002,3 +1002,99 @@
 
 // CHECK:         %[[T7:.+]] = vector.transpose %[[RD7]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
 // CHECK:         vector.transfer_write %[[T7]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID3]], %[[DIM0_ID]]]
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1], 
+  // We are reducing along dim=1, so each thread will reduce 
+  // 2 batches x 4 elements = 8 elements.
+  batches_per_subgroup = [2, 2], 
+  outers_per_batch = [1, 1], 
+  // We are reducing on dim=1, which is distributed over 4 threads. Based
+  // on the subgroup basis and thread order, the shuffle offset is 16.
+  threads_per_outer = [16, 4], 
+  elements_per_thread = [1, 4], 
+
+  subgroup_order = [1, 0], 
+  batch_order = [1, 0], 
+  outer_order = [1, 0], 
+  thread_order = [1, 0], 
+
+  subgroup_basis = [1, 1], 
+  thread_basis = [4, 16]
+>
+
+func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+  %0 = vector.multi_reduction <maximumf>, %arg0, %arg1 
+  {
+    __vector_layout_test_anchor_operand_0 = #nested
+  } [1] : vector<32x32xf32> to vector<32xf32>
+  return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
+// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
+// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[DARG1]] [0, 2, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
+// Global reduction
+// CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
+// CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
+// CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x1x1xf32> -> vector<32xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  // We are reducing along dim=1, so each thread will reduce 
+  // 4 batches x 4 elements = 16 elements.
+  batches_per_subgroup    = [1, 4],
+  outers_per_batch        = [1, 1],
+  // We are reducing on dim=1, which is distributed over 2 threads. Based
+  // on the subgroup basis and thread order, the shuffle offset is 32.
+  threads_per_outer       = [32, 2],
+  elements_per_thread     = [1, 4],
+
+  thread_order            = [1, 0],
+
+  subgroup_basis          = [1, 1],
+  thread_basis            = [2, 32]
+>
+
+func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+  %0 = vector.multi_reduction <maximumf>, %arg0, %arg1 
+  {
+    __vector_layout_test_anchor_operand_0 = #nested
+  } [1] : vector<32x32xf32> to vector<32xf32>
+  return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
+// Global reduction
+// CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 6fe8e16..dac7e3c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -933,7 +933,9 @@
   populateGPUDistributionPatterns(patterns);
   populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
   populateGPUReductionDistributionPatterns(patterns);
-  populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns);
+  // For testing we use subgroup size = 64.
+  populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
+                                                /*subgroupSize=*/64);
   populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
   if (getExperimental())
     populateGPULayoutResolutionDistributionPatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index 619a19d..ed14df5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -56,13 +56,15 @@
   ContractionVectorLayoutOptions(Operation *root,
                                  ArrayRef<int64_t> workgroupSize,
                                  IREE::GPU::MMAScheduleAttr schedule,
-                                 Value laneId, bool printLayout)
+                                 Value laneId, int64_t subgroupSize,
+                                 bool printLayout)
       : VectorLayoutOptions(root, /*fullConversion=*/!printLayout),
         workgroupSize(workgroupSize), schedule(schedule),
         printLayout(printLayout), patterns(root->getContext()) {
     populateGPUDistributionPatterns(patterns);
     populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
-    populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns);
+    populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
+                                                  subgroupSize);
   }
 
   LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override {
@@ -371,7 +373,9 @@
       std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
           getWorkgroupSize(func);
       if (!maybeWorkgroupSize) {
-        return;
+        func->emitOpError()
+            << "unable to query workgroup_size information from entry point";
+        return signalPassFailure();
       }
       for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) {
         workgroupSize[index] = value;
@@ -408,8 +412,16 @@
     Value linearThreadIdVal = affine::makeComposedAffineApply(
         builder, func.getLoc(), linearId, threadGrid);
 
+    std::optional<int64_t> subgroupSize = getSubgroupSize(func);
+    if (!subgroupSize) {
+      func->emitOpError()
+          << "unable to query subgroup size information from entry point";
+      return signalPassFailure();
+    }
+
     ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr,
-                                           linearThreadIdVal, testLayout);
+                                           linearThreadIdVal,
+                                           subgroupSize.value(), testLayout);
     if (failed(distributeVectorOps(func, options.getPatterns(), options))) {
       func->emitOpError() << "failed to distribute";
       return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 6ec034c..e73d935 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -1478,7 +1478,9 @@
   populateGPUDistributionPatterns(patterns);
   populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
   populateGPUReductionDistributionPatterns(patterns);
-  populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns);
+  // For testing we use subgroup size = 64.
+  populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
+                                                /*subgroupSize=*/64);
   populateAMDGPUDistributionPatterns(patterns);
   populateGPULayoutResolutionDistributionPatterns(patterns);
   if (failed(distributeVectorOps(target, patterns, options))) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
index ab30f16..bcd47c4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
@@ -1,12 +1,14 @@
 // RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
 func.func @mfma_matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
   %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
@@ -54,13 +56,15 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
 func.func @mfma_matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
   %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
@@ -120,13 +124,15 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [32, 1, 1]
+                                              subgroup_size = 32, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
 func.func @wmma_matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [32, 1, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
   %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
@@ -178,13 +184,15 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [32, 1, 1]
+                                              subgroup_size = 32, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
 func.func @wmma_matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                               %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [32, 1, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
   %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index 66da55c..520f7f9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -1,10 +1,11 @@
 // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute{test-layout}, canonicalize, cse))' %s | FileCheck %s
 
-func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
-      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>}>
+
+func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -27,11 +28,12 @@
 
 // -----
 
-func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
-      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>}>
+
+func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -54,11 +56,12 @@
 
 // -----
 
-func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: vector<16x64xf16>, %init: vector<192x64xf32>) -> vector<192x64xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
-      subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 2, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 2, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>}>
+
+func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: vector<16x64xf16>, %init: vector<192x64xf32>) -> vector<192x64xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -72,13 +75,15 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
 func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                                  %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                                  %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
   %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
@@ -129,13 +134,16 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
+
 func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                                          %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                                          %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
   %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
@@ -190,14 +198,16 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 1, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>}>
+
 func.func @matmul_16x16x256_fused(%lhs: memref<16x32xf16>,
                                   %rhs: memref<32x16xf16>,
                                   %bias: memref<16x16xf32>,
                                   %out: memref<16x16xf32>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-                     subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-    workgroup_size = [64, 1, 1]} {
+  attributes { translation_info = #translation } {
   %cst = arith.constant 0.000000e+00 : f16
   %cst_f32 = arith.constant 0.000000e+00 : f32
   %c32 = arith.constant 32 : index
@@ -228,11 +238,12 @@
 
 // -----
 
-func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
-      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-    workgroup_size = [32, 1, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [32, 1, 1]
+                                              subgroup_size = 32, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>}>
+
+func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -255,11 +266,12 @@
 
 // -----
 
-func.func @wmma_matmul_48x32x32_mmt(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
-      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-    workgroup_size = [32, 1, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [32, 1, 1]
+                                              subgroup_size = 32, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>}>
+
+func.func @wmma_matmul_48x32x32_mmt(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -282,11 +294,13 @@
 
 // -----
 
-func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-      subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 4, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
-    workgroup_size = [64, 2, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 2, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 4, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>}>
+
+
+func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -328,11 +342,12 @@
 
 // -----
 
-func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-      subgroup_m_count = 4, subgroup_n_count = 1, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
-    workgroup_size = [64, 2, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [64, 2, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 4, subgroup_n_count = 1, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>}>
+
+func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -353,11 +368,12 @@
 
 // -----
 
-func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
-      subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 8, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>,
-    workgroup_size = [128, 2, 1]} {
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [128, 2, 1]
+                                              subgroup_size = 64, 
+      {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 8, subgroup_n_tile_count = 4, subgroup_k_tile_count = 1>}>
+
+func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes { translation_info = #translation } {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
       iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -383,13 +399,15 @@
 
 // -----
 
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute 
+                                              workgroup_size = [32, 4, 1]
+                                              subgroup_size = 32, 
+      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4, subgroup_m_tile_count = 1, subgroup_n_tile_count = 2, subgroup_k_tile_count = 8>}>
+
 func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                                   %scale: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>,
                                   %zp: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-  attributes {
-    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
-                   subgroup_m_count = 1, subgroup_n_count = 4, subgroup_m_tile_count = 1, subgroup_n_tile_count = 2, subgroup_k_tile_count = 8>,
-    workgroup_size = [32, 4, 1]} {
+  attributes { translation_info = #translation } {
   %alloc = memref.alloc() : memref<128x128xf16, #gpu.address_space<workgroup>>
   %cst = arith.constant 0.000000e+00 : f16
   %cst_0 = arith.constant 0.000000e+00 : f32
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 6b0d760..0234765 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -312,8 +312,14 @@
   };
   SmallVector<bool> subgroupMask(getSubgroupActiveIds());
   SmallVector<bool> threadMask(getThreadActiveIds());
-  composeMasks(subgroupMask, applyPermutation(droppedDims, getSubgroupOrder()));
-  composeMasks(threadMask, applyPermutation(droppedDims, getThreadOrder()));
+
+  SmallVector<bool> invertedDroppedSubgroupMask =
+      applyPermutation(droppedDims, invertPermutationVector(getSubgroupOrder()));
+  composeMasks(subgroupMask, invertedDroppedSubgroupMask);
+
+  SmallVector<bool> invertedDroppedThreadMask =
+      applyPermutation(droppedDims, invertPermutationVector(getThreadOrder()));
+  composeMasks(threadMask, invertedDroppedThreadMask);
 
   return NestedLayoutAttr::get(getContext(), subgroupCount, subgroupOrder,
                                batchCount, batchOrder, outerCount, outerOrder,
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index ba61828..d369aa5 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -7,6 +7,7 @@
 """iree_generated_e2e_matmul_test generator for e2e matmul tests.
 """
 
+from typing import Optional
 import argparse
 import os
 import yaml
@@ -115,6 +116,7 @@
     mma_schedule: typing.Optional[MMASchedule]
     # Compilation info
     workgroup_size: typing.List[int]
+    subgroup_size: Optional[int] = None
 
     # Prints the workgroup size
     def workgroup_size_str(self):
@@ -286,6 +288,9 @@
                 workgroup_size=workgroup_size,
                 software_pipeline_depth=0,
                 mma_schedule=schedule,
+                # TODO: This is only valid for gfx9. Change this for RDNA3
+                # architectures.
+                subgroup_size=64,
             )
         )
     return infos
@@ -533,11 +538,16 @@
         mma_schedule = ""
         if compilation_info.mma_schedule is not None:
             mma_schedule = ", {}".format(compilation_info.mma_schedule)
+        subgroup_size_str = ""
+        if compilation_info.subgroup_size is not None:
+            subgroup_size_str = f"subgroup_size = {compilation_info.subgroup_size}"
+
         compilation_info_string = (
             f"#compilation{generate_function.compilation_index} = "
             "#iree_codegen.compilation_info<\n"
             f"  lowering_config = <tile_sizes = {compilation_info.tile_sizes}>,\n"
-            f"  translation_info = <{compiler_pipeline} {compilation_info.workgroup_size_str()},\n"
+            f"  translation_info = <{compiler_pipeline} {compilation_info.workgroup_size_str()}\n"
+            f"  {subgroup_size_str},\n"
             f"  {{ pipeline_depth = {compilation_info.software_pipeline_depth}, "
             f"  store_stage = 1{mma_schedule} }}>>\n"
         )