Add strategy for Mali GPU in vectorization path and enable hoisting (#3650)

Add launch strategy for Mali GPU and a set of new transformation needed
and specify a good tile and workgroup size to use based on experiment.
Also enable hoisting of transfer_ops for the vector path.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 58106b0..db1a6bb 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -197,6 +197,27 @@
   return success();
 }
 
+/// Launch configuration for different known GPU configuration.
+static LogicalResult getTargetSpecificConfig(
+    linalg::MatmulOp op, const SPIRVCodegenOptions &options,
+    spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
+    std::array<int64_t, 3> &workgroupSize,
+    std::array<int64_t, 3> &numSubgroups) {
+  if (spirv::lookupTargetEnv(op).getVendorID() != spirv::Vendor::ARM)
+    return failure();
+  workgroupSize[0] = resourceLimits.subgroup_size().getInt();
+  workgroupSize[1] = 1;
+  workgroupSize[2] = 1;
+  SmallVector<int64_t, 4> ts = {8, 64, 4};
+  tileSizes.emplace_back(ts);
+  // No tiling at the subgroup level since this target doesn't use subgroup op
+  // or shared memory.
+  tileSizes.emplace_back();
+  SmallVector<int64_t, 4> threadTs = {ts[0], ts[1] / workgroupSize[0], ts[2]};
+  tileSizes.emplace_back(threadTs);
+  return success();
+}
+
 template <>
 LogicalResult getOpLaunchConfig(linalg::MatmulOp op,
                                 const SPIRVCodegenOptions &options,
@@ -208,6 +229,11 @@
                                       op, options, resourceLimits, tileSizes,
                                       workgroupSize, numSubgroups))) {
     return success();
+  } else if (options.useVectorization &&
+             succeeded(getTargetSpecificConfig(op, options, resourceLimits,
+                                               tileSizes, workgroupSize,
+                                               numSubgroups))) {
+    return success();
   }
   unsigned maxWorkgroupSize =
       resourceLimits.max_compute_workgroup_invocations().getInt();
@@ -368,13 +394,37 @@
 template <>
 Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::ContractionOp>(
     vector::ContractionOp op) {
-  spirv::ResourceLimitsAttr resourceLimits =
-      spirv::lookupTargetEnv(op).getResourceLimits();
-  return getCooperativeMatmulSubgroupSize(
-      resourceLimits, op.getLhsType().getElementType(),
-      op.getRhsType().getElementType(),
-      op.getAccType().cast<VectorType>().getElementType(),
-      op.getResultType().cast<VectorType>().getElementType());
+  auto targetEnvAttr = spirv::lookupTargetEnv(op);
+  auto targetEnv = spirv::TargetEnv(targetEnvAttr);
+  if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
+      targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
+    spirv::ResourceLimitsAttr resourceLimits =
+        targetEnvAttr.getResourceLimits();
+    return getCooperativeMatmulSubgroupSize(
+        resourceLimits, op.getLhsType().getElementType(),
+        op.getRhsType().getElementType(),
+        op.getAccType().cast<VectorType>().getElementType(),
+        op.getResultType().cast<VectorType>().getElementType());
+  } else {
+    // Map to vec4 fma operations.
+    return SmallVector<int64_t, 4>({1, 4, 1});
+  }
+}
+
+template <>
+Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::TransferReadOp>(
+    vector::TransferReadOp op) {
+  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+  if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
+      targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
+    // Don't unroll cooperative martrix load as they should match the size of
+    // the contract.
+    return SmallVector<int64_t, 4>(op.getVectorType().getDimSize(0),
+                                   op.getVectorType().getDimSize(1));
+  } else {
+    // Map to load4.
+    return SmallVector<int64_t, 4>({1, 4});
+  }
 }
 
 Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
@@ -384,6 +434,7 @@
   }
 
   DISPATCH(vector::ContractionOp)
+  DISPATCH(vector::TransferReadOp)
 
 #undef DISPATCH
   return llvm::None;
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index a78343d..5d69e24 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/Function.h"
@@ -40,6 +41,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
 
 #define DEBUG_TYPE "iree-linalg-tile-and-fuse"
 
@@ -407,6 +409,57 @@
           getVectorizeMarker(), context));
 }
 
+//===----------------------------------------------------------------------===//
+// Patterns and methods for thread tiling.
+//===----------------------------------------------------------------------===//
+
+/// Patterns for third level tiling to target invocations.
+static void populateTilingToInvocationPatterns(
+    MLIRContext *context, const LaunchConfig &launchConfig,
+    OwningRewritePatternList &patterns) {
+  linalg::TileSizeComputationFunction getInnerTileSizeFn =
+      [&launchConfig](OpBuilder &builder, Operation *operation) {
+        ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 2);
+        if (tileSizes.empty()) return SmallVector<Value, 4>();
+        SmallVector<Value, 4> tileSizesVal;
+        tileSizesVal.reserve(tileSizes.size());
+        for (auto val : tileSizes) {
+          tileSizesVal.push_back(
+              builder.create<ConstantIndexOp>(operation->getLoc(), val));
+        }
+        return tileSizesVal;
+      };
+
+  auto getThreadProcInfoFn = [&launchConfig](
+                                 OpBuilder &builder, Location loc,
+                                 ArrayRef<Range> parallelLoopRanges) {
+    Type indexType = builder.getIndexType();
+    SmallVector<linalg::ProcInfo, 2> procInfo(2);
+    procInfo[1] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
+                                                   builder.getStringAttr("x")),
+                   builder.create<ConstantIndexOp>(
+                       loc, launchConfig.getWorkgroupSize()[0])};
+    procInfo[0] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
+                                                   builder.getStringAttr("y")),
+                   builder.create<ConstantIndexOp>(
+                       loc, launchConfig.getWorkgroupSize()[1])};
+    return procInfo;
+  };
+  linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+      getThreadProcInfoFn,
+      {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+       linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
+      context,
+      linalg::LinalgTilingOptions()
+          .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+          .setTileSizeComputationFunction(getInnerTileSizeFn)
+          .setDistributionOptions(subgroupDistributionOptions),
+      getLinalgMatchAndReplaceMarker(
+          {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+          getVectorizeMarker(), context));
+}
+
 //====---------------------------------------------------------------------===//
 // Patterns for vectorization
 //====---------------------------------------------------------------------===//
@@ -436,6 +489,9 @@
 
 static void populateVectorUnrollPatterns(MLIRContext *context,
                                          OwningRewritePatternList &patterns) {
+  patterns.insert<vector::UnrollVectorPattern<vector::TransferReadOp>>(
+      context,
+      vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
   patterns.insert<vector::UnrollVectorPattern<vector::ContractionOp>>(
       context,
       vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
@@ -444,6 +500,41 @@
 }
 
 //====---------------------------------------------------------------------===//
+// Vector patterns
+//====---------------------------------------------------------------------===//
+
+static void applyVectorTransformation(FuncOp funcOp) {
+  {
+    OwningRewritePatternList vectorUnrollPatterns;
+    populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
+    applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns));
+
+    OwningRewritePatternList canonicalizationPatterns;
+    vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns,
+                                                 funcOp.getContext());
+    applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After Vector Unroll ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+
+  {
+    // TODO(ravishankarm): remove this transformation once allocations get
+    // inserted at the top of the function.
+    linalg::hoistViewAllocOps(funcOp);
+    linalg::hoistRedundantVectorTransfers(funcOp);
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After Hoisting ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+}
+
+//====---------------------------------------------------------------------===//
 // Main pass implementation
 //====---------------------------------------------------------------------===//
 
@@ -549,6 +640,7 @@
         applyPatternsAndFoldGreedily(funcOp,
                                      std::move(secondLevelTilingPatterns));
         applyCanonicalizationPatterns(context, funcOp);
+        promoteSingleIterationLoops(funcOp);
 
         LLVM_DEBUG({
           llvm::dbgs() << "--- After Second level Tiling  ---\n";
@@ -558,6 +650,22 @@
       }
 
       {
+        OwningRewritePatternList thirdLevelTilingPatterns;
+        populateTilingToInvocationPatterns(context, launchConfig,
+                                           thirdLevelTilingPatterns);
+        applyPatternsAndFoldGreedily(funcOp,
+                                     std::move(thirdLevelTilingPatterns));
+        applyCanonicalizationPatterns(context, funcOp);
+        promoteSingleIterationLoops(funcOp);
+
+        LLVM_DEBUG({
+          llvm::dbgs() << "--- After Third level Tiling  ---\n";
+          funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+          llvm::dbgs() << "\n\n";
+        });
+      }
+
+      {
         OwningRewritePatternList vectorizationPatterns;
         populateVectorizationPatterns(context, launchConfig,
                                       vectorizationPatterns);
@@ -569,16 +677,7 @@
         });
       }
 
-      {
-        OwningRewritePatternList vectorUnrollPatterns;
-        populateVectorUnrollPatterns(context, vectorUnrollPatterns);
-        applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns));
-        LLVM_DEBUG({
-          llvm::dbgs() << "--- After Vector Unroll ---\n";
-          funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-          llvm::dbgs() << "\n\n";
-        });
-      }
+      applyVectorTransformation(funcOp);
     }
 
     launchConfig.finalize(funcOp);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index b09f909..b2c7142 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -20,6 +20,7 @@
 
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 
+#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
 #include "iree/compiler/Conversion/Common/Passes.h"
 #include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
@@ -175,6 +176,7 @@
   pm.addPass(createCSEPass());
   if (options.useVectorization) {
     pm.addPass(createVectorizeMemref());
+    pm.addPass(createForOpCanonicalizationPass());
     pm.addPass(createCanonicalizerPass());
     pm.addPass(createCSEPass());
   }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
index 6c2ba3c..d98ea68 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
@@ -61,21 +61,74 @@
 //  CHECK-DAG:  %[[C48:.+]] = constant 48 : index
 //      CHECK:  %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //      CHECK:  %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK:  scf.for %[[IV0:.+]] =
-//      CHECK:    %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//      CHECK:    %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
-// CHECK-SAME:      [%[[BOFFSET_Y]], %[[IV0]]] [64, 32]
-//      CHECK:    %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-//      CHECK:    %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
-// CHECK-SAME:      [%[[IV0]], %[[BOFFSET_X]]] [32, 64]
+//      CHECK:  %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:  %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK:    %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
 // CHECK-SAME:      [%[[BOFFSET_Y]], %[[BOFFSET_X]]] [64, 64]
+//      CHECK:  %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME:    [0, 0] [64, 64] [1, 1]
+
+//  CHECK-DAG:  %[[READ_INIT_0_0:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_0_1:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
+//  CHECK-DAG:  %[[READ_INIT_0_2:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
+//  CHECK-DAG:  %[[READ_INIT_0_3:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
+
+//  CHECK-DAG:  %[[READ_INIT_1_0:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_1_1:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
+//  CHECK-DAG:  %[[READ_INIT_1_2:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
+//  CHECK-DAG:  %[[READ_INIT_1_3:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
+
+//  CHECK-DAG:  %[[READ_INIT_2_0:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_2_1:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
+//  CHECK-DAG:  %[[READ_INIT_2_2:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
+//  CHECK-DAG:  %[[READ_INIT_2_3:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+
+//  CHECK-DAG:  %[[READ_INIT_3_0:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_3_1:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
+//  CHECK-DAG:  %[[READ_INIT_3_2:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
+//  CHECK-DAG:  %[[READ_INIT_3_3:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
+
+//      CHECK:  %[[FOR_RES:.+]]:16 = scf.for %[[IV0:.+]] = {{.*}} to
+// CHECK-SAME:  iter_args(%[[ACC_0_0:.+]] = %[[READ_INIT_0_0]],
+// CHECK-SAME:  %[[ACC_0_1:.+]] = %[[READ_INIT_0_1]],
+// CHECK-SAME:  %[[ACC_0_2:.+]] = %[[READ_INIT_0_2]],
+// CHECK-SAME:  %[[ACC_0_3:.+]] = %[[READ_INIT_0_3]],
+// CHECK-SAME:  %[[ACC_1_0:.+]] = %[[READ_INIT_1_0]],
+// CHECK-SAME:  %[[ACC_1_1:.+]] = %[[READ_INIT_1_1]],
+// CHECK-SAME:  %[[ACC_1_2:.+]] = %[[READ_INIT_1_2]],
+// CHECK-SAME:  %[[ACC_1_3:.+]] = %[[READ_INIT_1_3]],
+// CHECK-SAME:  %[[ACC_2_0:.+]] = %[[READ_INIT_2_0]],
+// CHECK-SAME:  %[[ACC_2_1:.+]] = %[[READ_INIT_2_1]],
+// CHECK-SAME:  %[[ACC_2_2:.+]] = %[[READ_INIT_2_2]],
+// CHECK-SAME:  %[[ACC_2_3:.+]] = %[[READ_INIT_2_3]],
+// CHECK-SAME:  %[[ACC_3_0:.+]] = %[[READ_INIT_3_0]],
+// CHECK-SAME:  %[[ACC_3_1:.+]] = %[[READ_INIT_3_1]],
+// CHECK-SAME:  %[[ACC_3_2:.+]] = %[[READ_INIT_3_2]],
+// CHECK-SAME:  %[[ACC_3_3:.+]] = %[[READ_INIT_3_3]])
+//      CHECK:    %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME:      [%[[BOFFSET_Y]], %[[IV0]]] [64, 32]
+//      CHECK:    %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME:      [%[[IV0]], %[[BOFFSET_X]]] [32, 64]
 //      CHECK:    %[[SUBVIEW_LHS_2:.+]] = subview %[[SUBVIEW_LHS]]
 // CHECK-SAME:      [0, 0] [64, 32] [1, 1]
 //      CHECK:    %[[SUBVIEW_RHS_2:.+]] = subview %[[SUBVIEW_RHS]]
 // CHECK-SAME:      [0, 0] [32, 64] [1, 1]
-//      CHECK:    %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
-// CHECK-SAME:      [0, 0] [64, 64] [1, 1]
 
 //  CHECK-DAG:    %[[READ_LHS_0_0:.+]] = vector.transfer_read
 // CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C0]]]
@@ -115,148 +168,103 @@
 //  CHECK-DAG:    %[[READ_RHS_1_3:.+]] = vector.transfer_read
 // CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C16]], %[[C48]]]
 
-//  CHECK-DAG:    %[[READ_INIT_0_0:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
-//  CHECK-DAG:    %[[READ_INIT_0_1:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
-//  CHECK-DAG:    %[[READ_INIT_0_2:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
-//  CHECK-DAG:    %[[READ_INIT_0_3:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
-
-//  CHECK-DAG:    %[[READ_INIT_1_0:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
-//  CHECK-DAG:    %[[READ_INIT_1_1:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
-//  CHECK-DAG:    %[[READ_INIT_1_2:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
-//  CHECK-DAG:    %[[READ_INIT_1_3:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
-
-//  CHECK-DAG:    %[[READ_INIT_2_0:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
-//  CHECK-DAG:    %[[READ_INIT_2_1:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
-//  CHECK-DAG:    %[[READ_INIT_2_2:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
-//  CHECK-DAG:    %[[READ_INIT_2_3:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
-
-//  CHECK-DAG:    %[[READ_INIT_3_0:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
-//  CHECK-DAG:    %[[READ_INIT_3_1:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
-//  CHECK-DAG:    %[[READ_INIT_3_2:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
-//  CHECK-DAG:    %[[READ_INIT_3_3:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
-
 //      CHECK:    %[[CONTRACT_0_0_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[READ_INIT_0_0]]
+// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[ACC_0_0]]
 //      CHECK:    %[[CONTRACT_0_0:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_0]], %[[CONTRACT_0_0_1]]
 //      CHECK:    %[[CONTRACT_0_1_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[READ_INIT_0_1]]
+// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[ACC_0_1]]
 //      CHECK:    %[[CONTRACT_0_1:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_1]], %[[CONTRACT_0_1_1]]
 //      CHECK:    %[[CONTRACT_0_2_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[READ_INIT_0_2]]
+// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[ACC_0_2]]
 //      CHECK:    %[[CONTRACT_0_2:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_2]], %[[CONTRACT_0_2_1]]
 //      CHECK:    %[[CONTRACT_0_3_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[READ_INIT_0_3]]
+// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[ACC_0_3]]
 //      CHECK:    %[[CONTRACT_0_3:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_3]], %[[CONTRACT_0_3_1]]
 
 //      CHECK:    %[[CONTRACT_1_0_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[READ_INIT_1_0]]
+// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[ACC_1_0]]
 //      CHECK:    %[[CONTRACT_1_0:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_0]], %[[CONTRACT_1_0_1]]
 //      CHECK:    %[[CONTRACT_1_1_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[READ_INIT_1_1]]
+// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[ACC_1_1]]
 //      CHECK:    %[[CONTRACT_1_1:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_1]], %[[CONTRACT_1_1_1]]
 //      CHECK:    %[[CONTRACT_1_2_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[READ_INIT_1_2]]
+// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[ACC_1_2]]
 //      CHECK:    %[[CONTRACT_1_2:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_2]], %[[CONTRACT_1_2_1]]
 //      CHECK:    %[[CONTRACT_1_3_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[READ_INIT_1_3]]
+// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[ACC_1_3]]
 //      CHECK:    %[[CONTRACT_1_3:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_3]], %[[CONTRACT_1_3_1]]
 
 //      CHECK:    %[[CONTRACT_2_0_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[READ_INIT_2_0]]
+// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[ACC_2_0]]
 //      CHECK:    %[[CONTRACT_2_0:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_0]], %[[CONTRACT_2_0_1]]
 //      CHECK:    %[[CONTRACT_2_1_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[READ_INIT_2_1]]
+// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[ACC_2_1]]
 //      CHECK:    %[[CONTRACT_2_1:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_1]], %[[CONTRACT_2_1_1]]
 //      CHECK:    %[[CONTRACT_2_2_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[READ_INIT_2_2]]
+// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[ACC_2_2]]
 //      CHECK:    %[[CONTRACT_2_2:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_2]], %[[CONTRACT_2_2_1]]
 //      CHECK:    %[[CONTRACT_2_3_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[READ_INIT_2_3]]
+// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[ACC_2_3]]
 //      CHECK:    %[[CONTRACT_2_3:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_3]], %[[CONTRACT_2_3_1]]
 
 //      CHECK:    %[[CONTRACT_3_0_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[READ_INIT_3_0]]
+// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[ACC_3_0]]
 //      CHECK:    %[[CONTRACT_3_0:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_0]], %[[CONTRACT_3_0_1]]
 //      CHECK:    %[[CONTRACT_3_1_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[READ_INIT_3_1]]
+// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[ACC_3_1]]
 //      CHECK:    %[[CONTRACT_3_1:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_1]], %[[CONTRACT_3_1_1]]
 //      CHECK:    %[[CONTRACT_3_2_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[READ_INIT_3_2]]
+// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[ACC_3_2]]
 //      CHECK:    %[[CONTRACT_3_2:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_2]], %[[CONTRACT_3_2_1]]
 //      CHECK:    %[[CONTRACT_3_3_1:.+]] = vector.contract
-// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[READ_INIT_3_3]]
+// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[ACC_3_3]]
 //      CHECK:    %[[CONTRACT_3_3:.+]] = vector.contract
 // CHECK-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_3]], %[[CONTRACT_3_3_1]]
 
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_0_0]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_0_1]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_0_2]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_0_3]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
+//      CHECK:    scf.yield %[[CONTRACT_0_0]], %[[CONTRACT_0_1]],
+// CHECK-SAME:      %[[CONTRACT_0_2]], %[[CONTRACT_0_3]], %[[CONTRACT_1_0]],
+// CHECK-SAME:      %[[CONTRACT_1_1]], %[[CONTRACT_1_2]], %[[CONTRACT_1_3]],
+// CHECK-SAME:      %[[CONTRACT_2_0]], %[[CONTRACT_2_1]], %[[CONTRACT_2_2]],
+// CHECK-SAME:      %[[CONTRACT_2_3]], %[[CONTRACT_3_0]], %[[CONTRACT_3_1]],
+// CHECK-SAME:      %[[CONTRACT_3_2]], %[[CONTRACT_3_3]]
 
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_1_0]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_1_1]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_1_2]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_1_3]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#0, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#1, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#2, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#3, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
 
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_2_0]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_2_1]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_2_2]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_2_3]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#4, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#5, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#6, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#7, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
 
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_3_0]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_3_1]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_3_2]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
-//  CHECK-DAG:    vector.transfer_write
-// CHECK-SAME:      %[[CONTRACT_3_3]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#8, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#9, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#10, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#11, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#12, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#13, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#14, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#15, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
 
 
-//  PROMOTE-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>
+//  PROMOTE-DAG: #[[MAP4:.+]] = affine_map<()[s0] -> (s0 * 64 - (s0 floordiv 2) * 128)>
 //      PROMOTE: func @matmul_static_shape
 //  PROMOTE-DAG:  %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
 //  PROMOTE-DAG:  %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
@@ -266,23 +274,78 @@
 //  PROMOTE-DAG:  %[[C16:.+]] = constant 16
 //  PROMOTE-DAG:  %[[C32:.+]] = constant 32
 //  PROMOTE-DAG:  %[[C48:.+]] = constant 48
-//      PROMOTE:  scf.for %[[IV0:.+]]
+
+//  PROMOTE:    %[[ALLOC1:.+]] = alloc()
+//  PROMOTE:    %[[ALLOC2:.+]] = alloc()
+//  PROMOTE:    %[[RESULT_SUBVIEW:.+]] = subview %[[RET0]]
+//  PROMOTE:    %[[WGMEM_LHS_SUBVIEW:.+]] = subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]
+//  PROMOTE:    %[[WGMEM_RHS_SUBVIEW:.+]] = subview %[[ALLOC2]][0, 0] [32, 128] [1, 1]
+//  PROMOTE:    %[[SG_X:.+]] = gpu.subgroup_id
+//  PROMOTE:    %[[SG_Y:.+]] = divi_signed %[[SG_X]], %[[C2]]
+//  PROMOTE:    %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP4]]()[%[[SG_Y]]]
+//  PROMOTE:    %[[SG_LHS_SUBVIEW:.+]] = subview %[[WGMEM_LHS_SUBVIEW]][%[[SGOFFSET_Y]], 0]
+//  PROMOTE:    %[[SGOFFSET_X:.+]] = affine.apply #[[MAP4]]()[%[[SG_X]]]
+//  PROMOTE:    %[[SG_RHS_SUBVIEW:.+]] = subview %[[WGMEM_RHS_SUBVIEW]][0, %[[SGOFFSET_X]]]
+//  PROMOTE:    %[[SG_RESULT_SUBVIEW:.+]] = subview %[[RESULT_SUBVIEW]][%[[SGOFFSET_Y]], %[[SGOFFSET_X]]]
+
+//  PROMOTE-DAG:  %[[READ_INIT_0_0:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
+//  PROMOTE-DAG:  %[[READ_INIT_0_1:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
+//  PROMOTE-DAG:  %[[READ_INIT_0_2:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
+//  PROMOTE-DAG:  %[[READ_INIT_0_3:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
+
+//  PROMOTE-DAG:  %[[READ_INIT_1_0:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
+//  PROMOTE-DAG:  %[[READ_INIT_1_1:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
+//  PROMOTE-DAG:  %[[READ_INIT_1_2:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
+//  PROMOTE-DAG:  %[[READ_INIT_1_3:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
+
+//  PROMOTE-DAG:  %[[READ_INIT_2_0:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
+//  PROMOTE-DAG:  %[[READ_INIT_2_1:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
+//  PROMOTE-DAG:  %[[READ_INIT_2_2:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
+//  PROMOTE-DAG:  %[[READ_INIT_2_3:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+
+//  PROMOTE-DAG:  %[[READ_INIT_3_0:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
+//  PROMOTE-DAG:  %[[READ_INIT_3_1:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
+//  PROMOTE-DAG:  %[[READ_INIT_3_2:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
+//  PROMOTE-DAG:  %[[READ_INIT_3_3:.+]] = vector.transfer_read
+// PROMOTE-SAME:    %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
+
+//      PROMOTE:  %[[FOR_RES:.+]]:16 = scf.for %[[IV0:.+]] = {{.*}} to
+// PROMOTE-SAME:  iter_args(%[[ACC_0_0:.+]] = %[[READ_INIT_0_0]],
+// PROMOTE-SAME:  %[[ACC_0_1:.+]] = %[[READ_INIT_0_1]],
+// PROMOTE-SAME:  %[[ACC_0_2:.+]] = %[[READ_INIT_0_2]],
+// PROMOTE-SAME:  %[[ACC_0_3:.+]] = %[[READ_INIT_0_3]],
+// PROMOTE-SAME:  %[[ACC_1_0:.+]] = %[[READ_INIT_1_0]],
+// PROMOTE-SAME:  %[[ACC_1_1:.+]] = %[[READ_INIT_1_1]],
+// PROMOTE-SAME:  %[[ACC_1_2:.+]] = %[[READ_INIT_1_2]],
+// PROMOTE-SAME:  %[[ACC_1_3:.+]] = %[[READ_INIT_1_3]],
+// PROMOTE-SAME:  %[[ACC_2_0:.+]] = %[[READ_INIT_2_0]],
+// PROMOTE-SAME:  %[[ACC_2_1:.+]] = %[[READ_INIT_2_1]],
+// PROMOTE-SAME:  %[[ACC_2_2:.+]] = %[[READ_INIT_2_2]],
+// PROMOTE-SAME:  %[[ACC_2_3:.+]] = %[[READ_INIT_2_3]],
+// PROMOTE-SAME:  %[[ACC_3_0:.+]] = %[[READ_INIT_3_0]],
+// PROMOTE-SAME:  %[[ACC_3_1:.+]] = %[[READ_INIT_3_1]],
+// PROMOTE-SAME:  %[[ACC_3_2:.+]] = %[[READ_INIT_3_2]],
+// PROMOTE-SAME:  %[[ACC_3_3:.+]] = %[[READ_INIT_3_3]])
+
 //      PROMOTE:    %[[LHS_SUBVIEW:.+]] = subview %[[ARG0]]
 //      PROMOTE:    %[[RHS_SUBVIEW:.+]] = subview %[[ARG1]]
-//      PROMOTE:    %[[RESULT_SUBVIEW:.+]] = subview %[[RET0]]
-//      PROMOTE:    %[[ALLOC1:.+]] = alloc()
-//      PROMOTE:    %[[WGMEM_LHS_SUBVIEW:.+]] = subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]
-//      PROMOTE:    %[[ALLOC2:.+]] = alloc()
-//      PROMOTE:    %[[WGMEM_RHS_SUBVIEW:.+]] = subview %[[ALLOC2]][0, 0] [32, 128] [1, 1]
 //      PROMOTE:    linalg.copy(%[[LHS_SUBVIEW]], %[[WGMEM_LHS_SUBVIEW]])
 //      PROMOTE:    linalg.copy(%[[RHS_SUBVIEW]], %[[WGMEM_RHS_SUBVIEW]])
-//      PROMOTE:    %[[SG_X:.+]] = gpu.subgroup_id
-//      PROMOTE:    %[[SG_Y:.+]] = divi_signed %[[SG_X]], %[[C2]]
-//      PROMOTE:    %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP4]](%[[SG_Y]])
-//      PROMOTE:    %[[SG_LHS_SUBVIEW:.+]] = subview %[[WGMEM_LHS_SUBVIEW]][%[[SGOFFSET_Y]], 0]
-//      PROMOTE:    %[[SGOFFSET_X:.+]] = affine.apply #[[MAP4]](%[[SG_X]])
-//      PROMOTE:    %[[SG_RHS_SUBVIEW:.+]] = subview %[[WGMEM_RHS_SUBVIEW]][0, %[[SGOFFSET_X]]]
-//      PROMOTE:    %[[SG_RESULT_SUBVIEW:.+]] = subview %[[RESULT_SUBVIEW]][%[[SGOFFSET_Y]], %[[SGOFFSET_X]]]
 
 //  PROMOTE-DAG:    %[[READ_LHS_0_0:.+]] = vector.transfer_read
 // PROMOTE-SAME:      %[[SG_LHS_SUBVIEW]][%[[C0]], %[[C0]]]
@@ -322,142 +385,97 @@
 //  PROMOTE-DAG:    %[[READ_RHS_1_3:.+]] = vector.transfer_read
 // PROMOTE-SAME:      %[[SG_RHS_SUBVIEW]][%[[C16]], %[[C48]]]
 
-//  PROMOTE-DAG:    %[[READ_INIT_0_0:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
-//  PROMOTE-DAG:    %[[READ_INIT_0_1:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
-//  PROMOTE-DAG:    %[[READ_INIT_0_2:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
-//  PROMOTE-DAG:    %[[READ_INIT_0_3:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
-
-//  PROMOTE-DAG:    %[[READ_INIT_1_0:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
-//  PROMOTE-DAG:    %[[READ_INIT_1_1:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
-//  PROMOTE-DAG:    %[[READ_INIT_1_2:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
-//  PROMOTE-DAG:    %[[READ_INIT_1_3:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
-
-//  PROMOTE-DAG:    %[[READ_INIT_2_0:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
-//  PROMOTE-DAG:    %[[READ_INIT_2_1:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
-//  PROMOTE-DAG:    %[[READ_INIT_2_2:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
-//  PROMOTE-DAG:    %[[READ_INIT_2_3:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
-
-//  PROMOTE-DAG:    %[[READ_INIT_3_0:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
-//  PROMOTE-DAG:    %[[READ_INIT_3_1:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
-//  PROMOTE-DAG:    %[[READ_INIT_3_2:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
-//  PROMOTE-DAG:    %[[READ_INIT_3_3:.+]] = vector.transfer_read
-// PROMOTE-SAME:      %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
-
 //      PROMOTE:    %[[CONTRACT_0_0_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[READ_INIT_0_0]]
+// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[ACC_0_0]]
 //      PROMOTE:    %[[CONTRACT_0_0:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_0]], %[[CONTRACT_0_0_1]]
 //      PROMOTE:    %[[CONTRACT_0_1_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[READ_INIT_0_1]]
+// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[ACC_0_1]]
 //      PROMOTE:    %[[CONTRACT_0_1:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_1]], %[[CONTRACT_0_1_1]]
 //      PROMOTE:    %[[CONTRACT_0_2_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[READ_INIT_0_2]]
+// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[ACC_0_2]]
 //      PROMOTE:    %[[CONTRACT_0_2:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_2]], %[[CONTRACT_0_2_1]]
 //      PROMOTE:    %[[CONTRACT_0_3_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[READ_INIT_0_3]]
+// PROMOTE-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[ACC_0_3]]
 //      PROMOTE:    %[[CONTRACT_0_3:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1_3]], %[[CONTRACT_0_3_1]]
 
 //      PROMOTE:    %[[CONTRACT_1_0_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[READ_INIT_1_0]]
+// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[ACC_1_0]]
 //      PROMOTE:    %[[CONTRACT_1_0:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_0]], %[[CONTRACT_1_0_1]]
 //      PROMOTE:    %[[CONTRACT_1_1_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[READ_INIT_1_1]]
+// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[ACC_1_1]]
 //      PROMOTE:    %[[CONTRACT_1_1:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_1]], %[[CONTRACT_1_1_1]]
 //      PROMOTE:    %[[CONTRACT_1_2_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[READ_INIT_1_2]]
+// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[ACC_1_2]]
 //      PROMOTE:    %[[CONTRACT_1_2:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_2]], %[[CONTRACT_1_2_1]]
 //      PROMOTE:    %[[CONTRACT_1_3_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[READ_INIT_1_3]]
+// PROMOTE-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[ACC_1_3]]
 //      PROMOTE:    %[[CONTRACT_1_3:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1_3]], %[[CONTRACT_1_3_1]]
 
 //      PROMOTE:    %[[CONTRACT_2_0_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[READ_INIT_2_0]]
+// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[ACC_2_0]]
 //      PROMOTE:    %[[CONTRACT_2_0:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_0]], %[[CONTRACT_2_0_1]]
 //      PROMOTE:    %[[CONTRACT_2_1_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[READ_INIT_2_1]]
+// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[ACC_2_1]]
 //      PROMOTE:    %[[CONTRACT_2_1:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_1]], %[[CONTRACT_2_1_1]]
 //      PROMOTE:    %[[CONTRACT_2_2_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[READ_INIT_2_2]]
+// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[ACC_2_2]]
 //      PROMOTE:    %[[CONTRACT_2_2:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_2]], %[[CONTRACT_2_2_1]]
 //      PROMOTE:    %[[CONTRACT_2_3_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[READ_INIT_2_3]]
+// PROMOTE-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[ACC_2_3]]
 //      PROMOTE:    %[[CONTRACT_2_3:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1_3]], %[[CONTRACT_2_3_1]]
 
 //      PROMOTE:    %[[CONTRACT_3_0_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[READ_INIT_3_0]]
+// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[ACC_3_0]]
 //      PROMOTE:    %[[CONTRACT_3_0:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_0]], %[[CONTRACT_3_0_1]]
 //      PROMOTE:    %[[CONTRACT_3_1_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[READ_INIT_3_1]]
+// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[ACC_3_1]]
 //      PROMOTE:    %[[CONTRACT_3_1:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_1]], %[[CONTRACT_3_1_1]]
 //      PROMOTE:    %[[CONTRACT_3_2_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[READ_INIT_3_2]]
+// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[ACC_3_2]]
 //      PROMOTE:    %[[CONTRACT_3_2:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_2]], %[[CONTRACT_3_2_1]]
 //      PROMOTE:    %[[CONTRACT_3_3_1:.+]] = vector.contract
-// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[READ_INIT_3_3]]
+// PROMOTE-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[ACC_3_3]]
 //      PROMOTE:    %[[CONTRACT_3_3:.+]] = vector.contract
 // PROMOTE-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1_3]], %[[CONTRACT_3_3_1]]
 
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_0_0]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_0_1]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_0_2]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_0_3]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
+//      PROMOTE:    scf.yield %[[CONTRACT_0_0]], %[[CONTRACT_0_1]],
+// PROMOTE-SAME:      %[[CONTRACT_0_2]], %[[CONTRACT_0_3]], %[[CONTRACT_1_0]],
+// PROMOTE-SAME:      %[[CONTRACT_1_1]], %[[CONTRACT_1_2]], %[[CONTRACT_1_3]],
+// PROMOTE-SAME:      %[[CONTRACT_2_0]], %[[CONTRACT_2_1]], %[[CONTRACT_2_2]],
+// PROMOTE-SAME:      %[[CONTRACT_2_3]], %[[CONTRACT_3_0]], %[[CONTRACT_3_1]],
+// PROMOTE-SAME:      %[[CONTRACT_3_2]], %[[CONTRACT_3_3]]
 
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_1_0]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_1_1]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_1_2]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_1_3]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#0, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#1, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#2, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#3, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
 
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_2_0]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_2_1]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_2_2]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_2_3]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#4, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#5, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#6, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#7, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
 
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_3_0]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_3_1]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_3_2]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
-//  PROMOTE-DAG:    vector.transfer_write
-// PROMOTE-SAME:      %[[CONTRACT_3_3]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#8, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#9, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#10, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#11, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#12, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#13, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#14, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
+//  PROMOTE-DAG:  vector.transfer_write %[[FOR_RES]]#15, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]] 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
index f988a2c..992e35a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -1,41 +1,31 @@
 // RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization}" %s | IreeFileCheck %s
-// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization use-workgroup-memory}" %s | IreeFileCheck %s
 
 module attributes {
   spv.target_env =
-    #spv.target_env<#spv.vce<v1.5,
+    #spv.target_env<#spv.vce<v1.3,
       [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
        StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
        UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
        GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
        GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
-       VariablePointersStorageBuffer, CooperativeMatrixNV],
+       VariablePointersStorageBuffer],
       [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
-       SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
-       SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
-      {cooperative_matrix_properties_nv = [
-        {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
-         m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
-        {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
-         m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
-         scope = 3 : i32},
-        {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
-         m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
-         scope = 3 : i32}],
-       max_compute_shared_memory_size = 49152 : i32,
-       max_compute_workgroup_invocations = 1024 : i32,
-       max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
-       subgroup_size = 32 : i32}>} {
+       SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+      ARM:IntegratedGPU,
+      {max_compute_shared_memory_size = 32768 : i32,
+       max_compute_workgroup_invocations = 512 : i32,
+       max_compute_workgroup_size = dense<512> : vector<3xi32>,
+       subgroup_size = 16 : i32}>} {
   func @matmul_static_shape()
     attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
     %arg0 = iree.placeholder for "interface buffer"
-      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf32>
     %arg1 = iree.placeholder for "interface buffer"
-      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf32>
     %ret0 = iree.placeholder for "interface buffer"
-      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
-    linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
-                 outs(%ret0 : memref<4096x4096xf16>)
+      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf32>
+    linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf32>, memref<4096x4096xf32>)
+                 outs(%ret0 : memref<4096x4096xf32>)
     return
   }
   func @matmul_static_shape__num_workgroups__
@@ -50,6 +40,10 @@
 }
 
 //    CHECK-LABEL: spv.func @matmul_static_shape
-// CHECK-COUNT-32:   spv.CooperativeMatrixLoadNV
-// CHECK-COUNT-32:   spv.CooperativeMatrixMulAddNV
-// CHECK-COUNT-16:   spv.CooperativeMatrixStoreNV
+// CHECK-COUNT-8:    spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+//          CHECK:   spv.loop
+// CHECK-COUNT-12:   spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK-COUNT-32:   spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
+// CHECK-COUNT-8:   spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
+
+
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
new file mode 100644
index 0000000..1eb90bb
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
@@ -0,0 +1,57 @@
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization}" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization use-workgroup-memory}" %s | IreeFileCheck %s
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.5,
+      [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+       StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+       UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+       GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+       GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+       VariablePointersStorageBuffer, CooperativeMatrixNV],
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+       SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
+       SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
+      {cooperative_matrix_properties_nv = [
+        {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
+         m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
+        {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
+         m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
+         scope = 3 : i32},
+        {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
+         m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
+         scope = 3 : i32}],
+       max_compute_shared_memory_size = 49152 : i32,
+       max_compute_workgroup_invocations = 1024 : i32,
+       max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+       subgroup_size = 32 : i32}>} {
+  func @matmul_static_shape()
+    attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+    %arg0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+    %arg1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+    %ret0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
+    linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
+                 outs(%ret0 : memref<4096x4096xf16>)
+    return
+  }
+  func @matmul_static_shape__num_workgroups__
+    (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+     !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+
+//    CHECK-LABEL: spv.func @matmul_static_shape
+// CHECK-COUNT-16:   spv.CooperativeMatrixLoadNV
+//          CHECK:   spv.loop
+// CHECK-COUNT-16:   spv.CooperativeMatrixLoadNV
+// CHECK-COUNT-32:   spv.CooperativeMatrixMulAddNV
+// CHECK-COUNT-16:   spv.CooperativeMatrixStoreNV