[VectorDistribution] Remove iteration on non-distributed levels (#17096)

This patch removes iteration order in the layout for non-distributed
levels. The layout now only represents how the register is tiled, and
does not define an iteration order over it. The patterns themselves
define an iteration order over the tiled layout.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index efe43d3..4088008 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include <numeric>
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
@@ -97,12 +98,12 @@
     LLVM_DEBUG(llvm::dbgs() << "init tile: " << finalTile << "\n");
 
     // Offsets into the LHS/RHS batches.
-    SmallVector<int64_t, 2> lhsBatchOffsets(lhsLayout.getRank(), 0);
-    SmallVector<int64_t, 2> rhsBatchOffsets(rhsLayout.getRank(), 0);
+    SmallVector<int64_t> lhsBatchOffsets(rank, 0);
+    SmallVector<int64_t> rhsBatchOffsets(rank, 0);
 
     // Offsets into the result batches.
     ArrayRef<int64_t> resultBatches = resultLayout.getBatchesPerSubgroup();
-    SmallVector<int64_t, 2> resultBatchTileSizes(rank, 1);
+    SmallVector<int64_t> resultBatchTileSizes(rank, 1);
     LLVM_DEBUG({
       llvm::dbgs() << "result batches: [";
       llvm::interleaveComma(resultBatches, llvm::dbgs());
@@ -114,18 +115,22 @@
     Value lhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
     Value rhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);
 
+    SmallVector<AffineMap> indexingMaps = contractOp.getIndexingMapsArray();
+    AffineMap lhsMap = compressUnusedDims(indexingMaps[0]);
+    AffineMap rhsMap = compressUnusedDims(indexingMaps[1]);
+    AffineMap resMap = compressUnusedDims(indexingMaps[2]);
+
+    SmallVector<int64_t> resBatchOrder(resMap.getNumResults());
+    std::iota(resBatchOrder.begin(), resBatchOrder.end(), 0);
+    resBatchOrder = applyPermutationMap(resMap, ArrayRef(resBatchOrder));
+
     // Iterate over all result batches and unroll computation to direct MFMA
     // intrinsic ops.
     Location loc = contractOp.getLoc();
     auto resultTiles = StaticTileOffsetRange(
-        resultBatches, resultBatchTileSizes, resultLayout.getBatchOrder());
+        resultBatches, resultBatchTileSizes, resBatchOrder);
     SmallVector<int64_t, 2> resultBatchOffsets;
-    for (SmallVector<int64_t, 2> originalResultBatchOffsets : resultTiles) {
-      // Permute the result batch offsets first to match the distributed shape
-      // dim order for indexing.
-      resultBatchOffsets = originalResultBatchOffsets;
-      applyPermutationToVector(resultBatchOffsets,
-                               resultLayout.getBatchOrder());
+    for (SmallVector<int64_t, 2> resultBatchOffsets : resultTiles) {
       LLVM_DEBUG({
         llvm::dbgs() << "current result batch offsets: [";
         llvm::interleaveComma(resultBatchOffsets, llvm::dbgs());
@@ -151,9 +156,9 @@
         // Fills the batch offsets for LHS and RHS. For the K dimension it's the
         // induction variable; for the M/N dimension we need to extract from the
         // result batch offsets.
-        fillOperandBatchOffsets(opDetail, k, originalResultBatchOffsets,
-                                resultLayout, lhsBatchOffsets, rhsBatchOffsets,
-                                lhsLayout, rhsLayout);
+        fillOperandBatchOffsets(opDetail, k, resultBatchOffsets,
+                                lhsBatchOffsets, rhsBatchOffsets, lhsMap,
+                                rhsMap);
         LLVM_DEBUG({
           llvm::dbgs() << "current lhs batch offsets: [";
           llvm::interleaveComma(lhsBatchOffsets, llvm::dbgs());
@@ -196,12 +201,10 @@
   // both LHS and RHS.
   void fillOperandBatchOffsets(const VectorContractOpInfo &opDetail,
                                int64_t kOffset, ArrayRef<int64_t> resultOffsets,
-                               NestedLayoutAttr resultLayout,
-                               SmallVector<int64_t, 2> &lhsOffsets,
-                               SmallVector<int64_t, 2> &rhsOffsets,
-                               NestedLayoutAttr lhsLayout,
-                               NestedLayoutAttr rhsLayout) const {
-
+                               SmallVector<int64_t> &lhsOffsets,
+                               SmallVector<int64_t> &rhsOffsets,
+                               AffineMap lhsMap, AffineMap rhsMap) const {
+    auto [lhsK, rhsK] = opDetail.getOperandKIndex();
     // resultOffsets contains batch indices into the C/D vector. It is a 2-D
     // index for both M and N. We need to split out for M and N, and add index
     // for K.
@@ -214,13 +217,8 @@
       rhsOffsets[rhsN] = resultOffsets[resultN];
     }
 
-    auto [lhsK, rhsK] = opDetail.getOperandKIndex();
     lhsOffsets[lhsK] = kOffset;
     rhsOffsets[rhsK] = kOffset;
-
-    // Now apply permutation on LHS/RHS according to their batch order.
-    applyPermutationToVector(lhsOffsets, lhsLayout.getBatchOrder());
-    applyPermutationToVector(rhsOffsets, rhsLayout.getBatchOrder());
   }
 
   struct AMDMMAParameters {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 1887c28..94a046a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -616,17 +616,18 @@
   }
 };
 
-struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
+struct DistributeTransposeLayoutAttr final
+    : OpDistributionPattern<vector::TransposeOp> {
   using OpDistributionPattern::OpDistributionPattern;
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
     VectorValue value = transposeOp.getVector();
-    VectorLayoutInterface layout =
-        dyn_cast<VectorLayoutInterface>(signature[value]);
+    VectorLayoutInterface layout = dyn_cast<LayoutAttr>(signature[value]);
     if (!layout) {
-      return failure();
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "layout must be LayoutAttr");
     }
 
     /// Transpose only changes the notion of where the data carried by each
@@ -849,7 +850,7 @@
   patterns
       .add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
           patterns.getContext(), laneId);
-  patterns.add<DistributeBroadcastLayoutAttr, DistributeTranspose>(
+  patterns.add<DistributeBroadcastLayoutAttr, DistributeTransposeLayoutAttr>(
       patterns.getContext());
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index f0859e8..f8122da 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -81,12 +81,8 @@
   int64_t rank = vectorLayout.getRank();
   // Permute the batch and outer vector offsets to match the order of
   // the vector dimensions using the inverse of the batch/offset order.
-  SmallVector<int64_t> batchOffsets =
-      applyPermutation(ArrayRef<int64_t>(offsets.begin(), rank),
-                       invertPermutationVector(vectorLayout.getBatchOrder()));
-  SmallVector<int64_t> outerVectorOffsets =
-      applyPermutation(ArrayRef<int64_t>(offsets.begin() + rank, rank),
-                       invertPermutationVector(vectorLayout.getOuterOrder()));
+  ArrayRef<int64_t> batchOffsets(offsets.begin(), rank);
+  ArrayRef<int64_t> outerVectorOffsets(offsets.begin() + rank, rank);
 
   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
   for (const auto &[i, dim] : llvm::enumerate(permutationMap.getResults())) {
@@ -111,29 +107,6 @@
   return slicedIndices;
 }
 
-static SmallVector<int64_t> getLoopOrder(NestedLayoutAttr vectorLayout) {
-  int64_t rank = vectorLayout.getRank();
-  // Let the unroll order first unroll the batch dimensions, then the
-  // outer vector dimensions. We unroll in the order specified by the
-  // layout.
-  SmallVector<int64_t> loopOrder;
-  int64_t base = 0;
-  for (auto b : vectorLayout.getBatchOrder()) {
-    loopOrder.push_back(base + b);
-  }
-  base += rank;
-  // We must unroll along the outer dimensions as well to match the rank
-  // requirements of vector transfer ops (<= memref rank up to broadcasts).
-  for (auto o : vectorLayout.getOuterOrder()) {
-    loopOrder.push_back(base + o);
-  }
-  base += rank;
-  for (int i = 0, e = rank; i < e; ++i) {
-    loopOrder.push_back(base + i);
-  }
-  return loopOrder;
-}
-
 static SmallVector<int64_t>
 getElementVectorTileShape(NestedLayoutAttr vectorLayout) {
   int64_t rank = vectorLayout.getRank();
@@ -215,7 +188,6 @@
 
     SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
     SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
-    SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
     int64_t rank = vectorLayout.getRank();
 
     Type elementType = readOp.getSource().getType().getElementType();
@@ -238,7 +210,7 @@
     ValueRange indices = readOp.getIndices();
     SmallVector<int64_t> strides(rank, 1);
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+         StaticTileOffsetRange(distShape, tileShape)) {
       SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
           rewriter, indices, offsets, vectorLayout, readOp.getPermutationMap(),
           warpIndices, threadIndices);
@@ -247,19 +219,6 @@
           readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
-      // Transpose to the element order.
-      //
-      // A = transfer_read
-      // B = transpose A
-      //
-      // P(A) = I
-      //
-      // P(A) * perm = P(B)
-      // perm = P(B)
-      if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
-        slicedRead = rewriter.create<vector::TransposeOp>(
-            slicedRead.getLoc(), slicedRead, vectorLayout.getElementOrder());
-      }
 
       acc = rewriter.create<vector::InsertStridedSliceOp>(
           readOp.getLoc(), slicedRead, acc, offsets, strides);
@@ -302,7 +261,6 @@
 
     SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
     SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
-    SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
     int64_t rank = vectorLayout.getRank();
 
     SmallVector<Value> warpIndices, threadIndices;
@@ -314,7 +272,7 @@
 
     ValueRange indices = writeOp.getIndices();
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+         StaticTileOffsetRange(distShape, tileShape)) {
       SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
           rewriter, indices, offsets, vectorLayout, writeOp.getPermutationMap(),
           warpIndices, threadIndices);
@@ -326,20 +284,6 @@
       Value slicedVector = rewriter.create<vector::ExtractOp>(
           writeOp.getLoc(), distributedVector,
           offsetArray.take_front(rank * 2));
-      // Transpose to the native dimension order.
-      // B = transpose(A)
-      // transfer_write B
-      //
-      // P(B) = I
-      //
-      // P(A) * perm = P(B)
-      // P(A) * perm = I
-      // perm = P(A) ^ -1
-      if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
-        slicedVector = rewriter.create<vector::TransposeOp>(
-            slicedVector.getLoc(), slicedVector,
-            invertPermutationVector(vectorLayout.getElementOrder()));
-      }
       rewriter.create<vector::TransferWriteOp>(
           writeOp.getLoc(), slicedVector, writeOp.getSource(), slicedIndices,
           writeOp.getPermutationMapAttr(), writeOp.getMask(),
@@ -388,53 +332,39 @@
         loc, vectorType, rewriter.getZeroAttr(vectorType));
 
     int64_t rank = vectorLayout.getRank();
-    int64_t sourceRank = sourceLayout.getRank();
     // We unroll along both the batch and outer dimensions for a similar reason
     // to the transfer ops. `vector.broadcast` can only broadcast along outer
     // dims, so mixing broadcasted and un-broadcasted element/outer dims can't
     // be represented with a single `vector.broadcast`.
     SmallVector<int64_t> resultVectorUnrollShape =
         getElementVectorTileShape(vectorLayout);
-    SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
 
     Value distributedSource = getDistributed(rewriter, srcVector, sourceLayout);
 
     VectorType broadcastTargetType =
-        VectorType::get(applyPermutation(vectorLayout.getElementsPerThread(),
-                                         vectorLayout.getElementOrder()),
-                        elementType);
+        VectorType::get(vectorLayout.getElementsPerThread(), elementType);
+
+    int64_t sourceRank = sourceLayout.getRank();
+
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(distShape, resultVectorUnrollShape, loopOrder)) {
+         StaticTileOffsetRange(distShape, resultVectorUnrollShape)) {
       ArrayRef<int64_t> offsetsRef(offsets);
-      // Invert the permutations on the batch/outer offsets to get the offsets
-      // in the order of the vector dimensions. We are iterating over each
-      // (batch x outer) tile, and the offsets for those tiles are already
-      // permuted by the layout batch/outer orders. Hence why we apply the
-      // inverse permutation here.
-      SmallVector<int64_t> permutedBatchOffsets = applyPermutation(
-          offsetsRef.slice(0, rank),
-          invertPermutationVector(vectorLayout.getBatchOrder()));
-      SmallVector<int64_t> permutedOuterOffsets = applyPermutation(
-          offsetsRef.slice(rank, rank),
-          invertPermutationVector(vectorLayout.getOuterOrder()));
 
       // Slice out the last |sourceRank| dimensions which is the inner
       // broadcasted shape.
       ArrayRef<int64_t> batchSourceOffsets =
-          ArrayRef<int64_t>(permutedBatchOffsets)
-              .slice(rank - sourceRank, sourceRank);
+          offsetsRef.slice(rank - sourceRank, sourceRank);
       ArrayRef<int64_t> outerSourceOffsets =
-          ArrayRef<int64_t>(permutedOuterOffsets)
-              .slice(rank - sourceRank, sourceRank);
+          offsetsRef.slice(2 * rank - sourceRank, sourceRank);
 
       // Construct the list of source offsets based on the batch/outer order of
       // the broadcasted vector. This is because we need to compute the offsets
       // into the distributed source vector with the distributed permutation.
       SmallVector<int64_t> sourceOffsets;
-      sourceOffsets.append(
-          applyPermutation(batchSourceOffsets, sourceLayout.getBatchOrder()));
-      sourceOffsets.append(
-          applyPermutation(outerSourceOffsets, sourceLayout.getOuterOrder()));
+      sourceOffsets.append(batchSourceOffsets.begin(),
+                           batchSourceOffsets.end());
+      sourceOffsets.append(outerSourceOffsets.begin(),
+                           outerSourceOffsets.end());
 
       // Extract a slice of the input to be broadcasted.
       Value slice = rewriter.create<vector::ExtractOp>(loc, distributedSource,
@@ -538,14 +468,13 @@
 
     // Do thread local reduce.
 
+    // The distributed reduction mask is simply the same mask appended
+    // thrice.
     SmallVector<bool> distributedReductionMask;
     distributedReductionMask.reserve(3 * rank);
-    distributedReductionMask.append(
-        applyPermutation(reducedDims, srcLayout.getBatchOrder()));
-    distributedReductionMask.append(
-        applyPermutation(reducedDims, srcLayout.getOuterOrder()));
-    distributedReductionMask.append(
-        applyPermutation(reducedDims, srcLayout.getElementOrder()));
+    for (int i = 0; i < 3; ++i) {
+      distributedReductionMask.append(reducedDims.begin(), reducedDims.end());
+    }
 
     auto localReduction = rewriter.create<vector::MultiDimReductionOp>(
         loc, disSrc, disAcc, distributedReductionMask, multiReduceOp.getKind());
@@ -627,6 +556,60 @@
   int64_t maxBitsPerShuffle;
 };
 
+struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    VectorValue value = transposeOp.getVector();
+    VectorLayoutInterface layout = dyn_cast<NestedLayoutAttr>(signature[value]);
+    if (!layout) {
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "layout must be NestedLayoutAttr");
+    }
+
+    /// Transpose only changes the notion of where the data carried by each
+    /// thread comes from in the transposed SIMD vector. The data carried by
+    /// each thread is still the same, transposed as requested by the operation.
+    /// So, for distributed dimensions (thread and subgroup) transpose is a
+    /// no-op.
+    ///
+    /// Example (indices [0-3] represent ids of the threads carrying the data):
+    ///
+    /// input: vector<2x4xf16>
+    ///
+    /// 0 0 1 1
+    /// 2 2 3 3
+    ///
+    /// after transpose,
+    ///
+    /// transp: vector<4x2xf16>
+    ///
+    /// 0 2
+    /// 0 2
+    /// 1 3
+    /// 1 3
+    ///
+    /// As it can be seen, each thread is still carrying the same data but
+    /// just holds a transposed version of it.
+
+    VectorValue input = getDistributed(rewriter, value, layout);
+    // Permute batch, outer and element based on the given permutation.
+    int64_t rank = value.getType().getRank();
+    SmallVector<int64_t> permutation;
+    for (int i = 0; i < 3; ++i) {
+      for (auto it : transposeOp.getPermutation()) {
+        permutation.push_back(it + (i * rank));
+      }
+    }
+    VectorValue transposed = rewriter.create<vector::TransposeOp>(
+        transposeOp.getLoc(), input, permutation);
+    replaceOpWithDistributedValues(rewriter, transposeOp, transposed);
+    return success();
+  }
+};
+
 } // namespace
 
 void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
@@ -635,7 +618,7 @@
                                                    int64_t maxBitsPerShuffle) {
   patterns.add<DistributeTransferRead, DistributeTransferWrite>(
       patterns.getContext(), threadId);
-  patterns.add<DistributeBroadcast>(patterns.getContext());
+  patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
   patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
                                          maxBitsPerShuffle);
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index 9141916..3c26d7f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-transform-dialect-interpreter --cse %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-transform-dialect-interpreter --canonicalize --cse %s | FileCheck %s
 
 // CDNA3 V_MFMA_F32_32x32x8_F16
 
@@ -28,8 +28,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -42,8 +40,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -71,24 +67,23 @@
 }
 
 // CHECK-LABEL: func @contract_to_mfma_32x32x8_mm
-//  CHECK-SAME: (%[[A:.+]]: vector<32x8xf16>, %[[B:.+]]: vector<8x32xf16>, %[[C:.+]]: vector<32x32xf32>)
-//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x4x1x1x4xf32>
-//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<32x32xf32> -> vector<1x1x4x1x1x4xf32>
-//       CHECK:   %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<32x8xf16> -> vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<8x32xf16> -> vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<1x1x4x1x1x4xf32>
-//       CHECK:   %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x1x4xf32> to vector<16xf32>
-//       CHECK:   %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
-//  CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
-//  CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
-//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x1x4xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
-//       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
-//       CHECK:   return {{.*}} %[[R_SIMD]]
+// CHECK-SAME: (%[[A:.+]]: vector<32x8xf16>, %[[B:.+]]: vector<8x32xf16>, %[[C:.+]]: vector<32x32xf32>)
+// CHECK:       %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<32x32xf32> -> vector<1x1x4x1x4x1xf32
+// CHECK:       %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<32x8xf16>  -> vector<1x1x1x1x1x4xf16>
+// CHECK:       %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<8x32xf16>  -> vector<1x1x1x1x4x1xf16>
+// CHECK:       %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<1x1x4x1x4x1xf32>
+// CHECK:       %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
+// CHECK:       %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x4x1xf16> from vector<1x1x1x1x4x1xf16>
+// CHECK:       %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
+// CHECK:       %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x4x1xf16> to vector<4xf16>
+// CHECK:       %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK:       %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
+// CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
+// CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK:       %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK:       %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
+// CHECK:       %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x4x1x4x1xf32> -> vector<32x32xf32>
+// CHECK:       return {{.*}} %[[R_SIMD]]
 
 // -----
 
@@ -120,8 +115,6 @@
   threads_per_outer       = [4, 16],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [4, 16]
 >
@@ -152,22 +145,22 @@
 
 // CHECK-LABEL: func @contract_to_mfma_16x16x16_mm
 //  CHECK-SAME: (%[[A:.+]]: vector<16x16xf16>, %[[B:.+]]: vector<16x16xf16>, %[[C:.+]]: vector<16x16xf32>)
-//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32>
-//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x1x1x1x4xf32>
+//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x1x1x4x1xf32>
 //       CHECK:   %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<16x16xf16> -> vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<1x1x1x4xf32> from vector<1x1x1x1x1x4xf32>
+//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x4x1xf16>
+//       CHECK:   %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<1x1x4x1xf32> from vector<1x1x1x1x4x1xf32>
 //       CHECK:   %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
-//       CHECK:   %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
+//       CHECK:   %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x4x1xf16> from vector<1x1x1x1x4x1xf16>
 //       CHECK:   %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<1x1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x4x1xf16> to vector<4xf16>
+//       CHECK:   %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<1x1x4x1xf32> to vector<4xf32>
 //       CHECK:   %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
 //  CHECK-SAME:     {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none
 //  CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<4xf32> to vector<1x1x1x4xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<1x1x1x4xf32> into vector<1x1x1x1x1x4xf32>
-//       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x1x1x1x4xf32> -> vector<16x16xf32>
+
+//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]]  : vector<4xf32> to vector<1x1x4x1xf32>
+//       CHECK:   %[[B_OUT:.*]]  = vector.broadcast %[[R_CAST]] : vector<1x1x4x1xf32> to vector<1x1x1x1x4x1xf32>
+//       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x1x1x4x1xf32> -> vector<16x16xf32>
 //       CHECK:   return {{.*}} %[[R_SIMD]]
 
 // -----
@@ -200,8 +193,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -214,8 +205,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -244,23 +233,23 @@
 
 // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch
 //       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00>
-//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x1x4xf32>
+//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x4x1xf32>
 //       CHECK:   %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
-//       CHECK:   %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
+//       CHECK:   %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32
 //       CHECK:   %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
 //       CHECK:   %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x1x4xf32> to vector<16xf32>
+//       CHECK:   %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x4x1xf32> to vector<16xf32>
 //       CHECK:   %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
-//       CHECK:   %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x1x4xf32>
-//       CHECK:   %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
-//       CHECK:   %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
+//       CHECK:   %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x4x1xf32>
+//       CHECK:   %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
+//       CHECK:   %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
 //       CHECK:   %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
 //       CHECK:   %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x1x4xf32> to vector<16xf32>
+//       CHECK:   %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x4x1xf32> to vector<16xf32>
 //       CHECK:   %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
-//       CHECK:   %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
-//       CHECK:   %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
-//       CHECK:   %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x1x4xf32> -> vector<64x32xf32>
+//       CHECK:   %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x4x1xf32>
+//       CHECK:   %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
+//       CHECK:   %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
 //       CHECK:   return {{.*}}} %[[R]]
 
 // -----
@@ -293,8 +282,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -307,8 +294,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -336,22 +321,19 @@
 }
 
 // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_kbatch(%arg0: vector<32x16xf16>, %arg1: vector<16x32xf16>, %arg2: vector<32x32xf32>) -> vector<32x32xf32> {
-//       CHECK:   %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<32x16xf16> -> vector<1x2x1x1x1x4xf16>
-//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<16x32xf16> -> vector<2x1x1x1x1x4xf16>
-//       CHECK:   %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x2x1x1x1x4xf16>
-//       CHECK:   %[[B_SLICE0:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
-//       CHECK:   %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
+//       CHECK:   %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<32x16xf16>
+//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<16x32xf16>
+//       CHECK:   %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0]
+//       CHECK:   %[[B_SLICE0:.+]] = vector.extract %[[B_SIMT]][0, 0]
+//       CHECK:   %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]]
+//       CHECK:   %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]]
 //       CHECK:   %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %[[B0_CAST]] + %{{.+}}
-//       CHECK:   %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1] : vector<1x1x1x4xf16> from vector<1x2x1x1x1x4xf16>
-//       CHECK:   %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
-//       CHECK:   %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
+//       CHECK:   %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1]
+//       CHECK:   %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0]
+//       CHECK:   %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]]
+//       CHECK:   %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]]
 //       CHECK:   %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
-//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[R_CAST]], %{{.+}} [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
-//       CHECK:   %[[R:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
-//       CHECK:   return {{.*}} %[[R]] : vector<32x32xf32>
+//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]]
 
 // -----
 
@@ -383,8 +365,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -397,9 +377,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  batch_order             = [1, 0],
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -427,27 +404,27 @@
 }
 
 // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
-//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<3x2x4x1x1x4xf32>
-//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<3x2x4x1x1x4xf32>
-//       CHECK:   vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<2x3x4x1x4x1xf32>
+//       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
+//       CHECK:   vector.extract %[[C_SIMT]][0, 0]
 //       CHECK:   amdgpu.mfma
-//       CHECK:   %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-//       CHECK:   vector.extract %[[C_SIMT]][0, 1] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+//       CHECK:   %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
+//       CHECK:   vector.extract %[[C_SIMT]][0, 1]
 //       CHECK:   amdgpu.mfma
-//       CHECK:   %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-//       CHECK:   vector.extract %[[C_SIMT]][1, 0] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+//       CHECK:   %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1]
+//       CHECK:   vector.extract %[[C_SIMT]][0, 2]
 //       CHECK:   amdgpu.mfma
-//       CHECK:   %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [1, 0] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-//       CHECK:   vector.extract %[[C_SIMT]][1, 1] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+//       CHECK:   %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [0, 2]
+//       CHECK:   vector.extract %[[C_SIMT]][1, 0]
 //       CHECK:   amdgpu.mfma
-//       CHECK:   %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 1] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-//       CHECK:   vector.extract %[[C_SIMT]][2, 0] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+//       CHECK:   %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 0]
+//       CHECK:   vector.extract %[[C_SIMT]][1, 1]
 //       CHECK:   amdgpu.mfma
-//       CHECK:   %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [2, 0] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-//       CHECK:   vector.extract %[[C_SIMT]][2, 1] : vector<4x1x1x4xf32> from vector<3x2x4x1x1x4xf32>
+//       CHECK:   %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [1, 1]
+//       CHECK:   vector.extract %[[C_SIMT]][1, 2]
 //       CHECK:   amdgpu.mfma
-//       CHECK:   %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [2, 1] : vector<4x1x1x4xf32> into vector<3x2x4x1x1x4xf32>
-//       CHECK:   iree_vector_ext.to_simd %[[INS5]] : vector<3x2x4x1x1x4xf32> -> vector<64x96xf32>
+//       CHECK:   %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [1, 2]
+//       CHECK:   iree_vector_ext.to_simd %[[INS5]]
 
 // -----
 
@@ -493,8 +470,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 32]
 >
@@ -522,15 +497,15 @@
 }
 
 // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mmt
-// CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x1x4xf32>
+// CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x4x1xf32>
 // CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
-// CHECK:   vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
+// CHECK:   vector.extract %[[B_SIMT]][0, 0]
 // CHECK:   amdgpu.mfma
-// CHECK:   %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<1x2x4x1x1x4xf32>
-// CHECK:   vector.extract %[[B_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
+// CHECK:   %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
+// CHECK:   vector.extract %[[B_SIMT]][1, 0]
 // CHECK:   amdgpu.mfma
-// CHECK:   %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1] : vector<4x1x1x4xf32> into vector<1x2x4x1x1x4xf32>
-// CHECK:   iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x1x4xf32> -> vector<32x64xf32>
+// CHECK:   %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1]
+// CHECK:   iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>
 
 // -----
 
@@ -562,8 +537,6 @@
   threads_per_outer       = [1, 16],
   elements_per_thread     = [16, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [1, 32]
 >
@@ -576,8 +549,6 @@
   threads_per_outer       = [2, 16],
   elements_per_thread     = [1, 1],
 
-  element_order = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [2, 16]
 >
@@ -607,19 +578,17 @@
 
 // CHECK-LABEL: func.func @contract_to_wmma_16x16x16_mm
 //  CHECK-SAME: (%[[A:.+]]: vector<16x16xf16>, %[[B:.+]]: vector<16x16xf16>, %[[C:.+]]: vector<16x16xf32>)
-//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x8x1x1x1xf32>
 //       CHECK:   %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x8x1x1x1xf32>
 //       CHECK:   %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<16x16xf16> -> vector<1x1x1x1x1x16xf16>
-//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x1x16xf16>
+//       CHECK:   %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x16x1xf16>
 //       CHECK:   %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<8x1x1x1xf32> from vector<1x1x8x1x1x1xf32>
 //       CHECK:   %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x16xf16> from vector<1x1x1x1x1x16xf16>
-//       CHECK:   %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x16xf16> from vector<1x1x1x1x1x16xf16>
-//       CHECK:   %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
-//       CHECK:   %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
+//       CHECK:   %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x16x1xf16> from vector<1x1x1x1x16x1xf16>
+//       CHECK:   %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x16xf16> to vector<16xf1
+//       CHECK:   %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x16x1xf16> to vector<16xf1
 //       CHECK:   %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<8x1x1x1xf32> to vector<8xf32>
 //       CHECK:   %[[WMMA:.+]] = amdgpu.wmma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
 //       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[WMMA]] : vector<8xf32> to vector<8x1x1x1xf32>
-//       CHECK:   %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<8x1x1x1xf32> into vector<1x1x8x1x1x1xf32>
-//       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
+//       CHECK:   %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
+//       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
 //       CHECK:   return {{.*}} %[[R_SIMD]]
-
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index d93b3c2..41002f8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -1,54 +1,5 @@
 // RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse %s | FileCheck %s
 
-#layout_row_major = #iree_vector_ext.nested_layout<
-  subgroups_per_workgroup = [1, 1],
-  batches_per_subgroup    = [2, 2],
-  outers_per_batch        = [1, 1],
-  threads_per_outer       = [8, 1],
-  elements_per_thread     = [1, 8],
-
-  batch_order             = [1, 0],
-
-  subgroup_basis          = [1, 1],
-  thread_basis            = [8, 1]
->
-
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 8)>
-// CHECK-LABEL: @distribute_transfer_read_row_major
-func.func @distribute_transfer_read_row_major(%arg0: memref<4x4xf16>) -> vector<16x16xf16> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f16
-  %root = vector.transfer_read %arg0[%c0, %c0], %cst
-          {in_bounds = [false, false],
-           "__vector_layout_test_anchor_result_0" = #layout_row_major}
-                  : memref<4x4xf16>, vector<16x16xf16>
-  func.return %root : vector<16x16xf16>
-}
-
-builtin.module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
-    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
-    transform.yield
-  }
-}
-
-// CHECK: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x1x1x1x8xf16>
-// CHECK: %[[IDX:.+]] = gpu.thread_id  x
-// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %[[IDX]] into (%c1, %c1, %c8, %c1) : index, index, index, index
-// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c0], {{.*}} : memref<4x4xf16>, vector<1x8xf16>
-// CHECK: vector.insert_strided_slice %{{.*}}, %[[ACC]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<2x2x1x1x1x8xf16>
-// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c8]
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
-// CHECK: %[[ID_PLUS_BATCH1:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
-// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c0]
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
-// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c8]
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 1, 0, 0, 0, 0]
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
-
-// -----
-
 #layout_col_major = #iree_vector_ext.nested_layout<
   subgroups_per_workgroup = [1, 1],
   batches_per_subgroup    = [1, 2],
@@ -56,9 +7,6 @@
   threads_per_outer       = [4, 8],
   elements_per_thread     = [4, 1],
 
-  batch_order             = [1, 0],
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [4, 8]
 >
@@ -88,13 +36,11 @@
 // CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c4, %c8) : index, index, index, index
 // CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
 // CHECK: %[[RD00:.+]] = vector.transfer_read %arg0[%[[LANEY:.+]], %[[IDS]]#3], {{.*}} : memref<32x32xf16>, vector<4x1xf16>
-// CHECK: %[[ELEM_ORDER:.+]] = vector.transpose %[[RD00]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
-// CHECK: vector.insert_strided_slice %[[ELEM_ORDER]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x4xf16> into vector<2x1x1x1x1x4xf16>
+// CHECK: vector.insert_strided_slice %[[RD00]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<4x1xf16> into vector<1x2x1x1x4x1xf16>
 // CHECK: %[[LANEX_PLUS_BATCH:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
 // CHECK: vector.transfer_read %arg0[%[[LANEY]], %[[LANEX_PLUS_BATCH]]], %{{.*}} {in_bounds = [true, true]} : memref<32x32xf16>, vector<4x1xf16>
-// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<4x1xf16> to vector<1x4xf16>
-// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x1x1x1x1x4xf16> -> vector<16x16xf16>
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<1x2x1x1x4x1xf16> -> vector<16x16xf16>
 
 // -----
 
@@ -105,8 +51,6 @@
   threads_per_outer       = [8, 1],
   elements_per_thread     = [1, 8],
 
-  batch_order             = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [8, 1]
 >
@@ -154,9 +98,6 @@
   threads_per_outer       = [4, 8],
   elements_per_thread     = [4, 1],
 
-  batch_order             = [1, 0],
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [4, 8]
 >
@@ -186,9 +127,8 @@
 // CHECK-SAME:    %[[I0:.+]]: index, %[[I1:.+]]: index
 
 // CHECK: %[[BROADCAST_READ:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[I1]]], %{{.*}} permutation_map = #[[$MAP]]
-// CHECK: %[[UNIT:.+]] = vector.transpose %[[BROADCAST_READ]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
-// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0]
-// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+// CHECK: vector.insert_strided_slice %[[BROADCAST_READ]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0]
+// CHECK: vector.insert_strided_slice %[[BROADCAST_READ]], %{{.*}} {offsets = [0, 1, 0, 0, 0, 0]
 
 // -----
 
@@ -199,8 +139,6 @@
   threads_per_outer       = [8, 1],
   elements_per_thread     = [1, 8],
 
-  batch_order             = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [8, 1]
 >
@@ -289,8 +227,6 @@
   elements_per_thread     = [1, 1, 1, 2],
 
   subgroup_order          = [1, 0, 2, 3],
-  batch_order             = [1, 2, 3, 0],
-  outer_order             = [0, 3, 1, 2],
   thread_order            = [0, 1, 3, 2],
 
   subgroup_basis          = [7, 3, 1, 1],
@@ -410,8 +346,6 @@
   threads_per_outer       = [8, 1],
   elements_per_thread     = [1, 8],
 
-  batch_order             = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [8, 1]
 >
@@ -439,10 +373,10 @@
 // CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
 // CHECK: %[[SLICE:.+]] = vector.extract %{{.*}}[0, 0, 0, 0] : vector<1x8xf16> from vector<2x2x1x1x1x8xf16>
 // CHECK: vector.transfer_write %[[SLICE]], %{{.*}}[%[[IDS]]#2, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<64x64xf16>
-// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[IDS]]#2, %c8]
 // CHECK: %[[LANEX_PLUS_VECDIMX:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
-// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c0]
 // CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c8]
@@ -456,9 +390,6 @@
   threads_per_outer       = [4, 8],
   elements_per_thread     = [4, 1],
 
-  batch_order             = [1, 0],
-  element_order           = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [4, 8]
 >
@@ -488,11 +419,9 @@
 // CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %[[TIDX]] into (%c1, %c1, %c4, %c8) : index, index, index, index
 // CHECK: %[[LANEY:.+]] = affine.apply #map()[%[[IDS]]#2]
 // CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
-// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
 // CHECK: vector.transfer_write %{{.*}}[%[[LANEY]], %[[IDS]]#3]
 // CHECK: %[[LANEX:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
-// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
-// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
 // CHECK: vector.transfer_write {{.*}}[%[[LANEY]], %[[LANEX]]]
 
 // -----
@@ -504,8 +433,6 @@
   threads_per_outer       = [8, 1],
   elements_per_thread     = [1, 8],
 
-  batch_order             = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [8, 1]
 >
@@ -541,10 +468,10 @@
 // CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
 // CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
-// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
 // CHECK: %[[LIN_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
-// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
 // CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
 // CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
@@ -558,8 +485,6 @@
   threads_per_outer       = [8, 1],
   elements_per_thread     = [1, 8],
 
-  batch_order             = [1, 0],
-
   subgroup_basis          = [1, 1],
   thread_basis            = [8, 1]
 >
@@ -623,8 +548,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [4, 2],
   thread_basis            = [2, 32]
 >
@@ -637,8 +560,6 @@
   threads_per_outer       = [2, 32],
   elements_per_thread     = [4, 1],
 
-  element_order           = [1, 0],
-
   subgroup_basis          = [4, 2],
   thread_basis            = [2, 32]
 >
@@ -715,10 +636,7 @@
   elements_per_thread     = [1, 4],
 
   subgroup_order          = [1, 0],
-  batch_order             = [1, 0],
-  outer_order             = [1, 0],
   thread_order            = [1, 0],
-  element_order           = [0, 1],
 
   subgroup_basis          = [4, 2],
   thread_basis            = [2, 32]
@@ -767,7 +685,6 @@
   outers_per_batch = [1, 1],
   threads_per_outer = [4, 16],
   elements_per_thread = [4, 1],
-  element_order = [1, 0],
   subgroup_basis = [2, 2],
   thread_basis = [4, 16]
 >
@@ -787,21 +704,21 @@
 }
 
 // CHECK: vector.extract {{.*}}[0, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
-// CHECK: vector.insert {{.*}} [0, 0, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
+// CHECK: vector.insert {{.*}} [0, 0, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
 // CHECK: vector.extract {{.*}}[1, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
-// CHECK: vector.insert {{.*}} [0, 1, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
+// CHECK: vector.insert {{.*}} [0, 1, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
 // CHECK: vector.extract {{.*}}[2, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
-// CHECK: vector.insert {{.*}} [0, 2, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
+// CHECK: vector.insert {{.*}} [0, 2, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
 // CHECK: vector.extract {{.*}}[3, 0] : vector<1xf16> from vector<4x1x1xf16>
-// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<1x4xf16>
+// CHECK: vector.broadcast {{.*}} : vector<1xf16> to vector<4x1xf16>
 
-// CHECK: vector.insert {{.*}} [1, 0, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
-// CHECK: vector.insert {{.*}} [1, 1, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
-// CHECK: vector.insert {{.*}} [1, 2, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
-// CHECK: vector.insert {{.*}} [1, 3, 0, 0] : vector<1x4xf16> into vector<2x4x1x1x1x4xf16>
+// CHECK: vector.insert {{.*}} [1, 0, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
+// CHECK: vector.insert {{.*}} [1, 1, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
+// CHECK: vector.insert {{.*}} [1, 2, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
+// CHECK: vector.insert {{.*}} [1, 3, 0, 0] : vector<4x1xf16> into vector<2x4x1x1x4x1xf16>
 
 // -----
 
@@ -811,7 +728,6 @@
   outers_per_batch = [2, 1, 1],
   threads_per_outer = [4, 16, 8],
   elements_per_thread = [1, 4, 4],
-  batch_order = [2, 1, 0],
   subgroup_basis = [2, 2, 2],
   thread_basis = [4, 16, 8]
 >
@@ -832,14 +748,14 @@
 
 // CHECK: %[[EXTRACT:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
 // CHECK: %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf16> to vector<1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 1, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 1, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 1, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
-// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 1, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 0, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 0, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 1, 0, 0, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}} [1, 1, 0, 1, 0, 0] : vector<1x4x4xf16> into vector<2x2x1x2x1x1x1x4x4xf16>
 
 // -----
 
@@ -869,7 +785,8 @@
 }
 
 // CHECK-LABEL: func @transpose
-// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<256x64xf16> -> vector<2x4x2x1x2x2xf16>
+// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<256x64xf16> -> vector<4x2x1x2x2x2xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 0, 3, 2, 5, 4] : vector<4x2x1x2x2x2xf16> to vector<2x4x2x1x2x2xf16>
 // CHECK: math.sqrt %{{.*}} : vector<2x4x2x1x2x2xf16>
 // CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x4x2x1x2x2xf16> -> vector<64x256xf16>
 
@@ -907,17 +824,15 @@
 // CHECK-SAME:   threads_per_outer = [16, 4]
 // CHECK-SAME:   elements_per_thread = [2, 2]
 // CHECK-SAME:   subgroup_order = [1, 0]
-// CHECK-SAME:   batch_order = [1, 0],
-// CHECK-SAME:   outer_order = [1, 0]
 // CHECK-SAME:   thread_order = [1, 0]
-// CHECK-SAME:   element_order = [1, 0]
 // CHECK-SAME:   subgroup_basis = [2, 2]
 // CHECK-SAME:   thread_basis = [4, 16]
 
 // CHECK-LABEL: func @transpose
 // CHECK: iree_vector_ext.to_simt %{{.*}} : vector<64x256xf16> -> vector<2x4x2x1x2x2xf16>
-// CHECK: math.sqrt %{{.*}} : vector<2x4x2x1x2x2xf16>
-// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x4x2x1x2x2xf16> -> vector<256x64xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 0, 3, 2, 5, 4] : vector<2x4x2x1x2x2xf16> to vector<4x2x1x2x2x2xf16>
+// CHECK: math.sqrt %{{.*}} : vector<4x2x1x2x2x2xf16>
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<4x2x1x2x2x2xf16> -> vector<256x64xf16>
 // CHECK: return {{.*}}#[[$LAYOUT]]
 
 // -----
@@ -964,69 +879,54 @@
 // CHECK-LABEL: func @transpose_3d
 // CHECK:         %[[IDS:.+]]:6 = affine.delinearize_index %{{.*}} into (%c2, %c1, %c1, %c4, %c8, %c2)
 
-// CHECK:         %[[DIM0_ID:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#0, %[[IDS]]#3]
-// CHECK:         %[[DIM2_ID0:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#5]
-// CHECK:         %[[RD0:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID0]]], {{.*}} : memref<32x32x32xf16>, vector<4x1x2xf16>
-// CHECK:         %[[DIM2_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#5]
-// CHECK:         %[[RD1:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID1]]]
-// CHECK:         %[[DIM2_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#5]
-// CHECK:         %[[RD2:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID2]]]
-// CHECK:         %[[DIM2_ID3:.+]] = affine.apply #[[$MAP4]]()[%[[IDS]]#5]
-// CHECK:         %[[RD3:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID3]]]
-// CHECK:         %[[DIM2_ID4:.+]] = affine.apply #[[$MAP5]]()[%[[IDS]]#4]
-// CHECK:         %[[RD4:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID0]]]
-// CHECK:         %[[RD5:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID1]]]
-// CHECK:         %[[RD6:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID2]]]
-// CHECK:         %[[RD7:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID3]]]
+// CHECK-DAG:         %[[DIM0_ID:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#0, %[[IDS]]#3]
+// CHECK-DAG:         %[[DIM2_ID0:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#5]
+// CHECK-DAG:         %[[RD0:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID0]]], {{.*}} : memref<32x32x32xf16>, vector<4x1x2xf16>
+// CHECK-DAG:         %[[DIM2_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#5]
+// CHECK-DAG:         %[[RD1:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID1]]]
+// CHECK-DAG:         %[[DIM2_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#5]
+// CHECK-DAG:         %[[RD2:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID2]]]
+// CHECK-DAG:         %[[DIM2_ID3:.+]] = affine.apply #[[$MAP4]]()[%[[IDS]]#5]
+// CHECK-DAG:         %[[RD3:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[IDS]]#4, %[[DIM2_ID3]]]
+// CHECK-DAG:         %[[DIM2_ID4:.+]] = affine.apply #[[$MAP5]]()[%[[IDS]]#4]
+// CHECK-DAG:         %[[RD4:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID0]]]
+// CHECK-DAG:         %[[RD5:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID1]]]
+// CHECK-DAG:         %[[RD6:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID2]]]
+// CHECK-DAG:         %[[RD7:.+]] = vector.transfer_read %arg0[%[[DIM0_ID]], %[[DIM2_ID4]], %[[DIM2_ID3]]]
 
-// CHECK:         %[[T0:.+]] = vector.transpose %[[RD0]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T0]], %arg0[%[[IDS]]#4, %[[DIM2_ID0]], %[[DIM0_ID]]] {{.*}} : vector<1x2x4xf16>, memref<32x32x32xf16>
+// CHECK:         vector.transpose %{{.*}}, [1, 2, 0, 4, 5, 3, 7, 8, 6] : vector<1x2x4x1x1x1x4x1x2xf16> to vector<2x4x1x1x1x1x1x2x4xf16>
 
-// CHECK:         %[[T1:.+]] = vector.transpose %[[RD4]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T1]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID0]], %[[DIM0_ID]]]
-
-// CHECK:         %[[T2:.+]] = vector.transpose %[[RD1]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T2]], %arg0[%[[IDS]]#4, %[[DIM2_ID1]], %[[DIM0_ID]]]
-
-// CHECK:         %[[T3:.+]] = vector.transpose %[[RD5]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T3]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID1]], %[[DIM0_ID]]]
-
-// CHECK:         %[[T4:.+]] = vector.transpose %[[RD2]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T4]], %arg0[%[[IDS]]#4, %[[DIM2_ID2]], %[[DIM0_ID]]]
-
-// CHECK:         %[[T5:.+]] = vector.transpose %[[RD6]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T5]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID2]], %[[DIM0_ID]]]
-
-// CHECK:         %[[T6:.+]] = vector.transpose %[[RD3]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T6]], %arg0[%[[IDS]]#4, %[[DIM2_ID3]], %[[DIM0_ID]]]
-
-// CHECK:         %[[T7:.+]] = vector.transpose %[[RD7]], [1, 2, 0] : vector<4x1x2xf16> to vector<1x2x4xf16>
-// CHECK:         vector.transfer_write %[[T7]], %arg0[%[[DIM2_ID4]], %[[DIM2_ID3]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID0]], %[[DIM0_ID]]] {{.*}} : vector<1x2x4xf16>, memref<32x32x32xf16>
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID0]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID1]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID1]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID2]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID2]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[IDS]]#4, %[[DIM2_ID3]], %[[DIM0_ID]]]
+// CHECK-DAG:         vector.transfer_write %{{.*}}, %arg0[%[[DIM2_ID4]], %[[DIM2_ID3]], %[[DIM0_ID]]]
 
 // -----
 
 #nested = #iree_vector_ext.nested_layout<
-  subgroups_per_workgroup = [1, 1], 
-  // We are reducing along dim=1, so each thread will reduce 
+  subgroups_per_workgroup = [1, 1],
+  // We are reducing along dim=1, so each thread will reduce
   // 2 batches x 4 elements = 8 elements.
-  batches_per_subgroup = [2, 2], 
-  outers_per_batch = [1, 1], 
+  batches_per_subgroup = [2, 2],
+  outers_per_batch = [1, 1],
   // We are reducing on dim=1, which is distributed over 4 threads. Based
   // on the subgroup basis and thread order, the shuffle offset is 16.
-  threads_per_outer = [16, 4], 
-  elements_per_thread = [1, 4], 
+  threads_per_outer = [16, 4],
+  elements_per_thread = [1, 4],
 
-  subgroup_order = [1, 0], 
-  batch_order = [1, 0], 
-  outer_order = [1, 0], 
-  thread_order = [1, 0], 
+  subgroup_order = [1, 0],
+  thread_order = [1, 0],
 
-  subgroup_basis = [1, 1], 
+  subgroup_basis = [1, 1],
   thread_basis = [4, 16]
 >
 
 func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
-  %0 = vector.multi_reduction <maximumf>, %arg0, %arg1 
+  %0 = vector.multi_reduction <maximumf>, %arg0, %arg1
   {
     __vector_layout_test_anchor_operand_0 = #nested
   } [1] : vector<32x32xf32> to vector<32xf32>
@@ -1048,7 +948,7 @@
 // CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
 // CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
 // Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[DARG1]] [0, 2, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
+// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[DARG1]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
 // Global reduction
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C16]], %[[C64]] : f32
 // CHECK: gpu.shuffle  xor %{{.*}}, %[[C32]], %[[C64]] : f32
@@ -1060,7 +960,7 @@
 
 #nested = #iree_vector_ext.nested_layout<
   subgroups_per_workgroup = [1, 1],
-  // We are reducing along dim=1, so each thread will reduce 
+  // We are reducing along dim=1, so each thread will reduce
   // 4 batches x 4 elements = 16 elements.
   batches_per_subgroup    = [1, 4],
   outers_per_batch        = [1, 1],
@@ -1076,7 +976,7 @@
 >
 
 func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
-  %0 = vector.multi_reduction <maximumf>, %arg0, %arg1 
+  %0 = vector.multi_reduction <maximumf>, %arg0, %arg1
   {
     __vector_layout_test_anchor_operand_0 = #nested
   } [1] : vector<32x32xf32> to vector<32xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index 0a03966..f36e8bd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -45,10 +45,7 @@
   elements_per_thread     = [1, 8, 2],
 
   subgroup_order          = [0, 1, 2],
-  batch_order             = [0, 1, 2],
-  outer_order             = [0, 1, 2],
   thread_order            = [0, 1, 2],
-  element_order           = [0, 2, 1],
 
   subgroup_basis          = [2, 1, 1],
   thread_basis            = [8, 2, 4]
@@ -58,15 +55,15 @@
 func.func @distribute_elementwise_nested_layout_f16(%a: vector<128x128x128xf16>, %b: vector<128x128x128xf16>) -> vector<128x128x128xf16> {
   %c0 = arith.constant 0 : index
   %cst_0 = arith.constant 0.0 : f16
-  // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x2x8xf16>
+  // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x8x2xf16>
   %root = arith.constant {"__vector_layout_test_anchor_result_0" = #nested} dense<0.0> : vector<128x128x128xf16>
-  // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16>
-  // CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x2x8xf16>
+  // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x8x2xf16>
+  // CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x8x2xf16>
   %c = arith.mulf %root, %b : vector<128x128x128xf16>
-  // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16>
-  // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<8x2x4x1x4x4x1x2x8xf16>
+  // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x8x2xf16>
+  // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<8x2x4x1x4x4x1x8x2xf16>
   %d = arith.addf %c, %a fastmath<reassoc,nnan> : vector<128x128x128xf16>
-  // CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x2x8xf16> -> vector<128x128x128xf16>
+  // CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x8x2xf16> -> vector<128x128x128xf16>
   return %d : vector<128x128x128xf16>
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index f27f67d..417f8ee 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -602,10 +602,9 @@
 NestedLayoutAttr permuteAndCreateNestedLayout(
     MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim,
     SmallVector<int64_t> subgroupCount, SmallVector<int64_t> subgroupOrder,
-    SmallVector<int64_t> batchCount, SmallVector<int64_t> batchOrder,
-    MMAAttr::SingleSubgroupLayout counts, MMAAttr::SingleSubgroupLayout orders,
-    ArrayRef<int64_t> dataDuplicate, ArrayRef<int64_t> subgroupBasis,
-    ArrayRef<bool> subgroupActiveIds) {
+    SmallVector<int64_t> batchCount, MMAAttr::SingleSubgroupLayout counts,
+    MMAAttr::SingleSubgroupLayout orders, ArrayRef<int64_t> dataDuplicate,
+    ArrayRef<int64_t> subgroupBasis, ArrayRef<bool> subgroupActiveIds) {
 
   LLVM_DEBUG({
     llvm::errs() << "Given:";
@@ -617,20 +616,14 @@
     llvm::interleaveComma(subgroupOrder, llvm::errs());
     llvm::errs() << "\n    batchCount: ";
     llvm::interleaveComma(batchCount, llvm::errs());
-    llvm::errs() << "\n    batchOrder: ";
-    llvm::interleaveComma(batchOrder, llvm::errs());
     llvm::errs() << "\n    counts.outer: ";
     llvm::interleaveComma(counts.outer, llvm::errs());
-    llvm::errs() << "\n    orders.outer: ";
-    llvm::interleaveComma(orders.outer, llvm::errs());
     llvm::errs() << "\n    counts.thread: ";
     llvm::interleaveComma(counts.thread, llvm::errs());
     llvm::errs() << "\n    orders.thread: ";
     llvm::interleaveComma(orders.thread, llvm::errs());
     llvm::errs() << "\n    counts.element: ";
     llvm::interleaveComma(counts.element, llvm::errs());
-    llvm::errs() << "\n    orders.element: ";
-    llvm::interleaveComma(orders.element, llvm::errs());
     llvm::errs() << "\n    subgroupBasis: ";
     llvm::interleaveComma(subgroupBasis, llvm::errs());
     llvm::errs() << "\n    subgroupActiveIds: ";
@@ -638,12 +631,8 @@
     llvm::errs() << "\n";
   });
 
-  SmallVector<int64_t> outerOrder =
-      getIdentityPermWithSwap(rank, orders.outer, outerDim, innerDim);
   SmallVector<int64_t> threadOrder =
       getIdentityPermWithSwap(rank, orders.thread, outerDim, innerDim);
-  SmallVector<int64_t> elementOrder =
-      getIdentityPermWithSwap(rank, orders.element, outerDim, innerDim);
 
   SmallVector<int64_t> threadBasis =
       getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim);
@@ -666,20 +655,14 @@
     llvm::interleaveComma(subgroupOrder, llvm::errs());
     llvm::errs() << "\n    batchCount: ";
     llvm::interleaveComma(batchCount, llvm::errs());
-    llvm::errs() << "\n    batchOrder: ";
-    llvm::interleaveComma(batchOrder, llvm::errs());
     llvm::errs() << "\n    outerCount: ";
     llvm::interleaveComma(outerCount, llvm::errs());
-    llvm::errs() << "\n    outerOrder: ";
-    llvm::interleaveComma(outerOrder, llvm::errs());
     llvm::errs() << "\n    threadCount: ";
     llvm::interleaveComma(threadCount, llvm::errs());
     llvm::errs() << "\n    threadOrder: ";
     llvm::interleaveComma(threadOrder, llvm::errs());
     llvm::errs() << "\n    elementCount: ";
     llvm::interleaveComma(elementCount, llvm::errs());
-    llvm::errs() << "\n    elementOrder: ";
-    llvm::interleaveComma(elementOrder, llvm::errs());
     llvm::errs() << "\n    subgroupBasis: ";
     llvm::interleaveComma(subgroupBasis, llvm::errs());
     llvm::errs() << "\n    subgroupActiveIds: ";
@@ -690,10 +673,9 @@
   });
 
   auto layoutAttr = NestedLayoutAttr::get(
-      context, subgroupCount, subgroupOrder, batchCount, batchOrder, outerCount,
-      outerOrder, threadCount, threadOrder, elementCount, elementOrder,
-      subgroupBasis, subgroupActiveIds, threadBasis,
-      SmallVector<bool>(threadBasis.size(), true));
+      context, subgroupCount, subgroupOrder, batchCount, outerCount,
+      threadCount, threadOrder, elementCount, subgroupBasis, subgroupActiveIds,
+      threadBasis, SmallVector<bool>(threadBasis.size(), true));
   return layoutAttr;
 }
 
@@ -847,8 +829,7 @@
       context, cRank, m, n,
       /*subgroupCount=*/cSubgroupSizes,
       /*subgroupOrder=*/cOverallOrder,
-      /*batchCount=*/cBatchSizes,
-      /*batchOrder=*/cOverallOrder, cCounts, cOrders,
+      /*batchCount=*/cBatchSizes, cCounts, cOrders,
       /*dataDuplicate=*/mmaAttr.getCDataDuplicate(), subgroupBasis,
       cActiveSubgroups);
   LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; });
@@ -868,15 +849,12 @@
   SmallVector<int64_t> aBatchSizes(aRank, 1);
   SmallVector<int64_t> aSubgroupSizes(aRank, 1);
   SmallVector<int64_t> aSubgroupOrder(aRank, 0);
-  SmallVector<int64_t> aBatchOrder(aRank, 0);
   for (auto [i, dim] : llvm::enumerate(aMDims)) {
     aBatchSizes[dim] = batchMSizes[i];
     aSubgroupSizes[dim] = subgroupMBasis[i];
     aSubgroupOrder[dim] = i;
-    aBatchOrder[dim] = i >= afk ? i + 1 : i;
   }
   aSubgroupOrder[afk] = aRank - 1;
-  aBatchOrder[afk] = afk;
   aBatchSizes[afk] = getSubgroupKTileCount();
 
   SmallVector<bool> aActiveSubgroups(subgroupBasis.size(), false);
@@ -889,8 +867,7 @@
       context, aRank, afm, afk,
       /*subgroupCount=*/aSubgroupSizes,
       /*subgroupOrder=*/aSubgroupOrder,
-      /*batchCount=*/aBatchSizes,
-      /*batchOrder=*/getIdentityPerm(aRank), aCounts, aOrders,
+      /*batchCount=*/aBatchSizes, aCounts, aOrders,
       /*dataDuplicate=*/mmaAttr.getADataDuplicate(), subgroupBasis,
       aActiveSubgroups);
   LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; });
@@ -907,15 +884,12 @@
   SmallVector<int64_t> bBatchSizes(bRank, 1);
   SmallVector<int64_t> bSubgroupSizes(bRank, 1);
   SmallVector<int64_t> bSubgroupOrder(bRank, 0);
-  SmallVector<int64_t> bBatchOrder(bRank, 0);
   for (auto [i, dim] : llvm::enumerate(bNDims)) {
     bBatchSizes[dim] = batchNSizes[i];
     bSubgroupSizes[dim] = subgroupNBasis[i];
     bSubgroupOrder[dim] = i;
-    bBatchOrder[dim] = i >= bfk ? i + 1 : i;
   }
   bSubgroupOrder[bfk] = bRank - 1;
-  bBatchOrder[bfk] = bfk;
   bBatchSizes[bfk] = getSubgroupKTileCount();
 
   SmallVector<bool> bActiveSubgroups(subgroupBasis.size(), false);
@@ -928,8 +902,7 @@
       context, bRank, bfk, bfn,
       /*subgroupCount=*/bSubgroupSizes,
       /*subgroupOrder=*/bSubgroupOrder,
-      /*batchCount=*/bBatchSizes,
-      /*batchOrder=*/bBatchOrder, bCounts, bOrders,
+      /*batchCount=*/bBatchSizes, bCounts, bOrders,
       /*dataDuplicate=*/mmaAttr.getBDataDuplicate(), subgroupBasis,
       bActiveSubgroups);
   LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; });
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index ed14df5..fa01423 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -326,8 +326,8 @@
     SmallVector<int64_t> threadBasis = threadCounts;
 
     auto layout = IREE::VectorExt::NestedLayoutAttr::get(
-        context, subgroupCounts, order, batchSizes, order, outerSizes, order,
-        threadCounts, order, elementSizes, order, subgroupBasis,
+        context, subgroupCounts, order, batchSizes, outerSizes, threadCounts,
+        order, elementSizes, subgroupBasis,
         SmallVector<bool>(subgroupBasis.size(), true), threadBasis,
         SmallVector<bool>(threadBasis.size(), true));
     if (analysis.setAnchor(transfer.getResult(), layout).failed()) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 99a864b..83ae1be 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -49,11 +49,11 @@
 
 //    CHECK-LABEL: func.func @matmul_256x256x256_f16_f32()
 //     CHECK-SAME:    translation_info = #[[$TRANSLATION]]
-//          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x1x4xf32>)
+//          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf32>)
 // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
 // along the K dimension. So in total 32 mfma ops.
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//          CHECK:     scf.yield %{{.+}} : vector<2x2x1x1x1x4xf32>
+//          CHECK:     scf.yield %{{.+}} : vector<2x2x1x1x4x1xf32>
 //  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
 
 // -----
@@ -100,11 +100,11 @@
 
 //    CHECK-LABEL: func.func @matmul_256x256x256_f16_f16()
 //     CHECK-SAME:     translation_info = #[[$TRANSLATION]]
-//          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x2x1x1x1x4xf16>)
-//          CHECK:     arith.extf %[[ARG]] : vector<2x2x1x1x1x4xf16> to vector<2x2x1x1x1x4xf32>
+//          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x2x1x1x4x1xf16>)
+//          CHECK:     arith.extf %[[ARG]] : vector<2x2x1x1x4x1xf16> to vector<2x2x1x1x4x1xf32>
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x2x1x1x1x4xf32> to vector<2x2x1x1x1x4xf16>
-//          CHECK:     scf.yield %[[TRUNC]] : vector<2x2x1x1x1x4xf16>
+//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x2x1x1x4x1xf32> to vector<2x2x1x1x4x1xf16>
+//          CHECK:     scf.yield %[[TRUNC]] : vector<2x2x1x1x4x1xf16>
 //  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
 
 // -----
@@ -170,11 +170,11 @@
 //     CHECK-SAME:     translation_info = #[[TRANSLATION]]
 // This has more than 2 iteartions. So we have prefetching enabled for this case. Due to
 // prefetching, we have one iteration peeled of so upper bound is 2048 - 128 = 1920.
-//          CHECK:   scf.for {{.*}} = %c0 to %c1920 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x1x4xf16>)
-//          CHECK:     arith.extf %[[ARG]] : vector<4x1x1x1x1x4xf16> to vector<4x1x1x1x1x4xf32>
+//          CHECK:   scf.for {{.*}} = %c0 to %c1920 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
+//          CHECK:     arith.extf %[[ARG]] : vector<4x1x1x1x4x1xf16> to vector<4x1x1x1x4x1xf32>
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x1x4xf32> to vector<4x1x1x1x1x4xf16>
-//          CHECK:     scf.yield %[[TRUNC]] : vector<4x1x1x1x1x4xf16>
+//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x4x1xf32> to vector<4x1x1x1x4x1xf16>
+//          CHECK:     scf.yield %[[TRUNC]] : vector<4x1x1x1x4x1xf16>
 // CHECK-COUNT-32:   amdgpu.mfma
 //  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<2x10x64x64xf16, #hal.descriptor_type<storage_buffer>>
 
@@ -221,7 +221,7 @@
 //          CHECK:     scf.for {{.*}} = %c0 to %c3
 // This has more than 2 iteartions. So we have prefetching enabled for this case. Due to
 // prefetching, we have one iteration peeled of so upper bound is 768 - 32 = 736.
-//          CHECK:       scf.for {{.*}} = %c0 to %c736 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf32>)
+//          CHECK:       scf.for {{.*}} = %c0 to %c736 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x4x1xf32>)
 // CHECK-COUNT-16:         amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 //          CHECK:         scf.yield
 // CHECK-COUNT-16:       amdgpu.mfma
@@ -300,11 +300,11 @@
 //    CHECK-LABEL: func.func @generic_2x1024x20x64x1280_f16
 // This has more than 2 iteartions. So we have prefetching enabled for this case. Due to
 // prefetching, we have one iteration peeled of so upper bound is 1280 - 128 = 1152.
-//          CHECK:   scf.for {{.*}} = %c0 to %c1152 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x1x4xf16>)
+//          CHECK:   scf.for {{.*}} = %c0 to %c1152 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf16>)
 // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
 // along the K dimension. So in total 32 mfma ops.
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//          CHECK:     scf.yield %{{.+}} : vector<2x2x1x1x1x4xf16>
+//          CHECK:     scf.yield %{{.+}} : vector<2x2x1x1x4x1xf16>
 // CHECK-COUNT-32:   amdgpu.mfma
 //  CHECK-COUNT-4:   vector.transfer_write {{.+}} : vector<4x1xf16>
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
index bcd47c4..244fef4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
@@ -38,10 +38,10 @@
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_256x256x256
-//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32>
+//       CHECK:   %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x4x1xf32>
 //       CHECK:   %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
 //       CHECK:   %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
-//       CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>)
+//       CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]])
 //       CHECK:     %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
 //       CHECK:     %[[RLOAD:.+]] = vector.transfer_read {{.*}} : memref<256x16xf16, {{.*}}>, vector<1x8xf16>
 //       CHECK:     vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
@@ -50,8 +50,8 @@
 // CHECK-COUNT-2:   vector.transfer_read %[[LHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<1x4xf16>
 // CHECK-COUNT-2:   vector.transfer_read %[[RHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4x1xf16>
 // CHECK-COUNT-2:   amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//       CHECK:     %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
-//       CHECK:     scf.yield %[[BCAST]] : vector<1x1x1x1x1x4xf32>
+//       CHECK:     %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x4x1xf32> to vector<1x1x1x1x4x1xf32>
+//       CHECK:     scf.yield %[[BCAST]]
 //       CHECK:  vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<16x16xf32{{.*}}>
 
 // -----
@@ -107,9 +107,8 @@
 //       CHECK:   %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
 //       CHECK:   affine.delinearize_index %[[LIN_ID]]
 //       CHECK:   %[[INIT_READ:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<4x1xf32>
-//       CHECK:   %[[INIT_TRANSP:.+]] = vector.transpose %[[INIT_READ]], [1, 0]
-//       CHECK:   %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_TRANSP]]
-//       CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>)
+//       CHECK:   %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_READ]]
+//       CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x4x1xf32>)
 //       CHECK:     %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
 //       CHECK:     %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP1]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16>
 //       CHECK:     vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
@@ -118,8 +117,8 @@
 // CHECK-COUNT-2:   vector.transfer_read %[[LHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space<workgroup>>, vector<1x4xf16>
 // CHECK-COUNT-2:   vector.transfer_read %[[RHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4x1xf16>
 // CHECK-COUNT-2:   amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//       CHECK:     %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
-//       CHECK:     scf.yield %[[BCAST]] : vector<1x1x1x1x1x4xf32>
+//       CHECK:     %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x4x1xf32> to vector<1x1x1x1x4x1xf32>
+//       CHECK:     scf.yield %[[BCAST]]
 //       CHECK:  vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<16x16xf32{{.*}}>
 
 // -----
@@ -237,29 +236,21 @@
 //       CHECK:   %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
 //       CHECK:   affine.delinearize_index %[[LIN_ID]]
 //       CHECK:   %[[INIT_READ0:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP0:.+]] = vector.transpose %[[INIT_READ0]], [1, 0]
-//       CHECK:   %[[INIT0:.+]] = vector.insert_strided_slice %[[INIT_TRANSP0]]
+//       CHECK:   %[[INIT0:.+]] = vector.insert_strided_slice %[[INIT_READ0]]
 //       CHECK:   %[[INIT_READ1:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP1:.+]] = vector.transpose %[[INIT_READ1]], [1, 0]
-//       CHECK:   %[[INIT1:.+]] = vector.insert_strided_slice %[[INIT_TRANSP1]]
+//       CHECK:   %[[INIT1:.+]] = vector.insert_strided_slice %[[INIT_READ1]]
 //       CHECK:   %[[INIT_READ2:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP2:.+]] = vector.transpose %[[INIT_READ2]], [1, 0]
-//       CHECK:   %[[INIT2:.+]] = vector.insert_strided_slice %[[INIT_TRANSP2]]
+//       CHECK:   %[[INIT2:.+]] = vector.insert_strided_slice %[[INIT_READ2]]
 //       CHECK:   %[[INIT_READ3:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP3:.+]] = vector.transpose %[[INIT_READ3]], [1, 0]
-//       CHECK:   %[[INIT3:.+]] = vector.insert_strided_slice %[[INIT_TRANSP3]]
+//       CHECK:   %[[INIT3:.+]] = vector.insert_strided_slice %[[INIT_READ3]]
 //       CHECK:   %[[INIT_READ4:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP4:.+]] = vector.transpose %[[INIT_READ4]], [1, 0]
-//       CHECK:   %[[INIT4:.+]] = vector.insert_strided_slice %[[INIT_TRANSP4]]
+//       CHECK:   %[[INIT4:.+]] = vector.insert_strided_slice %[[INIT_READ4]]
 //       CHECK:   %[[INIT_READ5:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP5:.+]] = vector.transpose %[[INIT_READ5]], [1, 0]
-//       CHECK:   %[[INIT5:.+]] = vector.insert_strided_slice %[[INIT_TRANSP5]]
+//       CHECK:   %[[INIT5:.+]] = vector.insert_strided_slice %[[INIT_READ5]]
 //       CHECK:   %[[INIT_READ6:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP6:.+]] = vector.transpose %[[INIT_READ6]], [1, 0]
-//       CHECK:   %[[INIT6:.+]] = vector.insert_strided_slice %[[INIT_TRANSP6]]
+//       CHECK:   %[[INIT6:.+]] = vector.insert_strided_slice %[[INIT_READ6]]
 //       CHECK:   %[[INIT_READ7:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<1x1xf32>
-//       CHECK:   %[[INIT_TRANSP7:.+]] = vector.transpose %[[INIT_READ7]], [1, 0]
-//       CHECK:   %[[INIT7:.+]] = vector.insert_strided_slice %[[INIT_TRANSP7]]
+//       CHECK:   %[[INIT7:.+]] = vector.insert_strided_slice %[[INIT_READ7]]
 //       CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT7]]) -> (vector<1x1x8x1x1x1xf32>)
 //       CHECK:     %[[LLOAD0:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
 //       CHECK:     %[[LLOAD1:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
index 520f7f9..76ca1f2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir
@@ -19,11 +19,9 @@
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
 
 // -----
@@ -47,11 +45,10 @@
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [2, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4],
-// CHECK-SAME:   outer_order = [1, 0], thread_order = [1, 0]
+// CHECK-SAME:   thread_order = [1, 0]
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [2, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0]
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 32]>
 
 // -----
@@ -125,11 +122,10 @@
 // CHECK-SAME:   subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME:   subgroup_order = [1, 0], element_order = [1, 0],
+// CHECK-SAME:   subgroup_order = [1, 0]
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
 
 // -----
@@ -180,7 +176,7 @@
 // CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [16, 4]>
 //      CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1],
-// CHECK-SAME:   subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0], element_order = [1, 0],
+// CHECK-SAME:   subgroup_order = [1, 0], thread_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1], thread_basis = [4, 16]>
 
 //      CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
@@ -189,11 +185,9 @@
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [4, 16]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [4, 16]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [4, 16]>
 
 // -----
@@ -257,11 +251,9 @@
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
 
 // -----
@@ -285,11 +277,10 @@
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, false, true], thread_basis = [1, 32]>
 //      CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16],
-// CHECK-SAME:   outer_order = [1, 0], thread_order = [1, 0],
+// CHECK-SAME:   thread_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [false, true, true], thread_basis = [1, 32]>
 //      CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
 // CHECK-SAME:   subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [1, 1, 1], subgroup_active_ids = [true, true, false], thread_basis = [2, 16]>
 
 // -----
@@ -325,7 +316,6 @@
 // CHECK-SAME:   threads_per_outer = [4, 16],
 // CHECK-SAME:   elements_per_thread = [4, 1],
 // CHECK-SAME:   subgroup_order = [1, 0],
-// CHECK-SAME:   element_order = [1, 0],
 // CHECK-SAME:   subgroup_basis = [2, 1, 1, 1],
 // CHECK-SAME:   subgroup_active_ids = [false, false, true, true],
 // CHECK-SAME:   thread_basis = [4, 16]>
@@ -335,7 +325,6 @@
 // CHECK-SAME:   outers_per_batch = [1, 1, 1],
 // CHECK-SAME:   threads_per_outer = [1, 4, 16],
 // CHECK-SAME:   elements_per_thread = [1, 4, 1],
-// CHECK-SAME:   element_order = [0, 2, 1],
 // CHECK-SAME:   subgroup_basis = [2, 1, 1, 1],
 // CHECK-SAME:   subgroup_active_ids = [true, true, true, false],
 // CHECK-SAME:   thread_basis = [1, 4, 16]>
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 7096682..61b5710 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -130,8 +130,7 @@
     would have 2-D tile sizes per compute hierarchy level.
 
     We now describe each level of tiling. Each level of tiling represents a
-    count of tiles over the next level (rather than a list of tile sizes) and
-    an ordering over the tiles:
+    count of tiles over the next level (rather than a list of tile sizes).
 
     1. Subgroups per Workgroup
 
@@ -178,9 +177,7 @@
       for b_1 in range(batch_1):
         ...
 
-    Batches are represented using two attributes:
-      - batches_per_subgroup: Ranges of each loop
-      - batch_order: Ordering of each loop, from outermost to innermost
+    `batches_per_subgroup` represents the range of each loop.
 
     The second level, outers, is a way to represent thread layout duplication
     required by a particular intrinsic. For example, some AMDGPU matrix
@@ -193,9 +190,7 @@
     0 1 2 3 4     outers_per_batch=[2, 1]
     5 6 7 8 9     threads_per_outer=[2, 5]
 
-    Outers is represented using two attributes:
-      - outers_per_batch: Number of outers in a batch
-      - outer_order: Ordering of outers, from outermost to innermost
+    `outers_per_batch` represents the number of outers in a batch.
 
     Finally, threads are distributed in a single outer. The thread
     distribution is represented by:
@@ -210,8 +205,6 @@
       outers_per_batch = [2, 2]
       threads_per_outer = [2, 2]
 
-      batch_order = [0, 1]
-      outer_order = [0, 1]
       thread_order = [1, 0]
     }
 
@@ -259,10 +252,7 @@
     The final level of tiling, representing the minimum shape of vector that
     is treated as an atom.
 
-    The elements are placed contigiously with their shape and ordering
-    determined by:
-      - `elements_per_thread`: Sizes of this level of tiling
-      - `element_order`: Ordering of dimensions, from outermost to innermost
+    `elements_per_thread` represents the native size of the vector.
   }];
 
   let parameters = (ins
@@ -270,16 +260,13 @@
     ArrayRefParameter<"int64_t", "subgroup_order">:$subgroupOrder,
 
     ArrayRefParameter<"int64_t", "batches_per_subgroup">:$batchesPerSubgroup,
-    ArrayRefParameter<"int64_t", "batch_order">:$batchOrder,
 
     ArrayRefParameter<"int64_t", "outers_per_batch">:$outersPerBatch,
-    ArrayRefParameter<"int64_t", "outer_order">:$outerOrder,
 
     ArrayRefParameter<"int64_t", "threads_per_outer">:$threadsPerOuter,
     ArrayRefParameter<"int64_t", "thread_order">:$threadOrder,
 
     ArrayRefParameter<"int64_t", "elements_per_thread">:$elementsPerThread,
-    ArrayRefParameter<"int64_t", "element_order">:$elementOrder,
 
     ArrayRefParameter<"int64_t", "subgroup_basis">:$subgroupBasis,
     ArrayRefParameter<"bool", "subgroup_active_ids">:$subgroupActiveIds,
@@ -298,10 +285,7 @@
         `elements_per_thread`     `=` `[` $elementsPerThread `]` `,`
 
         custom<Permutation>("\"subgroup_order\"", ref($subgroupsPerWorkgroup), "true", $subgroupOrder) ``
-        custom<Permutation>("\"batch_order\"", ref($batchesPerSubgroup), "true", $batchOrder) ``
-        custom<Permutation>("\"outer_order\"", ref($outersPerBatch), "true", $outerOrder) ``
         custom<Permutation>("\"thread_order\"", ref($threadsPerOuter), "true", $threadOrder) ``
-        custom<Permutation>("\"element_order\"", ref($elementsPerThread), "true", $elementOrder) ``
 
         custom<Basis>("\"subgroup_basis\"", "\"subgroup_active_ids\"", "true", $subgroupBasis, $subgroupActiveIds) ``
         custom<Basis>("\"thread_basis\"", "\"thread_active_ids\"", "false", $threadBasis, $threadActiveIds) ``
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 0234765..b481964 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -285,12 +285,8 @@
 
   SmallVector<int64_t> subgroupOrder =
       getRankReducedPermutation(getSubgroupOrder());
-  SmallVector<int64_t> batchOrder = getRankReducedPermutation(getBatchOrder());
-  SmallVector<int64_t> outerOrder = getRankReducedPermutation(getOuterOrder());
   SmallVector<int64_t> threadOrder =
       getRankReducedPermutation(getThreadOrder());
-  SmallVector<int64_t> elementOrder =
-      getRankReducedPermutation(getElementOrder());
 
   // Compose the projected dims with the basis mask to get the new active
   // ids. Active ids indicates that we should use the ids marked as true, and
@@ -322,9 +318,8 @@
   composeMasks(threadMask, invertedDroppedSubgroupMask);
 
   return NestedLayoutAttr::get(getContext(), subgroupCount, subgroupOrder,
-                               batchCount, batchOrder, outerCount, outerOrder,
-                               threadCount, threadOrder, elementCount,
-                               elementOrder, getSubgroupBasis(), subgroupMask,
+                               batchCount, outerCount, threadCount, threadOrder,
+                               elementCount, getSubgroupBasis(), subgroupMask,
                                getThreadBasis(), threadMask);
 }
 
@@ -360,33 +355,28 @@
       applyPermutation(invPerm, getSubgroupOrder());
   SmallVector<int64_t> batchCount =
       applyPermutation(getBatchesPerSubgroup(), permutation);
-  SmallVector<int64_t> batchOrder = applyPermutation(invPerm, getBatchOrder());
   SmallVector<int64_t> outerCount =
       applyPermutation(getOutersPerBatch(), permutation);
-  SmallVector<int64_t> outerOrder = applyPermutation(invPerm, getOuterOrder());
   SmallVector<int64_t> threadCount =
       applyPermutation(getThreadsPerOuter(), permutation);
   SmallVector<int64_t> threadOrder =
       applyPermutation(invPerm, getThreadOrder());
   SmallVector<int64_t> elementCount =
       applyPermutation(getElementsPerThread(), permutation);
-  SmallVector<int64_t> elementOrder =
-      applyPermutation(invPerm, getElementOrder());
 
   return NestedLayoutAttr::get(
-      getContext(), subgroupCount, subgroupOrder, batchCount, batchOrder,
-      outerCount, outerOrder, threadCount, threadOrder, elementCount,
-      elementOrder, getSubgroupBasis(), getSubgroupActiveIds(),
-      getThreadBasis(), getThreadActiveIds());
+      getContext(), subgroupCount, subgroupOrder, batchCount, outerCount,
+      threadCount, threadOrder, elementCount, getSubgroupBasis(),
+      getSubgroupActiveIds(), getThreadBasis(), getThreadActiveIds());
 }
 
 /// We distribute to:
 /// <BATCH x OUTER x ELEMENT>
 SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
   SmallVector<int64_t> shape;
-  shape.append(applyPermutation(getBatchesPerSubgroup(), getBatchOrder()));
-  shape.append(applyPermutation(getOutersPerBatch(), getOuterOrder()));
-  shape.append(applyPermutation(getElementsPerThread(), getElementOrder()));
+  shape.append(getBatchesPerSubgroup().begin(), getBatchesPerSubgroup().end());
+  shape.append(getOutersPerBatch().begin(), getOutersPerBatch().end());
+  shape.append(getElementsPerThread().begin(), getElementsPerThread().end());
   return shape;
 }
 
@@ -437,19 +427,26 @@
 LogicalResult NestedLayoutAttr::verify(
     llvm::function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<int64_t> subgroupsPerWorkgroup, ArrayRef<int64_t> subgroupOrder,
-    ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> batchOrder,
-    ArrayRef<int64_t> outersPerBatch, ArrayRef<int64_t> outerOrder,
+    ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> outersPerBatch,
     ArrayRef<int64_t> threadsPerOuter, ArrayRef<int64_t> threadOrder,
-    ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> elementOrder,
-    ArrayRef<int64_t> subgroupBasis, ArrayRef<bool> subgroupActiveIds,
-    ArrayRef<int64_t> threadBasis, ArrayRef<bool> threadActiveIds) {
+    ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> subgroupBasis,
+    ArrayRef<bool> subgroupActiveIds, ArrayRef<int64_t> threadBasis,
+    ArrayRef<bool> threadActiveIds) {
 
   size_t rank = subgroupsPerWorkgroup.size();
-  auto checkTile = [&](ArrayRef<int64_t> tileShape, ArrayRef<int64_t> order) {
-    if (tileShape.size() != rank || order.size() != rank) {
+  auto checkTile = [&](ArrayRef<int64_t> tileShape) {
+    if (tileShape.size() != rank) {
       emitError() << "all tiles must have the same rank as the layout";
       return failure();
     }
+    return success();
+  };
+
+  auto checkOrder = [&](ArrayRef<int64_t> order) {
+    if (order.size() != rank) {
+      emitError() << "all orders must have the same rank as the layout";
+      return failure();
+    }
     if (!mlir::isPermutationVector(order)) {
       emitError() << "all orderings must be permutation vectors";
       return failure();
@@ -457,11 +454,14 @@
     return success();
   };
 
-  if (failed(checkTile(subgroupsPerWorkgroup, subgroupOrder)) ||
-      failed(checkTile(batchesPerSubgroup, batchOrder)) ||
-      failed(checkTile(outersPerBatch, outerOrder)) ||
-      failed(checkTile(threadsPerOuter, threadOrder)) ||
-      failed(checkTile(elementsPerThread, elementOrder))) {
+  if (failed(checkTile(subgroupsPerWorkgroup)) ||
+      failed(checkTile(batchesPerSubgroup)) ||
+      failed(checkTile(outersPerBatch)) || failed(checkTile(threadsPerOuter)) ||
+      failed(checkTile(elementsPerThread))) {
+    return failure();
+  }
+
+  if (failed(checkOrder(subgroupOrder)) || failed(checkOrder(threadOrder))) {
     return failure();
   }
 
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 5600905..275eac0 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -29,10 +29,7 @@
   elements_per_thread = [1, 4],
 
   subgroup_order = [0, 1],
-  batch_order = [0, 1],
-  outer_order = [1, 0],
   thread_order = [0, 1],
-  element_order = [0, 1],
 
   subgroup_basis = [1, 1],
   thread_basis   = [4, 2]
@@ -46,9 +43,7 @@
   elements_per_thread = [4, 1],
 
   subgroup_order = [1, 0],
-  batch_order = [1, 0],
   thread_order = [1, 0],
-  element_order = [1, 0],
 
   subgroup_basis = [1, 1],
   thread_basis   = [2, 4]
@@ -62,9 +57,7 @@
   elements_per_thread = [4, 1],
 
   subgroup_order = [1, 0],
-  batch_order = [1, 0],
   thread_order = [1, 0],
-  element_order = [1, 0],
 
   subgroup_basis = [2, 4, 8],
   subgroup_active_ids = [true, true, false],
@@ -79,9 +72,7 @@
   elements_per_thread = [4, 1],
 
   subgroup_order = [1, 0],
-  batch_order = [1, 0],
   thread_order = [1, 0],
-  element_order = [1, 0],
 
   subgroup_basis = [2, 4, 8],
   subgroup_active_ids = [true, true, false],
@@ -97,9 +88,7 @@
   elements_per_thread = [4, 1],
 
   subgroup_order = [1, 0],
-  batch_order = [1, 0],
   thread_order = [1, 0],
-  element_order = [1, 0],
 
   subgroup_basis = [2, 4],
   subgroup_active_ids = [true, true],
@@ -127,7 +116,6 @@
 // CHECK-SAME: outers_per_batch = [4, 1],
 // CHECK-SAME: threads_per_outer = [4, 2],
 // CHECK-SAME: elements_per_thread = [1, 4],
-// CHECK-SAME: outer_order = [1, 0],
 // CHECK-SAME: subgroup_basis = [1, 1],
 // CHECK-SAME: thread_basis = [4, 2]>
 
@@ -138,9 +126,7 @@
 // CHECK-SAME: threads_per_outer = [2, 4],
 // CHECK-SAME: elements_per_thread = [4, 1],
 // CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
 // CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
 // CHECK-SAME: subgroup_basis = [1, 1],
 // CHECK-SAME: thread_basis = [2, 4]>
 
@@ -151,9 +137,7 @@
 // CHECK-SAME: threads_per_outer = [2, 4],
 // CHECK-SAME: elements_per_thread = [4, 1],
 // CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
 // CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
 // CHECK-SAME: subgroup_basis = [2, 4, 8],
 // CHECK-SAME: subgroup_active_ids = [true, true, false],
 // CHECK-SAME: thread_basis = [2, 4]>
@@ -165,9 +149,7 @@
 // CHECK-SAME: threads_per_outer = [2, 4],
 // CHECK-SAME: elements_per_thread = [4, 1],
 // CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
 // CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
 // CHECK-SAME: subgroup_basis = [2, 4, 8],
 // CHECK-SAME: subgroup_active_ids = [true, true, false],
 // CHECK-SAME: thread_basis = [2, 4, 2],
@@ -180,9 +162,7 @@
 // CHECK-SAME: threads_per_outer = [2, 4],
 // CHECK-SAME: elements_per_thread = [4, 1],
 // CHECK-SAME: subgroup_order = [1, 0],
-// CHECK-SAME: batch_order = [1, 0],
 // CHECK-SAME: thread_order = [1, 0],
-// CHECK-SAME: element_order = [1, 0],
 // CHECK-SAME: subgroup_basis = [2, 4],
 // CHECK-SAME: thread_basis = [4, 2]>