Mainloop multistaged pipeline and instruction scheduling for NVIDIA Ampere Tensor Cores (F16 and F32) (#12603)

This PR introduces coarse-grained multistage pipelining and fine-grained
instruction scheduling for optimal performance on NVIDIA Ampere Tensor
Cores.

Multi-staging is essential to hide the Global Memory load latency by
building longer software pipelines and using the available Shared Memory
capacity, especially on NVIDIA A100. Additionally, fine-grained
instruction scheduling hides the Shared Memory load latency by
prefetching the math operands into registers. The PR adds support for
the F16 and F32 datatypes through the native `mma.sync` Tensor Core exposure.
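
For illustration, the following is a minimal, self-contained sketch (not part
of the PR; `numStages`, `numKgroups`, and the string operation names are
stand-ins for the real `mlir::Operation*` handles) of the
(operation, pipeline stage) assignment the new scheduler emits: `cp.async`
copies issue in stage 0, the async wait/barrier and the first kgroup's Shared
Memory loads land one stage before the math, and the remaining kgroups' loads
are interleaved with `mma.sync` in the last stage.

```cpp
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
  constexpr unsigned numStages = 3;  // multi-buffer depth (>= 3 required)
  constexpr int numKgroups = 4;      // kgroups per mainloop iteration (>= 2)
  std::vector<std::pair<std::string, unsigned>> schedule;

  // Fine-grained interleave: load kgroup k+1 while issuing math on kgroup k.
  for (int k = 0; k < numKgroups - 1; ++k) {
    schedule.push_back(
        {"ldmatrix/ld.shared(kgroup " + std::to_string(k + 1) + ")",
         numStages - 1});
    schedule.push_back(
        {"mma.sync(kgroup " + std::to_string(k) + ")", numStages - 1});
  }
  // Coarse-grained pipeline: global->shared copies issued stages in advance.
  schedule.push_back({"cp.async + cp.async.commit_group", 0});
  schedule.push_back({"cp.async.wait_group + barrier", numStages - 2});
  // Prefetch the first kgroup one stage ahead, then finish the math.
  schedule.push_back({"ldmatrix/ld.shared(kgroup 0)", numStages - 2});
  schedule.push_back(
      {"mma.sync(kgroup " + std::to_string(numKgroups - 1) + ")",
       numStages - 1});

  for (const auto& [op, stage] : schedule)
    std::printf("stage %u : %s\n", stage, op.c_str());
}
```

In the pass itself, the same pairs are collected into a
`std::vector<std::pair<Operation*, unsigned>>` that is handed to the software
pipeliner.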
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
index f717357..37530ce 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
@@ -9,6 +9,7 @@
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -18,6 +19,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#define DEBUG_TYPE "iree-gpu-pipelining"
 
 //====---------------------------------------------------------------------===//
 // Pass to pipeline copy to shared memory for matmul op.
@@ -43,12 +45,12 @@
   return addrSpace.isa<IREE::HAL::DescriptorTypeAttr>();
 }
 
-// Returns a new predicated operation to support unpeeled epilogue. Unpeeled
-// epilogue needs to handle the last iterations within the mainloop which
-// requires predicating operations, for e.g., OOB global memory access. This
-// helper function predicates operations (where predication is avialable),
-// checks if unpredicated operations are side-effect free and acceptable to
-// execute speculatively.
+/// Returns a new predicated operation to support unpeeled epilogue. Unpeeled
+/// epilogue needs to handle the last iterations within the mainloop which
+/// requires predicating operations, e.g., OOB global memory access. This
+/// helper function predicates operations (where predication is available)
+/// and checks that unpredicated operations are side-effect free and
+/// acceptable to execute speculatively.
 static Operation* replaceOpWithPredicatedOp(Operation* op, Value pred,
                                             PatternRewriter& rewriter) {
   // Predication is only supported for AsyncCopyOp. Thus, for operations which
@@ -66,6 +68,10 @@
       auto loadBaseType = loadOp.getBase().getType().cast<MemRefType>();
       if (hasSharedMemoryAddressSpace(loadBaseType)) return op;
     }
+    if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+      auto loadBaseType = loadOp.getMemRefType();
+      if (hasSharedMemoryAddressSpace(loadBaseType)) return op;
+    }
     // If we are here that means the operation does not have predication support
     // and cannot be speculatively executed. Thus, unpeeled epilogue is not
     // supported.
@@ -109,21 +115,6 @@
   }
 }
 
-/// Add an op and its dependency to `ops` set and skip operations contained into
-/// filter. This also adds all the ops to filter so that they don't get matched
-/// again.
-static void addOpsAndDeps(llvm::SmallDenseSet<Operation*>& filter,
-                          llvm::SmallDenseSet<Operation*>& ops, Operation* op,
-                          Block* block) {
-  if (!filter.insert(op).second) return;
-  ops.insert(op);
-  for (Value operand : op->getOperands()) {
-    Operation* defOp = operand.getDefiningOp();
-    if (defOp && defOp->getBlock() == block)
-      addOpsAndDeps(filter, ops, defOp, block);
-  }
-}
-
 /// Assign stages to the loop ops. Simple logic by default, put load from global
 /// memory in stage 0 and the rest in stage 1. If store_stage = 0 then put store
 /// to shared memory in stage 0 as well.
@@ -236,107 +227,341 @@
   return copyToWorkgroupMemory;
 }
 
-/// Return true if the op is used as operand `index` of an mma.sync op.
-static bool isMMAOperand(Operation* op, int64_t index) {
-  OpOperand* use = &(*((op->getUses()).begin()));
-  if (auto extract = dyn_cast<vector::ExtractStridedSliceOp>(use->getOwner())) {
-    use = &(*((extract->getUses()).begin()));
+/// Warp-level TensorOp.
+/// This data structure holds the operations and their dependencies for the
+/// Shared Memory loads and Tensor Core (mma.sync) operations of a kgroup.
+struct WarpMmaOp {
+  // Load matrixA from Shared Memory to Registers.
+  llvm::SetVector<Operation*> loadOperationsA;
+  // Load matrixB from Shared Memory to Registers.
+  llvm::SetVector<Operation*> loadOperationsB;
+  // Warp-level Tensor Core operations on Registers.
+  llvm::SetVector<Operation*> mmaOperations;
+};
+
+/// Structure to hold the matmul's mainloop information:
+/// Separates the mma operations into kgroups and collects the Shared Memory
+/// loads for each kgroup. This information is used to pipeline the mainloop and
+/// to generate an optimal schedule interleaving Global Memory loads, Shared
+/// Memory loads, and math operations.
+struct MainLoopInfo {
+  // Mainloop asynchronous copy operations:
+  // `cp.async` GlobalMemory -> SharedMemory
+  llvm::SetVector<Operation*> copyGlobalToSharedOps;
+  llvm::SetVector<Operation*> asyncCreateGroupOp;
+  llvm::SetVector<Operation*> barrierOps;
+  llvm::SetVector<Operation*> asyncWaitOps;
+
+  // Mainloop asynchronous copy operations' dependencies.
+  llvm::SetVector<Operation*> copyGlobalToSharedOpDeps;
+
+  // Warp-level synchronous operations:
+  // `ldmatrix, ld.shared` SharedMemory -> Registers
+  // `mma.sync` Registers -> Tensor Cores.
+  llvm::SmallVector<WarpMmaOp, 4> warpOperations;
+
+  // Set to track the dependencies already added to a backward slice.
+  llvm::SetVector<Operation*> seenDepOps;
+
+  // Set to track the mma operations in the forward slice to count kgroups
+  // and populate the warp-level warpOperations.
+  llvm::SetVector<Operation*> seenMmaOps;
+
+  // Boolean storing whether the mainloop can be pipelined (coarse-grained
+  // scheduling) and the instructions can be interleaved (fine-grained
+  // scheduling).
+  bool isSchedulable = false;
+
+  // Recursively populates `dependentOps` with the operations the given op
+  // depends on that are in the same block and not already added to the
+  // backward slice of some other op.
+  void backwardSliceOfDependentOps(llvm::SetVector<Operation*>& dependentOps,
+                                   Operation* op, Block* block) {
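+    // SetVector::insert returns false if the op was already visited.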
+    if (!seenDepOps.insert(op)) return;
+    // Add the unseen op to the dependentOps and recurse on its operands.
+    dependentOps.insert(op);
+    for (Value operand : op->getOperands()) {
+      Operation* defOp = operand.getDefiningOp();
+      if (defOp && defOp->getBlock() == block)
+        backwardSliceOfDependentOps(dependentOps, defOp, block);
+    }
   }
-  if (!isa<nvgpu::MmaSyncOp>(use->getOwner())) return false;
-  return use->getOperandNumber() == index;
+
+  // Backtrack from the MmaSyncOp operand (mlir::OpOperand) to its defining
+  // mlir::Operation to find the ldmatrix or ld.shared operations that load
+  // MmaSyncOp operands.
+  void backtrackToFindSmemLoad(Operation* op,
+                               llvm::SetVector<Operation*>& loadOperations,
+                               Block* block) {
+    if (!op) return;
+
+    // If the operation is a ldmatrix or ld.shared, then add it to the set of
+    // the current kgroup load operations.
+    if (isa<nvgpu::LdMatrixOp, memref::LoadOp>(op)) {
+      if (op->getBlock() == block) {
+        loadOperations.insert(op);
+      }
+      return;
+    }
+
+    // If the operation is not a ldmatrix or ld.shared, it has to be a
+    // vector.extract_strided_slice or vector.insert operation; recurse up the
+    // chain to find the ldmatrix or ld.shared Shared Memory load operation.
+    if (!isa<vector::ExtractStridedSliceOp, vector::InsertOp>(op)) {
+      LLVM_DEBUG({
+        llvm::dbgs()
+            << "Unable to find Shared Memory load for MmaSyncOp. Thus, the "
+               "Shared Memory load into the register operand will not be "
+               "pipelined."
+            << "\n";
+      });
+      return;
+    }
+
+    // Recurse upwards towards the definition until a Shared Memory load is
+    // found. Only vector.extract_strided_slice or vector.insert operations
+    // leading up to a Shared Memory load are considered.
+    Operation* defOp = op->getOperand(0).getDefiningOp();
+
+    backtrackToFindSmemLoad(defOp, loadOperations, block);
+  }
+
+  // Recursively traverse the chain of mma operations for all kgroups, from 0
+  // (start) to numKgroups (ending at the scf.yield).
+  // Assumption: The mma operations are in a chain in monotonic increasing
+  // kgroup order.
+  void visitMmaSyncOp(Operation* op, int kgroup) {
+    // If the operation is an `scf.yield`, we reached the end of the MmaSyncOp
+    // chain; return.
+    if (seenMmaOps.count(op) || isa<scf::YieldOp>(op)) return;
+
+    seenMmaOps.insert(op);
+
+    // If the kgroup is not in the vector, create a new WarpMmaOp.
+    if (warpOperations.size() < kgroup + 1)
+      warpOperations.push_back(WarpMmaOp());
+
+    backtrackToFindSmemLoad(op->getOperand(0).getDefiningOp(),
+                            warpOperations[kgroup].loadOperationsA,
+                            op->getBlock());
+
+    backtrackToFindSmemLoad(op->getOperand(1).getDefiningOp(),
+                            warpOperations[kgroup].loadOperationsB,
+                            op->getBlock());
+
+    warpOperations[kgroup].mmaOperations.insert(op);
+
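+    // Follow the accumulator's single use to the next mma.sync (or the
+    // scf.yield) in the chain.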
+    visitMmaSyncOp((op->getUses().begin())->getOwner(), ++kgroup);
+  }
+
+  MainLoopInfo(scf::ForOp forOp) : isSchedulable(true) { analyze(forOp); }
+
+  // Iterate through the mainloop and collect `cp.async`, `cp.commit_group`,
+  // `cp.wait_group`, and `barrier` operations. These operations are used to
+  // pipeline the mainloop and choreograph asynchrony for a *coarse-grained*
+  // schedule. Additionally, collect the `mma.sync` and `ldmatrix`/`ld.shared`
+  // operations and separate them into kgroups. The information is helpful in
+  // generating an optimal *finer-grained* instruction interleaving of global
+  // memory loads, shared memory loads, and math operations.
+  void analyze(scf::ForOp forOp) {
+    for (Operation& op : forOp.getBody()->getOperations()) {
+      if (op.getNumRegions() > 0) {
+        // Pipeline and schedule the innermost for op, i.e., the mainloop,
+        // which should be a flat region.
+        isSchedulable = false;
+        return;
+      }
+      if (isa<nvgpu::DeviceAsyncCopyOp>(op)) {
+        copyGlobalToSharedOps.insert(&op);
+      }
+      if (isa<nvgpu::DeviceAsyncCreateGroupOp>(op)) {
+        asyncCreateGroupOp.insert(&op);
+      }
+      if (isa<gpu::BarrierOp>(op)) {
+        barrierOps.insert(&op);
+      }
+      if (isa<nvgpu::DeviceAsyncWaitOp>(op)) {
+        asyncWaitOps.insert(&op);
+      }
+      if (isa<nvgpu::MmaSyncOp>(op)) {
+        // MmaSyncOp visitor traverses the chain of mma operations and separates
+        // them into kgroups.
+        visitMmaSyncOp(&op, 0 /*kgroup=0*/);
+      }
+    }
+
+    // If one of the ingredients (`cp.async`, `cp.commit_group`,
+    // `cp.wait_group`, `bar.sync`, `mma.sync`, `ldmatrix` or `ld.shared`) for
+    // scheduling is missing, the mainloop cannot be scheduled.
+    if (copyGlobalToSharedOps.empty() || asyncCreateGroupOp.empty() ||
+        asyncWaitOps.empty() || barrierOps.empty() || warpOperations.empty()) {
+      isSchedulable = false;
+      return;
+    }
+
+    // Collect the dependent operations for `cp.async` in the mainloop order for
+    // coarse-grained software pipelining. The deps are collected in stage order,
+    // i.e., `cp.async`'s deps in stage 0 are collected first.
+    for (Operation& op : forOp.getBody()->getOperations()) {
+      if (isa<nvgpu::DeviceAsyncCopyOp>(&op)) {
+        backwardSliceOfDependentOps(copyGlobalToSharedOpDeps, &op,
+                                    forOp.getBody());
+      }
+    }
+
+    // Collect the dependent operations for `mma.sync` and `ldmatrix/ld.shared`
+    // operations separated by kgroups for fine-grained instruction scheduling.
+    for (int kgroup = 0; kgroup < getNumberOfKgroups(); ++kgroup) {
+      for (Operation& op : forOp.getBody()->getOperations()) {
+        if (isa<nvgpu::LdMatrixOp, memref::LoadOp>(&op)) {
+          if (warpOperations[kgroup].loadOperationsA.count(&op)) {
+            backwardSliceOfDependentOps(warpOperations[kgroup].loadOperationsA,
+                                        &op, forOp.getBody());
+          }
+          if (warpOperations[kgroup].loadOperationsB.count(&op)) {
+            backwardSliceOfDependentOps(warpOperations[kgroup].loadOperationsB,
+                                        &op, forOp.getBody());
+          }
+        }
+      }
+      for (Operation& op : forOp.getBody()->getOperations()) {
+        if (isa<nvgpu::MmaSyncOp>(&op)) {
+          if (warpOperations[kgroup].mmaOperations.count(&op)) {
+            backwardSliceOfDependentOps(warpOperations[kgroup].mmaOperations,
+                                        &op, forOp.getBody());
+          }
+        }
+      }
+    }
+  }
+
+  // Returns the number of kgroups in the Warp-level MMA operations.
+  int getNumberOfKgroups() { return warpOperations.size(); }
+};
+
+/// Prints the generated mainloop schedule: the pipeline stage assigned to
+/// each mainloop instruction.
+static void debugMainloopSchedule(
+    MainLoopInfo& mainloop, int numStages,
+    std::vector<std::pair<Operation*, unsigned>>& ops) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "//--- Mainloop schedule generated for Nvidia Ampere "
+                    "mma.sync TensorCore Pipeline. ---//\n";
+    llvm::dbgs() << " Number of stages: " << numStages << "\n";
+    llvm::dbgs() << " Number of kgroups: " << mainloop.getNumberOfKgroups()
+                 << "\n";
+    llvm::dbgs() << " Number of mainloop instructions " << ops.size() << "\n";
+    llvm::dbgs() << " Mainloop instructions schedule and stage assignment: ";
+    for (auto& stage_op_pair : ops) {
+      llvm::dbgs() << " Stage (" << stage_op_pair.second << ") , Operation: ";
+      stage_op_pair.first->dump();
+    }
+    llvm::dbgs() << "\n\n";
+  });
 }
 
-/// Return true if the op is used as operand A of an mma.sync op.
-static bool isAOperand(Operation* op) { return isMMAOperand(op, 0); }
-/// Return true if the op is used as operand B of an mma.sync op.
-static bool isBOperand(Operation* op) { return isMMAOperand(op, 1); }
-
-/// Return a pipelining schedule that gives good performance on Nvidia
-/// Ampere target.
-static void getNvidiaTensorCorePipeline(
+/// This function returns a *coarse-grained* stage assignment for software
+/// pipelining of the mainloop and a *fine-grained* instruction interleaving.
+/// The schedule provides good performance on the NVIDIA Ampere architecture
+/// using an Ampere-style multi-staged pipeline.
+///
+/// @param forOp the mainloop to pipeline and schedule.
+/// @param ops a vector of pairs: [(operations, pipeline_stage)].
+/// @param numStages the total number of pipeline stages used for multi-buffer.
+static void getNvidiaAmpereTensorCorePipeline(
     scf::ForOp forOp, std::vector<std::pair<Operation*, unsigned>>& ops,
-    unsigned depth) {
-  bool loopCanBePipelined = false;
-  // TODO: Tune this and make it a more fine grain logic.
-  static constexpr int64_t numPrefetchSmemLoadPerOperand = 4;
-  SmallVector<Operation*> stageCopyToSharedMemory;
-  SmallVector<Operation*> stagePrefetch;
-  SmallVector<Operation*> stageCompute;
-  int64_t numLoadA = 0;
-  int64_t numLoadB = 0;
-  for (Operation& op : forOp.getBody()->getOperations()) {
-    // Pipeline the most inner for op that should be a flat region.
-    if (op.getNumRegions() > 0) {
-      loopCanBePipelined = false;
-      break;
-    }
-    if (isa<gpu::BarrierOp, nvgpu::DeviceAsyncWaitOp>(op)) {
-      stagePrefetch.push_back(&op);
-    }
-    if (isa<nvgpu::MmaSyncOp>(op)) {
-      stageCompute.push_back(&op);
-    }
-    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
-      stageCopyToSharedMemory.push_back(&op);
-    }
-    if (isa<nvgpu::LdMatrixOp>(op)) {
-      // Prefecth some of the ldmatrix.
-      if (isAOperand(&op)) {
-        numLoadA++;
-        if (numLoadA <= numPrefetchSmemLoadPerOperand) {
-          stagePrefetch.push_back(&op);
-          continue;
-        }
-      }
-      if (isBOperand(&op)) {
-        numLoadB++;
-        if (numLoadB <= numPrefetchSmemLoadPerOperand) {
-          stagePrefetch.push_back(&op);
-          continue;
-        }
-      }
-      // If not prefected go in the last stage.
-      stageCompute.push_back(&op);
-    }
-  }
+    unsigned numStages) {
+  // Analyze the main loop and obtain information for coarse-grained pipelining
+  // and fine-grained instruction scheduling.
+  MainLoopInfo mainloop(forOp);
 
-  // Return an empty schedule if the loop is not a candidate to be pipelined.
-  if (loopCanBePipelined || stageCopyToSharedMemory.empty() ||
-      stageCompute.empty())
+  // If the mainloop is not schedulable, return an empty schedule.
+  if (!mainloop.isSchedulable) return;
+
+  // The NVIDIA Ampere Tensor Core multi-staged pipeline requires at least 2
+  // kgroups and 3 software pipeline stages. If the conditions are not met,
+  // return an empty schedule.
+  int numKgroups = mainloop.getNumberOfKgroups();
+  if (numKgroups < 2 || numStages < 3) {
     return;
-
-  // Track dependencies of stage 0 ops.
-  llvm::SmallDenseSet<Operation*> deps;
-  llvm::SmallDenseSet<Operation*> stageCopyToSharedMemoryDeps;
-  llvm::SmallDenseSet<Operation*> stageNMinusOneDeps;
-  llvm::SmallDenseSet<Operation*> stageNDeps;
-  for (Operation* op : stageCopyToSharedMemory) {
-    addOpsAndDeps(deps, stageCopyToSharedMemoryDeps, op, forOp.getBody());
   }
 
-  for (Operation* op : stagePrefetch) {
-    addOpsAndDeps(deps, stageNMinusOneDeps, op, forOp.getBody());
+  // The un-pipelined mainloop should have exactly one occurrence each of
+  // cp.async.commit_group and cp.async.wait_group. Additionally, two barrier
+  // ops are inserted around the staged copy. The barrier op before the copy is
+  // unnecessary and will be removed. If the conditions are not met, return an
+  // empty schedule.
+  if (mainloop.asyncCreateGroupOp.size() != 1 ||
+      mainloop.asyncWaitOps.size() != 1 || mainloop.barrierOps.size() != 2) {
+    return;
   }
 
-  for (Operation* op : stageCompute) {
-    addOpsAndDeps(deps, stageNDeps, op, forOp.getBody());
+  // Start pipelining and scheduling the main loop, all kgroups but the last
+  // one.
+  for (int kgroup = 0; kgroup < numKgroups - 1; kgroup++) {
+    // Fine-grained instruction scheduling: interleave Shared Memory loads
+    // and mma.sync operations to hide load latencies.
+
+    // Load the next kgroup into registers.
+    for (Operation& op : forOp.getBody()->getOperations()) {
+      if (mainloop.warpOperations[kgroup + 1].loadOperationsA.count(&op) ||
+          mainloop.warpOperations[kgroup + 1].loadOperationsB.count(&op)) {
+        ops.push_back(std::make_pair(&op, numStages - 1));
+      }
+    }
+
+    // Issue mma.sync on the previously loaded kgroup.
+    for (Operation& op : forOp.getBody()->getOperations()) {
+      if (mainloop.warpOperations[kgroup].mmaOperations.count(&op))
+        ops.push_back(std::make_pair(&op, numStages - 1));
+    }
   }
-  // Schedule Last stage followed by stage 0 follwed by prefetch.
+
+  // Coarse-grained instruction pipelining: pipeline Global Memory
+  // transfer (GMEM -> SMEM) several stages in advance.
+
+  // Schedule all cp.async and one cp.async.commit_group.
+  // TODO: Distribute cp.async throughout the main loop and do not concentrate
+  // it at one place.
   for (Operation& op : forOp.getBody()->getOperations()) {
-    if (stageNDeps.count(&op)) ops.push_back(std::make_pair(&op, depth - 1));
+    if (mainloop.copyGlobalToSharedOpDeps.count(&op))
+      ops.push_back(std::make_pair(&op, 0 /*pipelineStage*/));
   }
+  ops.push_back(
+      std::make_pair(mainloop.asyncCreateGroupOp[0], 0 /*pipelineStage*/));
+
+  // Schedule and pipeline the cp.async.wait_group and the remaining barrier.
+  ops.push_back(std::make_pair(mainloop.asyncWaitOps[0], numStages - 2));
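+  // Drop the redundant barrier that preceded the staged copy.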
+  mainloop.barrierOps[0]->erase();
+  ops.push_back(std::make_pair(mainloop.barrierOps[1], numStages - 2));
+
+  // Coarse-grained instruction pipelining: schedule the Shared Memory loads
+  // (SMEM -> Registers) for the first kgroup (kgroup = 0) and pipeline them
+  // one stage in advance.
   for (Operation& op : forOp.getBody()->getOperations()) {
-    if (stageCopyToSharedMemoryDeps.count(&op))
-      ops.push_back(std::make_pair(&op, 0));
+    if (mainloop.warpOperations[0].loadOperationsA.count(&op) ||
+        mainloop.warpOperations[0].loadOperationsB.count(&op))
+      ops.push_back(std::make_pair(&op, numStages - 2));
   }
+
+  // Issue mma.sync for the last kgroup at the end of the mainloop.
   for (Operation& op : forOp.getBody()->getOperations()) {
-    if (stageNMinusOneDeps.count(&op))
-      ops.push_back(std::make_pair(&op, depth - 2));
+    if (mainloop.warpOperations[numKgroups - 1].mmaOperations.count(&op))
+      ops.push_back(std::make_pair(&op, numStages - 1));
   }
+
+  // Prints the mainloop schedule generated for NVIDIA Ampere through native
+  // Tensor Core operations (asynchronous copy, load matrix, and mma.sync).
+  debugMainloopSchedule(mainloop, numStages, ops);
 }
 
-/// Apply pipeline rewrite pattern assuming the operations were already
-/// annotated with stage information.
+// Apply pipeline rewrite pattern assuming the operations were already
+// annotated with stage information.
 // TODO: move away from using attribute annotations.
 static FailureOr<scf::ForOp> applyPipelining(
     scf::ForOp forOp, int64_t depth, bool epiloguePeeling,
@@ -357,7 +582,7 @@
                          scf::ForOp forOp,
                          std::vector<std::pair<Operation*, unsigned>>& ops) {
     if (schedule == PipeliningSchedulingStrategy::nvidiaTensorCore) {
-      return getNvidiaTensorCorePipeline(forOp, ops, maxDepth);
+      return getNvidiaAmpereTensorCorePipeline(forOp, ops, maxDepth);
     }
     return getPipelineStages(forOp, ops, maxDepth);
   };
@@ -431,8 +656,8 @@
 /// Pass options
 /// epiloguePeeling - try enable/disable epilogue peeling.
 /// true  : Peel epilogue (no additional checks required)
-/// false : Try and use unpeeled epilogue (check if predication is supported is
-/// avialable)
+/// false : Try to use the unpeeled epilogue (check if predication support
+/// is available)
 std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
     bool epiloguePeeling, unsigned depth,
     PipeliningSchedulingStrategy schedule) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir
index cbf33c7..32597f9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_pipeline.mlir
@@ -57,7 +57,7 @@
 
 // -----
 
-func.func @nvidia_tenscore_schedule() {
+func.func @nvidia_tenscore_schedule_f16() {
   %c3 = arith.constant 3 : index
   %c31 = arith.constant 31 : index
   %c2 = arith.constant 2 : index
@@ -482,7 +482,7 @@
   return
 }
 
-//  CHECK-NV-LABEL: func.func @nvidia_tenscore_schedule
+//  CHECK-NV-LABEL: func.func @nvidia_tenscore_schedule_f16
 //  CHECK-NV-COUNT-6:  nvgpu.device_async_copy
 //          CHECK-NV:  nvgpu.device_async_create_group
 //  CHECK-NV-COUNT-6:  nvgpu.device_async_copy
@@ -491,12 +491,859 @@
 //          CHECK-NV:  gpu.barrier
 //  CHECK-NV-COUNT-8:  nvgpu.ldmatrix
 //          CHECK-NV:  scf.for
-//  CHECK-NV-COUNT-8:    nvgpu.ldmatrix
-// CHECK-NV-COUNT-64:    nvgpu.mma.sync
+//  CHECK-NV-COUNT-4:    nvgpu.ldmatrix
+// CHECK-NV-COUNT-32:    nvgpu.mma.sync
 //  CHECK-NV-COUNT-6:    nvgpu.device_async_copy
 //          CHECK-NV:    nvgpu.device_async_create_group
 //          CHECK-NV:    nvgpu.device_async_wait %{{.*}} {numGroups = 1 : i32}
 //          CHECK-NV:    gpu.barrier
 //  CHECK-NV-COUNT-8:    nvgpu.ldmatrix
+// CHECK-NV-COUNT-32:    nvgpu.mma.sync
 //          CHECK-NV:  }
-//          CHECK-NV:  vector.store
\ No newline at end of file
+//          CHECK-NV:  vector.store
+
+// -----
+func.func @nvidia_tenscore_schedule_f32() {
+  %c31 = arith.constant 31 : index
+  %c2 = arith.constant 2 : index
+  %c7 = arith.constant 7 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<2x1xf32>
+  %c32 = arith.constant 32 : index
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+  %c256 = arith.constant 256 : index
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = gpu.thread_id  x
+  %1 = gpu.thread_id  y
+  %2 = gpu.thread_id  z
+  %alloc = memref.alloc() : memref<128x128xf32, #gpu.address_space<workgroup>>
+  %alloc_2 = memref.alloc() : memref<3x128x32xf32, #gpu.address_space<workgroup>>
+  %alloc_3 = memref.alloc() : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+  %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<256x256xf32>
+  memref.assume_alignment %3, 64 : memref<256x256xf32>
+  %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<256x256xf32>
+  memref.assume_alignment %4, 64 : memref<256x256xf32>
+  %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<256x256xf32>
+  memref.assume_alignment %5, 64 : memref<256x256xf32>
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %6 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8)>()[%0, %1, %2, %workgroup_id_y]
+  %7 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8)>()[%0, %1, %2]
+  %8 = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>()[%0]
+  %9 = arith.andi %7, %c7 : index
+  %10 = arith.shli %9, %c2 : index
+  %11 = arith.xori %8, %10 : index
+  %12 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 16)>()[%0, %1, %2, %workgroup_id_y]
+  %13 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 16)>()[%0, %1, %2]
+  %14 = arith.andi %13, %c7 : index
+  %15 = arith.shli %14, %c2 : index
+  %16 = arith.xori %8, %15 : index
+  %17 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 32)>()[%0, %1, %2, %workgroup_id_y]
+  %18 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 32)>()[%0, %1, %2]
+  %19 = arith.andi %18, %c7 : index
+  %20 = arith.shli %19, %c2 : index
+  %21 = arith.xori %8, %20 : index
+  %22 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 48)>()[%0, %1, %2, %workgroup_id_y]
+  %23 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 48)>()[%0, %1, %2]
+  %24 = arith.andi %23, %c7 : index
+  %25 = arith.shli %24, %c2 : index
+  %26 = arith.xori %8, %25 : index
+  %27 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 64)>()[%0, %1, %2, %workgroup_id_y]
+  %28 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 64)>()[%0, %1, %2]
+  %29 = arith.andi %28, %c7 : index
+  %30 = arith.shli %29, %c2 : index
+  %31 = arith.xori %8, %30 : index
+  %32 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 80)>()[%0, %1, %2, %workgroup_id_y]
+  %33 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 80)>()[%0, %1, %2]
+  %34 = arith.andi %33, %c7 : index
+  %35 = arith.shli %34, %c2 : index
+  %36 = arith.xori %8, %35 : index
+  %37 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 96)>()[%0, %1, %2, %workgroup_id_y]
+  %38 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 96)>()[%0, %1, %2]
+  %39 = arith.andi %38, %c7 : index
+  %40 = arith.shli %39, %c2 : index
+  %41 = arith.xori %8, %40 : index
+  %42 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 8 + s2 * 16 + s3 * 128 + s0 floordiv 8 + 112)>()[%0, %1, %2, %workgroup_id_y]
+  %43 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 16 + s0 floordiv 8 + 112)>()[%0, %1, %2]
+  %44 = arith.andi %43, %c7 : index
+  %45 = arith.shli %44, %c2 : index
+  %46 = arith.xori %8, %45 : index
+  %47 = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 128 - (s0 floordiv 32) * 128)>()[%0, %workgroup_id_x]
+  %48 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32)>()[%0, %1, %2]
+  %49 = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 32) * 128)>()[%0]
+  %50 = arith.andi %48, %c31 : index
+  %51 = arith.shli %50, %c2 : index
+  %52 = arith.xori %49, %51 : index
+  %53 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 4)>()[%0, %1, %2]
+  %54 = arith.andi %53, %c31 : index
+  %55 = arith.shli %54, %c2 : index
+  %56 = arith.xori %49, %55 : index
+  %57 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 8)>()[%0, %1, %2]
+  %58 = arith.andi %57, %c31 : index
+  %59 = arith.shli %58, %c2 : index
+  %60 = arith.xori %49, %59 : index
+  %61 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 12)>()[%0, %1, %2]
+  %62 = arith.andi %61, %c31 : index
+  %63 = arith.shli %62, %c2 : index
+  %64 = arith.xori %49, %63 : index
+  %65 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 16)>()[%0, %1, %2]
+  %66 = arith.andi %65, %c31 : index
+  %67 = arith.shli %66, %c2 : index
+  %68 = arith.xori %49, %67 : index
+  %69 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 20)>()[%0, %1, %2]
+  %70 = arith.andi %69, %c31 : index
+  %71 = arith.shli %70, %c2 : index
+  %72 = arith.xori %49, %71 : index
+  %73 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 24)>()[%0, %1, %2]
+  %74 = arith.andi %73, %c31 : index
+  %75 = arith.shli %74, %c2 : index
+  %76 = arith.xori %49, %75 : index
+  %77 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 28)>()[%0, %1, %2]
+  %78 = arith.andi %77, %c31 : index
+  %79 = arith.shli %78, %c2 : index
+  %80 = arith.xori %49, %79 : index
+  %81 = gpu.lane_id
+  %82 = affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 64 - (d0 floordiv 16) * 16)>(%81)[%1]
+  %83 = affine.apply affine_map<(d0) -> ((d0 floordiv 16) * 4)>(%81)
+  %84 = arith.andi %82, %c7 : index
+  %85 = arith.shli %84, %c2 : index
+  %86 = arith.xori %83, %85 : index
+  %87 = affine.apply affine_map<(d0) -> ((d0 floordiv 16) * 4 + 8)>(%81)
+  %88 = arith.xori %87, %85 : index
+  %89 = affine.apply affine_map<(d0) -> ((d0 floordiv 16) * 4 + 16)>(%81)
+  %90 = arith.xori %89, %85 : index
+  %91 = affine.apply affine_map<(d0) -> ((d0 floordiv 16) * 4 + 24)>(%81)
+  %92 = arith.xori %91, %85 : index
+  %93 = affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 64 - (d0 floordiv 16) * 16 + 16)>(%81)[%1]
+  %94 = arith.andi %93, %c7 : index
+  %95 = arith.shli %94, %c2 : index
+  %96 = arith.xori %83, %95 : index
+  %97 = arith.xori %87, %95 : index
+  %98 = arith.xori %89, %95 : index
+  %99 = arith.xori %91, %95 : index
+  %100 = affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 64 - (d0 floordiv 16) * 16 + 32)>(%81)[%1]
+  %101 = arith.andi %100, %c7 : index
+  %102 = arith.shli %101, %c2 : index
+  %103 = arith.xori %83, %102 : index
+  %104 = arith.xori %87, %102 : index
+  %105 = arith.xori %89, %102 : index
+  %106 = arith.xori %91, %102 : index
+  %107 = affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 64 - (d0 floordiv 16) * 16 + 48)>(%81)[%1]
+  %108 = arith.andi %107, %c7 : index
+  %109 = arith.shli %108, %c2 : index
+  %110 = arith.xori %83, %109 : index
+  %111 = arith.xori %87, %109 : index
+  %112 = arith.xori %89, %109 : index
+  %113 = arith.xori %91, %109 : index
+  %114 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64)>(%81)[%0]
+  %115 = affine.apply affine_map<(d0) -> (d0 mod 4)>(%81)
+  %116 = arith.andi %115, %c31 : index
+  %117 = arith.shli %116, %c2 : index
+  %118 = arith.xori %114, %117 : index
+  %119 = affine.apply affine_map<(d0) -> (d0 mod 4 + 4)>(%81)
+  %120 = arith.andi %119, %c31 : index
+  %121 = arith.shli %120, %c2 : index
+  %122 = arith.xori %114, %121 : index
+  %123 = affine.apply affine_map<(d0) -> (d0 mod 4 + 8)>(%81)
+  %124 = arith.andi %123, %c31 : index
+  %125 = arith.shli %124, %c2 : index
+  %126 = arith.xori %114, %125 : index
+  %127 = affine.apply affine_map<(d0) -> (d0 mod 4 + 12)>(%81)
+  %128 = arith.andi %127, %c31 : index
+  %129 = arith.shli %128, %c2 : index
+  %130 = arith.xori %114, %129 : index
+  %131 = affine.apply affine_map<(d0) -> (d0 mod 4 + 16)>(%81)
+  %132 = arith.andi %131, %c31 : index
+  %133 = arith.shli %132, %c2 : index
+  %134 = arith.xori %114, %133 : index
+  %135 = affine.apply affine_map<(d0) -> (d0 mod 4 + 20)>(%81)
+  %136 = arith.andi %135, %c31 : index
+  %137 = arith.shli %136, %c2 : index
+  %138 = arith.xori %114, %137 : index
+  %139 = affine.apply affine_map<(d0) -> (d0 mod 4 + 24)>(%81)
+  %140 = arith.andi %139, %c31 : index
+  %141 = arith.shli %140, %c2 : index
+  %142 = arith.xori %114, %141 : index
+  %143 = affine.apply affine_map<(d0) -> (d0 mod 4 + 28)>(%81)
+  %144 = arith.andi %143, %c31 : index
+  %145 = arith.shli %144, %c2 : index
+  %146 = arith.xori %114, %145 : index
+  %147 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 8)>(%81)[%0]
+  %148 = arith.xori %147, %117 : index
+  %149 = arith.xori %147, %121 : index
+  %150 = arith.xori %147, %125 : index
+  %151 = arith.xori %147, %129 : index
+  %152 = arith.xori %147, %133 : index
+  %153 = arith.xori %147, %137 : index
+  %154 = arith.xori %147, %141 : index
+  %155 = arith.xori %147, %145 : index
+  %156 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 16)>(%81)[%0]
+  %157 = arith.xori %156, %117 : index
+  %158 = arith.xori %156, %121 : index
+  %159 = arith.xori %156, %125 : index
+  %160 = arith.xori %156, %129 : index
+  %161 = arith.xori %156, %133 : index
+  %162 = arith.xori %156, %137 : index
+  %163 = arith.xori %156, %141 : index
+  %164 = arith.xori %156, %145 : index
+  %165 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 24)>(%81)[%0]
+  %166 = arith.xori %165, %117 : index
+  %167 = arith.xori %165, %121 : index
+  %168 = arith.xori %165, %125 : index
+  %169 = arith.xori %165, %129 : index
+  %170 = arith.xori %165, %133 : index
+  %171 = arith.xori %165, %137 : index
+  %172 = arith.xori %165, %141 : index
+  %173 = arith.xori %165, %145 : index
+  %174 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 32)>(%81)[%0]
+  %175 = arith.xori %174, %117 : index
+  %176 = arith.xori %174, %121 : index
+  %177 = arith.xori %174, %125 : index
+  %178 = arith.xori %174, %129 : index
+  %179 = arith.xori %174, %133 : index
+  %180 = arith.xori %174, %137 : index
+  %181 = arith.xori %174, %141 : index
+  %182 = arith.xori %174, %145 : index
+  %183 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 40)>(%81)[%0]
+  %184 = arith.xori %183, %117 : index
+  %185 = arith.xori %183, %121 : index
+  %186 = arith.xori %183, %125 : index
+  %187 = arith.xori %183, %129 : index
+  %188 = arith.xori %183, %133 : index
+  %189 = arith.xori %183, %137 : index
+  %190 = arith.xori %183, %141 : index
+  %191 = arith.xori %183, %145 : index
+  %192 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 48)>(%81)[%0]
+  %193 = arith.xori %192, %117 : index
+  %194 = arith.xori %192, %121 : index
+  %195 = arith.xori %192, %125 : index
+  %196 = arith.xori %192, %129 : index
+  %197 = arith.xori %192, %133 : index
+  %198 = arith.xori %192, %137 : index
+  %199 = arith.xori %192, %141 : index
+  %200 = arith.xori %192, %145 : index
+  %201 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv 4 + (s0 floordiv 32) * 64 + 56)>(%81)[%0]
+  %202 = arith.xori %201, %117 : index
+  %203 = arith.xori %201, %121 : index
+  %204 = arith.xori %201, %125 : index
+  %205 = arith.xori %201, %129 : index
+  %206 = arith.xori %201, %133 : index
+  %207 = arith.xori %201, %137 : index
+  %208 = arith.xori %201, %141 : index
+  %209 = arith.xori %201, %145 : index
+  %210:32 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_0, %arg2 = %cst_0, %arg3 = %cst_0, %arg4 = %cst_0, %arg5 = %cst_0, %arg6 = %cst_0, %arg7 = %cst_0, %arg8 = %cst_0, %arg9 = %cst_0, %arg10 = %cst_0, %arg11 = %cst_0, %arg12 = %cst_0, %arg13 = %cst_0, %arg14 = %cst_0, %arg15 = %cst_0, %arg16 = %cst_0, %arg17 = %cst_0, %arg18 = %cst_0, %arg19 = %cst_0, %arg20 = %cst_0, %arg21 = %cst_0, %arg22 = %cst_0, %arg23 = %cst_0, %arg24 = %cst_0, %arg25 = %cst_0, %arg26 = %cst_0, %arg27 = %cst_0, %arg28 = %cst_0, %arg29 = %cst_0, %arg30 = %cst_0, %arg31 = %cst_0, %arg32 = %cst_0) -> (vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>) {
+    gpu.barrier
+    %390 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4 - (s1 floordiv 8) * 32)>()[%arg0, %0]
+    %391 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) mod 3)>(%arg0)
+    %392 = nvgpu.device_async_copy %3[%6, %390], %alloc_2[%391, %7, %11], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %393 = nvgpu.device_async_copy %3[%12, %390], %alloc_2[%391, %13, %16], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %394 = nvgpu.device_async_copy %3[%17, %390], %alloc_2[%391, %18, %21], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %395 = nvgpu.device_async_copy %3[%22, %390], %alloc_2[%391, %23, %26], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %396 = nvgpu.device_async_copy %3[%27, %390], %alloc_2[%391, %28, %31], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %397 = nvgpu.device_async_copy %3[%32, %390], %alloc_2[%391, %33, %36], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %398 = nvgpu.device_async_copy %3[%37, %390], %alloc_2[%391, %38, %41], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %399 = nvgpu.device_async_copy %3[%42, %390], %alloc_2[%391, %43, %46], 4 {bypassL1} : memref<256x256xf32> to memref<3x128x32xf32, #gpu.address_space<workgroup>>
+    %400 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32)>()[%arg0, %0, %1, %2]
+    %401 = nvgpu.device_async_copy %4[%400, %47], %alloc_3[%391, %48, %52], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %402 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 4)>()[%arg0, %0, %1, %2]
+    %403 = nvgpu.device_async_copy %4[%402, %47], %alloc_3[%391, %53, %56], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %404 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 8)>()[%arg0, %0, %1, %2]
+    %405 = nvgpu.device_async_copy %4[%404, %47], %alloc_3[%391, %57, %60], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %406 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 12)>()[%arg0, %0, %1, %2]
+    %407 = nvgpu.device_async_copy %4[%406, %47], %alloc_3[%391, %61, %64], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %408 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 16)>()[%arg0, %0, %1, %2]
+    %409 = nvgpu.device_async_copy %4[%408, %47], %alloc_3[%391, %65, %68], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %410 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 20)>()[%arg0, %0, %1, %2]
+    %411 = nvgpu.device_async_copy %4[%410, %47], %alloc_3[%391, %69, %72], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %412 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 24)>()[%arg0, %0, %1, %2]
+    %413 = nvgpu.device_async_copy %4[%412, %47], %alloc_3[%391, %73, %76], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %414 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 2 + s3 * 4 + s1 floordiv 32 + 28)>()[%arg0, %0, %1, %2]
+    %415 = nvgpu.device_async_copy %4[%414, %47], %alloc_3[%391, %77, %80], 4 {bypassL1} : memref<256x256xf32> to memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %416 = nvgpu.device_async_create_group %392, %393, %394, %395, %396, %397, %398, %399, %401, %403, %405, %407, %409, %411, %413, %415
+    nvgpu.device_async_wait %416
+    gpu.barrier
+    %417 = nvgpu.ldmatrix %alloc_2[%391, %82, %86] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %418 = nvgpu.ldmatrix %alloc_2[%391, %82, %88] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %419 = nvgpu.ldmatrix %alloc_2[%391, %82, %90] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %420 = nvgpu.ldmatrix %alloc_2[%391, %82, %92] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %421 = nvgpu.ldmatrix %alloc_2[%391, %93, %96] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %422 = nvgpu.ldmatrix %alloc_2[%391, %93, %97] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %423 = nvgpu.ldmatrix %alloc_2[%391, %93, %98] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %424 = nvgpu.ldmatrix %alloc_2[%391, %93, %99] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %425 = nvgpu.ldmatrix %alloc_2[%391, %100, %103] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %426 = nvgpu.ldmatrix %alloc_2[%391, %100, %104] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %427 = nvgpu.ldmatrix %alloc_2[%391, %100, %105] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %428 = nvgpu.ldmatrix %alloc_2[%391, %100, %106] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %429 = nvgpu.ldmatrix %alloc_2[%391, %107, %110] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %430 = nvgpu.ldmatrix %alloc_2[%391, %107, %111] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %431 = nvgpu.ldmatrix %alloc_2[%391, %107, %112] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %432 = nvgpu.ldmatrix %alloc_2[%391, %107, %113] {numTiles = 4 : i32, transpose = false} : memref<3x128x32xf32, #gpu.address_space<workgroup>> -> vector<4x1xf32>
+    %433 = memref.load %alloc_3[%391, %115, %118] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %434 = vector.insert %433, %cst [0, 0] : f32 into vector<2x1xf32>
+    %435 = memref.load %alloc_3[%391, %119, %122] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %436 = vector.insert %435, %434 [1, 0] : f32 into vector<2x1xf32>
+    %437 = memref.load %alloc_3[%391, %123, %126] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %438 = vector.insert %437, %cst [0, 0] : f32 into vector<2x1xf32>
+    %439 = memref.load %alloc_3[%391, %127, %130] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %440 = vector.insert %439, %438 [1, 0] : f32 into vector<2x1xf32>
+    %441 = memref.load %alloc_3[%391, %131, %134] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %442 = vector.insert %441, %cst [0, 0] : f32 into vector<2x1xf32>
+    %443 = memref.load %alloc_3[%391, %135, %138] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %444 = vector.insert %443, %442 [1, 0] : f32 into vector<2x1xf32>
+    %445 = memref.load %alloc_3[%391, %139, %142] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %446 = vector.insert %445, %cst [0, 0] : f32 into vector<2x1xf32>
+    %447 = memref.load %alloc_3[%391, %143, %146] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %448 = vector.insert %447, %446 [1, 0] : f32 into vector<2x1xf32>
+    %449 = memref.load %alloc_3[%391, %115, %148] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %450 = vector.insert %449, %cst [0, 0] : f32 into vector<2x1xf32>
+    %451 = memref.load %alloc_3[%391, %119, %149] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %452 = vector.insert %451, %450 [1, 0] : f32 into vector<2x1xf32>
+    %453 = memref.load %alloc_3[%391, %123, %150] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %454 = vector.insert %453, %cst [0, 0] : f32 into vector<2x1xf32>
+    %455 = memref.load %alloc_3[%391, %127, %151] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %456 = vector.insert %455, %454 [1, 0] : f32 into vector<2x1xf32>
+    %457 = memref.load %alloc_3[%391, %131, %152] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %458 = vector.insert %457, %cst [0, 0] : f32 into vector<2x1xf32>
+    %459 = memref.load %alloc_3[%391, %135, %153] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %460 = vector.insert %459, %458 [1, 0] : f32 into vector<2x1xf32>
+    %461 = memref.load %alloc_3[%391, %139, %154] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %462 = vector.insert %461, %cst [0, 0] : f32 into vector<2x1xf32>
+    %463 = memref.load %alloc_3[%391, %143, %155] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %464 = vector.insert %463, %462 [1, 0] : f32 into vector<2x1xf32>
+    %465 = memref.load %alloc_3[%391, %115, %157] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %466 = vector.insert %465, %cst [0, 0] : f32 into vector<2x1xf32>
+    %467 = memref.load %alloc_3[%391, %119, %158] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %468 = vector.insert %467, %466 [1, 0] : f32 into vector<2x1xf32>
+    %469 = memref.load %alloc_3[%391, %123, %159] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %470 = vector.insert %469, %cst [0, 0] : f32 into vector<2x1xf32>
+    %471 = memref.load %alloc_3[%391, %127, %160] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %472 = vector.insert %471, %470 [1, 0] : f32 into vector<2x1xf32>
+    %473 = memref.load %alloc_3[%391, %131, %161] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %474 = vector.insert %473, %cst [0, 0] : f32 into vector<2x1xf32>
+    %475 = memref.load %alloc_3[%391, %135, %162] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %476 = vector.insert %475, %474 [1, 0] : f32 into vector<2x1xf32>
+    %477 = memref.load %alloc_3[%391, %139, %163] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %478 = vector.insert %477, %cst [0, 0] : f32 into vector<2x1xf32>
+    %479 = memref.load %alloc_3[%391, %143, %164] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %480 = vector.insert %479, %478 [1, 0] : f32 into vector<2x1xf32>
+    %481 = memref.load %alloc_3[%391, %115, %166] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %482 = vector.insert %481, %cst [0, 0] : f32 into vector<2x1xf32>
+    %483 = memref.load %alloc_3[%391, %119, %167] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %484 = vector.insert %483, %482 [1, 0] : f32 into vector<2x1xf32>
+    %485 = memref.load %alloc_3[%391, %123, %168] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %486 = vector.insert %485, %cst [0, 0] : f32 into vector<2x1xf32>
+    %487 = memref.load %alloc_3[%391, %127, %169] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %488 = vector.insert %487, %486 [1, 0] : f32 into vector<2x1xf32>
+    %489 = memref.load %alloc_3[%391, %131, %170] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %490 = vector.insert %489, %cst [0, 0] : f32 into vector<2x1xf32>
+    %491 = memref.load %alloc_3[%391, %135, %171] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %492 = vector.insert %491, %490 [1, 0] : f32 into vector<2x1xf32>
+    %493 = memref.load %alloc_3[%391, %139, %172] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %494 = vector.insert %493, %cst [0, 0] : f32 into vector<2x1xf32>
+    %495 = memref.load %alloc_3[%391, %143, %173] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %496 = vector.insert %495, %494 [1, 0] : f32 into vector<2x1xf32>
+    %497 = memref.load %alloc_3[%391, %115, %175] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %498 = vector.insert %497, %cst [0, 0] : f32 into vector<2x1xf32>
+    %499 = memref.load %alloc_3[%391, %119, %176] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %500 = vector.insert %499, %498 [1, 0] : f32 into vector<2x1xf32>
+    %501 = memref.load %alloc_3[%391, %123, %177] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %502 = vector.insert %501, %cst [0, 0] : f32 into vector<2x1xf32>
+    %503 = memref.load %alloc_3[%391, %127, %178] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %504 = vector.insert %503, %502 [1, 0] : f32 into vector<2x1xf32>
+    %505 = memref.load %alloc_3[%391, %131, %179] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %506 = vector.insert %505, %cst [0, 0] : f32 into vector<2x1xf32>
+    %507 = memref.load %alloc_3[%391, %135, %180] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %508 = vector.insert %507, %506 [1, 0] : f32 into vector<2x1xf32>
+    %509 = memref.load %alloc_3[%391, %139, %181] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %510 = vector.insert %509, %cst [0, 0] : f32 into vector<2x1xf32>
+    %511 = memref.load %alloc_3[%391, %143, %182] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %512 = vector.insert %511, %510 [1, 0] : f32 into vector<2x1xf32>
+    %513 = memref.load %alloc_3[%391, %115, %184] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %514 = vector.insert %513, %cst [0, 0] : f32 into vector<2x1xf32>
+    %515 = memref.load %alloc_3[%391, %119, %185] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %516 = vector.insert %515, %514 [1, 0] : f32 into vector<2x1xf32>
+    %517 = memref.load %alloc_3[%391, %123, %186] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %518 = vector.insert %517, %cst [0, 0] : f32 into vector<2x1xf32>
+    %519 = memref.load %alloc_3[%391, %127, %187] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %520 = vector.insert %519, %518 [1, 0] : f32 into vector<2x1xf32>
+    %521 = memref.load %alloc_3[%391, %131, %188] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %522 = vector.insert %521, %cst [0, 0] : f32 into vector<2x1xf32>
+    %523 = memref.load %alloc_3[%391, %135, %189] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %524 = vector.insert %523, %522 [1, 0] : f32 into vector<2x1xf32>
+    %525 = memref.load %alloc_3[%391, %139, %190] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %526 = vector.insert %525, %cst [0, 0] : f32 into vector<2x1xf32>
+    %527 = memref.load %alloc_3[%391, %143, %191] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %528 = vector.insert %527, %526 [1, 0] : f32 into vector<2x1xf32>
+    %529 = memref.load %alloc_3[%391, %115, %193] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %530 = vector.insert %529, %cst [0, 0] : f32 into vector<2x1xf32>
+    %531 = memref.load %alloc_3[%391, %119, %194] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %532 = vector.insert %531, %530 [1, 0] : f32 into vector<2x1xf32>
+    %533 = memref.load %alloc_3[%391, %123, %195] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %534 = vector.insert %533, %cst [0, 0] : f32 into vector<2x1xf32>
+    %535 = memref.load %alloc_3[%391, %127, %196] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %536 = vector.insert %535, %534 [1, 0] : f32 into vector<2x1xf32>
+    %537 = memref.load %alloc_3[%391, %131, %197] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %538 = vector.insert %537, %cst [0, 0] : f32 into vector<2x1xf32>
+    %539 = memref.load %alloc_3[%391, %135, %198] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %540 = vector.insert %539, %538 [1, 0] : f32 into vector<2x1xf32>
+    %541 = memref.load %alloc_3[%391, %139, %199] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %542 = vector.insert %541, %cst [0, 0] : f32 into vector<2x1xf32>
+    %543 = memref.load %alloc_3[%391, %143, %200] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %544 = vector.insert %543, %542 [1, 0] : f32 into vector<2x1xf32>
+    %545 = memref.load %alloc_3[%391, %115, %202] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %546 = vector.insert %545, %cst [0, 0] : f32 into vector<2x1xf32>
+    %547 = memref.load %alloc_3[%391, %119, %203] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %548 = vector.insert %547, %546 [1, 0] : f32 into vector<2x1xf32>
+    %549 = memref.load %alloc_3[%391, %123, %204] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %550 = vector.insert %549, %cst [0, 0] : f32 into vector<2x1xf32>
+    %551 = memref.load %alloc_3[%391, %127, %205] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %552 = vector.insert %551, %550 [1, 0] : f32 into vector<2x1xf32>
+    %553 = memref.load %alloc_3[%391, %131, %206] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %554 = vector.insert %553, %cst [0, 0] : f32 into vector<2x1xf32>
+    %555 = memref.load %alloc_3[%391, %135, %207] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %556 = vector.insert %555, %554 [1, 0] : f32 into vector<2x1xf32>
+    %557 = memref.load %alloc_3[%391, %139, %208] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %558 = vector.insert %557, %cst [0, 0] : f32 into vector<2x1xf32>
+    %559 = memref.load %alloc_3[%391, %143, %209] : memref<3x32x128xf32, #gpu.address_space<workgroup>>
+    %560 = vector.insert %559, %558 [1, 0] : f32 into vector<2x1xf32>
+    %561 = nvgpu.mma.sync(%417, %436, %arg1) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %562 = nvgpu.mma.sync(%417, %452, %arg2) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %563 = nvgpu.mma.sync(%417, %468, %arg3) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %564 = nvgpu.mma.sync(%417, %484, %arg4) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %565 = nvgpu.mma.sync(%417, %500, %arg5) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %566 = nvgpu.mma.sync(%417, %516, %arg6) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %567 = nvgpu.mma.sync(%417, %532, %arg7) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %568 = nvgpu.mma.sync(%417, %548, %arg8) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %569 = nvgpu.mma.sync(%421, %436, %arg9) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %570 = nvgpu.mma.sync(%421, %452, %arg10) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %571 = nvgpu.mma.sync(%421, %468, %arg11) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %572 = nvgpu.mma.sync(%421, %484, %arg12) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %573 = nvgpu.mma.sync(%421, %500, %arg13) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %574 = nvgpu.mma.sync(%421, %516, %arg14) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %575 = nvgpu.mma.sync(%421, %532, %arg15) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %576 = nvgpu.mma.sync(%421, %548, %arg16) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %577 = nvgpu.mma.sync(%425, %436, %arg17) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %578 = nvgpu.mma.sync(%425, %452, %arg18) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %579 = nvgpu.mma.sync(%425, %468, %arg19) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %580 = nvgpu.mma.sync(%425, %484, %arg20) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %581 = nvgpu.mma.sync(%425, %500, %arg21) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %582 = nvgpu.mma.sync(%425, %516, %arg22) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %583 = nvgpu.mma.sync(%425, %532, %arg23) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %584 = nvgpu.mma.sync(%425, %548, %arg24) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %585 = nvgpu.mma.sync(%429, %436, %arg25) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %586 = nvgpu.mma.sync(%429, %452, %arg26) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %587 = nvgpu.mma.sync(%429, %468, %arg27) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %588 = nvgpu.mma.sync(%429, %484, %arg28) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %589 = nvgpu.mma.sync(%429, %500, %arg29) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %590 = nvgpu.mma.sync(%429, %516, %arg30) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %591 = nvgpu.mma.sync(%429, %532, %arg31) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %592 = nvgpu.mma.sync(%429, %548, %arg32) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %593 = nvgpu.mma.sync(%418, %440, %561) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %594 = nvgpu.mma.sync(%418, %456, %562) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %595 = nvgpu.mma.sync(%418, %472, %563) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %596 = nvgpu.mma.sync(%418, %488, %564) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %597 = nvgpu.mma.sync(%418, %504, %565) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %598 = nvgpu.mma.sync(%418, %520, %566) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %599 = nvgpu.mma.sync(%418, %536, %567) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %600 = nvgpu.mma.sync(%418, %552, %568) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %601 = nvgpu.mma.sync(%422, %440, %569) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %602 = nvgpu.mma.sync(%422, %456, %570) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %603 = nvgpu.mma.sync(%422, %472, %571) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %604 = nvgpu.mma.sync(%422, %488, %572) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %605 = nvgpu.mma.sync(%422, %504, %573) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %606 = nvgpu.mma.sync(%422, %520, %574) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %607 = nvgpu.mma.sync(%422, %536, %575) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %608 = nvgpu.mma.sync(%422, %552, %576) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %609 = nvgpu.mma.sync(%426, %440, %577) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %610 = nvgpu.mma.sync(%426, %456, %578) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %611 = nvgpu.mma.sync(%426, %472, %579) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %612 = nvgpu.mma.sync(%426, %488, %580) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %613 = nvgpu.mma.sync(%426, %504, %581) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %614 = nvgpu.mma.sync(%426, %520, %582) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %615 = nvgpu.mma.sync(%426, %536, %583) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %616 = nvgpu.mma.sync(%426, %552, %584) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %617 = nvgpu.mma.sync(%430, %440, %585) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %618 = nvgpu.mma.sync(%430, %456, %586) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %619 = nvgpu.mma.sync(%430, %472, %587) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %620 = nvgpu.mma.sync(%430, %488, %588) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %621 = nvgpu.mma.sync(%430, %504, %589) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %622 = nvgpu.mma.sync(%430, %520, %590) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %623 = nvgpu.mma.sync(%430, %536, %591) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %624 = nvgpu.mma.sync(%430, %552, %592) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %625 = nvgpu.mma.sync(%419, %444, %593) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %626 = nvgpu.mma.sync(%419, %460, %594) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %627 = nvgpu.mma.sync(%419, %476, %595) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %628 = nvgpu.mma.sync(%419, %492, %596) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %629 = nvgpu.mma.sync(%419, %508, %597) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %630 = nvgpu.mma.sync(%419, %524, %598) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %631 = nvgpu.mma.sync(%419, %540, %599) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %632 = nvgpu.mma.sync(%419, %556, %600) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %633 = nvgpu.mma.sync(%423, %444, %601) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %634 = nvgpu.mma.sync(%423, %460, %602) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %635 = nvgpu.mma.sync(%423, %476, %603) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %636 = nvgpu.mma.sync(%423, %492, %604) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %637 = nvgpu.mma.sync(%423, %508, %605) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %638 = nvgpu.mma.sync(%423, %524, %606) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %639 = nvgpu.mma.sync(%423, %540, %607) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %640 = nvgpu.mma.sync(%423, %556, %608) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %641 = nvgpu.mma.sync(%427, %444, %609) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %642 = nvgpu.mma.sync(%427, %460, %610) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %643 = nvgpu.mma.sync(%427, %476, %611) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %644 = nvgpu.mma.sync(%427, %492, %612) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %645 = nvgpu.mma.sync(%427, %508, %613) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %646 = nvgpu.mma.sync(%427, %524, %614) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %647 = nvgpu.mma.sync(%427, %540, %615) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %648 = nvgpu.mma.sync(%427, %556, %616) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %649 = nvgpu.mma.sync(%431, %444, %617) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %650 = nvgpu.mma.sync(%431, %460, %618) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %651 = nvgpu.mma.sync(%431, %476, %619) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %652 = nvgpu.mma.sync(%431, %492, %620) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %653 = nvgpu.mma.sync(%431, %508, %621) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %654 = nvgpu.mma.sync(%431, %524, %622) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %655 = nvgpu.mma.sync(%431, %540, %623) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %656 = nvgpu.mma.sync(%431, %556, %624) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %657 = nvgpu.mma.sync(%420, %448, %625) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %658 = nvgpu.mma.sync(%420, %464, %626) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %659 = nvgpu.mma.sync(%420, %480, %627) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %660 = nvgpu.mma.sync(%420, %496, %628) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %661 = nvgpu.mma.sync(%420, %512, %629) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %662 = nvgpu.mma.sync(%420, %528, %630) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %663 = nvgpu.mma.sync(%420, %544, %631) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %664 = nvgpu.mma.sync(%420, %560, %632) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %665 = nvgpu.mma.sync(%424, %448, %633) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %666 = nvgpu.mma.sync(%424, %464, %634) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %667 = nvgpu.mma.sync(%424, %480, %635) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %668 = nvgpu.mma.sync(%424, %496, %636) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %669 = nvgpu.mma.sync(%424, %512, %637) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %670 = nvgpu.mma.sync(%424, %528, %638) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %671 = nvgpu.mma.sync(%424, %544, %639) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %672 = nvgpu.mma.sync(%424, %560, %640) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %673 = nvgpu.mma.sync(%428, %448, %641) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %674 = nvgpu.mma.sync(%428, %464, %642) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %675 = nvgpu.mma.sync(%428, %480, %643) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %676 = nvgpu.mma.sync(%428, %496, %644) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %677 = nvgpu.mma.sync(%428, %512, %645) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %678 = nvgpu.mma.sync(%428, %528, %646) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %679 = nvgpu.mma.sync(%428, %544, %647) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %680 = nvgpu.mma.sync(%428, %560, %648) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %681 = nvgpu.mma.sync(%432, %448, %649) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %682 = nvgpu.mma.sync(%432, %464, %650) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %683 = nvgpu.mma.sync(%432, %480, %651) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %684 = nvgpu.mma.sync(%432, %496, %652) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %685 = nvgpu.mma.sync(%432, %512, %653) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %686 = nvgpu.mma.sync(%432, %528, %654) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %687 = nvgpu.mma.sync(%432, %544, %655) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    %688 = nvgpu.mma.sync(%432, %560, %656) {mmaShape = [16, 8, 8], tf32Enabled} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+    scf.yield %657, %658, %659, %660, %661, %662, %663, %664, %665, %666, %667, %668, %669, %670, %671, %672, %673, %674, %675, %676, %677, %678, %679, %680, %681, %682, %683, %684, %685, %686, %687, %688 : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+  }
+  %211 = gpu.lane_id
+  %212 = vector.extract %210#31[0] : vector<2x2xf32>
+  %213 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 48)>()[%1, %211]
+  %214 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 56)>()[%0, %211]
+  vector.store %212, %alloc[%213, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %215 = vector.extract %210#31[1] : vector<2x2xf32>
+  %216 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 56)>()[%1, %211]
+  vector.store %215, %alloc[%216, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %217 = vector.extract %210#30[0] : vector<2x2xf32>
+  %218 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 48)>()[%0, %211]
+  vector.store %217, %alloc[%213, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %219 = vector.extract %210#30[1] : vector<2x2xf32>
+  vector.store %219, %alloc[%216, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %220 = vector.extract %210#29[0] : vector<2x2xf32>
+  %221 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 40)>()[%0, %211]
+  vector.store %220, %alloc[%213, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %222 = vector.extract %210#29[1] : vector<2x2xf32>
+  vector.store %222, %alloc[%216, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %223 = vector.extract %210#28[0] : vector<2x2xf32>
+  %224 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 32)>()[%0, %211]
+  vector.store %223, %alloc[%213, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %225 = vector.extract %210#28[1] : vector<2x2xf32>
+  vector.store %225, %alloc[%216, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %226 = vector.extract %210#27[0] : vector<2x2xf32>
+  %227 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 24)>()[%0, %211]
+  vector.store %226, %alloc[%213, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %228 = vector.extract %210#27[1] : vector<2x2xf32>
+  vector.store %228, %alloc[%216, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %229 = vector.extract %210#26[0] : vector<2x2xf32>
+  %230 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 16)>()[%0, %211]
+  vector.store %229, %alloc[%213, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %231 = vector.extract %210#26[1] : vector<2x2xf32>
+  vector.store %231, %alloc[%216, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %232 = vector.extract %210#25[0] : vector<2x2xf32>
+  %233 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64 + 8)>()[%0, %211]
+  vector.store %232, %alloc[%213, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %234 = vector.extract %210#25[1] : vector<2x2xf32>
+  vector.store %234, %alloc[%216, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %235 = vector.extract %210#24[0] : vector<2x2xf32>
+  %236 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 - (s1 floordiv 4) * 8 + (s0 floordiv 32) * 64)>()[%0, %211]
+  vector.store %235, %alloc[%213, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %237 = vector.extract %210#24[1] : vector<2x2xf32>
+  vector.store %237, %alloc[%216, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %238 = vector.extract %210#23[0] : vector<2x2xf32>
+  %239 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 32)>()[%1, %211]
+  vector.store %238, %alloc[%239, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %240 = vector.extract %210#23[1] : vector<2x2xf32>
+  %241 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 40)>()[%1, %211]
+  vector.store %240, %alloc[%241, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %242 = vector.extract %210#22[0] : vector<2x2xf32>
+  vector.store %242, %alloc[%239, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %243 = vector.extract %210#22[1] : vector<2x2xf32>
+  vector.store %243, %alloc[%241, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %244 = vector.extract %210#21[0] : vector<2x2xf32>
+  vector.store %244, %alloc[%239, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %245 = vector.extract %210#21[1] : vector<2x2xf32>
+  vector.store %245, %alloc[%241, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %246 = vector.extract %210#20[0] : vector<2x2xf32>
+  vector.store %246, %alloc[%239, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %247 = vector.extract %210#20[1] : vector<2x2xf32>
+  vector.store %247, %alloc[%241, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %248 = vector.extract %210#19[0] : vector<2x2xf32>
+  vector.store %248, %alloc[%239, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %249 = vector.extract %210#19[1] : vector<2x2xf32>
+  vector.store %249, %alloc[%241, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %250 = vector.extract %210#18[0] : vector<2x2xf32>
+  vector.store %250, %alloc[%239, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %251 = vector.extract %210#18[1] : vector<2x2xf32>
+  vector.store %251, %alloc[%241, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %252 = vector.extract %210#17[0] : vector<2x2xf32>
+  vector.store %252, %alloc[%239, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %253 = vector.extract %210#17[1] : vector<2x2xf32>
+  vector.store %253, %alloc[%241, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %254 = vector.extract %210#16[0] : vector<2x2xf32>
+  vector.store %254, %alloc[%239, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %255 = vector.extract %210#16[1] : vector<2x2xf32>
+  vector.store %255, %alloc[%241, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %256 = vector.extract %210#15[0] : vector<2x2xf32>
+  %257 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 16)>()[%1, %211]
+  vector.store %256, %alloc[%257, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %258 = vector.extract %210#15[1] : vector<2x2xf32>
+  %259 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 24)>()[%1, %211]
+  vector.store %258, %alloc[%259, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %260 = vector.extract %210#14[0] : vector<2x2xf32>
+  vector.store %260, %alloc[%257, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %261 = vector.extract %210#14[1] : vector<2x2xf32>
+  vector.store %261, %alloc[%259, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %262 = vector.extract %210#13[0] : vector<2x2xf32>
+  vector.store %262, %alloc[%257, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %263 = vector.extract %210#13[1] : vector<2x2xf32>
+  vector.store %263, %alloc[%259, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %264 = vector.extract %210#12[0] : vector<2x2xf32>
+  vector.store %264, %alloc[%257, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %265 = vector.extract %210#12[1] : vector<2x2xf32>
+  vector.store %265, %alloc[%259, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %266 = vector.extract %210#11[0] : vector<2x2xf32>
+  vector.store %266, %alloc[%257, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %267 = vector.extract %210#11[1] : vector<2x2xf32>
+  vector.store %267, %alloc[%259, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %268 = vector.extract %210#10[0] : vector<2x2xf32>
+  vector.store %268, %alloc[%257, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %269 = vector.extract %210#10[1] : vector<2x2xf32>
+  vector.store %269, %alloc[%259, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %270 = vector.extract %210#9[0] : vector<2x2xf32>
+  vector.store %270, %alloc[%257, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %271 = vector.extract %210#9[1] : vector<2x2xf32>
+  vector.store %271, %alloc[%259, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %272 = vector.extract %210#8[0] : vector<2x2xf32>
+  vector.store %272, %alloc[%257, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %273 = vector.extract %210#8[1] : vector<2x2xf32>
+  vector.store %273, %alloc[%259, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %274 = vector.extract %210#7[0] : vector<2x2xf32>
+  %275 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4)>()[%1, %211]
+  vector.store %274, %alloc[%275, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %276 = vector.extract %210#7[1] : vector<2x2xf32>
+  %277 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 floordiv 4 + 8)>()[%1, %211]
+  vector.store %276, %alloc[%277, %214] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %278 = vector.extract %210#6[0] : vector<2x2xf32>
+  vector.store %278, %alloc[%275, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %279 = vector.extract %210#6[1] : vector<2x2xf32>
+  vector.store %279, %alloc[%277, %218] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %280 = vector.extract %210#5[0] : vector<2x2xf32>
+  vector.store %280, %alloc[%275, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %281 = vector.extract %210#5[1] : vector<2x2xf32>
+  vector.store %281, %alloc[%277, %221] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %282 = vector.extract %210#4[0] : vector<2x2xf32>
+  vector.store %282, %alloc[%275, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %283 = vector.extract %210#4[1] : vector<2x2xf32>
+  vector.store %283, %alloc[%277, %224] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %284 = vector.extract %210#3[0] : vector<2x2xf32>
+  vector.store %284, %alloc[%275, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %285 = vector.extract %210#3[1] : vector<2x2xf32>
+  vector.store %285, %alloc[%277, %227] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %286 = vector.extract %210#2[0] : vector<2x2xf32>
+  vector.store %286, %alloc[%275, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %287 = vector.extract %210#2[1] : vector<2x2xf32>
+  vector.store %287, %alloc[%277, %230] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %288 = vector.extract %210#1[0] : vector<2x2xf32>
+  vector.store %288, %alloc[%275, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %289 = vector.extract %210#1[1] : vector<2x2xf32>
+  vector.store %289, %alloc[%277, %233] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %290 = vector.extract %210#0[0] : vector<2x2xf32>
+  vector.store %290, %alloc[%275, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  %291 = vector.extract %210#0[1] : vector<2x2xf32>
+  vector.store %291, %alloc[%277, %236] : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+  gpu.barrier
+  %292 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32)>()[%0, %1, %2]
+  %293 = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 32) * 128)>()[%0]
+  %294 = vector.transfer_read %alloc[%292, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %295 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32)>()[%0, %1, %2, %workgroup_id_y]
+  %296 = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 128 - (s0 floordiv 32) * 128)>()[%0, %workgroup_id_x]
+  vector.transfer_write %294, %5[%295, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %297 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 4)>()[%0, %1, %2]
+  %298 = vector.transfer_read %alloc[%297, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %299 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 4)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %298, %5[%299, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %300 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 8)>()[%0, %1, %2]
+  %301 = vector.transfer_read %alloc[%300, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %302 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 8)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %301, %5[%302, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %303 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 12)>()[%0, %1, %2]
+  %304 = vector.transfer_read %alloc[%303, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %305 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 12)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %304, %5[%305, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %306 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 16)>()[%0, %1, %2]
+  %307 = vector.transfer_read %alloc[%306, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %308 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 16)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %307, %5[%308, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %309 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 20)>()[%0, %1, %2]
+  %310 = vector.transfer_read %alloc[%309, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %311 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 20)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %310, %5[%311, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %312 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 24)>()[%0, %1, %2]
+  %313 = vector.transfer_read %alloc[%312, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %314 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 24)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %313, %5[%314, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %315 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 28)>()[%0, %1, %2]
+  %316 = vector.transfer_read %alloc[%315, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %317 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 28)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %316, %5[%317, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %318 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 32)>()[%0, %1, %2]
+  %319 = vector.transfer_read %alloc[%318, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %320 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 32)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %319, %5[%320, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %321 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 36)>()[%0, %1, %2]
+  %322 = vector.transfer_read %alloc[%321, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %323 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 36)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %322, %5[%323, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %324 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 40)>()[%0, %1, %2]
+  %325 = vector.transfer_read %alloc[%324, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %326 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 40)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %325, %5[%326, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %327 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 44)>()[%0, %1, %2]
+  %328 = vector.transfer_read %alloc[%327, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %329 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 44)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %328, %5[%329, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %330 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 48)>()[%0, %1, %2]
+  %331 = vector.transfer_read %alloc[%330, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %332 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 48)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %331, %5[%332, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %333 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 52)>()[%0, %1, %2]
+  %334 = vector.transfer_read %alloc[%333, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %335 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 52)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %334, %5[%335, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %336 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 56)>()[%0, %1, %2]
+  %337 = vector.transfer_read %alloc[%336, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %338 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 56)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %337, %5[%338, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %339 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 60)>()[%0, %1, %2]
+  %340 = vector.transfer_read %alloc[%339, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %341 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 60)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %340, %5[%341, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %342 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 64)>()[%0, %1, %2]
+  %343 = vector.transfer_read %alloc[%342, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %344 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 64)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %343, %5[%344, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %345 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 68)>()[%0, %1, %2]
+  %346 = vector.transfer_read %alloc[%345, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %347 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 68)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %346, %5[%347, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %348 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 72)>()[%0, %1, %2]
+  %349 = vector.transfer_read %alloc[%348, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %350 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 72)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %349, %5[%350, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %351 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 76)>()[%0, %1, %2]
+  %352 = vector.transfer_read %alloc[%351, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %353 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 76)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %352, %5[%353, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %354 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 80)>()[%0, %1, %2]
+  %355 = vector.transfer_read %alloc[%354, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %356 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 80)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %355, %5[%356, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %357 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 84)>()[%0, %1, %2]
+  %358 = vector.transfer_read %alloc[%357, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %359 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 84)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %358, %5[%359, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %360 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 88)>()[%0, %1, %2]
+  %361 = vector.transfer_read %alloc[%360, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %362 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 88)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %361, %5[%362, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %363 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 92)>()[%0, %1, %2]
+  %364 = vector.transfer_read %alloc[%363, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %365 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 92)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %364, %5[%365, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %366 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 96)>()[%0, %1, %2]
+  %367 = vector.transfer_read %alloc[%366, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %368 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 96)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %367, %5[%368, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %369 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 100)>()[%0, %1, %2]
+  %370 = vector.transfer_read %alloc[%369, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %371 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 100)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %370, %5[%371, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %372 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 104)>()[%0, %1, %2]
+  %373 = vector.transfer_read %alloc[%372, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %374 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 104)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %373, %5[%374, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %375 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 108)>()[%0, %1, %2]
+  %376 = vector.transfer_read %alloc[%375, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %377 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 108)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %376, %5[%377, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %378 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 112)>()[%0, %1, %2]
+  %379 = vector.transfer_read %alloc[%378, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %380 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 112)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %379, %5[%380, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %381 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 116)>()[%0, %1, %2]
+  %382 = vector.transfer_read %alloc[%381, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %383 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 116)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %382, %5[%383, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %384 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 120)>()[%0, %1, %2]
+  %385 = vector.transfer_read %alloc[%384, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %386 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 120)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %385, %5[%386, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  %387 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 2 + s2 * 4 + s0 floordiv 32 + 124)>()[%0, %1, %2]
+  %388 = vector.transfer_read %alloc[%387, %293], %cst_1 {in_bounds = [true]} : memref<128x128xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  %389 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 2 + s2 * 4 + s3 * 128 + s0 floordiv 32 + 124)>()[%0, %1, %2, %workgroup_id_y]
+  vector.transfer_write %388, %5[%389, %296] {in_bounds = [true]} : vector<4xf32>, memref<256x256xf32>
+  gpu.barrier
+  return
+}
+
+
+//  CHECK-NV-LABEL: func.func @nvidia_tenscore_schedule_f32
+//  CHECK-NV-COUNT-6:  nvgpu.device_async_copy
+//          CHECK-NV:  nvgpu.device_async_create_group
+//  CHECK-NV-COUNT-6:  nvgpu.device_async_copy
+//          CHECK-NV:  nvgpu.device_async_create_group
+//          CHECK-NV:  nvgpu.device_async_wait %{{.*}} {numGroups = 1 : i32}
+//          CHECK-NV:  gpu.barrier
+//  CHECK-NV-COUNT-4:  nvgpu.ldmatrix
+//  CHECK-NV-COUNT-8:  memref.load
+//          CHECK-NV:  scf.for
+//  CHECK-NV-COUNT-4:    nvgpu.ldmatrix
+//  CHECK-NV-COUNT-16:   memref.load
+// CHECK-NV-COUNT-32:    nvgpu.mma.sync
+//  CHECK-NV-COUNT-4:    nvgpu.ldmatrix
+//  CHECK-NV-COUNT-16:   memref.load
+// CHECK-NV-COUNT-32:    nvgpu.mma.sync
+//  CHECK-NV-COUNT-4:    nvgpu.ldmatrix
+//  CHECK-NV-COUNT-16:   memref.load
+// CHECK-NV-COUNT-32:    nvgpu.mma.sync
+//  CHECK-NV-COUNT-6:    nvgpu.device_async_copy
+//          CHECK-NV:    nvgpu.device_async_create_group
+//          CHECK-NV:    nvgpu.device_async_wait %{{.*}} {numGroups = 1 : i32}
+//          CHECK-NV:    gpu.barrier
+//  CHECK-NV-COUNT-4:    nvgpu.ldmatrix
+//  CHECK-NV-COUNT-16:   memref.load
+// CHECK-NV-COUNT-32:    nvgpu.mma.sync
+//          CHECK-NV:  }
+//          CHECK-NV:  vector.store
+
+// -----
\ No newline at end of file
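
The CHECK-NV schedule above pins down the steady state of the multistage mainloop: two `cp.async` commit groups are issued before the loop, and each iteration waits until at most one group is in flight (`numGroups = 1`), barriers, interleaves `ldmatrix`/`memref.load` operand prefetches with blocks of 32 `mma.sync` ops, and refills one shared-memory stage. As a rough illustration only (not the generated code), the same 3-stage schedule expressed with CUDA's pipeline intrinsics looks like the sketch below; the kernel name, the 128-`float4` stage width, and the scalar reduction standing in for the `ldmatrix`/`mma.sync` block are assumptions for the example.

```cuda
// Minimal 3-stage cp.async pipeline sketch using <cuda_pipeline.h>.
#include <cuda_pipeline.h>

constexpr int kStages = 3;

__global__ void multistage_mainloop(const float4* gmem, float* out, int kTiles) {
  __shared__ float4 smem[kStages][128];
  const int t = threadIdx.x;  // assumes blockDim.x == 128
  // Prologue: fill kStages - 1 shared-memory stages, one commit group each.
  for (int s = 0; s < kStages - 1; ++s) {
    __pipeline_memcpy_async(&smem[s][t], &gmem[s * 128 + t], sizeof(float4));
    __pipeline_commit();
  }
  float acc = 0.f;
  for (int k = 0; k < kTiles; ++k) {
    // cp.async.wait_group 1: block until at most one copy group is still in
    // flight, i.e. the oldest stage has landed in shared memory.
    __pipeline_wait_prior(kStages - 2);
    __syncthreads();  // gpu.barrier before consuming the stage
    float4 v = smem[k % kStages][t];  // stand-in for the ldmatrix/mma.sync block
    acc += v.x + v.y + v.z + v.w;
    int next = k + kStages - 1;
    if (next < kTiles)  // keep kStages - 1 groups in flight
      __pipeline_memcpy_async(&smem[next % kStages][t],
                              &gmem[next * 128 + t], sizeof(float4));
    __pipeline_commit();  // commit (possibly empty) so the wait counts line up
  }
  out[blockIdx.x * blockDim.x + t] = acc;
}
```

The `numGroups = 1` wait in the checks corresponds to `__pipeline_wait_prior(kStages - 2)` here: with three stages, one copy group may still be in flight while the math runs on the stage that has already landed.
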
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
index 0f51926..9cfba09 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
@@ -3,7 +3,6 @@
 // Verify that a simple elementwise op gets lowered successfully all the way to
 // nvvm/llvm dialect via mma.sync path.
 
-
 // -----
 
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
@@ -67,13 +66,15 @@
 //  CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //          CHECK:   nvvm.cp.async.commit.group
 //          CHECK:   nvvm.cp.async.wait.group 2
-//  CHECK-COUNT-4:   nvvm.ldmatrix {{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-2:   nvvm.ldmatrix {{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
 //          CHECK:   llvm.br
-//  CHECK-COUNT-4:   nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+//  CHECK-COUNT-2:   nvvm.ldmatrix {{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-2:   nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 //  CHECK-COUNT-2:   llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
 //          CHECK:   nvvm.cp.async.commit.group
 //          CHECK:   nvvm.cp.async.wait.group 2
-//  CHECK-COUNT-4:   nvvm.ldmatrix {{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-2:   nvvm.ldmatrix {{.*}} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-2:   nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 //          CHECK:   llvm.br
 //      CHECK-NOT:   nvvm.mma.sync
 //  CHECK-COUNT-4:   llvm.store {{.*}} : !llvm.ptr<vector<2xf16>, 3>
@@ -90,7 +91,7 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable @mma_fused {
+hal.executable @mma_fused_f32 {
   hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
   hal.executable.export public @_large_aligned_dispatch_0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>) {
   ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
@@ -134,7 +135,7 @@
 }
 
 // mma.sync.1688.f32.tf32 / TensorCore(f32):
-//    CHECK-LABEL: hal.executable public @mma_fused
+//    CHECK-LABEL: hal.executable public @mma_fused_f32
 //          CHECK:   hal.executable.variant public @cuda
 //      CHECK-NOT:   llvm.store
 //  CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
@@ -144,13 +145,18 @@
 //  CHECK-COUNT-2:   nvvm.cp.async.shared.global {{.*}}, {{.*}}, 16
 //          CHECK:   nvvm.cp.async.commit.group
 //          CHECK:   nvvm.cp.async.wait.group 2
-//  CHECK-COUNT-2:   nvvm.ldmatrix{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-1:   nvvm.ldmatrix{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-4:   llvm.extractvalue{{.*}} : !llvm.struct<(i32, i32, i32, i32)> 
 //          CHECK:   llvm.br
-//  CHECK-COUNT-4:   nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+//  CHECK-COUNT-1:   nvvm.ldmatrix{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-4:   llvm.extractvalue{{.*}} : !llvm.struct<(i32, i32, i32, i32)> 
+//  CHECK-COUNT-2:   nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
 //  CHECK-COUNT-2:   llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
 //          CHECK:   nvvm.cp.async.commit.group
 //          CHECK:   nvvm.cp.async.wait.group 2
-//  CHECK-COUNT-2:   nvvm.ldmatrix{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-1:   nvvm.ldmatrix{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+//  CHECK-COUNT-4:   llvm.extractvalue{{.*}} : !llvm.struct<(i32, i32, i32, i32)> 
+//  CHECK-COUNT-2:   nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
 //          CHECK:   llvm.br
 //      CHECK-NOT:   nvvm.mma.sync
 //  CHECK-COUNT-4:   llvm.store {{.*}} : !llvm.ptr<vector<2xf32>, 3>
@@ -164,4 +170,4 @@
 //    CHECK-COUNT:   llvm.store {{.*}} : !llvm.ptr<vector<4xf32>>
 //    CHECK-COUNT:   llvm.load {{.*}} : !llvm.ptr<vector<4xf32>, 3>
 //    CHECK-COUNT:   llvm.fadd {{.*}} : vector<4xf32>
-//    CHECK-COUNT:   llvm.store {{.*}} : !llvm.ptr<vector<4xf32>>
+//    CHECK-COUNT:   llvm.store {{.*}} : !llvm.ptr<vector<4xf32>>
\ No newline at end of file
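
The check updates in this file capture the fine-grained operand scheduling: instead of fetching all operands up front (`CHECK-COUNT-4: nvvm.ldmatrix`) and then running all the math (`CHECK-COUNT-4: nvvm.mma.sync`), the prologue now prefetches the first k-group's fragments before `llvm.br`, and inside the loop each `CHECK-COUNT-2: nvvm.mma.sync` block is preceded by the `ldmatrix` (plus `llvm.extractvalue` unpacking on the tf32 path) that fetches the next k-group's operands, so each fetch hides behind the previous math. A register double-buffering sketch of the same idea, written against the public WMMA API purely for illustration (the pass itself emits `nvgpu.ldmatrix`/`nvgpu.mma.sync` at a finer granularity, and the 16x16x16 f16 tile shape and smem layout here are assumptions):

```cuda
#include <mma.h>
using namespace nvcuda::wmma;

__device__ void kgroup_mainloop(const half* smemA, const half* smemB,
                                fragment<accumulator, 16, 16, 16, float>& acc,
                                int kTiles, unsigned lda, unsigned ldb) {
  fragment<matrix_a, 16, 16, 16, half, row_major> a[2];
  fragment<matrix_b, 16, 16, 16, half, col_major> b[2];
  // Prologue: the first k-group's operands are fetched before the loop,
  // matching the ldmatrix block hoisted ahead of llvm.br in the checks.
  load_matrix_sync(a[0], smemA, lda);
  load_matrix_sync(b[0], smemB, ldb);
  for (int k = 0; k < kTiles; ++k) {
    int cur = k & 1, nxt = cur ^ 1;
    if (k + 1 < kTiles) {
      // Prefetch the next k-group's operands into the spare register buffer.
      load_matrix_sync(a[nxt], smemA + (k + 1) * 16, lda);
      load_matrix_sync(b[nxt], smemB + (k + 1) * 16, ldb);
    }
    // Math on the current operands; the shared-memory load latency of the
    // prefetch above is hidden behind the tensor-core issue.
    mma_sync(acc, a[cur], b[cur], acc);
  }
}
```

Double buffering the fragments costs a second set of operand registers but keeps the tensor cores fed: the math for k-step `k` never has to stall on the shared-memory loads for k-step `k + 1`.
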
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 80074d5..4e0d4ec 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 80074d5fc0ab3f165865b15f5bf55ffac0917bcd
+Subproject commit 4e0d4ec3bdc8d6dbb3e6a4c6f4b6dbe7d51e94a5