Add strategy for Mali GPU in vectorization path and enable hoisting (#3650)
Add launch strategy for Mali GPU and a set of new transformation needed
and specify a good tile and workgroup size to use based on experiment.
Also enable hoisting of transfer_ops for the vector path.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 58106b0..db1a6bb 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -197,6 +197,27 @@
return success();
}
+/// Launch configuration for different known GPU configuration.
+static LogicalResult getTargetSpecificConfig(
+ linalg::MatmulOp op, const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ if (spirv::lookupTargetEnv(op).getVendorID() != spirv::Vendor::ARM)
+ return failure();
+ workgroupSize[0] = resourceLimits.subgroup_size().getInt();
+ workgroupSize[1] = 1;
+ workgroupSize[2] = 1;
+ SmallVector<int64_t, 4> ts = {8, 64, 4};
+ tileSizes.emplace_back(ts);
+ // No tiling at the subgroup level since this target doesn't use subgroup op
+ // or shared memory.
+ tileSizes.emplace_back();
+ SmallVector<int64_t, 4> threadTs = {ts[0], ts[1] / workgroupSize[0], ts[2]};
+ tileSizes.emplace_back(threadTs);
+ return success();
+}
+
template <>
LogicalResult getOpLaunchConfig(linalg::MatmulOp op,
const SPIRVCodegenOptions &options,
@@ -208,6 +229,11 @@
op, options, resourceLimits, tileSizes,
workgroupSize, numSubgroups))) {
return success();
+ } else if (options.useVectorization &&
+ succeeded(getTargetSpecificConfig(op, options, resourceLimits,
+ tileSizes, workgroupSize,
+ numSubgroups))) {
+ return success();
}
unsigned maxWorkgroupSize =
resourceLimits.max_compute_workgroup_invocations().getInt();
@@ -368,13 +394,37 @@
template <>
Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::ContractionOp>(
vector::ContractionOp op) {
- spirv::ResourceLimitsAttr resourceLimits =
- spirv::lookupTargetEnv(op).getResourceLimits();
- return getCooperativeMatmulSubgroupSize(
- resourceLimits, op.getLhsType().getElementType(),
- op.getRhsType().getElementType(),
- op.getAccType().cast<VectorType>().getElementType(),
- op.getResultType().cast<VectorType>().getElementType());
+ auto targetEnvAttr = spirv::lookupTargetEnv(op);
+ auto targetEnv = spirv::TargetEnv(targetEnvAttr);
+ if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
+ targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
+ spirv::ResourceLimitsAttr resourceLimits =
+ targetEnvAttr.getResourceLimits();
+ return getCooperativeMatmulSubgroupSize(
+ resourceLimits, op.getLhsType().getElementType(),
+ op.getRhsType().getElementType(),
+ op.getAccType().cast<VectorType>().getElementType(),
+ op.getResultType().cast<VectorType>().getElementType());
+ } else {
+ // Map to vec4 fma operations.
+ return SmallVector<int64_t, 4>({1, 4, 1});
+ }
+}
+
+template <>
+Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::TransferReadOp>(
+ vector::TransferReadOp op) {
+ auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+ if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
+ targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
+ // Don't unroll cooperative martrix load as they should match the size of
+ // the contract.
+ return SmallVector<int64_t, 4>(op.getVectorType().getDimSize(0),
+ op.getVectorType().getDimSize(1));
+ } else {
+ // Map to load4.
+ return SmallVector<int64_t, 4>({1, 4});
+ }
}
Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
@@ -384,6 +434,7 @@
}
DISPATCH(vector::ContractionOp)
+ DISPATCH(vector::TransferReadOp)
#undef DISPATCH
return llvm::None;
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index a78343d..5d69e24 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -31,6 +31,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/Function.h"
@@ -40,6 +41,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
#define DEBUG_TYPE "iree-linalg-tile-and-fuse"
@@ -407,6 +409,57 @@
getVectorizeMarker(), context));
}
+//===----------------------------------------------------------------------===//
+// Patterns and methods for thread tiling.
+//===----------------------------------------------------------------------===//
+
+/// Patterns for third level tiling to target invocations.
+static void populateTilingToInvocationPatterns(
+ MLIRContext *context, const LaunchConfig &launchConfig,
+ OwningRewritePatternList &patterns) {
+ linalg::TileSizeComputationFunction getInnerTileSizeFn =
+ [&launchConfig](OpBuilder &builder, Operation *operation) {
+ ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 2);
+ if (tileSizes.empty()) return SmallVector<Value, 4>();
+ SmallVector<Value, 4> tileSizesVal;
+ tileSizesVal.reserve(tileSizes.size());
+ for (auto val : tileSizes) {
+ tileSizesVal.push_back(
+ builder.create<ConstantIndexOp>(operation->getLoc(), val));
+ }
+ return tileSizesVal;
+ };
+
+ auto getThreadProcInfoFn = [&launchConfig](
+ OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ Type indexType = builder.getIndexType();
+ SmallVector<linalg::ProcInfo, 2> procInfo(2);
+ procInfo[1] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
+ builder.getStringAttr("x")),
+ builder.create<ConstantIndexOp>(
+ loc, launchConfig.getWorkgroupSize()[0])};
+ procInfo[0] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
+ builder.getStringAttr("y")),
+ builder.create<ConstantIndexOp>(
+ loc, launchConfig.getWorkgroupSize()[1])};
+ return procInfo;
+ };
+ linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+ getThreadProcInfoFn,
+ {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+ patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+ .setTileSizeComputationFunction(getInnerTileSizeFn)
+ .setDistributionOptions(subgroupDistributionOptions),
+ getLinalgMatchAndReplaceMarker(
+ {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+ getVectorizeMarker(), context));
+}
+
//====---------------------------------------------------------------------===//
// Patterns for vectorization
//====---------------------------------------------------------------------===//
@@ -436,6 +489,9 @@
static void populateVectorUnrollPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
+ patterns.insert<vector::UnrollVectorPattern<vector::TransferReadOp>>(
+ context,
+ vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
patterns.insert<vector::UnrollVectorPattern<vector::ContractionOp>>(
context,
vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
@@ -444,6 +500,41 @@
}
//====---------------------------------------------------------------------===//
+// Vector patterns
+//====---------------------------------------------------------------------===//
+
+static void applyVectorTransformation(FuncOp funcOp) {
+ {
+ OwningRewritePatternList vectorUnrollPatterns;
+ populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
+ applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns));
+
+ OwningRewritePatternList canonicalizationPatterns;
+ vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns,
+ funcOp.getContext());
+ applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Vector Unroll ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ {
+ // TODO(ravishankarm): remove this transformation once allocations get
+ // inserted at the top of the function.
+ linalg::hoistViewAllocOps(funcOp);
+ linalg::hoistRedundantVectorTransfers(funcOp);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Hoisting ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+}
+
+//====---------------------------------------------------------------------===//
// Main pass implementation
//====---------------------------------------------------------------------===//
@@ -549,6 +640,7 @@
applyPatternsAndFoldGreedily(funcOp,
std::move(secondLevelTilingPatterns));
applyCanonicalizationPatterns(context, funcOp);
+ promoteSingleIterationLoops(funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After Second level Tiling ---\n";
@@ -558,6 +650,22 @@
}
{
+ OwningRewritePatternList thirdLevelTilingPatterns;
+ populateTilingToInvocationPatterns(context, launchConfig,
+ thirdLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp,
+ std::move(thirdLevelTilingPatterns));
+ applyCanonicalizationPatterns(context, funcOp);
+ promoteSingleIterationLoops(funcOp);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Third level Tiling ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ {
OwningRewritePatternList vectorizationPatterns;
populateVectorizationPatterns(context, launchConfig,
vectorizationPatterns);
@@ -569,16 +677,7 @@
});
}
- {
- OwningRewritePatternList vectorUnrollPatterns;
- populateVectorUnrollPatterns(context, vectorUnrollPatterns);
- applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns));
- LLVM_DEBUG({
- llvm::dbgs() << "--- After Vector Unroll ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
+ applyVectorTransformation(funcOp);
}
launchConfig.finalize(funcOp);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index b09f909..b2c7142 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -20,6 +20,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
@@ -175,6 +176,7 @@
pm.addPass(createCSEPass());
if (options.useVectorization) {
pm.addPass(createVectorizeMemref());
+ pm.addPass(createForOpCanonicalizationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
index 6c2ba3c..d98ea68 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
@@ -61,21 +61,74 @@
// CHECK-DAG: %[[C48:.+]] = constant 48 : index
// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: scf.for %[[IV0:.+]] =
-// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
-// CHECK-SAME: [%[[BOFFSET_Y]], %[[IV0]]] [64, 32]
-// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
-// CHECK-SAME: [%[[IV0]], %[[BOFFSET_X]]] [32, 64]
+// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
// CHECK-SAME: [%[[BOFFSET_Y]], %[[BOFFSET_X]]] [64, 64]
+// CHECK: %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME: [0, 0] [64, 64] [1, 1]
+
+// CHECK-DAG: %[[READ_INIT_0_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_0_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_0_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_0_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_1_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_1_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_1_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_1_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_2_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_2_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_2_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_2_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_3_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_3_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_3_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_3_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
+
+// CHECK: %[[FOR_RES:.+]]:16 = scf.for %[[IV0:.+]] = {{.*}} to
+// CHECK-SAME: iter_args(%[[ACC_0_0:.+]] = %[[READ_INIT_0_0]],
+// CHECK-SAME: %[[ACC_0_1:.+]] = %[[READ_INIT_0_1]],
+// CHECK-SAME: %[[ACC_0_2:.+]] = %[[READ_INIT_0_2]],
+// CHECK-SAME: %[[ACC_0_3:.+]] = %[[READ_INIT_0_3]],
+// CHECK-SAME: %[[ACC_1_0:.+]] = %[[READ_INIT_1_0]],
+// CHECK-SAME: %[[ACC_1_1:.+]] = %[[READ_INIT_1_1]],
+// CHECK-SAME: %[[ACC_1_2:.+]] = %[[READ_INIT_1_2]],
+// CHECK-SAME: %[[ACC_1_3:.+]] = %[[READ_INIT_1_3]],
+// CHECK-SAME: %[[ACC_2_0:.+]] = %[[READ_INIT_2_0]],
+// CHECK-SAME: %[[ACC_2_1:.+]] = %[[READ_INIT_2_1]],
+// CHECK-SAME: %[[ACC_2_2:.+]] = %[[READ_INIT_2_2]],
+// CHECK-SAME: %[[ACC_2_3:.+]] = %[[READ_INIT_2_3]],
+// CHECK-SAME: %[[ACC_3_0:.+]] = %[[READ_INIT_3_0]],
+// CHECK-SAME: %[[ACC_3_1:.+]] = %[[READ_INIT_3_1]],
+// CHECK-SAME: %[[ACC_3_2:.+]] = %[[READ_INIT_3_2]],
+// CHECK-SAME: %[[ACC_3_3:.+]] = %[[READ_INIT_3_3]])
+// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME: [%[[BOFFSET_Y]], %[[IV0]]] [64, 32]
+// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [%[[IV0]], %[[BOFFSET_X]]] [32, 64]
// CHECK: %[[SUBVIEW_LHS_2:.+]] = subview %[[SUBVIEW_LHS]]
// CHECK-SAME: [0, 0] [64, 32] [1, 1]
// CHECK: %[[SUBVIEW_RHS_2:.+]] = subview %[[SUBVIEW_RHS]]
// CHECK-SAME: [0, 0] [32, 64] [1, 1]
-// CHECK: %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
-// CHECK-SAME: [0, 0] [64, 64] [1, 1]
// CHECK-DAG: %[[READ_LHS_0_0:.+]] = vector.transfer_read
// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C0]]]
@@ -115,148 +168,103 @@
// CHECK-DAG: %[[READ_RHS_1_3:.+]] = vector.transfer_read
// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C16]], %[[C48]]]
-// CHECK-DAG: %[[READ_INIT_0_0:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
-// CHECK-DAG: %[[READ_INIT_0_1:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
-// CHECK-DAG: %[[READ_INIT_0_2:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
-// CHECK-DAG: %[[READ_INIT_0_3:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
-
-// CHECK-DAG: %[[READ_INIT_1_0:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
-// CHECK-DAG: %[[READ_INIT_1_1:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
-// CHECK-DAG: %[[READ_INIT_1_2:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
-// CHECK-DAG: %[[READ_INIT_1_3:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
-
-// CHECK-DAG: %[[READ_INIT_2_0:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
-// CHECK-DAG: %[[READ_INIT_2_1:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
-// CHECK-DAG: %[[READ_INIT_2_2:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
-// CHECK-DAG: %[[READ_INIT_2_3:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
-
-// CHECK-DAG: %[[READ_INIT_3_0:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
-// CHECK-DAG: %[[READ_INIT_3_1:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
-// CHECK-DAG: %[[READ_INIT_3_2:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
-// CHECK-DAG: %[[READ_INIT_3_3:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
-
// CHECK: %[[CONTRACT_0_0_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[READ_INIT_0_0]]
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[ACC_0_0]]
// CHECK: %[[CONTRACT_0_0:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_0]], %[[CONTRACT_0_0_1]]
// CHECK: %[[CONTRACT_0_1_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[READ_INIT_0_1]]
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[ACC_0_1]]
// CHECK: %[[CONTRACT_0_1:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_1]], %[[CONTRACT_0_1_1]]
// CHECK: %[[CONTRACT_0_2_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[READ_INIT_0_2]]
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[ACC_0_2]]
// CHECK: %[[CONTRACT_0_2:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_2]], %[[CONTRACT_0_2_1]]
// CHECK: %[[CONTRACT_0_3_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[READ_INIT_0_3]]
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[ACC_0_3]]
// CHECK: %[[CONTRACT_0_3:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_3]], %[[CONTRACT_0_3_1]]
// CHECK: %[[CONTRACT_1_0_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[READ_INIT_1_0]]
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[ACC_1_0]]
// CHECK: %[[CONTRACT_1_0:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_0]], %[[CONTRACT_1_0_1]]
// CHECK: %[[CONTRACT_1_1_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[READ_INIT_1_1]]
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[ACC_1_1]]
// CHECK: %[[CONTRACT_1_1:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_1]], %[[CONTRACT_1_1_1]]
// CHECK: %[[CONTRACT_1_2_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[READ_INIT_1_2]]
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[ACC_1_2]]
// CHECK: %[[CONTRACT_1_2:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_2]], %[[CONTRACT_1_2_1]]
// CHECK: %[[CONTRACT_1_3_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[READ_INIT_1_3]]
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[ACC_1_3]]
// CHECK: %[[CONTRACT_1_3:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_3]], %[[CONTRACT_1_3_1]]
// CHECK: %[[CONTRACT_2_0_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[READ_INIT_2_0]]
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[ACC_2_0]]
// CHECK: %[[CONTRACT_2_0:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_0]], %[[CONTRACT_2_0_1]]
// CHECK: %[[CONTRACT_2_1_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[READ_INIT_2_1]]
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[ACC_2_1]]
// CHECK: %[[CONTRACT_2_1:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_1]], %[[CONTRACT_2_1_1]]
// CHECK: %[[CONTRACT_2_2_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[READ_INIT_2_2]]
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[ACC_2_2]]
// CHECK: %[[CONTRACT_2_2:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_2]], %[[CONTRACT_2_2_1]]
// CHECK: %[[CONTRACT_2_3_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[READ_INIT_2_3]]
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[ACC_2_3]]
// CHECK: %[[CONTRACT_2_3:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_3]], %[[CONTRACT_2_3_1]]
// CHECK: %[[CONTRACT_3_0_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[READ_INIT_3_0]]
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[ACC_3_0]]
// CHECK: %[[CONTRACT_3_0:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_0]], %[[CONTRACT_3_0_1]]
// CHECK: %[[CONTRACT_3_1_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[READ_INIT_3_1]]
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[ACC_3_1]]
// CHECK: %[[CONTRACT_3_1:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_1]], %[[CONTRACT_3_1_1]]
// CHECK: %[[CONTRACT_3_2_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[READ_INIT_3_2]]
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[ACC_3_2]]
// CHECK: %[[CONTRACT_3_2:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_2]], %[[CONTRACT_3_2_1]]
// CHECK: %[[CONTRACT_3_3_1:.+]] = vector.contract
-// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[READ_INIT_3_3]]
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[ACC_3_3]]
// CHECK: %[[CONTRACT_3_3:.+]] = vector.contract
// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_3]], %[[CONTRACT_3_3_1]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_0_0]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_0_1]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_0_2]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_0_3]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
+// CHECK: scf.yield %[[CONTRACT_0_0]], %[[CONTRACT_0_1]],
+// CHECK-SAME: %[[CONTRACT_0_2]], %[[CONTRACT_0_3]], %[[CONTRACT_1_0]],
+// CHECK-SAME: %[[CONTRACT_1_1]], %[[CONTRACT_1_2]], %[[CONTRACT_1_3]],
+// CHECK-SAME: %[[CONTRACT_2_0]], %[[CONTRACT_2_1]], %[[CONTRACT_2_2]],
+// CHECK-SAME: %[[CONTRACT_2_3]], %[[CONTRACT_3_0]], %[[CONTRACT_3_1]],
+// CHECK-SAME: %[[CONTRACT_3_2]], %[[CONTRACT_3_3]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_1_0]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_1_1]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_1_2]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_1_3]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#0, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#1, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#2, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#3, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_2_0]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_2_1]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_2_2]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_2_3]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#4, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#5, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#6, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#7, %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_3_0]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_3_1]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_3_2]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
-// CHECK-DAG: vector.transfer_write
-// CHECK-SAME: %[[CONTRACT_3_3]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#8, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#9, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#10, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#11, %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#12, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#13, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#14, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#15, %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
-// PROMOTE-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>
+// PROMOTE-DAG: #[[MAP4:.+]] = affine_map<()[s0] -> (s0 * 64 - (s0 floordiv 2) * 128)>
// PROMOTE: func @matmul_static_shape
// PROMOTE-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
// PROMOTE-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
@@ -266,23 +274,78 @@
// PROMOTE-DAG: %[[C16:.+]] = constant 16
// PROMOTE-DAG: %[[C32:.+]] = constant 32
// PROMOTE-DAG: %[[C48:.+]] = constant 48
-// PROMOTE: scf.for %[[IV0:.+]]
+
+// PROMOTE: %[[ALLOC1:.+]] = alloc()
+// PROMOTE: %[[ALLOC2:.+]] = alloc()
+// PROMOTE: %[[RESULT_SUBVIEW:.+]] = subview %[[RET0]]
+// PROMOTE: %[[WGMEM_LHS_SUBVIEW:.+]] = subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]
+// PROMOTE: %[[WGMEM_RHS_SUBVIEW:.+]] = subview %[[ALLOC2]][0, 0] [32, 128] [1, 1]
+// PROMOTE: %[[SG_X:.+]] = gpu.subgroup_id
+// PROMOTE: %[[SG_Y:.+]] = divi_signed %[[SG_X]], %[[C2]]
+// PROMOTE: %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP4]]()[%[[SG_Y]]]
+// PROMOTE: %[[SG_LHS_SUBVIEW:.+]] = subview %[[WGMEM_LHS_SUBVIEW]][%[[SGOFFSET_Y]], 0]
+// PROMOTE: %[[SGOFFSET_X:.+]] = affine.apply #[[MAP4]]()[%[[SG_X]]]
+// PROMOTE: %[[SG_RHS_SUBVIEW:.+]] = subview %[[WGMEM_RHS_SUBVIEW]][0, %[[SGOFFSET_X]]]
+// PROMOTE: %[[SG_RESULT_SUBVIEW:.+]] = subview %[[RESULT_SUBVIEW]][%[[SGOFFSET_Y]], %[[SGOFFSET_X]]]
+
+// PROMOTE-DAG: %[[READ_INIT_0_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_0_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_0_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_0_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_1_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_1_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_1_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_1_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_2_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_2_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_2_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_2_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_3_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_3_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_3_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_3_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
+
+// PROMOTE: %[[FOR_RES:.+]]:16 = scf.for %[[IV0:.+]] = {{.*}} to
+// PROMOTE-SAME: iter_args(%[[ACC_0_0:.+]] = %[[READ_INIT_0_0]],
+// PROMOTE-SAME: %[[ACC_0_1:.+]] = %[[READ_INIT_0_1]],
+// PROMOTE-SAME: %[[ACC_0_2:.+]] = %[[READ_INIT_0_2]],
+// PROMOTE-SAME: %[[ACC_0_3:.+]] = %[[READ_INIT_0_3]],
+// PROMOTE-SAME: %[[ACC_1_0:.+]] = %[[READ_INIT_1_0]],
+// PROMOTE-SAME: %[[ACC_1_1:.+]] = %[[READ_INIT_1_1]],
+// PROMOTE-SAME: %[[ACC_1_2:.+]] = %[[READ_INIT_1_2]],
+// PROMOTE-SAME: %[[ACC_1_3:.+]] = %[[READ_INIT_1_3]],
+// PROMOTE-SAME: %[[ACC_2_0:.+]] = %[[READ_INIT_2_0]],
+// PROMOTE-SAME: %[[ACC_2_1:.+]] = %[[READ_INIT_2_1]],
+// PROMOTE-SAME: %[[ACC_2_2:.+]] = %[[READ_INIT_2_2]],
+// PROMOTE-SAME: %[[ACC_2_3:.+]] = %[[READ_INIT_2_3]],
+// PROMOTE-SAME: %[[ACC_3_0:.+]] = %[[READ_INIT_3_0]],
+// PROMOTE-SAME: %[[ACC_3_1:.+]] = %[[READ_INIT_3_1]],
+// PROMOTE-SAME: %[[ACC_3_2:.+]] = %[[READ_INIT_3_2]],
+// PROMOTE-SAME: %[[ACC_3_3:.+]] = %[[READ_INIT_3_3]])
+
// PROMOTE: %[[LHS_SUBVIEW:.+]] = subview %[[ARG0]]
// PROMOTE: %[[RHS_SUBVIEW:.+]] = subview %[[ARG1]]
-// PROMOTE: %[[RESULT_SUBVIEW:.+]] = subview %[[RET0]]
-// PROMOTE: %[[ALLOC1:.+]] = alloc()
-// PROMOTE: %[[WGMEM_LHS_SUBVIEW:.+]] = subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]
-// PROMOTE: %[[ALLOC2:.+]] = alloc()
-// PROMOTE: %[[WGMEM_RHS_SUBVIEW:.+]] = subview %[[ALLOC2]][0, 0] [32, 128] [1, 1]
// PROMOTE: linalg.copy(%[[LHS_SUBVIEW]], %[[WGMEM_LHS_SUBVIEW]])
// PROMOTE: linalg.copy(%[[RHS_SUBVIEW]], %[[WGMEM_RHS_SUBVIEW]])
-// PROMOTE: %[[SG_X:.+]] = gpu.subgroup_id
-// PROMOTE: %[[SG_Y:.+]] = divi_signed %[[SG_X]], %[[C2]]
-// PROMOTE: %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP4]](%[[SG_Y]])
-// PROMOTE: %[[SG_LHS_SUBVIEW:.+]] = subview %[[WGMEM_LHS_SUBVIEW]][%[[SGOFFSET_Y]], 0]
-// PROMOTE: %[[SGOFFSET_X:.+]] = affine.apply #[[MAP4]](%[[SG_X]])
-// PROMOTE: %[[SG_RHS_SUBVIEW:.+]] = subview %[[WGMEM_RHS_SUBVIEW]][0, %[[SGOFFSET_X]]]
-// PROMOTE: %[[SG_RESULT_SUBVIEW:.+]] = subview %[[RESULT_SUBVIEW]][%[[SGOFFSET_Y]], %[[SGOFFSET_X]]]
// PROMOTE-DAG: %[[READ_LHS_0_0:.+]] = vector.transfer_read
// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C0]], %[[C0]]]
@@ -322,142 +385,97 @@
// PROMOTE-DAG: %[[READ_RHS_1_3:.+]] = vector.transfer_read
// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C16]], %[[C48]]]
-// PROMOTE-DAG: %[[READ_INIT_0_0:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
-// PROMOTE-DAG: %[[READ_INIT_0_1:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
-// PROMOTE-DAG: %[[READ_INIT_0_2:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
-// PROMOTE-DAG: %[[READ_INIT_0_3:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
-
-// PROMOTE-DAG: %[[READ_INIT_1_0:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
-// PROMOTE-DAG: %[[READ_INIT_1_1:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
-// PROMOTE-DAG: %[[READ_INIT_1_2:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
-// PROMOTE-DAG: %[[READ_INIT_1_3:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
-
-// PROMOTE-DAG: %[[READ_INIT_2_0:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
-// PROMOTE-DAG: %[[READ_INIT_2_1:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
-// PROMOTE-DAG: %[[READ_INIT_2_2:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
-// PROMOTE-DAG: %[[READ_INIT_2_3:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
-
-// PROMOTE-DAG: %[[READ_INIT_3_0:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
-// PROMOTE-DAG: %[[READ_INIT_3_1:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
-// PROMOTE-DAG: %[[READ_INIT_3_2:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
-// PROMOTE-DAG: %[[READ_INIT_3_3:.+]] = vector.transfer_read
-// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
-
// PROMOTE: %[[CONTRACT_0_0_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[READ_INIT_0_0]]
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[ACC_0_0]]
// PROMOTE: %[[CONTRACT_0_0:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_0]], %[[CONTRACT_0_0_1]]
// PROMOTE: %[[CONTRACT_0_1_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[READ_INIT_0_1]]
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[ACC_0_1]]
// PROMOTE: %[[CONTRACT_0_1:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_1]], %[[CONTRACT_0_1_1]]
// PROMOTE: %[[CONTRACT_0_2_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[READ_INIT_0_2]]
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[ACC_0_2]]
// PROMOTE: %[[CONTRACT_0_2:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_2]], %[[CONTRACT_0_2_1]]
// PROMOTE: %[[CONTRACT_0_3_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[READ_INIT_0_3]]
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[ACC_0_3]]
// PROMOTE: %[[CONTRACT_0_3:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_3]], %[[CONTRACT_0_3_1]]
// PROMOTE: %[[CONTRACT_1_0_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[READ_INIT_1_0]]
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[ACC_1_0]]
// PROMOTE: %[[CONTRACT_1_0:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_0]], %[[CONTRACT_1_0_1]]
// PROMOTE: %[[CONTRACT_1_1_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[READ_INIT_1_1]]
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[ACC_1_1]]
// PROMOTE: %[[CONTRACT_1_1:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_1]], %[[CONTRACT_1_1_1]]
// PROMOTE: %[[CONTRACT_1_2_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[READ_INIT_1_2]]
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[ACC_1_2]]
// PROMOTE: %[[CONTRACT_1_2:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_2]], %[[CONTRACT_1_2_1]]
// PROMOTE: %[[CONTRACT_1_3_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[READ_INIT_1_3]]
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[ACC_1_3]]
// PROMOTE: %[[CONTRACT_1_3:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_3]], %[[CONTRACT_1_3_1]]
// PROMOTE: %[[CONTRACT_2_0_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[READ_INIT_2_0]]
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[ACC_2_0]]
// PROMOTE: %[[CONTRACT_2_0:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_0]], %[[CONTRACT_2_0_1]]
// PROMOTE: %[[CONTRACT_2_1_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[READ_INIT_2_1]]
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[ACC_2_1]]
// PROMOTE: %[[CONTRACT_2_1:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_1]], %[[CONTRACT_2_1_1]]
// PROMOTE: %[[CONTRACT_2_2_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[READ_INIT_2_2]]
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[ACC_2_2]]
// PROMOTE: %[[CONTRACT_2_2:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_2]], %[[CONTRACT_2_2_1]]
// PROMOTE: %[[CONTRACT_2_3_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[READ_INIT_2_3]]
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[ACC_2_3]]
// PROMOTE: %[[CONTRACT_2_3:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_3]], %[[CONTRACT_2_3_1]]
// PROMOTE: %[[CONTRACT_3_0_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[READ_INIT_3_0]]
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[ACC_3_0]]
// PROMOTE: %[[CONTRACT_3_0:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_0]], %[[CONTRACT_3_0_1]]
// PROMOTE: %[[CONTRACT_3_1_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[READ_INIT_3_1]]
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[ACC_3_1]]
// PROMOTE: %[[CONTRACT_3_1:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_1]], %[[CONTRACT_3_1_1]]
// PROMOTE: %[[CONTRACT_3_2_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[READ_INIT_3_2]]
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[ACC_3_2]]
// PROMOTE: %[[CONTRACT_3_2:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_2]], %[[CONTRACT_3_2_1]]
// PROMOTE: %[[CONTRACT_3_3_1:.+]] = vector.contract
-// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[READ_INIT_3_3]]
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[ACC_3_3]]
// PROMOTE: %[[CONTRACT_3_3:.+]] = vector.contract
// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_3]], %[[CONTRACT_3_3_1]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_0_0]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_0_1]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_0_2]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_0_3]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
+// PROMOTE: scf.yield %[[CONTRACT_0_0]], %[[CONTRACT_0_1]],
+// PROMOTE-SAME: %[[CONTRACT_0_2]], %[[CONTRACT_0_3]], %[[CONTRACT_1_0]],
+// PROMOTE-SAME: %[[CONTRACT_1_1]], %[[CONTRACT_1_2]], %[[CONTRACT_1_3]],
+// PROMOTE-SAME: %[[CONTRACT_2_0]], %[[CONTRACT_2_1]], %[[CONTRACT_2_2]],
+// PROMOTE-SAME: %[[CONTRACT_2_3]], %[[CONTRACT_3_0]], %[[CONTRACT_3_1]],
+// PROMOTE-SAME: %[[CONTRACT_3_2]], %[[CONTRACT_3_3]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_1_0]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_1_1]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_1_2]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_1_3]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#0, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#1, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#2, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#3, %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_2_0]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_2_1]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_2_2]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_2_3]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#4, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#5, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#6, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#7, %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_3_0]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_3_1]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_3_2]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
-// PROMOTE-DAG: vector.transfer_write
-// PROMOTE-SAME: %[[CONTRACT_3_3]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#8, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#9, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#10, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#11, %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#12, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#13, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#14, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write %[[FOR_RES]]#15, %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
index f988a2c..992e35a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -1,41 +1,31 @@
// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization}" %s | IreeFileCheck %s
-// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization use-workgroup-memory}" %s | IreeFileCheck %s
module attributes {
spv.target_env =
- #spv.target_env<#spv.vce<v1.5,
+ #spv.target_env<#spv.vce<v1.3,
[Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
- VariablePointersStorageBuffer, CooperativeMatrixNV],
+ VariablePointersStorageBuffer],
[SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
- SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
- SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- {cooperative_matrix_properties_nv = [
- {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
- m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
- {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
- m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
- scope = 3 : i32},
- {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
- m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
- scope = 3 : i32}],
- max_compute_shared_memory_size = 49152 : i32,
- max_compute_workgroup_invocations = 1024 : i32,
- max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
- subgroup_size = 32 : i32}>} {
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ ARM:IntegratedGPU,
+ {max_compute_shared_memory_size = 32768 : i32,
+ max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<512> : vector<3xi32>,
+ subgroup_size = 16 : i32}>} {
func @matmul_static_shape()
attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
%arg0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf32>
%arg1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf32>
%ret0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
- linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
- outs(%ret0 : memref<4096x4096xf16>)
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf32>
+ linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf32>, memref<4096x4096xf32>)
+ outs(%ret0 : memref<4096x4096xf32>)
return
}
func @matmul_static_shape__num_workgroups__
@@ -50,6 +40,10 @@
}
// CHECK-LABEL: spv.func @matmul_static_shape
-// CHECK-COUNT-32: spv.CooperativeMatrixLoadNV
-// CHECK-COUNT-32: spv.CooperativeMatrixMulAddNV
-// CHECK-COUNT-16: spv.CooperativeMatrixStoreNV
+// CHECK-COUNT-8: spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK: spv.loop
+// CHECK-COUNT-12: spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK-COUNT-32: spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
+// CHECK-COUNT-8: spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
+
+
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
new file mode 100644
index 0000000..1eb90bb
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir
@@ -0,0 +1,57 @@
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization}" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization use-workgroup-memory}" %s | IreeFileCheck %s
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.5,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer, CooperativeMatrixNV],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
+ SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
+ {cooperative_matrix_properties_nv = [
+ {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
+ m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
+ scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
+ scope = 3 : i32}],
+ max_compute_shared_memory_size = 49152 : i32,
+ max_compute_workgroup_invocations = 1024 : i32,
+ max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+ subgroup_size = 32 : i32}>} {
+ func @matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
+ linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
+ outs(%ret0 : memref<4096x4096xf16>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-LABEL: spv.func @matmul_static_shape
+// CHECK-COUNT-16: spv.CooperativeMatrixLoadNV
+// CHECK: spv.loop
+// CHECK-COUNT-16: spv.CooperativeMatrixLoadNV
+// CHECK-COUNT-32: spv.CooperativeMatrixMulAddNV
+// CHECK-COUNT-16: spv.CooperativeMatrixStoreNV