Adjust configuration used for matmul vectorization. (#3540)
This PR changes the codegen configuration so that the IREE
Vulkan/SPIR-V path generates the same code as ModelBuilder.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 0021c77..7673cf9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -33,6 +33,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
@@ -224,16 +225,15 @@
std::tie(workgroupSize[0], workgroupSize[1]) =
distributeProcs2D(maxWorkgroupSize);
workgroupSize[2] = 1;
- // TODO(#3131): This is just being hard-wired for now to be minimal viable,
- // but this can be decided better when we have better estimates of device
- // charecteristics.
+ // This is just being hard-wired for now to be minimally viable, but this can
+ // be decided better when we have better estimates of device characteristics.
const int64_t nRowsPerWorkitem = 1;
const int64_t nColsPerWorkitem = 1;
const int64_t nBatchesPerWorkitem = 1;
int64_t tileSizeK = 0;
if (options.useWorkgroupMemory) {
- // TODO(#3131): This number should be decided based on the amount of
- // shared memory available (maybe). For now, just hard-wire it.
+ // This number should be decided based on the amount of shared memory
+ // available (maybe). For now, just hard-wire it.
tileSizeK = 32;
}
assert(tileSizes.empty());
@@ -244,40 +244,53 @@
return success();
}
-/// The size of the co-operative matrix multiply operations on the device.
-// TODO(#3131): This needs to be queried from the device.
-Optional<std::array<int64_t, 3>> getCooperativeMatmulSubgroupSize(
- Type dataType, Type accumulatorType) {
- if (dataType.isInteger(8) && accumulatorType.isInteger(32)) {
- return std::array<int64_t, 3>{8, 8, 32};
+/// Returns the size of the co-operative matrix multiply operations on the
+/// device.
+static Optional<SmallVector<int64_t, 4>> getCooperativeMatmulSubgroupSize(
+ spirv::ResourceLimitsAttr resourceLimits, Type lhsType, Type rhsType,
+ Type initType, Type resultType) {
+ for (auto coopMatmulProperties :
+ resourceLimits.cooperative_matrix_properties_nv()
+ .getAsRange<spirv::CooperativeMatrixPropertiesNVAttr>()) {
+ if (coopMatmulProperties.a_type().getValue() == lhsType &&
+ coopMatmulProperties.b_type().getValue() == rhsType &&
+ coopMatmulProperties.c_type().getValue() == initType &&
+ coopMatmulProperties.result_type().getValue() == resultType &&
+ spirv::symbolizeScope(
+ coopMatmulProperties.scope().getValue().getZExtValue())
+ .getValue() == spirv::Scope::Subgroup) {
+ return SmallVector<int64_t, 4>{
+ coopMatmulProperties.m_size().getValue().getSExtValue(),
+ coopMatmulProperties.n_size().getValue().getSExtValue(),
+ coopMatmulProperties.k_size().getValue().getSExtValue()};
+ }
}
- if (dataType.isF16() &&
- (accumulatorType.isF32() || accumulatorType.isF16())) {
- return std::array<int64_t, 3>{8, 8, 16};
- }
- return {};
+ return llvm::None;
}
/// Launch configuration for using spv.CooperativeMatrixMulAddNV
/// operations. Needs two levels of tiling.
static LogicalResult getConfigForCooperativeMatmul(
- linalg::MatmulOp op, spirv::ResourceLimitsAttr resourceLimits,
- TileSizesListType &tileSizes, std::array<int64_t, 3> &workgroupSize,
+ linalg::MatmulOp op, const SPIRVCodegenOptions &options,
+ spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
std::array<int64_t, 3> &numSubgroups) {
auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
if (!targetEnv.allows(spirv::Capability::CooperativeMatrixNV) ||
!targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix))
return failure();
- ShapedType lhsType = op.getOperand(0).getType().cast<ShapedType>();
+ ShapedType lhsType = op.inputs().front().getType().cast<ShapedType>();
ArrayRef<int64_t> lhsShape = lhsType.getShape();
- ShapedType rhsType = op.getOperand(1).getType().cast<ShapedType>();
+ ShapedType rhsType = op.inputs().back().getType().cast<ShapedType>();
ArrayRef<int64_t> rhsShape = rhsType.getShape();
- ShapedType outputType = op.getOperand(2).getType().cast<ShapedType>();
+ ShapedType outputType =
+ op.output_buffers().front().getType().cast<ShapedType>();
- Optional<std::array<int64_t, 3>> coopMatmulSize =
- getCooperativeMatmulSubgroupSize(lhsType.getElementType(),
- outputType.getElementType());
+ Optional<SmallVector<int64_t, 4>> coopMatmulSize =
+ getCooperativeMatmulSubgroupSize(
+ resourceLimits, lhsType.getElementType(), rhsType.getElementType(),
+ outputType.getElementType(), outputType.getElementType());
if (!coopMatmulSize) return failure();
// Check that the matmul sizes are a multiple of the tilesize.
@@ -290,30 +303,35 @@
!isMultipleOf(rhsShape[0], (*coopMatmulSize)[2]))
return failure();
- // TODO(ravishankarm, antiagainst): For now hardwire the subgroup size.
- const int64_t subgroupSize = 32;
- unsigned maxWorkgroupSize =
- resourceLimits.max_compute_workgroup_invocations().getInt();
- std::tie(numSubgroups[0], numSubgroups[1]) =
- distributeProcs2D(maxWorkgroupSize / subgroupSize);
+ if (options.useWorkgroupMemory) {
+ numSubgroups[0] = 2;
+ numSubgroups[1] = 2;
+ } else {
+ numSubgroups[0] = 1;
+ numSubgroups[1] = 1;
+ }
numSubgroups[2] = 1;
- // TODO(#3131): This is just being hard-wired for now to be minimal viable,
- // but this can be decided better when we have better estimates of device
- // charecteristics.
- const int64_t numVecMatmulPerSubgroupX = 1;
- const int64_t numVecMatmulPerSubgroupY = 1;
+
+ // For now this is being hard-wired to be {4, 4, 2}. This could be set to
+ // other values as well, but the right choice ultimately depends on register
+ // pressure.
+ const int64_t numVecMatmulPerSubgroupX = 4;
+ const int64_t numVecMatmulPerSubgroupY = 4;
+ const int64_t numVecMatmulPerSubgroupK = 2;
SmallVector<int64_t, 4> ts = {
numVecMatmulPerSubgroupY * (*coopMatmulSize)[0] * numSubgroups[1],
- numVecMatmulPerSubgroupX * (*coopMatmulSize)[1] * numSubgroups[0]};
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1] * numSubgroups[0],
+ numVecMatmulPerSubgroupK * (*coopMatmulSize)[2]};
tileSizes.emplace_back(std::move(ts));
- workgroupSize[0] = numSubgroups[0] * subgroupSize;
- workgroupSize[1] = numSubgroups[1];
+ int64_t subgroupSize =
+ resourceLimits.subgroup_size().getValue().getSExtValue();
+ workgroupSize[0] = numSubgroups[0] * numSubgroups[1] * subgroupSize;
+ workgroupSize[1] = 1;
workgroupSize[2] = 1;
// Subgroup tile sizes
SmallVector<int64_t, 4> subgroupTs = {
numVecMatmulPerSubgroupY * (*coopMatmulSize)[0],
- numVecMatmulPerSubgroupX * (*coopMatmulSize)[1], (*coopMatmulSize)[2]};
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1]};
tileSizes.emplace_back(std::move(subgroupTs));
return success();
}
@@ -325,9 +343,9 @@
TileSizesListType &tileSizes,
std::array<int64_t, 3> &workgroupSize,
std::array<int64_t, 3> &numSubgroups) {
- if (options.useVectorization &&
- succeeded(getConfigForCooperativeMatmul(op, resourceLimits, tileSizes,
- workgroupSize, numSubgroups))) {
+ if (options.useVectorization && succeeded(getConfigForCooperativeMatmul(
+ op, options, resourceLimits, tileSizes,
+ workgroupSize, numSubgroups))) {
return success();
}
unsigned maxWorkgroupSize =
@@ -478,9 +496,37 @@
void LaunchConfig::finalize(FuncOp funcOp) {
funcOp.walk([&](linalg::LinalgOp linalgOp) {
linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
- ;
});
}
+template <typename OpTy>
+static Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize(OpTy op) {
+ return llvm::None;
+}
+
+template <>
+Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::ContractionOp>(
+ vector::ContractionOp op) {
+ spirv::ResourceLimitsAttr resourceLimits =
+ spirv::lookupTargetEnv(op).getResourceLimits();
+ return getCooperativeMatmulSubgroupSize(
+ resourceLimits, op.getLhsType().getElementType(),
+ op.getRhsType().getElementType(),
+ op.getAccType().cast<VectorType>().getElementType(),
+ op.getResultType().cast<VectorType>().getElementType());
+}
+
+Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
+#define DISPATCH(opname) \
+ if (isa<opname>(op)) { \
+ return getOpNativeVectorSize(cast<opname>(op)); \
+ }
+
+ DISPATCH(vector::ContractionOp)
+
+#undef DISPATCH
+ return llvm::None;
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 919464e..918cd6b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -152,6 +152,10 @@
Optional<StringRef> getKey(Operation *op) const;
};
+/// Returns the size of an instruction in the `vector` dialect that maps
+/// directly to the hardware.
+Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op);
+
} // namespace iree_compiler
} // namespace mlir
#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_DISPATCHUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 20deb83..b74d5e3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -377,8 +377,11 @@
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
.setTileSizeComputationFunction(getInnerTileSizeFn)
.setDistributionOptions(subgroupDistributionOptions),
- linalg::LinalgMarker(Identifier::get(getWorkgroupMarker(), context),
- Identifier::get(getVectorizeMarker(), context)));
+ linalg::LinalgMarker(
+ /*matchDisjunction=*/{Identifier::get(getWorkgroupMemoryMarker(),
+ context),
+ Identifier::get(getWorkgroupMarker(), context)},
+ /*replacement=*/Identifier::get(getVectorizeMarker(), context)));
}
//====---------------------------------------------------------------------===//
@@ -404,6 +407,19 @@
applyPatternsAndFoldGreedily(op, canonicalizationPatterns);
}
+//====---------------------------------------------------------------------===//
+// Patterns for unrolling vectors.
+//====---------------------------------------------------------------------===//
+
+static void populateVectorUnrollPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<vector::UnrollVectorPattern<vector::ContractionOp>>(
+ context,
+ vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
+ vector::populateVectorToVectorTransformationPatterns(patterns, context);
+}
+
void LinalgTileAndFusePass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@@ -476,6 +492,12 @@
});
}
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After First level of tile+distribute ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
if (options.useWorkgroupMemory) {
// The promotion patterns are put separate from the tiling patterns to
// make sure that the allocated scratchspace memory is constant sizes
@@ -484,19 +506,51 @@
populatePromotionPatterns(context, promotionPatterns);
applyPatternsAndFoldGreedily(funcOp, promotionPatterns);
applyCanonicalizationPatterns(context, funcOp);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Promotion ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
}
if (options.useVectorization) {
- OwningRewritePatternList secondLevelTilingPatterns;
- populateTilingToSubgroupPatterns(context, launchConfig,
- secondLevelTilingPatterns);
- applyPatternsAndFoldGreedily(funcOp, secondLevelTilingPatterns);
- applyCanonicalizationPatterns(context, funcOp);
+ {
+ OwningRewritePatternList secondLevelTilingPatterns;
+ populateTilingToSubgroupPatterns(context, launchConfig,
+ secondLevelTilingPatterns);
+ applyPatternsAndFoldGreedily(funcOp, secondLevelTilingPatterns);
+ applyCanonicalizationPatterns(context, funcOp);
- OwningRewritePatternList vectorizationPatterns;
- populateVectorizationPatterns(context, launchConfig,
- vectorizationPatterns);
- applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Second level Tiling ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ {
+ OwningRewritePatternList vectorizationPatterns;
+ populateVectorizationPatterns(context, launchConfig,
+ vectorizationPatterns);
+ applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Vectorization ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ {
+ OwningRewritePatternList vectorUnrollPatterns;
+ populateVectorUnrollPatterns(context, vectorUnrollPatterns);
+ applyPatternsAndFoldGreedily(funcOp, vectorUnrollPatterns);
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Vector Unroll ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
}
launchConfig.finalize(funcOp);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index d751065..d413df5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -63,12 +63,18 @@
llvm::cl::desc(
"Enable use of vectorization in SPIR-V code generation pipeline"),
llvm::cl::init(false)};
- Option<bool> useVectorPass{
- *this, "use-vector-pass",
- llvm::cl::desc("Enable use of Linalg vectorization in SPIR-V code "
+ Option<bool> useVectorizeMemrefPass{
+ *this, "use-vectorize-memref-pass",
+ llvm::cl::desc("Enable use of Vector loads/stores in SPIR-V code "
"generation pipeline"),
llvm::cl::init(false)};
+ Option<bool> useWorkgroupMemory{
+ *this, "use-workgroup-memory",
+ llvm::cl::desc(
+ "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+ llvm::cl::init(false)};
};
+} // namespace
static void addLinalgToSPIRVPasses(OpPassManager &pm,
const SPIRVCodegenOptions &options) {
@@ -99,7 +105,7 @@
//===--------------------------------------------------------------------===//
pm.addPass(createSplitDispatchFunctionPass());
pm.addPass(createLinalgTileAndFusePass(options));
- if (options.useVectorPass) {
+ if (options.useVectorizeMemrefPass) {
pm.addPass(createLoadStoreVectorizationPass());
}
pm.addPass(createCanonicalizerPass());
@@ -113,6 +119,9 @@
// - Linalg ops are converted to loop.for ops and mapped to workitems.
//===--------------------------------------------------------------------===//
pm.addPass(createConvertToGPUPass());
+ if (options.useVectorization) {
+ pm.addPass(createVectorToGPUPass());
+ }
pm.addPass(createLowerAffinePass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
@@ -163,6 +172,11 @@
pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
+ if (options.useVectorization) {
+ pm.addPass(createVectorizeMemref());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ }
//===--------------------------------------------------------------------===//
// Final conversion to SPIR-V dialect.
@@ -186,7 +200,6 @@
spirvModulePM.addPass(createCSEPass());
spirvModulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
}
-} // namespace
void buildSPIRVTransformPassPipeline(OpPassManager &pm,
const SPIRVCodegenOptions &options) {
@@ -255,6 +268,8 @@
const LinalgToSPIRVPassPipelineOptions &options) {
SPIRVCodegenOptions codegenOptions;
codegenOptions.useVectorization = options.useVectorization;
+ codegenOptions.useWorkgroupMemory = options.useWorkgroupMemory;
+ codegenOptions.useVectorizeMemrefPass = options.useVectorizeMemrefPass;
return codegenOptions;
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 80805e2..30067d7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -29,7 +29,7 @@
SmallVector<int64_t, 3> tileSizes = {};
bool useWorkgroupMemory = false;
bool useVectorization = false;
- bool useVectorPass = false;
+ bool useVectorizeMemrefPass = false;
};
/// Pass to initialize the function that computes the number of workgroups for
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
deleted file mode 100644
index 31b3cea..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
+++ /dev/null
@@ -1,77 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-vectorization %s | IreeFileCheck %s
-
-module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader, CooperativeMatrixNV],
- [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
- {max_compute_workgroup_invocations = 512 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_static_shape()
- attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<128x64xf16>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<64x256xf16>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<128x256xf16>
- linalg.matmul ins(%arg0, %arg1 : memref<128x64xf16>, memref<64x256xf16>)
- outs(%ret0 : memref<128x256xf16>)
- return
- }
- func @matmul_static_shape__num_workgroups__
- (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
- !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
- attributes {sym_visibility = "private"}
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 32)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>
-// CHECK: func @matmul_static_shape
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[CST:.+]] = constant 0.0
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
-// CHECK-SAME: [%[[BOFFSET_Y]], 0] [32, 64]
-// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
-// CHECK-SAME: [0, %[[BOFFSET_X]]] [64, 32]
-// CHECK: %[[BOFFSET_Y_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[BOFFSET_X_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
-// CHECK-SAME: [%[[BOFFSET_Y_2]], %[[BOFFSET_X_2]]] [32, 32]
-// CHECK: %[[SGID:.+]] = gpu.subgroup_id
-// CHECK: %[[SGID_Y:.+]] = divi_signed %[[SGID]], %[[C4]]
-// CHECK: scf.for %[[IV2:.+]] =
-// CHECK: %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
-// CHECK: %[[SUBVIEW2_LHS:.+]] = subview %[[SUBVIEW_LHS]]
-// CHECK-SAME: [%[[SGOFFSET_Y]], %[[IV2]]] [8, 16]
-// CHECK: %[[SGOFFSET_X:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
-// CHECK: %[[SUBVIEW2_RHS:.+]] = subview %[[SUBVIEW_RHS]]
-// CHECK-SAME: [%[[IV2]], %[[SGOFFSET_X]]] [16, 8]
-// CHECK: %[[SGOFFSET_Y_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
-// CHECK: %[[SGOFFSET_X_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
-// CHECK: %[[SUBVIEW2_RESULT:.+]] = subview %[[SUBVIEW_RESULT]]
-// CHECK-SAME: [%[[SGOFFSET_Y_2]], %[[SGOFFSET_X_2]]] [8, 8]
-// CHECK: %[[VTR_LHS:.+]] = vector.transfer_read %[[SUBVIEW2_LHS]]
-// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
-// CHECK: %[[VTR_RHS:.+]] = vector.transfer_read %[[SUBVIEW2_RHS]]
-// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
-// CHECK: %[[VTR_RESULT:.+]] = vector.transfer_read %[[SUBVIEW2_RESULT]]
-// CHECK-SAME: [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
-// CHECK: %[[VECTOR_CONTRACT:.+]] = vector.contract
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-SAME: vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf16>
-// CHECK: vector.transfer_write %[[VECTOR_CONTRACT]], %[[SUBVIEW2_RESULT]]
-// CHECK-SAME: masked = [false, false]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
index dd846ce..6c2ba3c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
@@ -1,21 +1,463 @@
-// RUN: iree-opt --iree-codegen-linalg-to-gpu-matmul-vectorization-pass
-// RUN: -split-input-file %s --iree-codegen-linalg-to-gpu-unroll-size=8,8,32 \
-// RUN: -iree-codegen-linalg-to-gpu-matmul-licm | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-tile-and-fuse{use-vectorization},canonicalize,cse" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-tile-and-fuse{use-vectorization use-workgroup-memory},canonicalize,cse" %s | IreeFileCheck %s -check-prefix=PROMOTE
-// CHECK-LABEL: func @matmul_128x128x128
-// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
-func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
- linalg.matmul ins(%arg0, %arg1 : memref<128x128xf32>, memref<128x128xf32>) outs(%arg2 : memref<128x128xf32>)
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.5,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer, CooperativeMatrixNV],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
+ SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
+ {cooperative_matrix_properties_nv = [
+ {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
+ m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
+ scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
+ scope = 3 : i32}],
+ max_compute_shared_memory_size = 49152 : i32,
+ max_compute_workgroup_invocations = 1024 : i32,
+ max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+ subgroup_size = 32 : i32}>} {
+ func @matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
+ linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
+ outs(%ret0 : memref<4096x4096xf16>)
return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
-// CHECK-DAG: %[[TILESIZE:.+]] = constant 32 : index
-// CHECK-DAG: %[[MATSIZE:.+]] = constant 128 : index
-// CHECK-DAG: %[[START:.+]] = constant 0 : index
-// CHECK: scf.for %[[IL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
-// CHECK: scf.for %[[JL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
-// CHECK: %[[SUBVVIEWC:.+]] = subview %[[ARG2]][%[[IL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
-// CHECK: scf.for %[[KL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
-// CHECK: %[[SUBVVIEWA:.+]] = subview %[[ARG0]][%[[IL]], %[[KL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
-// CHECK: %[[SUBVVIEWB:.+]] = subview %[[ARG1]][%[[KL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 64)>
+// CHECK: func @matmul_static_shape
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = constant 0.0
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C48:.+]] = constant 48 : index
+// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: scf.for %[[IV0:.+]] =
+// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME: [%[[BOFFSET_Y]], %[[IV0]]] [64, 32]
+// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [%[[IV0]], %[[BOFFSET_X]]] [32, 64]
+// CHECK: %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BOFFSET_Y]], %[[BOFFSET_X]]] [64, 64]
+// CHECK: %[[SUBVIEW_LHS_2:.+]] = subview %[[SUBVIEW_LHS]]
+// CHECK-SAME: [0, 0] [64, 32] [1, 1]
+// CHECK: %[[SUBVIEW_RHS_2:.+]] = subview %[[SUBVIEW_RHS]]
+// CHECK-SAME: [0, 0] [32, 64] [1, 1]
+// CHECK: %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME: [0, 0] [64, 64] [1, 1]
+// CHECK-DAG: %[[READ_LHS_0_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_0_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C16]]]
+
+// CHECK-DAG: %[[READ_LHS_1_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C16]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_1_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C16]], %[[C16]]]
+
+// CHECK-DAG: %[[READ_LHS_2_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C32]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_2_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C32]], %[[C16]]]
+
+// CHECK-DAG: %[[READ_LHS_3_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C48]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_3_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C48]], %[[C16]]]
+
+// CHECK-DAG: %[[READ_RHS_0_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_RHS_0_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C16]]]
+// CHECK-DAG: %[[READ_RHS_0_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C32]]]
+// CHECK-DAG: %[[READ_RHS_0_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_RHS_1_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C16]], %[[C0]]]
+// CHECK-DAG: %[[READ_RHS_1_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C16]], %[[C16]]]
+// CHECK-DAG: %[[READ_RHS_1_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C16]], %[[C32]]]
+// CHECK-DAG: %[[READ_RHS_1_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C16]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_0_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_0_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_0_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_0_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_1_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_1_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_1_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_1_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_2_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_2_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_2_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_2_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+
+// CHECK-DAG: %[[READ_INIT_3_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_3_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
+// CHECK-DAG: %[[READ_INIT_3_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
+// CHECK-DAG: %[[READ_INIT_3_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
+
+// CHECK: %[[CONTRACT_0_0_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[READ_INIT_0_0]]
+// CHECK: %[[CONTRACT_0_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_0]], %[[CONTRACT_0_0_1]]
+// CHECK: %[[CONTRACT_0_1_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[READ_INIT_0_1]]
+// CHECK: %[[CONTRACT_0_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_1]], %[[CONTRACT_0_1_1]]
+// CHECK: %[[CONTRACT_0_2_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[READ_INIT_0_2]]
+// CHECK: %[[CONTRACT_0_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_2]], %[[CONTRACT_0_2_1]]
+// CHECK: %[[CONTRACT_0_3_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[READ_INIT_0_3]]
+// CHECK: %[[CONTRACT_0_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_3]], %[[CONTRACT_0_3_1]]
+
+// CHECK: %[[CONTRACT_1_0_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[READ_INIT_1_0]]
+// CHECK: %[[CONTRACT_1_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_0]], %[[CONTRACT_1_0_1]]
+// CHECK: %[[CONTRACT_1_1_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[READ_INIT_1_1]]
+// CHECK: %[[CONTRACT_1_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_1]], %[[CONTRACT_1_1_1]]
+// CHECK: %[[CONTRACT_1_2_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[READ_INIT_1_2]]
+// CHECK: %[[CONTRACT_1_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_2]], %[[CONTRACT_1_2_1]]
+// CHECK: %[[CONTRACT_1_3_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[READ_INIT_1_3]]
+// CHECK: %[[CONTRACT_1_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_3]], %[[CONTRACT_1_3_1]]
+
+// CHECK: %[[CONTRACT_2_0_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[READ_INIT_2_0]]
+// CHECK: %[[CONTRACT_2_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_0]], %[[CONTRACT_2_0_1]]
+// CHECK: %[[CONTRACT_2_1_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[READ_INIT_2_1]]
+// CHECK: %[[CONTRACT_2_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_1]], %[[CONTRACT_2_1_1]]
+// CHECK: %[[CONTRACT_2_2_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[READ_INIT_2_2]]
+// CHECK: %[[CONTRACT_2_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_2]], %[[CONTRACT_2_2_1]]
+// CHECK: %[[CONTRACT_2_3_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[READ_INIT_2_3]]
+// CHECK: %[[CONTRACT_2_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_3]], %[[CONTRACT_2_3_1]]
+
+// CHECK: %[[CONTRACT_3_0_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[READ_INIT_3_0]]
+// CHECK: %[[CONTRACT_3_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_0]], %[[CONTRACT_3_0_1]]
+// CHECK: %[[CONTRACT_3_1_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[READ_INIT_3_1]]
+// CHECK: %[[CONTRACT_3_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_1]], %[[CONTRACT_3_1_1]]
+// CHECK: %[[CONTRACT_3_2_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[READ_INIT_3_2]]
+// CHECK: %[[CONTRACT_3_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_2]], %[[CONTRACT_3_2_1]]
+// CHECK: %[[CONTRACT_3_3_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[READ_INIT_3_3]]
+// CHECK: %[[CONTRACT_3_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_3]], %[[CONTRACT_3_3_1]]
+
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_0_0]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_0_1]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_0_2]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_0_3]], %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C48]]]
+
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_1_0]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_1_1]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_1_2]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_1_3]], %[[SUBVIEW_RESULT_2]][%[[C16]], %[[C48]]]
+
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_2_0]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_2_1]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_2_2]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_2_3]], %[[SUBVIEW_RESULT_2]][%[[C32]], %[[C48]]]
+
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_3_0]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_3_1]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C16]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_3_2]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C32]]]
+// CHECK-DAG: vector.transfer_write
+// CHECK-SAME: %[[CONTRACT_3_3]], %[[SUBVIEW_RESULT_2]][%[[C48]], %[[C48]]]
+
+
+// PROMOTE-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>
+// PROMOTE: func @matmul_static_shape
+// PROMOTE-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// PROMOTE-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// PROMOTE-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// PROMOTE-DAG: %[[C0:.+]] = constant 0
+// PROMOTE-DAG: %[[C2:.+]] = constant 2
+// PROMOTE-DAG: %[[C16:.+]] = constant 16
+// PROMOTE-DAG: %[[C32:.+]] = constant 32
+// PROMOTE-DAG: %[[C48:.+]] = constant 48
+// PROMOTE: scf.for %[[IV0:.+]]
+// PROMOTE: %[[LHS_SUBVIEW:.+]] = subview %[[ARG0]]
+// PROMOTE: %[[RHS_SUBVIEW:.+]] = subview %[[ARG1]]
+// PROMOTE: %[[RESULT_SUBVIEW:.+]] = subview %[[RET0]]
+// PROMOTE: %[[ALLOC1:.+]] = alloc()
+// PROMOTE: %[[WGMEM_LHS_SUBVIEW:.+]] = subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]
+// PROMOTE: %[[ALLOC2:.+]] = alloc()
+// PROMOTE: %[[WGMEM_RHS_SUBVIEW:.+]] = subview %[[ALLOC2]][0, 0] [32, 128] [1, 1]
+// PROMOTE: linalg.copy(%[[LHS_SUBVIEW]], %[[WGMEM_LHS_SUBVIEW]])
+// PROMOTE: linalg.copy(%[[RHS_SUBVIEW]], %[[WGMEM_RHS_SUBVIEW]])
+// PROMOTE: %[[SG_X:.+]] = gpu.subgroup_id
+// PROMOTE: %[[SG_Y:.+]] = divi_signed %[[SG_X]], %[[C2]]
+// PROMOTE: %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP4]](%[[SG_Y]])
+// PROMOTE: %[[SG_LHS_SUBVIEW:.+]] = subview %[[WGMEM_LHS_SUBVIEW]][%[[SGOFFSET_Y]], 0]
+// PROMOTE: %[[SGOFFSET_X:.+]] = affine.apply #[[MAP4]](%[[SG_X]])
+// PROMOTE: %[[SG_RHS_SUBVIEW:.+]] = subview %[[WGMEM_RHS_SUBVIEW]][0, %[[SGOFFSET_X]]]
+// PROMOTE: %[[SG_RESULT_SUBVIEW:.+]] = subview %[[RESULT_SUBVIEW]][%[[SGOFFSET_Y]], %[[SGOFFSET_X]]]
+
+// PROMOTE-DAG: %[[READ_LHS_0_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C0]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_LHS_0_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C0]], %[[C16]]]
+
+// PROMOTE-DAG: %[[READ_LHS_1_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C16]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_LHS_1_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C16]], %[[C16]]]
+
+// PROMOTE-DAG: %[[READ_LHS_2_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C32]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_LHS_2_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C32]], %[[C16]]]
+
+// PROMOTE-DAG: %[[READ_LHS_3_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C48]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_LHS_3_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_LHS_SUBVIEW]][%[[C48]], %[[C16]]]
+
+// PROMOTE-DAG: %[[READ_RHS_0_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C0]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_RHS_0_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C0]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_RHS_0_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C0]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_RHS_0_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C0]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_RHS_1_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C16]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_RHS_1_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C16]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_RHS_1_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C16]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_RHS_1_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RHS_SUBVIEW]][%[[C16]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_0_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_0_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_0_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_0_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_1_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_1_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_1_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_1_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_2_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_2_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_2_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_2_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+
+// PROMOTE-DAG: %[[READ_INIT_3_0:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
+// PROMOTE-DAG: %[[READ_INIT_3_1:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
+// PROMOTE-DAG: %[[READ_INIT_3_2:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
+// PROMOTE-DAG: %[[READ_INIT_3_3:.+]] = vector.transfer_read
+// PROMOTE-SAME: %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
+
+// PROMOTE: %[[CONTRACT_0_0_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_0]], %[[READ_INIT_0_0]]
+// PROMOTE: %[[CONTRACT_0_0:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_0]], %[[CONTRACT_0_0_1]]
+// PROMOTE: %[[CONTRACT_0_1_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_1]], %[[READ_INIT_0_1]]
+// PROMOTE: %[[CONTRACT_0_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_1]], %[[CONTRACT_0_1_1]]
+// PROMOTE: %[[CONTRACT_0_2_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_2]], %[[READ_INIT_0_2]]
+// PROMOTE: %[[CONTRACT_0_2:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_2]], %[[CONTRACT_0_2_1]]
+// PROMOTE: %[[CONTRACT_0_3_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0_3]], %[[READ_INIT_0_3]]
+// PROMOTE: %[[CONTRACT_0_3:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1_3]], %[[CONTRACT_0_3_1]]
+
+// PROMOTE: %[[CONTRACT_1_0_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_0]], %[[READ_INIT_1_0]]
+// PROMOTE: %[[CONTRACT_1_0:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_0]], %[[CONTRACT_1_0_1]]
+// PROMOTE: %[[CONTRACT_1_1_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_1]], %[[READ_INIT_1_1]]
+// PROMOTE: %[[CONTRACT_1_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_1]], %[[CONTRACT_1_1_1]]
+// PROMOTE: %[[CONTRACT_1_2_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_2]], %[[READ_INIT_1_2]]
+// PROMOTE: %[[CONTRACT_1_2:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_2]], %[[CONTRACT_1_2_1]]
+// PROMOTE: %[[CONTRACT_1_3_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0_3]], %[[READ_INIT_1_3]]
+// PROMOTE: %[[CONTRACT_1_3:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1_3]], %[[CONTRACT_1_3_1]]
+
+// PROMOTE: %[[CONTRACT_2_0_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_0]], %[[READ_INIT_2_0]]
+// PROMOTE: %[[CONTRACT_2_0:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_0]], %[[CONTRACT_2_0_1]]
+// PROMOTE: %[[CONTRACT_2_1_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_1]], %[[READ_INIT_2_1]]
+// PROMOTE: %[[CONTRACT_2_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_1]], %[[CONTRACT_2_1_1]]
+// PROMOTE: %[[CONTRACT_2_2_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_2]], %[[READ_INIT_2_2]]
+// PROMOTE: %[[CONTRACT_2_2:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_2]], %[[CONTRACT_2_2_1]]
+// PROMOTE: %[[CONTRACT_2_3_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0_3]], %[[READ_INIT_2_3]]
+// PROMOTE: %[[CONTRACT_2_3:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1_3]], %[[CONTRACT_2_3_1]]
+
+// PROMOTE: %[[CONTRACT_3_0_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_0]], %[[READ_INIT_3_0]]
+// PROMOTE: %[[CONTRACT_3_0:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_0]], %[[CONTRACT_3_0_1]]
+// PROMOTE: %[[CONTRACT_3_1_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_1]], %[[READ_INIT_3_1]]
+// PROMOTE: %[[CONTRACT_3_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_1]], %[[CONTRACT_3_1_1]]
+// PROMOTE: %[[CONTRACT_3_2_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_2]], %[[READ_INIT_3_2]]
+// PROMOTE: %[[CONTRACT_3_2:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_2]], %[[CONTRACT_3_2_1]]
+// PROMOTE: %[[CONTRACT_3_3_1:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0_3]], %[[READ_INIT_3_3]]
+// PROMOTE: %[[CONTRACT_3_3:.+]] = vector.contract
+// PROMOTE-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1_3]], %[[CONTRACT_3_3_1]]
+
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_0_0]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_0_1]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_0_2]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_0_3]], %[[SG_RESULT_SUBVIEW]][%[[C0]], %[[C48]]]
+
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_1_0]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_1_1]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_1_2]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_1_3]], %[[SG_RESULT_SUBVIEW]][%[[C16]], %[[C48]]]
+
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_2_0]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_2_1]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_2_2]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_2_3]], %[[SG_RESULT_SUBVIEW]][%[[C32]], %[[C48]]]
+
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_3_0]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C0]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_3_1]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C16]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_3_2]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C32]]]
+// PROMOTE-DAG: vector.transfer_write
+// PROMOTE-SAME: %[[CONTRACT_3_3]], %[[SG_RESULT_SUBVIEW]][%[[C48]], %[[C48]]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization_licm.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization_licm.mlir
new file mode 100644
index 0000000..dd846ce
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization_licm.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt --iree-codegen-linalg-to-gpu-matmul-vectorization-pass \
+// RUN: -split-input-file %s --iree-codegen-linalg-to-gpu-unroll-size=8,8,32 \
+// RUN: -iree-codegen-linalg-to-gpu-matmul-licm | IreeFileCheck %s
+
+// CHECK-LABEL: func @matmul_128x128x128
+// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
+func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<128x128xf32>, memref<128x128xf32>) outs(%arg2 : memref<128x128xf32>)
+ return
+}
+
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 32 : index
+// CHECK-DAG: %[[MATSIZE:.+]] = constant 128 : index
+// CHECK-DAG: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: scf.for %[[JL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVIEWC:.+]] = subview %[[ARG2]][%[[IL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: scf.for %[[KL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVIEWA:.+]] = subview %[[ARG0]][%[[IL]], %[[KL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: %[[SUBVIEWB:.+]] = subview %[[ARG1]][%[[KL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
index fd05ac6..77822ab 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -1,37 +1,135 @@
-// RUN: iree-opt -split-input-file -iree-codegen-linalg-to-spirv-pipeline=use-vectorization %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization}" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization use-workgroup-memory}" %s | IreeFileCheck %s
module attributes {
spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Float16, Shader, CooperativeMatrixNV],
- [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
- {max_compute_workgroup_invocations = 512 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ #spv.target_env<#spv.vce<v1.5,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer, CooperativeMatrixNV],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
+ SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
+ {cooperative_matrix_properties_nv = [
+ {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
+ m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
+ scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
+ scope = 3 : i32}],
+ max_compute_shared_memory_size = 49152 : i32,
+ max_compute_workgroup_invocations = 1024 : i32,
+ max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+ subgroup_size = 32 : i32}>} {
func @matmul_static_shape()
attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
- %0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0, operand_result_num = 0} : memref<128x64xf16>
- %1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1, operand_result_num = 1} : memref<64x256xf16>
- %2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0, operand_result_num = 2} : memref<128x256xf16>
- linalg.matmul ins(%0, %1 : memref<128x64xf16>, memref<64x256xf16>)
- outs(%2 : memref<128x256xf16>)
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
+ linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
+ outs(%ret0 : memref<4096x4096xf16>)
return
}
func @matmul_static_shape__num_workgroups__
- (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
- !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
}
}
+
// CHECK-LABEL: spv.func @matmul_static_shape
// CHECK: spv.CooperativeMatrixLoadNV
// CHECK: spv.CooperativeMatrixLoadNV
// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+// CHECK: spv.CooperativeMatrixLoadNV
+
// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+// CHECK: spv.CooperativeMatrixMulAddNV
+
// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+// CHECK: spv.CooperativeMatrixStoreNV
+
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 91de450c..f2787b6 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -48,10 +48,10 @@
// llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
// "IREE Vulkan/SPIR-V backend options");
- static llvm::cl::opt<bool> clUseVectorPass(
- "iree-spirv-use-vector-pass",
+ static llvm::cl::opt<bool> clUseVectorizeMemrefPass(
+ "iree-spirv-use-vectorize-memref-pass",
llvm::cl::desc(
- "Enable use of Linalg vectorization in SPIR-V code generation"),
+ "Enable use of Memref vectorization in SPIR-V code generation"),
llvm::cl::init(false));
static llvm::cl::opt<bool> clUseWorkgroupMemory(
@@ -87,7 +87,8 @@
targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
clTileSizes.end());
targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
- targetOptions.codegenOptions.useVectorPass = clUseVectorPass;
+ targetOptions.codegenOptions.useVectorizeMemrefPass =
+ clUseVectorizeMemrefPass;
if (!clVulkanTargetEnv.empty()) {
targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
} else {
diff --git a/iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.cpp b/iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.cpp
index 392adb5..24d47b0 100644
--- a/iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.cpp
+++ b/iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.cpp
@@ -256,10 +256,10 @@
mSize = 8: i32, nSize = 8: i32, kSize = 32: i32, aType = i8,
bType = i8, cType = i32, resultType = i32, scope = 3: i32
}, {
- mSize = 8: i32, nSize = 8: i32, kSize = 16: i32, aType = f16,
+ mSize = 16: i32, nSize = 16: i32, kSize = 16: i32, aType = f16,
bType = f16, cType = f16, resultType = f16, scope = 3: i32
}, {
- mSize = 8: i32, nSize = 8: i32, kSize = 16: i32, aType = f16,
+ mSize = 16: i32, nSize = 16: i32, kSize = 16: i32, aType = f16,
bType = f16, cType = f32, resultType = f32, scope = 3: i32
}]
}>)";
diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir b/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
index 7549884..5280c5a 100644
--- a/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
+++ b/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
@@ -10,7 +10,7 @@
// DEFAULT: #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>
// ADRENO640: #spv.target_env<#spv.vce<v1.3, [Shader, Int16, GroupNonUniform, GroupNonUniformVote, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, Qualcomm:IntegratedGPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[1024, 1024, 64]> : vector<3xi32>, subgroup_size = 64 : i32}>
// MALIG77: #spv.target_env<#spv.vce<v1.3, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}>
-// TURINGT4: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = f16, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = f32, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>, subgroup_size = 32 : i32}>
+// TURINGT4: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f16, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f32, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>, subgroup_size = 32 : i32}>
flow.executable @simpleMath_ex_dispatch_0 {
flow.dispatch.entry @simpleMath_rgn_dispatch_0 attributes {
workload = 4 : index
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index f4741ac..655864b 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -53,7 +53,7 @@
"log_plus_one.mlir",
"pw_add_multiwg.mlir",
],
- compiler_flags = ["-iree-spirv-use-vector-pass"],
+ compiler_flags = ["-iree-spirv-use-vectorize-memref-pass"],
driver = "vulkan",
target_backend = "vulkan-spirv",
)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index d5bd481..ef6c730 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -54,5 +54,5 @@
DRIVER
vulkan
COMPILER_FLAGS
- "-iree-spirv-use-vector-pass"
+ "-iree-spirv-use-vectorize-memref-pass"
)