Merge spirv_dynamic_pipeline into main
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
index 281f3c7..ad24dac 100644
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
@@ -21,6 +21,39 @@
namespace iree_compiler {
namespace {
+/// Sets the hal.interace.workgroup.size operation to the constant value passed
+/// in as `workloadPerWorkgroup`. The number of entries in
+/// `workloadPerWorkgroup` is at least as much as the dimensionality of the
+/// workgroup. It is assumed that the inner-most loop is mapped to the fastest
+/// varying dimension in flow.dispatch.workgroup_size.
+class SetWorkgroupSizePattern
+ : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
+ public:
+ SetWorkgroupSizePattern(MLIRContext *context,
+ ArrayRef<int64_t> workloadPerWorkgroupRef,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit),
+ workloadPerWorkgroup(llvm::to_vector<4>(
+ workloadPerWorkgroupRef.size() > kNumMaxParallelDims
+ ? workloadPerWorkgroupRef.take_front(kNumMaxParallelDims)
+ : workloadPerWorkgroupRef)) {}
+
+ LogicalResult matchAndRewrite(
+ IREE::HAL::InterfaceWorkgroupSizeOp workgroupSizeOp,
+ PatternRewriter &rewriter) const override {
+ int64_t dim = workgroupSizeOp.dimension().getSExtValue();
+ if (dim >= workloadPerWorkgroup.size()) {
+ return failure();
+ }
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(workgroupSizeOp,
+ workloadPerWorkgroup[dim]);
+ return success();
+ }
+
+ private:
+ SmallVector<int64_t, 4> workloadPerWorkgroup;
+};
+
class SetNumWorkgroupsPass : public SetNumWorkgroupsBase<SetNumWorkgroupsPass> {
public:
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -50,9 +83,10 @@
for (auto funcOp : module.getOps<FuncOp>()) {
auto entryPointOp = entryPoints.lookup(funcOp.getName());
if (!entryPointOp) continue;
+
SmallVector<int64_t, 4> currWorkloadPerWorkgroup;
- // First check if there is a workload provided.
+ // First check if there is a per-workgroup workload provided.
if (!workloadPerWorkgroup.empty()) {
currWorkloadPerWorkgroup.assign(workloadPerWorkgroup.begin(),
workloadPerWorkgroup.end());
@@ -66,29 +100,51 @@
}
}
- if (currWorkloadPerWorkgroup.empty()) {
- // If no workgroup size is specified, leave the workgroup size as is, just
- // set the number of workgroups to be 1, 1, 1 to have a single invocation.
- WorkgroupCountRegionBuilder regionBuilder =
- [](OpBuilder &b, Location loc,
- std::array<Value, 3> workload) -> std::array<Value, 3> {
- Value one = b.create<ConstantIndexOp>(loc, 1);
- return {one, one, one};
- };
- OpBuilder builder(context);
- for (auto funcOp : module.getOps<FuncOp>()) {
- if (failed(
- defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
- return signalPassFailure();
- }
- }
- } else {
- if (failed(materializeStaticLaunchInformation(
- funcOp, currWorkloadPerWorkgroup))) {
- funcOp.emitError("failed to materialize constant workgroup size");
+ if (!currWorkloadPerWorkgroup.empty()) {
+ // Fold hal.workgroup.size ops.
+ OwningRewritePatternList patterns(funcOp.getContext());
+ patterns.insert<SetWorkgroupSizePattern>(funcOp.getContext(),
+ currWorkloadPerWorkgroup);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
+
+ // The workgroup count region might already be set by op-specific
+ // configuration logic. If so, just return to avoid overwriting that.
+ if (!entryPointOp.workgroup_count_region().empty()) return;
+
+ WorkgroupCountRegionBuilder regionBuilder;
+ if (currWorkloadPerWorkgroup.empty()) {
+ // If no workgroup size is specified, leave the workgroup size as is, just
+ // set the number of workgroups to be 1, 1, 1 to have a single invocation.
+ regionBuilder = [](OpBuilder &b, Location loc,
+ std::array<Value, 3> workload) {
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ return std::array<Value, 3>{one, one, one};
+ };
+ } else {
+ assert(currWorkloadPerWorkgroup.size() <= kNumMaxParallelDims &&
+ "workloadPerWorkgroup size greater than max num parallel dims");
+ regionBuilder = [&currWorkloadPerWorkgroup](
+ OpBuilder &b, Location loc,
+ std::array<Value, 3> workload) {
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ std::array<Value, 3> returnValues = {one, one, one};
+ for (auto ts : llvm::enumerate(currWorkloadPerWorkgroup)) {
+ returnValues[ts.index()] = linalg::applyMapToValues(
+ b, loc,
+ AffineMap::get(0, 1,
+ b.getAffineSymbolExpr(0).ceilDiv(ts.value())),
+ workload[ts.index()])[0];
+ }
+ return returnValues;
+ };
+ }
+
+ OpBuilder builder(context);
+ if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder)))
+ return signalPassFailure();
}
// Apply post distribution canonicalization passes.
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index f65e912..9133eca 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -230,7 +230,6 @@
"unhandled multiple roots in dispatch region");
}
rootOp = computeOp;
- continue;
}
}
@@ -251,7 +250,6 @@
"unhandled multiple roots in dispatch region");
}
rootOp = computeOp;
- continue;
}
}
}
@@ -279,7 +277,8 @@
// on it, just add the default.
if (!getTranslationInfo(entryPointOp)) {
setTranslationInfo(funcOp,
- IREE::HAL::DispatchLoweringPassPipeline::CPUDefault);
+ IREE::HAL::DispatchLoweringPassPipeline::CPUDefault,
+ /*workgroupSize =*/{}, /*workloadPerWorkgroup =*/{});
}
}
return success();
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index a3d2fef..8e9d20c 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -211,14 +211,17 @@
if (getTranslationInfo(entryPointOp)) continue;
SmallVector<Operation *, 4> computeOps;
SmallVector<Operation *, 4> tiledLoops;
- if (succeeded(getComputeOps(funcOp, computeOps, tiledLoops)) &&
- !computeOps.empty()) {
+ if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
+ return funcOp.emitOpError("failed to get compute ops");
}
if (computeOps.empty()) {
+ // TODO(ravishankarm): Maybe this should just return without setting
+ // anything. Without any compute ops, this shouldnt be using tile and
+ // distribute.
setTranslationInfo(
funcOp, IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUDistribute,
- {1, 1, 1});
+ {1, 1, 1}, /*workloadPerWorkgroup=*/{});
continue;
}
@@ -241,8 +244,16 @@
}
}
+ if (!rootOperation) {
+ // TODO(ravishankarm): Maybe this should just return without setting
+ // anything. Without any compute ops, this shouldnt be using tile and
+ // distribute.
+ setTranslationInfo(
+ funcOp, IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUDistribute,
+ {1, 1, 1}, /*workloadPerWorkgroup=*/{});
+ continue;
+ }
if (failed(setRootConfig(funcOp, rootOperation))) continue;
- IREE::HAL::LoweringConfig config = getLoweringConfig(rootOperation);
// Propogate the configuration to the other ops.
// TODO(ravishankarm, thomasraoux): This is a very specific use (and
@@ -250,6 +261,7 @@
// and distributed. The rest of the compilation must be structured to either
// use `TileAndFuse` or they are independent configurations that are
// determined based on the op.
+ IREE::HAL::LoweringConfig config = getLoweringConfig(rootOperation);
for (auto op : computeOps) {
if (op == rootOperation) continue;
setLoweringConfig(op, config);
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
index 91ff071..091c36b 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
@@ -123,7 +123,7 @@
}
/// Return a flattened Id Value by combining the 3D gpu thread IDs.
-static Value createFlatId(FuncOp funcOp, std::array<int64_t, 3> workgroupSize) {
+static Value createFlatId(FuncOp funcOp, ArrayRef<int64_t> workgroupSize) {
OpBuilder b(funcOp.getBody());
Type indexType = b.getIndexType();
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -196,7 +196,10 @@
}
void runOnOperation() override {
FuncOp funcOp = getOperation();
- std::array<int64_t, 3> workgroupSize = getWorkgroupSize(funcOp);
+ auto entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return;
+ auto workgroupSize = getWorkgroupSize(entryPointOp);
+ workgroupSize.resize(3, 1);
MLIRContext *context = &getContext();
SmallVector<linalg::CopyOp> copiesToWorkgroupMem;
funcOp.walk([&](linalg::CopyOp copyOp) {
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp
index aa6ed43..1e2b860 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPURemoveTrivialLoops.cpp
@@ -42,7 +42,10 @@
LLVMGPURemoveSingleIterationLoopPass> {
void runOnOperation() override {
FuncOp funcOp = getOperation();
- std::array<int64_t, 3> workgroupSize = getWorkgroupSize(funcOp);
+ auto entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return;
+ auto workgroupSize = getWorkgroupSize(entryPointOp);
+ workgroupSize.resize(3, 1);
auto getThreadIdMinMax = [&workgroupSize](Value value,
SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &symbols) {
diff --git a/iree/compiler/Codegen/Passes.cpp b/iree/compiler/Codegen/Passes.cpp
index 07f7d19..dde5dcb 100644
--- a/iree/compiler/Codegen/Passes.cpp
+++ b/iree/compiler/Codegen/Passes.cpp
@@ -41,21 +41,12 @@
buildLLVMGPUTransformPassPipeline(passManager, true);
});
- static PassPipelineRegistration<> linalgToSPIRVPipeline(
+ static PassPipelineRegistration<> LinalgSPIRVPipeline(
"iree-codegen-linalg-to-spirv-pipeline",
- "Runs the progressive lowering pipeline from Linalg to SPIR-V",
- [](OpPassManager &passManager) {
- buildLinalgToSPIRVPassPipeline(passManager,
- SPIRVCodegenOptions::getFromCLOptions());
- });
-
- static PassPipelineRegistration<> hloToLinalgSPIRVPipeline(
- "iree-codegen-hlo-to-spirv-pipeline",
"Runs the progressive lowering pipeline from XLA HLO to Linalg to "
"SPIR-V",
[](OpPassManager &passManager) {
- buildSPIRVCodegenPassPipeline(passManager,
- SPIRVCodegenOptions::getFromCLOptions());
+ buildSPIRVCodegenPassPipeline(passManager);
});
}
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 1c13c26..84cfd52 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -257,16 +257,17 @@
// SPIRV Passes
//------------------------------------------------------------------------------
-// Options that can be used to configure SPIR-V code generation.
-struct SPIRVCodegenOptions {
- llvm::SmallVector<unsigned, 3> workgroupSize = {};
- llvm::SmallVector<unsigned, 3> workgroupTileSizes = {};
- llvm::SmallVector<unsigned, 3> invocationTileSizes = {};
+/// Pass pipeline to lower executable obtained from Linalg tile + distribute to
+/// scalar + vector code. Does distribution to threads (no vectorization).
+void addSPIRVDistributePassPipeline(OpPassManager &pm);
- bool useWorkgroupMemory = false;
+/// Pass pipeline to lower executables that contain operations that are not
+/// tiled + distributed.
+void addSPIRVDistributeToGlobalIDPipeline(OpPassManager &pm);
- static SPIRVCodegenOptions getFromCLOptions();
-};
+/// pipeline to lower executable obtained from Linalg tile + distribute to
+/// scalar + vector code. Does distribution to threads and vectorization.
+void addSPIRVVectorizationPassPipeline(OpPassManager &pm);
/// Pass to perform the final conversion to SPIR-V dialect.
/// This pass converts remaining interface ops into SPIR-V global variables,
@@ -274,24 +275,31 @@
/// corresponding SPIR-V ops.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass();
-/// Creates a pass to concretize hal.interface.workgroup.* ops with concrete
-/// tiling and distribution scheme.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConcretizeWorkgroupTilesPass(const SPIRVCodegenOptions &options);
-
/// Pass to add the synchronizations and attributes needed to lower from PLoops
/// to GPU dialect.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConvertToGPUPass();
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVConvertToGPUPass();
/// Creates a pass to fold processor ID uses where possible.
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVFoldProcessorIDUsesPass();
+
+/// Main pass to lower executables to scalar + vector code on SPIR-V
+/// path. Invokes one of the pass pipelines that translate the executable to
+/// scalar + vector code.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVFoldProcessorIDUsesPass();
+createSPIRVLowerExecutableTargetPass();
+
+/// Pass to remove loop generated at Flow for tile + distribute when the loop is
+/// known to have a single trip count. NOTE: DO NOT USE. This is a legacy pass
+/// that is to be deprecated.
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVRemoveOneTripTiledLoopPass();
+
+/// Pass to tile and distribute Linalg operations on buffers in a single
+/// workgroup.
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndDistributePass();
/// Pass to tile and vectorize Linalg operations on buffers in a single
/// workgroup.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVTileAndVectorizePass(const SPIRVCodegenOptions &options);
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndVectorizePass();
/// Pass to convert vector read/write/arithmetic operations to the corresponding
/// cooperative matrix ops when possible.
@@ -310,10 +318,6 @@
// SPIRV Codegen Pass Pipelines.
//----------------------------------------------------------------------------//
-/// Populates passes need to lower from Linalf to SPIR-V.
-void buildLinalgToSPIRVPassPipeline(OpPassManager &pm,
- const SPIRVCodegenOptions &options);
-
/// Populates passes needed to lower a XLA HLO op to SPIR-V dialect via the
/// structured ops path. The pass manager `pm` in here operate on the module
/// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
@@ -321,8 +325,7 @@
/// testing purposes only. The pass pipeline will set an appropriate workgroup
/// size.
/// TODO: Are both of these needed and does this one still work on HLO?
-void buildSPIRVCodegenPassPipeline(OpPassManager &pm,
- const SPIRVCodegenOptions &options);
+void buildSPIRVCodegenPassPipeline(OpPassManager &pm);
//----------------------------------------------------------------------------//
// SPIRV Codegen specific patterns.
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 8bfa428..c322788 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -221,39 +221,45 @@
let constructor = "mlir::iree_compiler::createConvertToSPIRVPass()";
}
-def SPIRVConcretizeWorkgroupTiles :
- Pass<"iree-spirv-concretize-workgroup-tiles",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
- let summary = "Replace hal.interface.workgroup.* ops with constant values";
- let constructor =
- "mlir::iree_compiler::createSPIRVConcretizeWorkgroupTilesPass(mlir::iree_compiler::SPIRVCodegenOptions::getFromCLOptions())";
-}
-
// TODO: Rename argument to be fully qualified.
-def SPIRVConvertToGPU :
- Pass<"iree-spirv-convert-to-gpu",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+def SPIRVConvertToGPU : Pass<"iree-spirv-convert-to-gpu", "FuncOp"> {
let summary = "Map tiled linalg and loop ops to GPU";
let constructor = "mlir::iree_compiler::createSPIRVConvertToGPUPass()";
}
// TODO: Rename argument to be fully qualified.
// TODO: Does not appear used?
-def SPIRVFoldProcessorIDUses :
- Pass<"iree-spirv-fold-gpu-procid-uses",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+def SPIRVFoldProcessorIDUses : Pass<"iree-spirv-fold-gpu-procid-uses", "FuncOp"> {
let summary = "Fold GPU processor ID uses where possible";
let constructor = "mlir::iree_compiler::createSPIRVFoldProcessorIDUsesPass()";
}
+def SPIRVLowerExecutableTarget :
+ Pass<"iree-spirv-lower-executable-target-pass", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+ let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
+ let constructor = "mlir::iree_compiler::createSPIRVLowerExecutableTargetPass()";
+}
+
+def SPIRVRemoveOneTripTiledLoop :
+ Pass<"iree-spirv-remove-one-trip-tiled-loop", "FuncOp"> {
+ let summary = "Remove one trip tiled loop. ---- Legacy Pass! Do not use ---";
+ let constructor = "mlir::iree_compiler::createSPIRVRemoveOneTripTiledLoopPass()";
+}
+
// TODO: Rename argument to be fully qualified.
-def SPIRVTileAndVectorize :
- Pass<"iree-spirv-tile-and-vectorize",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+def SPIRVTileAndVectorize : Pass<"iree-spirv-tile-and-vectorize", "FuncOp"> {
let summary =
"Tile and vectorize Linalg operations on buffers in one workgroup";
let constructor =
- "mlir::iree_compiler::createSPIRVTileAndVectorizePass(mlir::iree_compiler::SPIRVCodegenOptions::getFromCLOptions())";
+ "mlir::iree_compiler::createSPIRVTileAndVectorizePass()";
+}
+
+// TODO: Rename argument to be fully qualified.
+def SPIRVTileAndDistribute : Pass<"iree-spirv-tile-and-distribute", "FuncOp"> {
+ let summary =
+ "Tile and distribute Linalg operations on buffers in one workgroup";
+ let constructor =
+ "mlir::iree_compiler::createSPIRVTileAndDistributePass()";
}
// TODO: Rename argument to be fully qualified.
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index d9533b1..63662d0 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -13,15 +13,15 @@
cc_library(
name = "SPIRV",
srcs = [
- "CodeGenOptionUtils.cpp",
"ConvertToSPIRVPass.cpp",
"KernelDispatchUtils.cpp",
- "LaunchConfig.cpp",
"Passes.cpp",
- "SPIRVConcretizeWorkgroupTiles.cpp",
"SPIRVConvertToGPU.cpp",
"SPIRVCopyToWorkgroupMemory.cpp",
"SPIRVFoldGPUProcessorIDUses.cpp",
+ "SPIRVLowerExecutableTargetPass.cpp",
+ "SPIRVRemoveOneTripTiledLoops.cpp",
+ "SPIRVTileAndDistribute.cpp",
"SPIRVTileAndVectorize.cpp",
"SPIRVVectorToCooperativeMatrix.cpp",
"SPIRVVectorizeLoadStore.cpp",
@@ -29,7 +29,6 @@
],
hdrs = [
"KernelDispatchUtils.h",
- "LaunchConfig.h",
"MemorySpace.h",
"Utils.h",
],
@@ -41,6 +40,8 @@
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
+ "//iree/compiler/Dialect/LinalgExt/IR",
+ "//iree/compiler/Dialect/LinalgExt/Transforms",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index d5411b6..f55dd2a 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -15,19 +15,18 @@
SPIRV
HDRS
"KernelDispatchUtils.h"
- "LaunchConfig.h"
"MemorySpace.h"
"Utils.h"
SRCS
- "CodeGenOptionUtils.cpp"
"ConvertToSPIRVPass.cpp"
"KernelDispatchUtils.cpp"
- "LaunchConfig.cpp"
"Passes.cpp"
- "SPIRVConcretizeWorkgroupTiles.cpp"
"SPIRVConvertToGPU.cpp"
"SPIRVCopyToWorkgroupMemory.cpp"
"SPIRVFoldGPUProcessorIDUses.cpp"
+ "SPIRVLowerExecutableTargetPass.cpp"
+ "SPIRVRemoveOneTripTiledLoops.cpp"
+ "SPIRVTileAndDistribute.cpp"
"SPIRVTileAndVectorize.cpp"
"SPIRVVectorToCooperativeMatrix.cpp"
"SPIRVVectorizeLoadStore.cpp"
@@ -70,6 +69,8 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
+ iree::compiler::Dialect::LinalgExt::IR
+ iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
diff --git a/iree/compiler/Codegen/SPIRV/CodeGenOptionUtils.cpp b/iree/compiler/Codegen/SPIRV/CodeGenOptionUtils.cpp
deleted file mode 100644
index 9d16e61..0000000
--- a/iree/compiler/Codegen/SPIRV/CodeGenOptionUtils.cpp
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Passes.h"
-#include "llvm/Support/CommandLine.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-SPIRVCodegenOptions SPIRVCodegenOptions::getFromCLOptions() {
- static llvm::cl::list<unsigned> clWorkgroupTileSizes(
- "iree-spirv-workgroup-tile-size",
- llvm::cl::desc("Set tile sizes to use for each workgroup when tiling "
- "Linalg ops in SPIR-V code generation"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
-
- static llvm::cl::list<unsigned> clInvocationTileSizes(
- "iree-spirv-invocation-tile-size",
- llvm::cl::desc("Set tile sizes for each invocation when tiling Linalg "
- "ops in SPIR-V code generation"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
-
- static llvm::cl::opt<bool> clUseWorkgroupMemory(
- "iree-spirv-use-workgroup-memory",
- llvm::cl::desc("Use workgroup memory in SPIR-V code generation"),
- llvm::cl::init(false));
-
- static llvm::cl::list<unsigned> clWorkgroupSizes(
- "iree-spirv-workgroup-size",
- llvm::cl::desc("Set workgroup size to use for SPIR-V code generation"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
-
- SPIRVCodegenOptions options;
- options.workgroupSize.assign(clWorkgroupSizes.begin(),
- clWorkgroupSizes.end());
- options.workgroupTileSizes.assign(clWorkgroupTileSizes.begin(),
- clWorkgroupTileSizes.end());
- options.invocationTileSizes.assign(clInvocationTileSizes.begin(),
- clInvocationTileSizes.end());
- options.useWorkgroupMemory = clUseWorkgroupMemory;
- return options;
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 52736e8..6f1e29e 100644
--- a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -18,6 +18,7 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/DenseMapInfo.h"
@@ -297,6 +298,27 @@
MLIRContext *context = &getContext();
ModuleOp moduleOp = getOperation();
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
+ getAllEntryPoints(moduleOp);
+ for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+ auto entryPointOp = entryPoints.lookup(funcOp.getName());
+ if (!entryPointOp) continue;
+ // TODO(ravishankarm): This needs to be removed after ConvertToGPU is
+ // deprecated. All passes must set the `workgroup_size` on the
+ // `hal.executable.entry_point` directly and not on the function.
+ if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) continue;
+ SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp);
+ if (workgroupSize.empty()) {
+ entryPointOp.emitOpError(
+ "expected workgroup_size attribute to be set for SPIR-V lowering");
+ return signalPassFailure();
+ }
+ auto workgroupSize32 = llvm::to_vector<4>(llvm::map_range(
+ workgroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
+ funcOp->setAttr(spirv::getEntryPointABIAttrName(),
+ spirv::getEntryPointABIAttr(workgroupSize32, context));
+ }
+
auto targetAttr = spirv::lookupTargetEnv(moduleOp);
SPIRVTypeConverter typeConverter(targetAttr);
OwningRewritePatternList patterns(&getContext());
diff --git a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp b/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp
index 0296257..3d2b5ea 100644
--- a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp
@@ -16,9 +16,12 @@
#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/LaunchConfig.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/LoweringConfig.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/Support/Debug.h"
@@ -55,74 +58,7 @@
return std::min(shape, tileSize);
}
-/// Sets the `tileSizes` and `workgroupSize` for an Linalg `op` to the default,
-/// where at most 3 inner parallel dimensions of `op` are tiled and distributed,
-/// and each invocation handles one scalar elements.
-// TODO(#5852): revisit the default here: they were chosen to get started and
-// not very good.
-static LogicalResult setDefaultTilingScheme(
- const spirv::TargetEnv &targetEnv, linalg::LinalgOp op,
- TileSizesListType &tileSizes, std::array<int64_t, 3> &workgroupSize) {
- auto maxWorkgroupSize =
- targetEnv.getResourceLimits().max_compute_workgroup_invocations();
-
- const int64_t tileSizeX = 32;
- const int64_t tileSizeY = maxWorkgroupSize.getInt() / tileSizeX;
-
- unsigned numParallelDims = getNumOuterParallelLoops(op);
-
- SmallVector<int64_t, 4> workgroupLevel(numParallelDims, 0);
- SmallVector<int64_t, 4> invocationLevel(numParallelDims, 0);
-
- if (numParallelDims >= 1) {
- workgroupLevel.back() = tileSizeX;
- invocationLevel.back() = 1;
- }
- if (numParallelDims >= 2) {
- workgroupLevel[numParallelDims - 2] = tileSizeY;
- invocationLevel[numParallelDims - 2] = 1;
- }
- if (numParallelDims >= 3) {
- workgroupLevel[numParallelDims - 3] = 1;
- invocationLevel[numParallelDims - 3] = 1;
- }
-
- tileSizes.emplace_back(std::move(workgroupLevel));
- tileSizes.emplace_back(); // Subgroup level
- tileSizes.emplace_back(std::move(invocationLevel));
-
- workgroupSize = {tileSizeX, tileSizeY, 1};
-
- return success();
-}
-
-/// Fills `inputTypes` and `outputTypes` with the original input/output types
-/// for all tiles for `op`.
-static std::tuple<SmallVector<ShapedType>, SmallVector<ShapedType>>
-getInputOutputTypes(linalg::LinalgOp op) {
- SmallVector<ShapedType> inputTypes(op.getNumInputs()),
- outputTypes(op.getNumOutputs());
- auto inputOperands = op.getInputOperands();
- for (auto operand : enumerate(inputOperands)) {
- assert(!op.isScalar(operand.value()));
- inputTypes[operand.index()] =
- getUntiledType(operand.value()->get()).dyn_cast<ShapedType>();
- }
- auto outputOperands = op.getOutputOperands();
- for (auto operand : enumerate(outputOperands)) {
- outputTypes[operand.index()] =
- getUntiledType(operand.value()->get()).dyn_cast<ShapedType>();
- }
- return std::make_tuple(std::move(inputTypes), std::move(outputTypes));
-}
-
namespace {
-struct LaunchConfigInfo {
- std::array<int64_t, 3> workgroupSize = {32, 1, 1};
- std::array<int64_t, 3> numSubgroups = {1, 1, 1};
- bool vectorize = false;
-};
-
struct TileWorkgroupSizePair {
// How many scalar elements each workgroup should handle along each dimension.
std::array<int64_t, 3> tileSize;
@@ -130,20 +66,6 @@
};
} // namespace
-/// For a given operation `op`, compute the following configurations according
-/// to SPIR-V `targetEnv` and `options`:
-/// 1) number of tiling levels and tile sizes to use (updates `tileSizes`),
-/// 2) workgroup size to use (updates `workgroupSize`),
-/// 3) number of subgroups to use if two level tiling is used (updates
-/// `numSubgroups`).
-template <typename T>
-static LogicalResult getOpLaunchConfig(T op, const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
- return setDefaultTilingScheme(targetEnv, op, tileSizes, config.workgroupSize);
-}
-
static void getMaliBestMatMulTileSizes(
Type elementType, SmallVectorImpl<TileWorkgroupSizePair> &tileSizes,
int64_t dstSize) {
@@ -188,89 +110,83 @@
}
/// Launch configuration for Mali GPU configuration.
-static LogicalResult getMaliSpecificConfig(
- linalg::BatchMatmulOp op, const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+static LogicalResult setMaliSpecificConfig(FuncOp entryPoint,
+ const spirv::TargetEnv &targetEnv,
+ linalg::BatchMatmulOp op) {
if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
- SmallVector<ShapedType> inputTypes, outputTypes;
- std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
-
- ShapedType lhsType = inputTypes[0], rhsType = inputTypes[1];
- if (!lhsType || !rhsType || !lhsType.hasStaticShape() ||
- !rhsType.hasStaticShape())
+ ArrayRef<int64_t> lhsShape = getUntiledShape(op.inputs()[0]);
+ ArrayRef<int64_t> rhsShape = getUntiledShape(op.inputs()[1]);
+ // If the shape size is unknonw fall back to none vectorized path.
+ if (llvm::any_of(lhsShape, ShapedType::isDynamic) ||
+ llvm::any_of(rhsShape, ShapedType::isDynamic)) {
return failure();
+ }
+
// Get a vector of best tile size ordered from best to worst.
SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
- int64_t dstSize =
- lhsType.getDimSize(0) * lhsType.getDimSize(1) * rhsType.getDimSize(2);
- getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs,
- dstSize);
+ int64_t dstSize = lhsShape[0] * lhsShape[1] * rhsShape[2];
+ getMaliBestMatMulTileSizes(
+ op.inputs()[0].getType().cast<ShapedType>().getElementType(),
+ workgroupLevelTs, dstSize);
for (TileWorkgroupSizePair pair : workgroupLevelTs) {
- if (lhsType.getDimSize(1) % pair.tileSize[0] != 0 ||
- rhsType.getDimSize(2) % pair.tileSize[1] != 0 ||
- lhsType.getDimSize(2) % pair.tileSize[2] != 0) {
+ if (lhsShape[1] % pair.tileSize[0] != 0 ||
+ rhsShape[2] % pair.tileSize[1] != 0 ||
+ lhsShape[2] % pair.tileSize[2] != 0) {
continue;
}
- workgroupSize = pair.workgroupSize;
SmallVector<int64_t, 4> batchTs;
batchTs.append({1, pair.tileSize[0], pair.tileSize[1], pair.tileSize[2]});
+ TileSizesListType tileSizes;
tileSizes.emplace_back(batchTs);
// No tiling at the subgroup level since this target doesn't use subgroup op
// or shared memory.
tileSizes.emplace_back();
SmallVector<int64_t, 4> invocationLevelTs = {
- batchTs[0], batchTs[1] / workgroupSize[1],
- batchTs[2] / workgroupSize[0], batchTs[3]};
+ batchTs[0], batchTs[1] / pair.workgroupSize[1],
+ batchTs[2] / pair.workgroupSize[0], batchTs[3]};
tileSizes.emplace_back(invocationLevelTs);
- return success();
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
+ pair.workgroupSize);
}
return failure();
}
/// Launch config for `linalg.batchmatmul`.
-template <>
-LogicalResult getOpLaunchConfig(linalg::BatchMatmulOp op,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
- if (succeeded(getMaliSpecificConfig(op, targetEnv, options, tileSizes,
- config.workgroupSize,
- config.numSubgroups))) {
- config.vectorize = true;
+static LogicalResult setRootConfig(FuncOp entryPoint,
+ const spirv::TargetEnv &targetEnv,
+ linalg::BatchMatmulOp op) {
+ if (succeeded(setMaliSpecificConfig(entryPoint, targetEnv, op))) {
return success();
}
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
- std::tie(config.workgroupSize[0], config.workgroupSize[1]) =
+ std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+ std::tie(workgroupSize[0], workgroupSize[1]) =
distributeProcs2D(maxWorkgroupSize);
- config.workgroupSize[2] = 1;
// This is just being hard-wired for now to be minimal viable, but this can be
// decided better when we have better estimates of device charecteristics.
const int64_t nRowsPerWorkitem = 1;
const int64_t nColsPerWorkitem = 1;
const int64_t nBatchesPerWorkitem = 1;
int64_t tileSizeK = 0;
- if (options.useWorkgroupMemory) {
- // This number should be decided based on the amount of shared memory
- // available (maybe). For now, just hard-wire it.
- tileSizeK = 32;
- }
SmallVector<int64_t, 4> workgroupLevel = {
- nBatchesPerWorkitem, nRowsPerWorkitem * config.workgroupSize[1],
- nColsPerWorkitem * config.workgroupSize[0], tileSizeK};
+ nBatchesPerWorkitem, nRowsPerWorkitem * workgroupSize[1],
+ nColsPerWorkitem * workgroupSize[0], tileSizeK};
SmallVector<int64_t, 4> invocationLevel = {
nBatchesPerWorkitem, nRowsPerWorkitem, nColsPerWorkitem, 0};
+ TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupLevel));
tileSizes.emplace_back(); // subgroup level
tileSizes.emplace_back(std::move(invocationLevel));
- return success();
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute, workgroupSize);
}
/// Returns the size of the co-operative matrix multiply operations on the
@@ -297,26 +213,29 @@
/// Launch configuration for using spv.CooperativeMatrixMulAddNV
/// operations. Needs two levels of tiling.
-static LogicalResult getConfigForCooperativeMatmul(
- linalg::MatmulOp op, const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+static LogicalResult setConfigForCooperativeMatmul(
+ FuncOp entryPoint, const spirv::TargetEnv &targetEnv, linalg::MatmulOp op) {
if (!targetEnv.allows(spirv::Capability::CooperativeMatrixNV) ||
!targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix))
return failure();
- SmallVector<ShapedType> inputTypes, outputTypes;
- std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
-
- ShapedType lhsType = inputTypes[0], rhsType = inputTypes[1];
- ShapedType outputType = outputTypes[0];
+ ArrayRef<int64_t> lhsShape = getUntiledShape(op.inputs()[0]);
+ ArrayRef<int64_t> rhsShape = getUntiledShape(op.inputs()[1]);
+ // If the shape size is unknonw fall back to none vectorized path.
+ if (llvm::any_of(lhsShape, ShapedType::isDynamic) ||
+ llvm::any_of(rhsShape, ShapedType::isDynamic)) {
+ return failure();
+ }
auto resourceLimits = targetEnv.getResourceLimits();
+ auto getElementType = [](Value v) {
+ return v.getType().cast<ShapedType>().getElementType();
+ };
+ auto outputElementType = getElementType(op.outputs()[0]);
Optional<SmallVector<int64_t, 4>> coopMatmulSize =
getCooperativeMatmulSubgroupSize(
- resourceLimits, lhsType.getElementType(), rhsType.getElementType(),
- outputType.getElementType(), outputType.getElementType());
+ resourceLimits, getElementType(op.inputs()[0]),
+ getElementType(op.inputs()[1]), outputElementType, outputElementType);
if (!coopMatmulSize) return failure();
// Check that the matmul sizes are a multiple of the tilesize.
@@ -324,241 +243,223 @@
return !ShapedType::isDynamic(s) && (s % ts) == 0;
};
- ArrayRef<int64_t> lhsShape = lhsType.getShape();
- ArrayRef<int64_t> rhsShape = rhsType.getShape();
if (!isMultipleOf(lhsShape[0], (*coopMatmulSize)[0]) ||
!isMultipleOf(rhsShape[1], (*coopMatmulSize)[1]) ||
!isMultipleOf(lhsShape[1], (*coopMatmulSize)[2]) ||
!isMultipleOf(rhsShape[0], (*coopMatmulSize)[2]))
return failure();
- if (options.useWorkgroupMemory) {
- numSubgroups[0] = 2;
- numSubgroups[1] = 2;
- } else {
- numSubgroups[0] = 1;
- numSubgroups[1] = 1;
- }
- numSubgroups[2] = 1;
-
// For now this is being hard-wired to be {4, 4, 2}. This can actually be set
// to whatever, but ultimately depends on register pressure.
const int64_t numVecMatmulPerSubgroupX = 4;
const int64_t numVecMatmulPerSubgroupY = 4;
const int64_t numVecMatmulPerSubgroupK = 2;
SmallVector<int64_t, 4> ts = {
- numVecMatmulPerSubgroupY * (*coopMatmulSize)[0] * numSubgroups[1],
- numVecMatmulPerSubgroupX * (*coopMatmulSize)[1] * numSubgroups[0],
+ numVecMatmulPerSubgroupY * (*coopMatmulSize)[0],
+ numVecMatmulPerSubgroupX * (*coopMatmulSize)[1],
numVecMatmulPerSubgroupK * (*coopMatmulSize)[2]};
+ TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(ts));
int64_t subgroupSize =
resourceLimits.subgroup_size().getValue().getSExtValue();
- workgroupSize[0] = numSubgroups[0] * numSubgroups[1] * subgroupSize;
- workgroupSize[1] = 1;
- workgroupSize[2] = 1;
- // Subgroup tile sizes
+ std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
SmallVector<int64_t, 4> subgroupTs = {
numVecMatmulPerSubgroupY * (*coopMatmulSize)[0],
numVecMatmulPerSubgroupX * (*coopMatmulSize)[1]};
tileSizes.emplace_back(std::move(subgroupTs));
- return success();
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, workgroupSize);
}
/// Launch config for element-wise linalg.generic.
-LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
+LogicalResult setDefaultRootConfig(FuncOp entryPoint,
+ const spirv::TargetEnv &targetEnv,
+ Operation *op) {
+ auto partitionedLoops = getPartitionedLoops(op);
+ if (partitionedLoops.empty()) {
+ // Serialized computation.
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, /*tileSizes =*/TileSizesListType{{}},
+ /*nativeVectorSize=*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, {1, 1, 1});
+ }
+
// Skip vectorization for non-minor identity inputs as it generates
// transfer_read ops with permutation maps that we currently cannot lower.
// TODO: Remove this restriction once the lowering of the permutation map is
// supported in core.
- bool vectorize = !linalgOp.hasIndexSemantics() &&
- llvm::all_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
- return map.isMinorIdentity();
- });
- // TODO(thomasraoux): Lowering of integers other than i32 may require
- // emulation. This is currently not supported for vector operation. Re-enable
- // this when the bug is fixed on SPIR-V lowering side.
- if (llvm::any_of(linalgOp->getOperands(), [](Value operand) {
- Type memrefType = operand.getType().cast<MemRefType>().getElementType();
- return !memrefType.isa<FloatType>() && !memrefType.isInteger(32);
- }))
- vectorize = false;
int64_t subgroupSize =
targetEnv.getResourceLimits().subgroup_size().getValue().getSExtValue();
- config.workgroupSize[0] = subgroupSize;
- config.workgroupSize[1] = 1;
- config.workgroupSize[2] = 1;
- SmallVector<ShapedType> inputTypes, outputTypes;
- std::tie(inputTypes, outputTypes) = getInputOutputTypes(linalgOp);
- ShapedType outputShape = outputTypes[0];
- SmallVector<int64_t, 4> candidateTileSizes;
- // When Vectororization is not enabled we skil the second level of tiling and
- // fall back to convertToGPU which will map one element to one thread. To
- // avoid a mismatch in the number of workgroup dispatched, we pick a tile size
- // to have one element per thread.
- // TODO: Remove this once we switch to linalg on tensor path.
- if (vectorize) {
- candidateTileSizes.append({4 * subgroupSize, 2 * subgroupSize});
+ int64_t lowerWorkgroupTs = subgroupSize;
+ int64_t lowerThreadTs = 1;
+ IREE::HAL::DispatchLoweringPassPipeline pipeline =
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute;
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ bool vectorize = false;
+ // TODO(thomasraoux): Lowering of integers other than i32 may require
+ // emulation. This is currently not supported for vector operation.
+ // Re-enable this when the bug is fixed on SPIR-V lowering side.
+ auto outputShape = getUntiledResultShape(linalgOp, 0);
+ if (!linalgOp.hasIndexSemantics() &&
+ llvm::all_of(linalgOp.getIndexingMaps(),
+ [](AffineMap &map) { return map.isMinorIdentity(); }) &&
+ llvm::all_of(
+ linalgOp->getOperands(),
+ [](Value operand) {
+ auto shapedType = operand.getType().dyn_cast<ShapedType>();
+ Type elementType = (shapedType ? shapedType.getElementType()
+ : operand.getType());
+ return elementType.isa<FloatType>() || elementType.isInteger(32);
+ }) &&
+ llvm::all_of(outputShape,
+ [](int64_t dim) { return !ShapedType::isDynamic(dim); })) {
+ vectorize = true;
+ }
+ SmallVector<int64_t, 4> candidateTileSizes;
+ if (vectorize) candidateTileSizes.push_back(4 * subgroupSize);
+ candidateTileSizes.push_back(subgroupSize);
+ for (int64_t size : candidateTileSizes) {
+ if (outputShape.back() % size != 0) continue;
+ lowerWorkgroupTs = size;
+ break;
+ }
+ if (lowerWorkgroupTs <= subgroupSize ||
+ outputShape.back() % lowerWorkgroupTs != 0) {
+ vectorize = false;
+ }
+ if (vectorize) {
+ lowerThreadTs = lowerWorkgroupTs / subgroupSize;
+ pipeline = IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize;
+ }
}
- candidateTileSizes.push_back(subgroupSize);
- // Use the first tile size that can divide the shape. If the shape is not
- // aligned on any of the tile sizes pick the smallest tile of one element per
- // thread.
- int64_t lowerTs = config.workgroupSize[0];
- for (int64_t size : candidateTileSizes) {
- if (outputShape.getShape().back() % size != 0) continue;
- lowerTs = size;
- break;
- }
- unsigned numLoops = getNumOuterParallelLoops(linalgOp);
- SmallVector<int64_t, 4> ts(numLoops, 1);
- ts.back() = lowerTs;
- tileSizes.emplace_back(ts); // Workgroup level
- tileSizes.emplace_back(); // Subgroup level
+ std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
- if (!vectorize || outputShape.getShape().back() % lowerTs != 0) {
- ts.back() = 1;
- tileSizes.emplace_back(ts); // Thread level
- config.vectorize = false;
- } else {
- ts.back() = lowerTs / subgroupSize;
- tileSizes.emplace_back(ts); // Thread level
- // Vectorize only if we are processing more than one element per thread.
- config.vectorize = vectorize && (ts.back() > 1);
+ unsigned loopDepth = partitionedLoops.back() + 1;
+ SmallVector<int64_t, 4> workgroupTileSize(loopDepth, 0);
+ SmallVector<int64_t, 4> threadTileSize(loopDepth, 0);
+
+ // Tiling along partitioned loops with size 1.
+ for (int64_t loopIndex : partitionedLoops) {
+ workgroupTileSize[loopIndex] = threadTileSize[loopIndex] = 1;
}
- return success();
+ // Overwrite the configuration for the innermost dimension.
+ workgroupTileSize.back() = lowerWorkgroupTs;
+ threadTileSize.back() = lowerThreadTs;
+
+ TileSizesListType tileSizes;
+ tileSizes.emplace_back(workgroupTileSize); // Workgroup level
+ tileSizes.emplace_back(); // Subgroup level
+ tileSizes.emplace_back(threadTileSize); // Invocation level
+
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes,
+ /*nativeVectorSize =*/ArrayRef<int64_t>{}, pipeline, workgroupSize);
}
-#define GET_GENERIC_OP_LAUNCH_CONFIG(opType) \
- template <> \
- LogicalResult getOpLaunchConfig( \
- opType op, const spirv::TargetEnv &targetEnv, \
- const SPIRVCodegenOptions &options, TileSizesListType &tileSizes, \
- LaunchConfigInfo &config) { \
- return getGenericOpLaunchConfig(op, targetEnv, options, tileSizes, \
- config); \
- }
-
-GET_GENERIC_OP_LAUNCH_CONFIG(linalg::GenericOp)
-
-#undef GET_GENERIC_OP_LAUNCH_CONFIG
-
/// Launch configuration for different known GPU configuration.
-static LogicalResult getTargetSpecificConfig(
- linalg::MatmulOp op, const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+static LogicalResult setTargetSpecificConfig(FuncOp entryPoint,
+ const spirv::TargetEnv &targetEnv,
+ linalg::MatmulOp op) {
if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
- SmallVector<ShapedType> inputTypes, outputTypes;
- std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
-
- ShapedType lhsType = inputTypes[0], rhsType = inputTypes[1];
+ ArrayRef<int64_t> lhsShape = getUntiledShape(op.inputs()[0]);
+ ArrayRef<int64_t> rhsShape = getUntiledShape(op.inputs()[1]);
// If the shape size is unknonw fall back to none vectorized path.
- if (!lhsType || !rhsType || !lhsType.hasStaticShape() ||
- !rhsType.hasStaticShape())
+ if (llvm::any_of(lhsShape, ShapedType::isDynamic) ||
+ llvm::any_of(rhsShape, ShapedType::isDynamic)) {
return failure();
+ }
// Pick ideal tile size based on the type.
SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
- int64_t dstSize = lhsType.getDimSize(0) * rhsType.getDimSize(1);
- getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs,
- dstSize);
+ int64_t dstSize = lhsShape[0] * rhsShape[1];
+ getMaliBestMatMulTileSizes(
+ op.inputs()[0].getType().cast<ShapedType>().getElementType(),
+ workgroupLevelTs, dstSize);
for (TileWorkgroupSizePair pair : workgroupLevelTs) {
- if (lhsType.getDimSize(0) % pair.tileSize[0] != 0 ||
- rhsType.getDimSize(1) % pair.tileSize[1] != 0 ||
- lhsType.getDimSize(1) % pair.tileSize[2] != 0) {
+ if (lhsShape[0] % pair.tileSize[0] != 0 ||
+ rhsShape[1] % pair.tileSize[1] != 0 ||
+ lhsShape[1] % pair.tileSize[2] != 0) {
continue;
}
- workgroupSize = pair.workgroupSize;
+ TileSizesListType tileSizes;
SmallVector<int64_t, 4> matmulTS(pair.tileSize.begin(),
pair.tileSize.end());
tileSizes.emplace_back(matmulTS);
// No tiling at the subgroup level since this target doesn't use subgroup op
// or shared memory.
tileSizes.emplace_back();
- SmallVector<int64_t, 4> invocationLevelTs = {matmulTS[0] / workgroupSize[1],
- matmulTS[1] / workgroupSize[0],
- matmulTS[2]};
+ SmallVector<int64_t, 4> invocationLevelTs = {
+ matmulTS[0] / pair.workgroupSize[1],
+ matmulTS[1] / pair.workgroupSize[0], matmulTS[2]};
tileSizes.emplace_back(invocationLevelTs);
- return success();
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes,
+ /*nativeVectorSize =*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
+ pair.workgroupSize);
}
return failure();
}
-template <>
-LogicalResult getOpLaunchConfig(linalg::MatmulOp op,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
- if (succeeded(getConfigForCooperativeMatmul(op, targetEnv, options, tileSizes,
- config.workgroupSize,
- config.numSubgroups))) {
- config.vectorize = true;
+LogicalResult setRootConfig(FuncOp entryPoint,
+ const spirv::TargetEnv &targetEnv,
+ linalg::MatmulOp op) {
+ if (succeeded(setConfigForCooperativeMatmul(entryPoint, targetEnv, op))) {
return success();
}
- if (succeeded(getTargetSpecificConfig(op, targetEnv, options, tileSizes,
- config.workgroupSize,
- config.numSubgroups))) {
- config.vectorize = true;
+ if (succeeded(setTargetSpecificConfig(entryPoint, targetEnv, op))) {
return success();
}
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
- std::tie(config.workgroupSize[0], config.workgroupSize[1]) =
+ std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+ std::tie(workgroupSize[0], workgroupSize[1]) =
distributeProcs2D(maxWorkgroupSize);
- config.workgroupSize[2] = 1;
const int nRowsPerWorkitem = 1;
const int nColsPerWorkitem = 1;
int64_t tileSizeK = 0;
- if (options.useWorkgroupMemory) {
- // TODO(#3131): This number should be decided based on the amount of shared
- // memory available (maybe). For now, just hard-wire it.
- tileSizeK = 32;
- }
- SmallVector<ShapedType> inputTypes;
- std::tie(inputTypes, std::ignore) = getInputOutputTypes(op);
- int64_t M = inputTypes[0].getShape()[0];
- int64_t N = inputTypes[1].getShape()[1];
- int64_t K = inputTypes[0].getShape()[1];
+ ArrayRef<int64_t> lhsShape = getUntiledShape(op.inputs()[0]);
+ ArrayRef<int64_t> rhsShape = getUntiledShape(op.inputs()[1]);
+
+ int64_t M = lhsShape[0];
+ int64_t N = rhsShape[1];
+ int64_t K = lhsShape[1];
SmallVector<int64_t, 4> workgroupLevel = {
- getMinIfShapeStatic(M, nRowsPerWorkitem * config.workgroupSize[1]),
- getMinIfShapeStatic(N, nColsPerWorkitem * config.workgroupSize[0]),
+ getMinIfShapeStatic(M, nRowsPerWorkitem * workgroupSize[1]),
+ getMinIfShapeStatic(N, nColsPerWorkitem * workgroupSize[0]),
getMinIfShapeStatic(K, tileSizeK)};
SmallVector<int64_t, 4> invocationLevel = {1, 1, 0};
+ TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupLevel));
tileSizes.emplace_back(); // subgroup level
tileSizes.emplace_back(std::move(invocationLevel));
- return success();
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes, /*nativeVectorSize =*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute, workgroupSize);
}
-static LogicalResult getMaliSpecificConfig(linalg::ConvInputNHWCFilterHWCFOp op,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
- SmallVector<ShapedType> inputTypes, outputTypes;
- std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
-
- ShapedType inputType = inputTypes[0], outputType = outputTypes[0];
- if (!inputType || !outputType || !inputType.hasStaticShape() ||
- !outputType.hasStaticShape())
+static LogicalResult setMaliSpecificConfig(
+ FuncOp entryFn, linalg::ConvInputNHWCFilterHWCFOp op) {
+ ArrayRef<int64_t> inputShape = getUntiledShape(op.inputs()[0]);
+ ArrayRef<int64_t> outputShape =
+ getUntiledResultShape(cast<linalg::LinalgOp>(op.getOperation()), 0);
+ if (llvm::any_of(inputShape, ShapedType::isDynamic) ||
+ llvm::any_of(outputShape, ShapedType::isDynamic)) {
return failure();
+ }
- bool isInputTilable =
- inputType.getDimSize(3) % 4 == 0 || inputType.getDimSize(3) < 4;
+ bool isInputTilable = inputShape[3] % 4 == 0 || inputShape[3] < 4;
if (!isInputTilable) return failure();
// A list of preferred tile sizes and workgroup sizes. This is for Mali
@@ -575,13 +476,13 @@
const std::array<int64_t, 3> &tileSize = pair.tileSize;
const std::array<int64_t, 3> &workgroupSize = pair.workgroupSize;
- auto outputShape = outputType.getShape();
bool isOutputTilable = (outputShape[0] == 1) &&
(outputShape[1] % tileSize[0] == 0) &&
(outputShape[2] % tileSize[1] == 0) &&
(outputShape[3] % tileSize[2] == 0);
if (!isOutputTilable) continue;
+ TileSizesListType tileSizes;
SmallVector<int64_t, 4> workgroupLevel = {
/*batch=*/0, /*output_height=*/tileSize[0],
/*output_width=*/tileSize[1], /*output_channel=*/tileSize[2]};
@@ -603,39 +504,51 @@
SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 1, 1, 4};
tileSizes.emplace_back(fourthLevel);
- config.workgroupSize = workgroupSize;
- config.vectorize = true;
+ if (failed(setOpConfigAndEntryPointFnTranslation(
+ entryFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
+ workgroupSize)))
+ return failure();
- return success();
+ // Let the entry point region to return fully static number of workgroups.
+ // This is needed for folding `affine.min` ops to expose static-shaped tiled
+ // convolution for vectorization.
+ // TODO(#5034): Use a proper way to prove tilability and fold `affine.min`s.
+ auto numWorkgroupsFn = [&](OpBuilder &b, Location loc,
+ std::array<Value, 3>) {
+ std::array<Value, 3> xyz;
+ for (unsigned i = 0; i < 3; ++i) {
+ int64_t count = outputShape[i + 1] / tileSize[i];
+ xyz[2 - i] = b.create<ConstantIndexOp>(loc, count);
+ }
+ return xyz;
+ };
+
+ OpBuilder builder(op.getContext());
+ return defineWorkgroupCountRegion(builder, entryFn, numWorkgroupsFn);
}
-
return failure();
}
-template <>
-LogicalResult getOpLaunchConfig(linalg::ConvInputNHWCFilterHWCFOp op,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
+LogicalResult setRootConfig(FuncOp entryPoint,
+ const spirv::TargetEnv &targetEnv,
+ linalg::ConvInputNHWCFilterHWCFOp op) {
if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
- succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
+ succeeded(setMaliSpecificConfig(entryPoint, op))) {
return success();
}
-
- return setDefaultTilingScheme(targetEnv, op, tileSizes, config.workgroupSize);
+ return setDefaultRootConfig(entryPoint, targetEnv, op);
}
-static LogicalResult getMaliSpecificConfig(
- linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
- SmallVector<ShapedType> inputTypes, outputTypes;
- std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
-
- ShapedType inputType = inputTypes[0], outputType = outputTypes[0];
- if (!inputType || !outputType || !inputType.hasStaticShape() ||
- !outputType.hasStaticShape())
+static LogicalResult setMaliSpecificConfig(
+ FuncOp entryFn, linalg::DepthwiseConvInputNHWCFilterHWCOp op) {
+ ArrayRef<int64_t> inputShape = getUntiledShape(op.inputs()[0]);
+ ArrayRef<int64_t> outputShape =
+ getUntiledResultShape(cast<linalg::LinalgOp>(op.getOperation()), 0);
+ if (llvm::any_of(inputShape, ShapedType::isDynamic) ||
+ llvm::any_of(outputShape, ShapedType::isDynamic)) {
return failure();
+ }
// A list of preferred tile sizes and workgroup sizes. This is for Mali
// G77 now and it's fairly ad-hoc. We need to have a better story for
@@ -644,13 +557,13 @@
{{2, 2, 32}, {8, 2, 2}},
{{1, 4, 16}, {4, 4, 1}},
{{1, 1, 64}, {16, 1, 1}},
+ {{4, 4, 8}, {2, 4, 2}},
};
for (const auto &pair : tileWorkgroupSizePairs) {
const std::array<int64_t, 3> &tileSize = pair.tileSize;
const std::array<int64_t, 3> &workgroupSize = pair.workgroupSize;
- auto outputShape = outputType.getShape();
bool isOutputTilable = outputShape[0] == 1 &&
(outputShape[1] % tileSize[0] == 0) &&
(outputShape[2] % tileSize[1] == 0) &&
@@ -661,6 +574,7 @@
/*output_height=*/tileSize[0],
/*output_width=*/tileSize[1],
/*output_channel=*/tileSize[2]};
+ TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupLevel));
// No tiling at the subgroup level given that we don't use subgroup
@@ -678,118 +592,166 @@
SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 1, 1};
tileSizes.emplace_back(fourthLevel);
- config.workgroupSize = workgroupSize;
- config.vectorize = true;
+ if (failed(setOpConfigAndEntryPointFnTranslation(
+ entryFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
+ workgroupSize)))
+ return failure();
- return success();
+ // Let the entry point region to return fully static number of workgroups.
+ // This is needed for folding `affine.min` ops to expose static-shaped tiled
+ // convolution for vectorization.
+ // TODO(#5034): Use a proper way to prove tilability and fold `affine.min`s.
+ auto numWorkgroupsFn = [&](OpBuilder &b, Location loc,
+ std::array<Value, 3>) {
+ std::array<Value, 3> xyz;
+ for (unsigned i = 0; i < 3; ++i) {
+ int64_t count = outputShape[i + 1] / tileSize[i];
+ xyz[2 - i] = b.create<ConstantIndexOp>(loc, count);
+ }
+ return xyz;
+ };
+
+ OpBuilder builder(op.getContext());
+ return defineWorkgroupCountRegion(builder, entryFn, numWorkgroupsFn);
}
return failure();
}
-template <>
-LogicalResult getOpLaunchConfig(linalg::DepthwiseConvInputNHWCFilterHWCOp op,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
+static LogicalResult setRootConfig(
+ FuncOp entryPoint, const spirv::TargetEnv &targetEnv,
+ linalg::DepthwiseConvInputNHWCFilterHWCOp op) {
if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
- succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
+ succeeded(setMaliSpecificConfig(entryPoint, op))) {
return success();
}
-
- return setDefaultTilingScheme(targetEnv, op, tileSizes, config.workgroupSize);
+ return setDefaultRootConfig(entryPoint, targetEnv, op);
}
-Optional<LaunchConfig> initGPULaunchConfig(
- MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
- const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps) {
- LaunchConfig launchConfig;
- if (!options.workgroupSize.empty()) {
- SmallVector<int64_t, 3> workgroupTileSizes(
- options.workgroupTileSizes.begin(), options.workgroupTileSizes.end());
- SmallVector<int64_t, 3> invocationTileSizes(
- options.invocationTileSizes.begin(), options.invocationTileSizes.end());
- for (linalg::LinalgOp linalgOp : linalgOps) {
- launchConfig.setTileSizes(linalgOp.getOperation(), workgroupTileSizes, 0);
- // Subgroup level.
- launchConfig.setTileSizes(linalgOp.getOperation(), {}, 1);
- // Invocation level.
- launchConfig.setTileSizes(linalgOp.getOperation(), invocationTileSizes,
- 2);
- launchConfig.setVectorize(true);
+/// Helper function to generate the number of workgroups when the
+/// `SPIRVDistributeToGlobalID` is used.
+// TODO(ravishankarm): Remove this when that pipeline is deprecated.
+static LogicalResult setTranslationUsingDistributeToGlobalId(
+ FuncOp funcOp, ArrayRef<int64_t> workgroupSize) {
+ auto entryPointOp = getEntryPoint(funcOp);
+ MLIRContext *context = entryPointOp.getContext();
+ auto translationInfo = buildTranslationInfo(
+ IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistributeToGlobalID,
+ /*workloadPerWorkgroup =*/{}, context);
+ setTranslationInfo(entryPointOp, translationInfo, workgroupSize);
+ OpBuilder builder(context);
+ int64_t workgroupSizeX = workgroupSize[0];
+ auto numWorkgroupsFn =
+ [workgroupSizeX](OpBuilder &b, Location loc,
+ std::array<Value, 3> workload) -> std::array<Value, 3> {
+ AffineExpr e1, e2, e3;
+ bindSymbols(b.getContext(), e1, e2, e3);
+ AffineExpr expr = e1 * e2 * e3;
+ expr = expr.ceilDiv(workgroupSizeX);
+ Value numWorkgroupsX = linalg::applyMapToValues(
+ b, loc, AffineMap::get(0, 3, expr), workload)[0];
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ return {numWorkgroupsX, one, one};
+ };
+ return defineWorkgroupCountRegion(builder, funcOp, numWorkgroupsFn);
+}
+
+LogicalResult initSPIRVLaunchConfig(ModuleOp module) {
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPointOps =
+ getAllEntryPoints(module);
+
+ for (auto funcOp : module.getOps<FuncOp>()) {
+ auto entryPointOp = entryPointOps.lookup(funcOp.getName());
+ if (!entryPointOp) continue;
+ if (getTranslationInfo(entryPointOp)) continue;
+ SmallVector<Operation *, 4> computeOps;
+ SmallVector<Operation *, 4> tiledLoops;
+ if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
+ return funcOp.emitOpError("failed to get compute ops");
}
- SmallVector<int64_t, 3> workgroupSize(options.workgroupSize.begin(),
- options.workgroupSize.end());
- launchConfig.setWorkgroupSize(workgroupSize);
- }
+ spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(funcOp));
+ int64_t subgroupSize =
+ targetEnv.getResourceLimits().subgroup_size().getValue().getSExtValue();
- if (linalgOps.empty()) return launchConfig;
-
- spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(*linalgOps.begin()));
-
- Optional<linalg::LinalgOp> rootOperation = {};
- LaunchConfigInfo config;
-#define DISPATCH(opName) \
- if (auto op = dyn_cast<opName>(linalgOp.getOperation())) { \
- rootOperation = linalgOp; \
- if (launchConfig.hasTileSizes(linalgOp.getOperation())) break; \
- TileSizesListType tileSizesInfo; \
- if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo, \
- config))) { \
- return llvm::None; \
- } \
- launchConfig.setTileSizes(op, tileSizesInfo); \
- break; \
- }
-
- for (linalg::LinalgOp linalgOp : linalgOps) {
- DISPATCH(linalg::BatchMatmulOp)
- DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCOp)
- DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCFOp)
- DISPATCH(linalg::ConvInputNWCFilterWCFOp)
- DISPATCH(linalg::ConvInputNHWCFilterHWCFOp)
- DISPATCH(linalg::ConvInputNDHWCFilterDHWCFOp)
- DISPATCH(linalg::MatmulOp)
- DISPATCH(linalg::PoolingNhwcMaxOp)
- DISPATCH(linalg::PoolingNhwcMinOp)
- DISPATCH(linalg::PoolingNhwcSumOp)
- }
-
- // Any generic operations found are made the root if no other op is the root
- if (!rootOperation) {
- for (linalg::LinalgOp linalgOp : reverse(linalgOps)) {
- size_t numLoops = getNumOuterParallelLoops(linalgOp);
- if (numLoops == 0 ||
- llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
- return !map.isProjectedPermutation();
- })) {
- return llvm::None;
+ if (computeOps.empty() || llvm::none_of(computeOps, [](Operation *op) {
+ return hasMarker(op, getWorkgroupMarker());
+ })) {
+ // TODO(ravishankarm): `tensor.insert_slice` is not a compute op but still
+ // needs to be handled in dispatch region. For now it is handled in
+ // ConvertToGPU pass. Eventually this will be handled as a compute
+ // op. This is just to keep scope of change to dynamic pass pipelines
+ // limited. Remove this when dropping ConvertToGPU pass.
+ if (failed(getFilteredOps(
+ funcOp,
+ [](Operation *op) {
+ return isa<tensor::InsertSliceOp, tensor::ExtractSliceOp>(op);
+ },
+ computeOps, tiledLoops)) ||
+ computeOps.empty()) {
+ continue;
}
+ std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
+ if (failed(
+ setTranslationUsingDistributeToGlobalId(funcOp, workgroupSize))) {
+ return computeOps[0]->emitOpError(
+ "failed to set translation info for distributing to global IDs");
+ }
+ continue;
+ }
- DISPATCH(linalg::GenericOp)
+ Operation *rootOperation = nullptr;
+ for (Operation *computeOp : computeOps) {
+ auto setConfigFn = [&](Operation *rootOp) -> LogicalResult {
+ return TypeSwitch<Operation *, LogicalResult>(rootOp)
+ .Case<linalg::BatchMatmulOp,
+ linalg::DepthwiseConvInputNHWCFilterHWCOp,
+ linalg::ConvInputNHWCFilterHWCFOp, linalg::MatmulOp>(
+ [&](auto op) { return setRootConfig(funcOp, targetEnv, op); })
+ .Default([&](Operation *) { return success(); });
+ };
+ if (failed(setConfigFn(computeOp))) {
+ return failure();
+ }
+ // Check if the op configuration was set.
+ if (getLoweringConfig(computeOp)) {
+ if (rootOperation) {
+ return computeOp->emitOpError(
+ "unhandled multiple roots in dispatch region");
+ }
+ rootOperation = computeOp;
+ }
+ }
+
+ // If there are still no roots, check for any generic op.
+ if (!rootOperation) {
+ for (Operation *computeOp : computeOps) {
+ if (isa<linalg::FillOp, linalg::CopyOp>(computeOp)) continue;
+ if (failed(setDefaultRootConfig(funcOp, targetEnv, computeOp))) {
+ return failure();
+ }
+ if (getLoweringConfig(computeOp)) {
+ if (rootOperation) {
+ return computeOp->emitOpError(
+ "unhandled multiple roots in dispatch region");
+ }
+ rootOperation = computeOp;
+ }
+ }
+ }
+
+ // Propogate the configuration to the other ops.
+ // TODO(ravishankarm, antiagainst): This is a very specific use (and
+ // fragile). In general, this should not be needed. Things are already tiled
+ // and distributed. The rest of the compilation must be structured to either
+ // use `TileAndFuse` or they are independent configurations that are
+ // determined based on the op.
+ IREE::HAL::LoweringConfig config = getLoweringConfig(rootOperation);
+ for (auto op : computeOps) {
+ if (op == rootOperation) continue;
+ setLoweringConfig(op, config);
}
}
-
-#undef DISPATCH
-
- if (!rootOperation) {
- return llvm::None;
- }
-
- launchConfig.setRootOperation(*rootOperation);
- if (options.workgroupSize.empty()) {
- launchConfig.setWorkgroupSize(config.workgroupSize);
- launchConfig.setVectorize(config.vectorize);
- }
- launchConfig.setNumSubgroups(config.numSubgroups);
-
- if (failed(propogateRootOperationLaunchConfig(launchConfig, *rootOperation,
- dependenceGraph)))
- return llvm::None;
-
- // TODO(ravishankarm): Verify that the set configurations is within the device
- // limits.
- return launchConfig;
+ return success();
}
template <typename OpTy>
diff --git a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h b/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h
index a93bb3b..3465e7d 100644
--- a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h
@@ -18,7 +18,6 @@
#include <array>
#include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/LaunchConfig.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/FormatVariadic.h"
@@ -35,9 +34,7 @@
namespace mlir {
namespace iree_compiler {
-Optional<LaunchConfig> initGPULaunchConfig(
- MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
- const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps);
+LogicalResult initSPIRVLaunchConfig(ModuleOp moduleOp);
/// Returns the size of instruction in `vector` dialect that maps directly to
/// the hardware.
diff --git a/iree/compiler/Codegen/SPIRV/LaunchConfig.cpp b/iree/compiler/Codegen/SPIRV/LaunchConfig.cpp
deleted file mode 100644
index ce31d1e..0000000
--- a/iree/compiler/Codegen/SPIRV/LaunchConfig.cpp
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-//===- LaunchConfig.cpp - Specifies configuration used to drive the nfo ---===//
-//
-// This file defines the data structure that is used by the codegeneration to
-// lower to target specific IR. The values of the parameters are archtecture
-// specific. Once set the same transformations can be used to generate the
-// desired code. This allows sharing codegen infra between different backends.
-//
-//===----------------------------------------------------------------------===//
-
-#include "iree/compiler/Codegen/SPIRV/LaunchConfig.h"
-
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Name of the StrAttr that can be used to get the key to access the tile size
-/// information.
-static const char kLaunchInfoKey[] = "launch_info_key";
-static const char kRootOpKey[] = "is_root_op";
-
-static Optional<StringRef> getKey(Operation *op) {
- StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
- if (!attr) return {};
- return attr.getValue();
-}
-
-static void setKey(Operation *op, StringRef key) {
- MLIRContext *context = op->getContext();
- op->setAttr(Identifier::get(kLaunchInfoKey, op->getContext()),
- StringAttr::get(context, key));
-}
-
-static std::string getOrSetNewKey(Operation *op, int64_t suffix) {
- Optional<StringRef> key = getKey(op);
- if (key) return key->str();
- std::string newKey = llvm::formatv("__op_num_{0}__", suffix).str();
- setKey(op, StringRef(newKey));
- return newKey;
-}
-
-void LaunchConfig::finalize(FuncOp funcOp) {
- funcOp.walk([&](linalg::LinalgOp linalgOp) {
- linalgOp->removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
- });
-}
-
-TileSizesListTypeRef LaunchConfig::getTileSizes(Operation *op) const {
- auto key = getKey(op);
- if (!key) return {};
- auto it = tileSizes.find(*key);
- if (it == tileSizes.end()) return {};
- return it->second;
-}
-
-ArrayRef<int64_t> LaunchConfig::getTileSizes(Operation *op,
- size_t level) const {
- auto t = getTileSizes(op);
- if (level >= t.size()) return {};
- return t[level];
-}
-
-Operation *LaunchConfig::getRootOperation(ArrayRef<Operation *> ops) {
- for (auto op : ops) {
- if (op->getAttrOfType<UnitAttr>(kRootOpKey)) return op;
- }
- return nullptr;
-}
-
-void LaunchConfig::setTileSizes(Operation *op, TileSizesListType vTileSizes) {
- tileSizes[getOrSetNewKey(op, tileSizes.size())] = vTileSizes;
-}
-
-void LaunchConfig::setTileSizes(Operation *op, ArrayRef<int64_t> vTileSizes,
- size_t level) {
- tileSizes[getOrSetNewKey(op, tileSizes.size())].emplace_back(
- vTileSizes.begin(), vTileSizes.end());
-}
-
-static void setArrayVals(std::array<int64_t, 3> &array,
- ArrayRef<int64_t> vals) {
- if (vals.size() > 3) vals = vals.take_front(3);
- for (auto size : enumerate(vals)) array[size.index()] = size.value();
- for (unsigned i : llvm::seq<unsigned>(vals.size(), 3)) array[i] = 1;
-}
-
-void LaunchConfig::setWorkgroupSize(ArrayRef<int64_t> vWorkgroupSize) {
- setArrayVals(workgroupSize, vWorkgroupSize);
-}
-
-void LaunchConfig::setNumSubgroups(ArrayRef<int64_t> vNumSubgroups) {
- setArrayVals(numSubgroups, vNumSubgroups);
-}
-
-void LaunchConfig::setRootOperation(Operation *op) {
- op->setAttr(kRootOpKey, UnitAttr::get(op->getContext()));
-}
-
-void LaunchConfig::setSameConfig(Operation *source, Operation *target) {
- assert(getKey(source) && "missing configuration of source operation");
- setKey(target, *getKey(source));
-}
-
-void LaunchConfig::setVectorize(bool enableVectorize) {
- vectorize = enableVectorize;
-}
-
-LogicalResult propogateRootOperationLaunchConfig(
- LaunchConfig &config, linalg::LinalgOp rootOperation,
- const linalg::LinalgDependenceGraph &dependenceGraph) {
- auto dependences = dependenceGraph.getDependentOperations(rootOperation);
- // The dependent operations get the same tile size information as the root
- // operation. To propogate that information, just use the same key as the root
- // operation.
- for (auto dependence : dependences) {
- config.setSameConfig(rootOperation, dependence.getDependentOp());
- }
- return success();
-}
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/LaunchConfig.h b/iree/compiler/Codegen/SPIRV/LaunchConfig.h
deleted file mode 100644
index 99f30fc..0000000
--- a/iree/compiler/Codegen/SPIRV/LaunchConfig.h
+++ /dev/null
@@ -1,144 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-//===- LaunchConfig.h - Configuration used to drive arch specific codegen -===//
-//
-// This file declares the data structure that is used by the codegeneration to
-// lower to target specific IR. The values of the parameters are archtecture
-// specific. Once set the same transformations can be used to generate the
-// desired code. This allows sharing codegen infra between different backends.
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_CODEGEN_COMMON_LAUNCHCONFIG_H_
-#define IREE_COMPILER_CODEGEN_COMMON_LAUNCHCONFIG_H_
-#include <array>
-
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringMap.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Stores the tile sizes to use at different levels of tiling as a vector of
-/// vectors.
-/// - First level tiling maps to workgroups.
-/// - Second level tiling maps to subgroups.
-/// - Third level tiling maps to invocations.
-using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
-using TileSizesListTypeRef = ArrayRef<SmallVector<int64_t, 4>>;
-
-/// Configurations for mapping Linalg ops to CPU/GPU parallel hiearchies.
-///
-/// Based on the linalg operations in a dispatch region, the number of levels of
-/// tiling, the tile sizes needed, the workgroup size, etc. need to be
-/// decided. These parameters are called `LaunchConfig`. This class implements
-/// one heuristic to compute these for the different linalg operations on
-/// buffers. This can be adapted later to support multiple configurations that
-/// can be picked based on device information/problem size information. It
-/// exposes the information needed by the codegenerators, and hides the
-/// implementation from the rest of the pipeline.
-class LaunchConfig {
- public:
- LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
-
- /// Removes attributes added to operations for retrieving tile size
- /// information.
- void finalize(FuncOp funcOp);
-
- /// Gets the tile size computed for an operation at all levels.
- TileSizesListTypeRef getTileSizes(Operation *op) const;
-
- /// Gets the tile size computed for an operation for an level.
- ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const;
-
- /// Returns the workgroup size to use based on the tile sizes.
- ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
-
- /// Returns the number of subgroups to use.
- ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
-
- /// Of the given operations return the operation that has been marked as the
- /// root operation. Within a dispatch region a single root operation (like
- /// matmul, conv, etc.) decides the launch configuration to be used. The rest
- /// of the ops that are fused with it obey this configuration. Returns nullptr
- /// if unable to find an operation that is set as root in the list.
- Operation *getRootOperation(ArrayRef<Operation *> ops);
-
- /// Returns true if tile sizes have been computed for the operation. If tile
- /// sizes arent set, it implies operation is not to be tiled.
- bool hasTileSizes(Operation *op, size_t level = 0) const {
- return !getTileSizes(op, level).empty();
- }
-
- /// Use vectorize transformations.
- bool useVectorize() const { return vectorize; }
-
- /// Sets the tile sizes to use for all levels of tiling of `op`.
- void setTileSizes(Operation *op, TileSizesListType vTileSizes);
-
- /// Sets the tile sizes to use for a given `level` of tiling of `op`.
- void setTileSizes(Operation *op, ArrayRef<int64_t> vTileSizes, size_t level);
-
- /// Sets the workgroup size to use for the function.
- void setWorkgroupSize(ArrayRef<int64_t> vWorkgroupSize);
-
- /// Sets number of subgroups to use.
- void setNumSubgroups(ArrayRef<int64_t> vNumSubgroups);
-
- /// Sets the root operation. Within a dispatch region a single root operation
- /// (like matmul, conv, etc.) decides the launch configuration to be used. The
- /// rest of the ops that are fused with it obey this configuration.
- void setRootOperation(Operation *root);
-
- /// Sets the configuration of the `targetOp` to be same as the configuration
- /// of the `sourceOp`.
- void setSameConfig(Operation *sourceOp, Operation *targetOp);
-
- /// Sets flag to enable vectorization.
- void setVectorize(bool enableVectorize);
-
- protected:
- /// Current tile size configuration per operation. They key used here to
- /// retrieve the tile size information per operation is the value of a StrAttr
- /// added to operations during `init`. When tiled this attribute is copied
- /// over to the tiled operation, thereby the same key can be used to retrieve
- /// the tile sizes for the next level of tiling. The `finalize` method removes
- /// these attributes.
- llvm::StringMap<TileSizesListType> tileSizes;
-
- /// Workgroup size to use.
- std::array<int64_t, 3> workgroupSize = {1, 1, 1};
-
- /// Number of subgroups that are logically distributed along x, y & z.
- std::array<int64_t, 3> numSubgroups = {1, 1, 1};
-
- /// Use vectorization.
- bool vectorize = false;
-};
-
-/// Propogates tile sizes from `rootOperation` to other linalg operations in the
-/// dispatch region. This assumes that each dispatch region has a single root
-/// operation (like matmul, conv, etc.) that determines the tile sizes to use
-/// for tile+fuse+distribute. These are then propogated to the other operations.
-/// Note: This is a temporary solution and might be defunct when the codegen
-/// becomes more sophisticated.
-LogicalResult propogateRootOperationLaunchConfig(
- LaunchConfig &launchConfig, linalg::LinalgOp rootOperation,
- const linalg::LinalgDependenceGraph &dependenceGraph);
-
-} // namespace iree_compiler
-} // namespace mlir
-#endif // IREE_COMPILER_CODEGEN_COMMON_LAUNCHCONFIG_H_
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index 3c122c7..4fa7ee2 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Codegen/SPIRV/MemorySpace.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
@@ -42,35 +43,39 @@
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
+#define DEBUG_TYPE "iree-spirv-lowering-pass-pipeline"
+
namespace mlir {
namespace iree_compiler {
-void buildLinalgToSPIRVPassPipeline(OpPassManager &pm,
- const SPIRVCodegenOptions &options) {
- pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
- pm.nest<ModuleOp>().addPass(createCSEPass());
+static Value gpuAllocationFunction(OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> staticShape,
+ Type elementType,
+ ArrayRef<Value> dynamicSizes) {
+ MemRefType allocType =
+ MemRefType::get(staticShape, elementType, {}, getWorkgroupMemorySpace());
+ return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
+}
+void addSPIRVVectorizationPassPipeline(OpPassManager &pm) {
//===--------------------------------------------------------------------===//
- // Tiling, distribution, vectorization
+ // Initial clean up.
//===--------------------------------------------------------------------===//
-
- // flow.dispatch.workgroups performed abstract tiling and distribution. Make
- // them concrete now since we know the target and settings now.
- pm.addPass(createSPIRVConcretizeWorkgroupTilesPass(options));
- // Tile and distribute to GPU subgroups/invocations and vectorize.
- pm.addPass(createSPIRVTileAndVectorizePass(options));
- pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ pm.addNestedPass<FuncOp>(createSPIRVRemoveOneTripTiledLoopPass());
+ // Tile and distribute to GPU subgroups/invocations and vectorize.
+ pm.addNestedPass<FuncOp>(createSPIRVTileAndVectorizePass());
+ pm.addPass(createCanonicalizerPass());
// Handle ops that cannot go through the previous tiling, distribution, and
// vectorization flow. Only perform one level of distribution to map them to
// GPU global invocation IDs for distribution.
// TODO(antiagainst): Handle all the cases uniformly and remove this pass.
- pm.nest<ModuleOp>().addNestedPass<FuncOp>(
- createSPIRVCopyToWorkgroupMemoryPass());
- pm.addPass(createSPIRVConvertToGPUPass());
- pm.nest<ModuleOp>().addPass(createLowerAffinePass());
- pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
- pm.nest<ModuleOp>().addPass(createCSEPass());
+ pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass());
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
//===--------------------------------------------------------------------===//
// Optimizations and cleanups
@@ -78,65 +83,110 @@
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
- pm.nest<ModuleOp>().addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
+ pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
+}
+void addSPIRVDistributePassPipeline(OpPassManager &pm) {
+ //===--------------------------------------------------------------------===//
+ // Initial clean up.
+ //===--------------------------------------------------------------------===//
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ // Tile and distribute to GPU subgroups/invocations and vectorize.
+ pm.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass());
+ pm.addPass(createCanonicalizerPass());
+
+ // Handle ops that cannot go through the previous tiling, distribution, and
+ // vectorization flow. Only perform one level of distribution to map them to
+ // GPU global invocation IDs for distribution.
+ // TODO(antiagainst): Handle all the cases uniformly and remove this pass.
+ pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass());
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ //===--------------------------------------------------------------------===//
+ // Optimizations and cleanups
+ //===--------------------------------------------------------------------===//
+
+ // Perform various vector-level cross-op optimizations like load-store
+ // forwarding, shape casting and casting op cancelling.
+ pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
+}
+
+void addSPIRVDistributeToGlobalIDPipeline(OpPassManager &pm) {
+ // Handle ops that cannot go through the previous tiling, distribution, and
+ // vectorization flow. Only perform one level of distribution to map them to
+ // GPU global invocation IDs for distribution.
+ // TODO(antiagainst): Handle all the cases uniformly and remove this pass.
+ pm.addNestedPass<FuncOp>(createSPIRVConvertToGPUPass());
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ //===--------------------------------------------------------------------===//
+ // Optimizations and cleanups
+ //===--------------------------------------------------------------------===//
+
+ // Perform various vector-level cross-op optimizations like load-store
+ // forwarding, shape casting and casting op cancelling.
+ pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
+}
+
+static void addLowerToSPIRVPasses(OpPassManager &pm) {
// Fold load/store from/to subview ops into the original memref when possible.
// In SPIR-V we don't use memref descriptor so it's not possible to handle
// subview ops.
- pm.nest<ModuleOp>().addPass(memref::createFoldSubViewOpsPass());
- pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
- pm.nest<ModuleOp>().addPass(createCSEPass());
+ pm.addPass(memref::createFoldSubViewOpsPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
// Turn scalar load/store from memrefs into vectorized ones if possible. This
// gives better memory access patterns, which is very important for perf.
- pm.nest<ModuleOp>().addPass(createSPIRVVectorizeLoadStore());
+ pm.addPass(createSPIRVVectorizeLoadStore());
// Lower vector ops to SPIR-V cooperative matrix ops. This needs to be done
// before flattening memref because we still need the multi-dimension
// structure.
- pm.nest<ModuleOp>().addNestedPass<FuncOp>(
- createSPIRVVectorToCooperativeMatrixPass());
+ pm.addNestedPass<FuncOp>(createSPIRVVectorToCooperativeMatrixPass());
// Perform optimizations that need to across the scf.for region boundary.
- pm.nest<ModuleOp>().addNestedPass<FuncOp>(createForOpCanonicalizationPass());
- pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
- pm.nest<ModuleOp>().addPass(createCSEPass());
+ pm.addNestedPass<FuncOp>(createForOpCanonicalizationPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
// Turn multi-dimension memref into one-dimension. This is needed for SPIR-V
// because we don't use upstream memref descriptors.
- pm.nest<ModuleOp>().addPass(createFlattenMemRefSubspanPass());
- pm.nest<ModuleOp>().addPass(createLowerAffinePass());
- pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
- pm.nest<ModuleOp>().addPass(createCSEPass());
+ pm.addPass(createFlattenMemRefSubspanPass());
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
//===--------------------------------------------------------------------===//
// SPIR-V conversions
//===--------------------------------------------------------------------===//
// Finally convert everything to SPIR-V.
- pm.nest<ModuleOp>().addPass(createConvertToSPIRVPass());
+ pm.addPass(createConvertToSPIRVPass());
- OpPassManager &spirvModulePM = pm.nest<ModuleOp>().nest<spirv::ModuleOp>();
+ OpPassManager &spirvModulePM = pm.nest<spirv::ModuleOp>();
spirvModulePM.addPass(spirv::createLowerABIAttributesPass());
spirvModulePM.addPass(createCanonicalizerPass());
spirvModulePM.addPass(createCSEPass());
spirvModulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
}
-void buildSPIRVCodegenPassPipeline(OpPassManager &pm,
- const SPIRVCodegenOptions &options) {
- pm.nest<ModuleOp>().addPass(createInlinerPass());
- pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCleanupBufferAllocViewPass());
+void buildSPIRVCodegenPassPipeline(OpPassManager &pm) {
+ {
+ OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
+ addLinalgBufferizePasses(nestedModulePM, gpuAllocationFunction);
+ }
+ pm.addPass(createSPIRVLowerExecutableTargetPass());
+ OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
+ addLowerToSPIRVPasses(nestedModulePM);
- WorkgroupMemoryAllocationFn allocationFn =
- [](OpBuilder &builder, Location loc, ArrayRef<int64_t> staticShape,
- Type elementType, ArrayRef<Value> dynamicSizes) {
- MemRefType allocType = MemRefType::get(staticShape, elementType, {},
- getWorkgroupMemorySpace());
- return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
- };
- addLinalgBufferizePasses(pm.nest<ModuleOp>(), allocationFn);
-
- buildLinalgToSPIRVPassPipeline(pm, options);
+ LLVM_DEBUG({
+ llvm::dbgs() << "Using SPIRV Pass pipeline :\n";
+ pm.printAsTextualPipeline(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
}
} // namespace iree_compiler
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVConcretizeWorkgroupTiles.cpp b/iree/compiler/Codegen/SPIRV/SPIRVConcretizeWorkgroupTiles.cpp
deleted file mode 100644
index 550e489..0000000
--- a/iree/compiler/Codegen/SPIRV/SPIRVConcretizeWorkgroupTiles.cpp
+++ /dev/null
@@ -1,411 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-//===- SPIRVConcretizeWorkgroupTiles.cpp ----------------------------------===//
-//
-// This pass concretizes hal.interface.workgroup ops by replacing them with
-// constant values from the chosen tiling and distribution scheme.
-//
-// During dispatch region formation in IREE Flow transformations, ops are tiled
-// and distributed in an abstract way by using symbolic hal.interface.workgroup
-// ops. That is because the same source region is compiled towards different
-// target backends and each target backend could use different tiling and
-// distribution schemes. However, after HAL interface materialization, the
-// hal.executable.variant is just meant for one target backend. We need to
-// concretize the tiling and distribution in order to inject static information
-// for further compilation.
-//
-// This pass performs the conretization in two modes:
-//
-// 1) Partically static: where have a concrete tiling and distirbution sheme
-// *but not* a full static original problem size (e.g., due to dynamic
-// shapes). Under such circumstances, we can only replace ops like
-// hal.interface.workgroup.size ops and still need to compute the number
-// of workgroups using symbolic values.
-// 2) Fully static: where we have a concrete tiling and distribution scheme
-// *and* the full static original problem size. Under such circumstances,
-// we can fully deduce the number of workgroups to dispatch and replace
-// hal.interface.workgroup.count ops with constant values too.
-//
-//===----------------------------------------------------------------------===//
-
-#include "iree/compiler/Codegen/PassDetail.h"
-#include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
-#include "iree/compiler/Codegen/SPIRV/LaunchConfig.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "iree-spirv-concretize-workgroups-tile"
-
-namespace mlir {
-namespace iree_compiler {
-
-static constexpr unsigned kMaxWorkgroupDimCount = 3;
-
-static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
-
-/// Returns the root Linalg op that dictates tiling and distribution policy.
-static linalg::LinalgOp getRootLinalgOp(FuncOp funcOp,
- const SPIRVCodegenOptions &options) {
- SmallVector<linalg::LinalgOp, 4> linalgOps;
- SmallVector<Operation *, 4> tiledLoops;
- if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) return {};
-
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- Optional<LaunchConfig> launchConfigOpt = initGPULaunchConfig(
- funcOp.getContext(), dependenceGraph, options, linalgOps);
- if (!launchConfigOpt) return {};
-
- LaunchConfig &launchConfig = *launchConfigOpt;
- Operation *rootOp =
- launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range(
- linalgOps, [](linalg::LinalgOp op) { return op.getOperation(); })));
-
- // Clean up internal markers that are set during launch configuration
- // preparation.
- launchConfig.finalize(funcOp);
-
- return rootOp;
-}
-
-namespace {
-/// Replaces hal.interface.workgroup.count op with the constant value chosen
-/// from tiling scheme.
-class ConcretizeWorkgroupCountOp final
- : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupCountOp> {
- public:
- ConcretizeWorkgroupCountOp(MLIRContext *context,
- ArrayRef<int64_t> numWorkgroups,
- PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit),
- numWorkgroups(numWorkgroups.begin(), numWorkgroups.end()) {}
-
- LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op,
- PatternRewriter &rewriter) const override {
- unsigned dimIndex = op.dimension().getZExtValue();
-
- if (dimIndex >= numWorkgroups.size()) return failure();
- rewriter.replaceOpWithNewOp<ConstantOp>(
- op, rewriter.getIndexAttr(numWorkgroups[dimIndex]));
-
- return success();
- }
-
- private:
- SmallVector<int64_t, 4> numWorkgroups;
-};
-
-// Canonicalizes away a trip-one scf.for loop by inlining its body and removing
-// the loop.
-//
-// This pattern is needed because in Flow abstract tiling and distribution we
-// will create scf.for loops that distribute workload cyclically. After
-// concretizing hal.interface.workgroup.* ops, these scf.for loops still remain,
-// and they will be of the form:
-//
-// %lb = mul %workgroup_id_{x|y|z}, %cst_tile_size_{x|y|z}
-// scf.for %iv = %lb to %cst_wokload_size_{x|y|z}
-// step %cst_workload_size_{x|y|z} { ... }
-//
-// Such scf.for loops can be inlined if %lb is smaller than upper bound.
-class RemoveTripOneLoop final : public OpRewritePattern<scf::ForOp> {
- public:
- RemoveTripOneLoop(MLIRContext *context, ArrayRef<int64_t> workloadSize,
- ArrayRef<int64_t> tileSize, PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit),
- workloadSize(workloadSize.begin(), workloadSize.end()),
- tileSize(tileSize.begin(), tileSize.end()) {}
-
- LogicalResult matchAndRewrite(scf::ForOp op,
- PatternRewriter &rewriter) const override {
- // Get constant upper bound and step values.
- IntegerAttr ub, step;
- if (!matchPattern(op.upperBound(), m_Constant(&ub)) ||
- !matchPattern(op.step(), m_Constant(&step))) {
- return failure();
- }
-
- // Require that they are the same.
- if (ub != step) return failure();
-
- // Now make sure the lower bound is smaller than upper bound. The lower
- // bound should be multiplying the workgroup ID with some constant.
-
- auto mulOp = op.lowerBound().getDefiningOp<AffineApplyOp>();
- if (!mulOp || mulOp.mapOperands().size() != 2) return failure();
-
- AffineExpr lhs, rhs;
- bindSymbols(op.getContext(), lhs, rhs);
- auto mulMap = AffineMap::get(0, 2, lhs * rhs);
- if (mulOp.getAffineMap() != mulMap) return failure();
-
- auto mulLhs = mulOp.mapOperands().front();
- auto mulRhs = mulOp.mapOperands().back();
-
- auto idOp = mulLhs.getDefiningOp<IREE::HAL::InterfaceWorkgroupIDOp>();
- IntegerAttr multipler;
- if (!idOp || !matchPattern(mulRhs, m_Constant(&multipler)))
- return failure();
-
- // We just need to make sure the max value of the workgroup ID multipled by
- // the multipler is smaller than the upper bound to guarantee one trip.
- unsigned dimIndex = idOp.dimension().getZExtValue();
- int64_t dimSize = workloadSize[dimIndex];
- int64_t dimTile = tileSize[dimIndex];
-
- if (dimSize == ShapedType::kDynamicSize) return failure();
-
- int64_t count = ceilDiv(dimSize, dimTile);
- assert(count > 0 && "expected at least one tile!");
-
- // ID should be in range [0, count).
- if ((count - 1) * multipler.getInt() >= ub.getInt()) {
- // Dead loop. It can actually be removed entirely. But we aren't expecting
- // it to happen here. Do not canonicalize for such case.
- return failure();
- }
-
- SmallVector<Value, 4> blockArgs;
- blockArgs.reserve(op.getNumIterOperands() + 1);
- blockArgs.push_back(op.lowerBound());
- llvm::append_range(blockArgs, op.getIterOperands());
-
- Block *block = &op.getLoopBody().front();
- Operation *terminator = block->getTerminator();
- ValueRange results = terminator->getOperands();
- rewriter.mergeBlockBefore(block, op, blockArgs);
- rewriter.replaceOp(op, results);
- rewriter.eraseOp(terminator);
-
- return success();
- }
-
- private:
- SmallVector<int64_t, 4> workloadSize;
- SmallVector<int64_t, 4> tileSize;
-};
-
-static void removeOneTripTiledLoops(MLIRContext *context, FuncOp funcOp,
- linalg::LinalgOp rootLinalgOp,
- ArrayRef<int64_t> halWorkgroupSize) {
- if (rootLinalgOp.getNumOutputs() != 1) return;
- unsigned numParallelDims = getNumOuterParallelLoops(rootLinalgOp);
- unsigned numTiledDims =
- std::min<size_t>(numParallelDims, kMaxWorkgroupDimCount);
-
- ArrayRef<int64_t> outputShape =
- getUntiledShape(rootLinalgOp.getOutputOperand(0)->get());
- if (outputShape.size() < numParallelDims) return;
-
- // TODO(ravishankarm, antiagainst): Its pure co-incidence that the
- // workload is derivable from the output shape. There is no requirement
- // for this but is the case for all operations we are interested in.
- auto workloadSize = llvm::to_vector<4>(llvm::reverse(
- outputShape.take_front(numParallelDims).take_back(numTiledDims)));
- if (llvm::any_of(workloadSize, [](int64_t dim) {
- return dim == ShapedType::kDynamicSize;
- })) {
- return;
- }
- LLVM_DEBUG({
- llvm::dbgs() << "Queried workload size: ";
- llvm::interleaveComma(workloadSize, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
- SmallVector<int64_t, 3> numWorkgroups;
- assert(halWorkgroupSize.size() == workloadSize.size());
- for (auto pair : llvm::zip(workloadSize, halWorkgroupSize)) {
- auto workload = std::get<0>(pair);
- auto size = std::get<1>(pair);
- numWorkgroups.push_back(ceilDiv(workload, size));
- }
- numWorkgroups.resize(kMaxWorkgroupDimCount, 1);
- WorkgroupCountRegionBuilder regionBuilder = [&](OpBuilder &b, Location loc,
- std::array<Value, 3>) {
- std::array<Value, 3> returnValues;
- for (unsigned i = 0; i < kMaxWorkgroupDimCount; ++i) {
- returnValues[i] = b.create<ConstantIndexOp>(loc, numWorkgroups[i]);
- }
- return returnValues;
- };
-
- OpBuilder builder(context);
- if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
- return;
- }
-
- {
- OwningRewritePatternList workgroupCountPatterns(context);
- workgroupCountPatterns.insert<ConcretizeWorkgroupCountOp>(context,
- numWorkgroups);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(workgroupCountPatterns));
- }
- {
- OwningRewritePatternList removeTripOneLoopPatterns(context);
- removeTripOneLoopPatterns.insert<RemoveTripOneLoop>(context, workloadSize,
- halWorkgroupSize);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(removeTripOneLoopPatterns));
- }
-}
-
-/// Concretizes hal.interface.workgroup.* ops with constants from the chosen
-/// tiling sheme when possible and perform loop canonicalization afterwards.
-class SPIRVConcretizeWorkgroupTilesPass
- : public SPIRVConcretizeWorkgroupTilesBase<
- SPIRVConcretizeWorkgroupTilesPass> {
- public:
- SPIRVConcretizeWorkgroupTilesPass(const SPIRVCodegenOptions &options)
- : options(options) {}
- SPIRVConcretizeWorkgroupTilesPass(
- const SPIRVConcretizeWorkgroupTilesPass &that)
- : options(that.options) {
- inlineTripOneLoops = that.inlineTripOneLoops;
- }
-
- void runOnOperation() override {
- IREE::HAL::ExecutableVariantOp variantOp = getOperation();
- ModuleOp module = variantOp.getInnerModule();
- for (FuncOp funcOp : module.getOps<FuncOp>()) {
- if (!funcOp.isPublic()) continue;
- (void)runOnFunction(funcOp);
- }
- }
-
- private:
- LogicalResult runOnFunction(FuncOp funcOp) {
- MLIRContext &context = getContext();
-
- // 1. Get the linalg operations within the function. The callee here
- // successed only for functions with single basic block.
- SmallVector<linalg::LinalgOp> linalgOps;
- SmallVector<Operation *> tiledLoops;
- if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
- return failure();
- }
- // If there are no Linalg ops. Nothing to do. Return.
- if (linalgOps.empty()) return success();
-
- // 2. Get the launch configuration to use for the function.
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- Optional<LaunchConfig> launchConfig = initGPULaunchConfig(
- funcOp.getContext(), dependenceGraph, options, linalgOps);
- if (!launchConfig) {
- // Having no config implies that there is nothing to do here. Return
- return success();
- }
-
- // 3. The root operation determines the tile size to use. This has already
- // been computed by the launch configuration.
- // TODO(ravishankarm): The configuration actually makes sure that all tile
- // sizes for the parallel loops are consistent, but get the root operation
- // for now.
- Operation *rootOp =
- launchConfig->getRootOperation(llvm::to_vector<4>(llvm::map_range(
- linalgOps, [](linalg::LinalgOp op) { return op.getOperation(); })));
-
- unsigned numParallelDims = getNumOuterParallelLoops(rootOp);
- unsigned numTiledDims =
- std::min<size_t>(numParallelDims, kMaxWorkgroupDimCount);
- ArrayRef<int64_t> tileSizes = launchConfig->getTileSizes(rootOp, 0);
- if (tileSizes.size() < numParallelDims) {
- return rootOp->emitError(
- "invalid tile size configuration, expected at least as many "
- "as the number of tiled loops : ")
- << numParallelDims;
- }
-
- // TODO(ravishankarm): The flow tiling only tiles the inner parallel loops
- // by default. Using the same approach here. This spooky distant shake hand
- // needs to be resolved. Potentially can be made cleaner with use of
- // `linalg.tile` operation.
- tileSizes = tileSizes.take_front(numParallelDims).take_back(numTiledDims);
- if (llvm::any_of(tileSizes, [](int64_t ts) { return ts == 0; })) {
- return rootOp->emitError(
- "unhandled tile size setting of 0 for a loop that was tiled");
- }
-
- // 4. The hal.workgroup.size is a representation of the tile size. Note that
- // this is not the actual workgroup size used eventually. That is computed
- // by the launch configuration and is set below.
- auto halWorkgroupSize = llvm::to_vector<4>(llvm::reverse(tileSizes));
-
- LLVM_DEBUG({
- llvm::dbgs() << "Queried tile size: ";
- llvm::interleaveComma(tileSizes, llvm::dbgs());
- llvm::dbgs() << ", HAL workgroup size: ";
- llvm::interleaveComma(halWorkgroupSize, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
- // 4. Materialize the constant values for the hal.workgroup.size along
- // different dimensions.
- if (failed(materializeStaticLaunchInformation(funcOp, halWorkgroupSize))) {
- return funcOp.emitOpError(
- "failed to materialize static launch information");
- }
-
- // 5. Update the actual workgroup size to use based on launch configuraiton.
- if (failed(updateWorkGroupSize(funcOp, launchConfig->getWorkgroupSize()))) {
- return funcOp.emitOpError("failed to set workgroup size on function");
- }
- launchConfig->finalize(funcOp);
-
- if (inlineTripOneLoops) {
- removeOneTripTiledLoops(&context, funcOp, cast<linalg::LinalgOp>(rootOp),
- halWorkgroupSize);
- }
-
- return success();
- }
-
- private:
- SPIRVCodegenOptions options;
-
- // TODO(#5034): Investigate whether there is a better way to prove tileability
- // and canonicalize affine.min ops, without matching against the specific
- // pattern involving loops.
- Option<bool> inlineTripOneLoops{
- *this, "inline-trip-one-loops",
- llvm::cl::desc(
- "Inline a loop's body if it can be proven to just have one trip"),
- llvm::cl::init(true)};
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConcretizeWorkgroupTilesPass(const SPIRVCodegenOptions &options) {
- return std::make_unique<SPIRVConcretizeWorkgroupTilesPass>(options);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp b/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp
index ed4aad3..5613f85 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp
@@ -115,7 +115,6 @@
generateGuard);
}
-
//===----------------------------------------------------------------------===//
// Pass and patterns.
//===----------------------------------------------------------------------===//
@@ -165,7 +164,6 @@
linalg::linalgOpToParallelLoops(rewriter, linalgOp);
if (!loops) return failure();
- SmallVector<int64_t, 3> workgroupSize(3, 1);
if (!loops.getValue().empty()) {
scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
// If there are parallel loops partition them to threads using global
@@ -179,22 +177,8 @@
return rewriter.notifyMatchFailure(
linalgOp, "mapping to GlobalInvocationID failed");
}
- workgroupSize = {32, 1, 1};
}
}
- WorkgroupCountRegionBuilder regionBuilder =
- [&workgroupSize](OpBuilder &b, Location loc,
- std::array<Value, 3> workload) {
- Value one = b.create<ConstantIndexOp>(loc, 1);
- return std::array<Value, 3>{
- getWorkgroupCountX(b, loc, workload, workgroupSize[0]), one, one};
- };
- if (failed(defineWorkgroupCountRegion(rewriter, funcOp, regionBuilder))) {
- return failure();
- }
- if (failed(updateWorkGroupSize(funcOp, workgroupSize))) {
- return failure();
- }
rewriter.eraseOp(linalgOp);
return success();
}
@@ -202,8 +186,9 @@
} // namespace
-
void SPIRVConvertToGPUPass::runOnOperation() {
+ FuncOp funcOp = getOperation();
+ if (!isEntryPoint(funcOp)) return;
MLIRContext *context = &getContext();
ConversionTarget target(*context);
// After this pass Linalg and scf.parallel ops should be gone.
@@ -223,20 +208,16 @@
MapLinalgOpToGlobalInvocationId<linalg::GenericOp>>(context);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- for (FuncOp funcOp : getOperation().getInnerModule().getOps<FuncOp>()) {
- if (!isEntryPoint(funcOp)) continue;
- Region &body = funcOp.getBody();
- if (!llvm::hasSingleElement(body)) {
- funcOp.emitError("unhandled dispatch function with multiple blocks");
- return signalPassFailure();
- }
- if (failed(applyFullConversion(funcOp, target, frozenPatterns)))
- return signalPassFailure();
+ Region &body = funcOp.getBody();
+ if (!llvm::hasSingleElement(body)) {
+ funcOp.emitError("unhandled dispatch function with multiple blocks");
+ return signalPassFailure();
}
+ if (failed(applyFullConversion(funcOp, target, frozenPatterns)))
+ return signalPassFailure();
}
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConvertToGPUPass() {
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVConvertToGPUPass() {
return std::make_unique<SPIRVConvertToGPUPass>();
}
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVFoldGPUProcessorIDUses.cpp b/iree/compiler/Codegen/SPIRV/SPIRVFoldGPUProcessorIDUses.cpp
index 6b5faca..52d000c 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVFoldGPUProcessorIDUses.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVFoldGPUProcessorIDUses.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
@@ -84,16 +85,7 @@
IREE::HAL::ReturnOp getEntryPointReturnOp(Operation *op) {
auto funcOp = op->getParentOfType<FuncOp>();
- auto variantOp =
- funcOp.getOperation()->getParentOfType<IREE::HAL::ExecutableVariantOp>();
-
- IREE::HAL::ExecutableEntryPointOp entryPointOp;
- for (auto op : variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
- if (op.sym_name() == funcOp.getName()) {
- entryPointOp = op;
- break;
- }
- }
+ IREE::HAL::ExecutableEntryPointOp entryPointOp = getEntryPoint(funcOp);
if (!entryPointOp || !entryPointOp.getBody()) return {};
Operation *terminator = entryPointOp.getBlock()->getTerminator();
@@ -136,13 +128,17 @@
/// point ABI.
Optional<int64_t> getProcessorIDUpperBound(gpu::ThreadIdOp threadIDOp) {
FuncOp funcOp = threadIDOp->getParentOfType<FuncOp>();
- auto abiAttr = funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
- spirv::getEntryPointABIAttrName());
- if (!abiAttr) return llvm::None;
+ IREE::HAL::ExecutableEntryPointOp entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return {};
+
+ Optional<ArrayAttr> sizes = entryPointOp.workgroup_size();
+ if (!sizes) return {};
int index = dimensionToIndex(threadIDOp.dimension());
- auto valueIt = abiAttr.local_size().getIntValues().begin() + index;
- return (*valueIt).getZExtValue();
+ if (index < sizes->size()) {
+ return sizes->getValue()[index].cast<IntegerAttr>().getInt();
+ }
+ return llvm::None;
}
/// Folds `affine.min` ops which has only one symbol operand, which is a
@@ -268,8 +264,7 @@
MLIRContext *context = &getContext();
OwningRewritePatternList patterns(&getContext());
populateFoldGPUProcessorIDUsesPatterns(context, patterns);
- (void)applyPatternsAndFoldGreedily(getOperation().getInnerModule(),
- std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -281,8 +276,7 @@
AffineMinOp::getCanonicalizationPatterns(patterns, context);
}
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVFoldProcessorIDUsesPass() {
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVFoldProcessorIDUsesPass() {
return std::make_unique<SPIRVFoldProcessorIDUsesPass>();
}
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
new file mode 100644
index 0000000..91e499d
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -0,0 +1,132 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+#define DEBUG_TYPE "iree-spirv-lower-executable-target-pass"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+/// Lowers an hal.executable.variant operation to scalar/native-vector
+/// code. Invokes different compilation pipeline to
+/// - first lower to scalar/native-vector code
+/// - then convert to SPIRV dialect.
+class SPIRVLowerExecutableTargetPass
+ : public SPIRVLowerExecutableTargetBase<SPIRVLowerExecutableTargetPass> {
+ public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, gpu::GPUDialect, IREE::HAL::HALDialect,
+ linalg::LinalgDialect, linalg_ext::LinalgExtDialect,
+ memref::MemRefDialect, scf::SCFDialect, ShapeDialect,
+ spirv::SPIRVDialect, vector::VectorDialect>();
+ }
+
+ SPIRVLowerExecutableTargetPass() = default;
+ SPIRVLowerExecutableTargetPass(const SPIRVLowerExecutableTargetPass &pass){};
+
+ void runOnOperation() override;
+
+ private:
+ Option<bool> testLoweringConfiguration{
+ *this, "test-lowering-configuration",
+ llvm::cl::desc(
+ "Flag used for lit-testing the default configuration set for root "
+ "ops in hal.executable.variants. Defaults to false and is set to "
+ "true "
+ "for lit tests. Not for general usage"),
+ llvm::cl::init(false)};
+};
+} // namespace
+
+void SPIRVLowerExecutableTargetPass::runOnOperation() {
+ IREE::HAL::ExecutableVariantOp variantOp = getOperation();
+ ModuleOp moduleOp = variantOp.getInnerModule();
+
+ OpPassManager executableLoweringPipeline(
+ IREE::HAL::ExecutableVariantOp::getOperationName());
+
+ if (failed(initSPIRVLaunchConfig(moduleOp))) {
+ return signalPassFailure();
+ }
+ // There might be multiple entry points in the module. Currently, all of
+ // them need to have the same pipeline.
+ // TODO(ravishankarm): This is strange that this is not enforced
+ // structurally, but something to address later on. For now this restriction
+ // is fine.
+ llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
+ getAllEntryPoints(moduleOp);
+ Optional<IREE::HAL::DispatchLoweringPassPipeline> passPipeline;
+ for (auto &it : entryPoints) {
+ auto entryPointOp = it.second;
+ if (IREE::HAL::TranslationInfo translationInfo =
+ getTranslationInfo(entryPointOp)) {
+ IREE::HAL::DispatchLoweringPassPipeline currPipeline =
+ translationInfo.passPipeline().getValue();
+ if (passPipeline) {
+ if (currPipeline != passPipeline.getValue()) {
+ moduleOp.emitError(
+ "unhandled compilation of entry point function with different "
+ "pass pipelines within a module");
+ return signalPassFailure();
+ }
+ continue;
+ }
+ passPipeline = currPipeline;
+ }
+ }
+
+ executableLoweringPipeline.addPass(createSetNumWorkgroupsPass());
+ executableLoweringPipeline.addPass(createCanonicalizerPass());
+ if (!testLoweringConfiguration && passPipeline.hasValue()) {
+ OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>();
+ switch (*passPipeline) {
+ case IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute:
+ addSPIRVDistributePassPipeline(nestedModulePM);
+ break;
+ case IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistributeToGlobalID:
+ addSPIRVDistributeToGlobalIDPipeline(nestedModulePM);
+ break;
+ case IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize:
+ addSPIRVVectorizationPassPipeline(nestedModulePM);
+ break;
+ default:
+ llvm_unreachable("Unsupported pipeline on GPU target.");
+ }
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Using SPIRV Lowering Pass pipeline :\n";
+ executableLoweringPipeline.printAsTextualPipeline(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ if (failed(runPipeline(executableLoweringPipeline, variantOp))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
+createSPIRVLowerExecutableTargetPass() {
+ return std::make_unique<SPIRVLowerExecutableTargetPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp b/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp
new file mode 100644
index 0000000..50d0160
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVRemoveOneTripTiledLoops.cpp
@@ -0,0 +1,234 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-remove-one-trip-tiled-loops"
+
+namespace mlir {
+namespace iree_compiler {
+
+static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
+
+namespace {
+/// Replaces hal.interface.workgroup.count op with the constant value chosen
+/// from tiling scheme.
+class ConcretizeWorkgroupCountOp final
+ : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupCountOp> {
+ public:
+ ConcretizeWorkgroupCountOp(MLIRContext *context,
+ ArrayRef<int64_t> numWorkgroups,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit),
+ numWorkgroups(numWorkgroups.begin(), numWorkgroups.end()) {}
+
+ LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op,
+ PatternRewriter &rewriter) const override {
+ unsigned dimIndex = op.dimension().getZExtValue();
+
+ if (dimIndex >= numWorkgroups.size()) return failure();
+ rewriter.replaceOpWithNewOp<ConstantOp>(
+ op, rewriter.getIndexAttr(numWorkgroups[dimIndex]));
+
+ return success();
+ }
+
+ private:
+ SmallVector<int64_t, 4> numWorkgroups;
+};
+
+// Canonicalizes away a trip-one scf.for loop by inlining its body and removing
+// the loop.
+//
+// This pattern is needed because in Flow abstract tiling and distribution we
+// will create scf.for loops that distribute workload cyclically. After
+// concretizing hal.interface.workgroup.* ops, these scf.for loops still remain,
+// and they will be of the form:
+//
+// %lb = mul %workgroup_id_{x|y|z}, %cst_tile_size_{x|y|z}
+// scf.for %iv = %lb to %cst_wokload_size_{x|y|z}
+// step %cst_workload_size_{x|y|z} { ... }
+//
+// Such scf.for loops can be inlined if %lb is smaller than upper bound.
+class RemoveTripOneLoop final : public OpRewritePattern<scf::ForOp> {
+ public:
+ RemoveTripOneLoop(MLIRContext *context, ArrayRef<int64_t> workloadSize,
+ ArrayRef<int64_t> workloadPerWorkgroup,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit),
+ workloadSize(workloadSize.begin(), workloadSize.end()),
+ workloadPerWorkgroup(workloadPerWorkgroup.begin(),
+ workloadPerWorkgroup.end()) {}
+
+ LogicalResult matchAndRewrite(scf::ForOp op,
+ PatternRewriter &rewriter) const override {
+ // Get constant upper bound and step values.
+ IntegerAttr ub, step;
+ if (!matchPattern(op.upperBound(), m_Constant(&ub)) ||
+ !matchPattern(op.step(), m_Constant(&step))) {
+ return failure();
+ }
+
+ // Require that they are the same.
+ if (ub != step) return failure();
+
+ // Now make sure the lower bound is smaller than upper bound. The lower
+ // bound should be multiplying the workgroup ID with some constant.
+ auto mulOp = op.lowerBound().getDefiningOp<AffineApplyOp>();
+ if (!mulOp || mulOp.mapOperands().size() != 1) return failure();
+
+ auto operand = mulOp.mapOperands().front();
+ auto idOp = operand.getDefiningOp<IREE::HAL::InterfaceWorkgroupIDOp>();
+ if (!idOp) return failure();
+
+ // We just need to make sure the max value of the workgroup ID multipled by
+ // the multipler is smaller than the upper bound to guarantee one trip.
+ unsigned dimIndex = idOp.dimension().getZExtValue();
+ int64_t dimSize = workloadSize[dimIndex];
+ if (dimSize == ShapedType::kDynamicSize) return failure();
+
+ int64_t dimTile = workloadPerWorkgroup[dimIndex];
+ AffineExpr symbol;
+ bindSymbols(op.getContext(), symbol);
+ auto mulMap = AffineMap::get(0, 1, symbol * dimTile);
+ if (mulOp.getAffineMap() != mulMap) return failure();
+
+ int64_t count = ceilDiv(dimSize, dimTile);
+ assert(count > 0 && "expected at least one tile!");
+
+ // ID should be in range [0, count).
+ if ((count - 1) * dimTile >= ub.getInt()) {
+ // Dead loop. It can actually be removed entirely. But we aren't expecting
+ // it to happen here. Do not canonicalize for such case.
+ return failure();
+ }
+
+ SmallVector<Value, 4> blockArgs;
+ blockArgs.reserve(op.getNumIterOperands() + 1);
+ blockArgs.push_back(op.lowerBound());
+ llvm::append_range(blockArgs, op.getIterOperands());
+
+ Block *block = &op.getLoopBody().front();
+ Operation *terminator = block->getTerminator();
+ ValueRange results = terminator->getOperands();
+ rewriter.mergeBlockBefore(block, op, blockArgs);
+ rewriter.replaceOp(op, results);
+ rewriter.eraseOp(terminator);
+
+ return success();
+ }
+
+ private:
+ SmallVector<int64_t, 4> workloadSize;
+ SmallVector<int64_t, 4> workloadPerWorkgroup;
+};
+} // namespace
+
+static void removeOneTripTiledLoops(MLIRContext *context, FuncOp funcOp,
+ linalg::LinalgOp rootLinalgOp,
+ ArrayRef<int64_t> workloadPerWorkgroup) {
+ unsigned numParallelDims = getNumOuterParallelLoops(rootLinalgOp);
+ unsigned numTiledDims =
+ std::min<size_t>(numParallelDims, kNumMaxParallelDims);
+ ArrayRef<int64_t> outputShape = getUntiledResultShape(rootLinalgOp, 0);
+ if (outputShape.size() < numParallelDims) return;
+
+ // TODO(ravishankarm, antiagainst): Its pure co-incidence that the
+ // workload is derivable from the output shape. There is no requirement
+ // for this but is the case for all operations we are interested in.
+ auto workloadSize = llvm::to_vector<4>(llvm::reverse(
+ outputShape.take_front(numParallelDims).take_back(numTiledDims)));
+ if (llvm::any_of(workloadSize, [](int64_t dim) {
+ return dim == ShapedType::kDynamicSize;
+ })) {
+ return;
+ }
+ LLVM_DEBUG({
+ llvm::dbgs() << "Queried workload size: ";
+ llvm::interleaveComma(workloadSize, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ SmallVector<int64_t, 3> numWorkgroups;
+ assert(workloadPerWorkgroup.size() == workloadSize.size());
+ for (auto pair : llvm::zip(workloadSize, workloadPerWorkgroup)) {
+ auto workload = std::get<0>(pair);
+ auto size = std::get<1>(pair);
+ numWorkgroups.push_back(ceilDiv(workload, size));
+ }
+ numWorkgroups.resize(kNumMaxParallelDims, 1);
+ {
+ OwningRewritePatternList workgroupCountPatterns(context);
+ workgroupCountPatterns.insert<ConcretizeWorkgroupCountOp>(context,
+ numWorkgroups);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(workgroupCountPatterns));
+ }
+ {
+ OwningRewritePatternList removeTripOneLoopPatterns(context);
+ removeTripOneLoopPatterns.insert<RemoveTripOneLoop>(context, workloadSize,
+ workloadPerWorkgroup);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(removeTripOneLoopPatterns));
+ }
+}
+
+namespace {
+class SPIRVRemoveOneTripTiledLoopPass
+ : public SPIRVRemoveOneTripTiledLoopBase<SPIRVRemoveOneTripTiledLoopPass> {
+ public:
+ SPIRVRemoveOneTripTiledLoopPass() = default;
+ SPIRVRemoveOneTripTiledLoopPass(const SPIRVRemoveOneTripTiledLoopPass &) {}
+
+ void runOnOperation() {
+ FuncOp funcOp = getOperation();
+ auto entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return;
+
+ // This pass seems to be only needed for the convolution vectorization. So
+ // filter out the necessary conv ops.
+ SmallVector<Operation *> rootOp;
+ SmallVector<Operation *> tiledLoops;
+ if (failed(getFilteredOps(
+ funcOp,
+ [](Operation *op) {
+ return isa<linalg::DepthwiseConvInputNHWCFilterHWCOp,
+ linalg::ConvInputNHWCFilterHWCFOp>(op);
+ },
+ rootOp, tiledLoops))) {
+ return;
+ }
+
+ if (!llvm::hasSingleElement(rootOp)) return;
+ IREE::HAL::TranslationInfo translationInfo =
+ getTranslationInfo(entryPointOp);
+ if (!translationInfo) return;
+ ArrayAttr workloadPerWorkgroupAttr = translationInfo.workloadPerWorkgroup();
+ if (!workloadPerWorkgroupAttr) return;
+ auto workloadPerWorkgroup = llvm::to_vector<4>(llvm::map_range(
+ workloadPerWorkgroupAttr,
+ [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+
+ MLIRContext *context = &getContext();
+ removeOneTripTiledLoops(context, funcOp, cast<linalg::LinalgOp>(rootOp[0]),
+ workloadPerWorkgroup);
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVRemoveOneTripTiledLoopPass() {
+ return std::make_unique<SPIRVRemoveOneTripTiledLoopPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
new file mode 100644
index 0000000..a4750ce
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -0,0 +1,404 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- SPIRVTileAndDistribute.cpp
+//------------------------------------------===//
+//
+// This pass tiles and vectorizes Linalg ops on buffers within in a single
+// workgroup.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Codegen/SPIRV/MemorySpace.h"
+#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
+
+#define DEBUG_TYPE "iree-spirv-tile-and-vectorize"
+
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Returns a Linalg marker that matches any of the `matchMarkers` and replaces
+/// it with `replaceMarker`.
+static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
+ ArrayRef<StringRef> matchMarkers, StringRef replaceMarker,
+ MLIRContext *context) {
+ SmallVector<Identifier, 2> markers;
+ markers.reserve(matchMarkers.size());
+ for (StringRef marker : matchMarkers) {
+ markers.emplace_back(Identifier::get(marker, context));
+ }
+ return linalg::LinalgTransformationFilter(
+ markers, Identifier::get(replaceMarker, context));
+}
+
+/// Converts a symbolic GPU processor dimension to its numeric one.
+static unsigned dimToIndex(StringRef dim) {
+ return StringSwitch<unsigned>(dim).Case("x", 0).Case("y", 1).Case("z", 2);
+}
+
+//===----------------------------------------------------------------------===//
+// Main pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Function pass that implements tiling and fusion in Linalg on buffers.
+class SPIRVTileAndDistributePass
+ : public SPIRVTileAndDistributeBase<SPIRVTileAndDistributePass> {
+ public:
+ SPIRVTileAndDistributePass() = default;
+ SPIRVTileAndDistributePass(const SPIRVTileAndDistributePass &pass) = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect,
+ linalg::LinalgDialect, memref::MemRefDialect,
+ scf::SCFDialect, ShapeDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Patterns to tile computation to map to subgroups
+//===----------------------------------------------------------------------===//
+
+/// Computes the Value for subgroupID along each dimension given number of
+/// subgroups `numSubGroups` along each dimension (x-first, y-second, z-third).
+static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+ OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) {
+ Type indexType = builder.getIndexType();
+ Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType);
+ SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size());
+
+ // subgroupID
+ // = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x
+ for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) {
+ Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]);
+ AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
+ AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
+ Value procId =
+ makeComposedAffineApply(builder, loc, d0 % s0, {subgroupId, nprocs});
+ procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs};
+ subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs);
+ }
+ return procInfo;
+}
+
+namespace {
+/// Pattern to tile linalg.matmul for subgroups.
+struct TileMatmulSubgroupPattern
+ : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
+ using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+ TileMatmulSubgroupPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter marker,
+ PatternBenefit benefit = 1)
+ : Base(context, options, marker, benefit) {}
+};
+} // namespace
+
+/// Patterns for second level tiling to target subgroups.
+static void populateTilingToSubgroupPatterns(MLIRContext *context,
+ RewritePatternSet &patterns) {
+ auto getInnerTileSizeFn = [&](OpBuilder &builder,
+ Operation *operation) -> SmallVector<Value, 4> {
+ SmallVector<int64_t> tileSizes = getTileSizes(operation, 1);
+ return llvm::to_vector<4>(
+ llvm::map_range(tileSizes, [&](int64_t v) -> Value {
+ return builder.create<ConstantIndexOp>(operation->getLoc(), v);
+ }));
+ };
+
+ auto getSubgroupProcInfoFn = [&](OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ // TODO(ravishankarm): For now assume that there is always a single subgroup
+ std::array<int64_t, 3> numSubgroups = {1, 1, 1};
+ return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
+ };
+
+ linalg::LinalgLoopDistributionOptions subgroupDistributionOptions;
+ subgroupDistributionOptions.procInfo = getSubgroupProcInfoFn;
+ subgroupDistributionOptions.distributionMethod = {
+ {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+
+ patterns.insert<TileMatmulSubgroupPattern>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+ .setTileSizeComputationFunction(getInnerTileSizeFn)
+ .setDistributionOptions(subgroupDistributionOptions),
+ getLinalgMatchAndReplaceMarker(
+ {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+ getVectorizeMarker(), context));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns and methods for thread tiling.
+//===----------------------------------------------------------------------===//
+
+/// Patterns for third level tiling to target invocations.
+static void populateTilingToInvocationPatterns(MLIRContext *context,
+ RewritePatternSet &patterns) {
+ linalg::TileSizeComputationFunction getInnerTileSizeFn =
+ [&](OpBuilder &builder, Operation *operation) {
+ SmallVector<int64_t> tileSizes = getTileSizes(operation, 2);
+ return llvm::to_vector<4>(
+ llvm::map_range(tileSizes, [&](int64_t v) -> Value {
+ return builder.create<ConstantIndexOp>(operation->getLoc(), v);
+ }));
+ };
+
+ auto getThreadProcInfoFn = [&](OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
+ builder, loc, parallelLoopRanges.size());
+ };
+ linalg::LinalgLoopDistributionOptions invocationDistributionOptions;
+ invocationDistributionOptions.procInfo = getThreadProcInfoFn;
+ invocationDistributionOptions.distributionMethod = {
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic}};
+
+ auto tilingOptions =
+ linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizeComputationFunction(getInnerTileSizeFn)
+ .setDistributionOptions(invocationDistributionOptions);
+
+ patterns.insert<
+ linalg::LinalgTilingPattern<linalg::MatmulOp>,
+ linalg::LinalgTilingPattern<linalg::FillOp>,
+ linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::GenericOp>,
+ linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
+ linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
+ linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>>(
+ context, tilingOptions,
+ getLinalgMatchAndReplaceMarker(
+ {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+ getVectorizeMarker(), context));
+
+ patterns.insert<
+ linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
+ context, tilingOptions,
+ getLinalgMatchAndReplaceMarker(
+ {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+ getConvFilterTileMarker(), context));
+}
+
+/// Returns the corresponding range for the given `processorValue` is a GPU
+/// thread id or block dim.
+static Optional<std::pair<AffineExpr, AffineExpr>> getThreadRange(
+ Value processorValue, SmallVectorImpl<Value> & /*dims*/,
+ SmallVectorImpl<Value> & /*symbols*/, ArrayRef<int64_t> workgroupSize) {
+ if (auto idOp = processorValue.getDefiningOp<gpu::ThreadIdOp>()) {
+ OpBuilder builder(processorValue.getContext());
+ unsigned index = dimToIndex(idOp.dimension());
+ AffineExpr zero = builder.getAffineConstantExpr(0);
+ AffineExpr ubExpr = builder.getAffineConstantExpr(workgroupSize[index]);
+ return std::make_pair(zero, ubExpr - 1);
+ }
+ if (auto dimOp = processorValue.getDefiningOp<gpu::BlockDimOp>()) {
+ OpBuilder builder(processorValue.getContext());
+ unsigned index = dimToIndex(dimOp.dimension());
+ AffineExpr bound = builder.getAffineConstantExpr(workgroupSize[index]);
+ return std::make_pair(bound, bound);
+ }
+ return llvm::None;
+}
+
+//====---------------------------------------------------------------------===//
+// Patterns to tile convolution window dimensions
+//====---------------------------------------------------------------------===//
+
+static void populateTilingConvFilterPatterns(
+ MLIRContext *context, RewritePatternSet &patterns,
+ linalg::LinalgTransformationFilter marker) {
+ auto getTileSizeFn = [&](OpBuilder &builder, Operation *op) {
+ SmallVector<Value, 4> tileSizes;
+ SmallVector<int64_t, 4> fourthLevel = getTileSizes(op, 3);
+ tileSizes.reserve(fourthLevel.size());
+
+ Location loc = op->getLoc();
+ for (int64_t size : fourthLevel) {
+ tileSizes.push_back(builder.create<ConstantIndexOp>(loc, size));
+ }
+ return tileSizes;
+ };
+
+ auto tilingOptions = linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizeComputationFunction(getTileSizeFn);
+
+ patterns.insert<
+ linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
+ context, tilingOptions, marker);
+}
+
+//====---------------------------------------------------------------------===//
+// Patterns to lower linalg ops to loops
+//====---------------------------------------------------------------------===//
+
+template <typename OpTy>
+struct LowerToLoops final : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Only handle the cases where tiling to invocations was done, where tiling
+ // convolution filters or vectorization is expected.
+ if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()}))
+ return failure();
+
+ if (linalg::linalgOpToLoops(rewriter, op)) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+//====---------------------------------------------------------------------===//
+// Main pass implementation
+//====---------------------------------------------------------------------===//
+
+void SPIRVTileAndDistributePass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ FuncOp funcOp = getOperation();
+ auto entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return;
+
+ {
+ RewritePatternSet thirdLevelTilingPatterns(&getContext());
+ populateTilingToInvocationPatterns(context, thirdLevelTilingPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(thirdLevelTilingPatterns));
+
+ // Remove trip-one loops created during cyclic loop distribution if we can
+ // prove the tiling was perfect.
+ RewritePatternSet canoncalizationPatterns(context);
+ populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns);
+ auto workgroupSize = getWorkgroupSize(entryPointOp);
+ auto getThreadRangeFn = [workgroupSize](Value processorValue,
+ SmallVectorImpl<Value> &dims,
+ SmallVectorImpl<Value> &symbols) {
+ return getThreadRange(processorValue, dims, symbols, workgroupSize);
+ };
+ populateRemoveSingleIterationLoopPattern(canoncalizationPatterns,
+ getThreadRangeFn);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canoncalizationPatterns));
+
+ // Perform generic canonicalization.
+ RewritePatternSet threadLevelTilingCanonicalizationPatterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ populateAffineMinCanonicalizationPattern(
+ threadLevelTilingCanonicalizationPatterns);
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, std::move(threadLevelTilingCanonicalizationPatterns));
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling to invocations ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ {
+ RewritePatternSet tilingPatterns(&getContext());
+ auto marker = getLinalgMatchAndReplaceMarker(getConvFilterTileMarker(),
+ getVectorizeMarker(), context);
+ populateTilingConvFilterPatterns(context, tilingPatterns, marker);
+ populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
+ tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
+
+ RewritePatternSet convTilingCanonicalizationPatterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ populateAffineMinCanonicalizationPattern(
+ convTilingCanonicalizationPatterns);
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, std::move(convTilingCanonicalizationPatterns));
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling convolution filter ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ // Lower ops that were tiled to invocations but not vectorized to loops.
+ // TODO(antiagainst): This is here now to simplify the interaction with
+ // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
+ // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
+ // directly.
+ {
+ RewritePatternSet patterns(context);
+ patterns.add<LowerToLoops<linalg::BatchMatmulOp>,
+ LowerToLoops<linalg::ConvInputNWCFilterWCFOp>,
+ LowerToLoops<linalg::ConvInputNHWCFilterHWCFOp>,
+ LowerToLoops<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
+ LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCOp>,
+ LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
+ LowerToLoops<linalg::MatmulOp>,
+ LowerToLoops<linalg::PoolingNhwcMaxOp>,
+ LowerToLoops<linalg::PoolingNhwcMinOp>,
+ LowerToLoops<linalg::PoolingNhwcSumOp>>(context);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Pass entry point and registration
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndDistributePass() {
+ return std::make_unique<SPIRVTileAndDistributePass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp
index 04cc42b..05cc9fd 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp
@@ -79,10 +79,8 @@
class SPIRVTileAndVectorizePass
: public SPIRVTileAndVectorizeBase<SPIRVTileAndVectorizePass> {
public:
- SPIRVTileAndVectorizePass(const SPIRVCodegenOptions &passOptions)
- : options(passOptions) {}
- SPIRVTileAndVectorizePass(const SPIRVTileAndVectorizePass &pass)
- : options(pass.options) {}
+ SPIRVTileAndVectorizePass() = default;
+ SPIRVTileAndVectorizePass(const SPIRVTileAndVectorizePass &pass) = default;
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect,
@@ -91,9 +89,6 @@
}
void runOnOperation() override;
-
- private:
- SPIRVCodegenOptions options;
};
} // namespace
@@ -198,27 +193,20 @@
/// Patterns for second level tiling to target subgroups.
static void populateTilingToSubgroupPatterns(MLIRContext *context,
- const LaunchConfig &launchConfig,
RewritePatternSet &patterns) {
- auto getInnerTileSizeFn = [&launchConfig](
- OpBuilder &builder,
+ auto getInnerTileSizeFn = [&](OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
- ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 1);
- if (tileSizes.empty()) return {};
- SmallVector<Value, 4> tileSizesVal;
- tileSizesVal.reserve(tileSizes.size());
- for (auto val : tileSizes) {
- tileSizesVal.push_back(
- builder.create<ConstantIndexOp>(operation->getLoc(), val));
- }
- return tileSizesVal;
+ SmallVector<int64_t> tileSizes = getTileSizes(operation, 1);
+ return llvm::to_vector<4>(
+ llvm::map_range(tileSizes, [&](int64_t v) -> Value {
+ return builder.create<ConstantIndexOp>(operation->getLoc(), v);
+ }));
};
- auto getSubgroupProcInfoFn = [&launchConfig](
- OpBuilder &builder, Location loc,
+ auto getSubgroupProcInfoFn = [&](OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
- ArrayRef<int64_t> numSubgroups =
- launchConfig.getNumSubgroups().take_front(parallelLoopRanges.size());
+ // TODO(ravishankarm): For now assume that there is always a single subgroup
+ std::array<int64_t, 3> numSubgroups = {1, 1, 1};
return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
};
@@ -245,19 +233,14 @@
/// Patterns for third level tiling to target invocations.
static void populateTilingToInvocationPatterns(MLIRContext *context,
- const LaunchConfig &launchConfig,
RewritePatternSet &patterns) {
linalg::TileSizeComputationFunction getInnerTileSizeFn =
- [&launchConfig](OpBuilder &builder, Operation *operation) {
- ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 2);
- if (tileSizes.empty()) return SmallVector<Value, 4>();
- SmallVector<Value, 4> tileSizesVal;
- tileSizesVal.reserve(tileSizes.size());
- for (auto val : tileSizes) {
- tileSizesVal.push_back(
- builder.create<ConstantIndexOp>(operation->getLoc(), val));
- }
- return tileSizesVal;
+ [&](OpBuilder &builder, Operation *operation) {
+ SmallVector<int64_t> tileSizes = getTileSizes(operation, 2);
+ return llvm::to_vector<4>(
+ llvm::map_range(tileSizes, [&](int64_t v) -> Value {
+ return builder.create<ConstantIndexOp>(operation->getLoc(), v);
+ }));
};
auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
@@ -328,7 +311,6 @@
//====---------------------------------------------------------------------===//
static void populateVectorizationPatterns(MLIRContext *context,
- const LaunchConfig &launchConfig,
RewritePatternSet &patterns) {
linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp,
linalg::ContractionOpInterface>(
@@ -471,11 +453,10 @@
static void populateTilingConvFilterPatterns(
MLIRContext *context, RewritePatternSet &patterns,
- const LaunchConfig &launchConfig,
linalg::LinalgTransformationFilter marker) {
- auto getTileSizeFn = [&launchConfig](OpBuilder &builder, Operation *op) {
+ auto getTileSizeFn = [&](OpBuilder &builder, Operation *op) {
SmallVector<Value, 4> tileSizes;
- ArrayRef<int64_t> fourthLevel = launchConfig.getTileSizes(op, 3);
+ SmallVector<int64_t, 4> fourthLevel = getTileSizes(op, 3);
tileSizes.reserve(fourthLevel.size());
Location loc = op->getLoc();
@@ -526,209 +507,137 @@
void SPIRVTileAndVectorizePass::runOnOperation() {
MLIRContext *context = &getContext();
- IREE::HAL::ExecutableVariantOp variantOp = getOperation();
- ModuleOp module = variantOp.getInnerModule();
+ FuncOp funcOp = getOperation();
+ auto entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return;
- for (FuncOp funcOp : module.getOps<FuncOp>()) {
- if (!isEntryPoint(funcOp)) continue;
+ // TODO(thomasraoux, antiagainst): Tiling to subgroups shouldn't be
+ // controlled by vectorization. This is needed due to historical reasons.
+ // Change the second level tiling to cyclic to loops and remove this.
+ RewritePatternSet secondLevelTilingPatterns(&getContext());
+ populateTilingToSubgroupPatterns(context, secondLevelTilingPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(secondLevelTilingPatterns));
- SmallVector<linalg::LinalgOp, 4> linalgOps;
- SmallVector<Operation *, 4> tiledLoops;
+ RewritePatternSet secondLevelTilingCanonicalizationPatterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ populateAffineMinCanonicalizationPattern(
+ secondLevelTilingCanonicalizationPatterns);
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, std::move(secondLevelTilingCanonicalizationPatterns));
+ promoteSingleIterationLoops(funcOp);
- if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
- // Nothing to do here.
- continue;
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling to subgroups ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ {
+ RewritePatternSet thirdLevelTilingPatterns(&getContext());
+ populateTilingToInvocationPatterns(context, thirdLevelTilingPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(thirdLevelTilingPatterns));
+
+ // Remove trip-one loops created during cyclic loop distribution if we can
+ // prove the tiling was perfect.
+ RewritePatternSet canoncalizationPatterns(context);
+ populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns);
+ SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp);
+ if (workgroupSize.empty()) {
+ entryPointOp.emitError("expected to have workgroup_size attribute");
+ return signalPassFailure();
}
+ auto getThreadRangeFn = [workgroupSize](Value processorValue,
+ SmallVectorImpl<Value> &dims,
+ SmallVectorImpl<Value> &symbols) {
+ return getThreadRange(processorValue, dims, symbols, workgroupSize);
+ };
+ populateRemoveSingleIterationLoopPattern(canoncalizationPatterns,
+ getThreadRangeFn);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(canoncalizationPatterns));
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- Optional<LaunchConfig> launchConfigOpt =
- initGPULaunchConfig(context, dependenceGraph, options, linalgOps);
- if (!launchConfigOpt) {
- // No configuration to tile and vectorize. Nothing to do here.
- continue;
- }
- LaunchConfig &launchConfig = *launchConfigOpt;
+ // Perform generic canonicalization.
+ RewritePatternSet threadLevelTilingCanonicalizationPatterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ populateAffineMinCanonicalizationPattern(
+ threadLevelTilingCanonicalizationPatterns);
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, std::move(threadLevelTilingCanonicalizationPatterns));
LLVM_DEBUG({
- llvm::dbgs() << "\n--- Linalg tile configuration ---\n";
- llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
- interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
- llvm::dbgs() << "]\n";
- for (auto op : linalgOps) {
- llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
- TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op);
- llvm::dbgs() << "{";
- std::string sep = "";
- for (auto &level : enumerate(tileSizes)) {
- llvm::dbgs() << sep << level.index() << " : [";
- sep = ", ";
- interleaveComma(level.value(), llvm::dbgs());
- llvm::dbgs() << "]";
- }
- llvm::dbgs() << "}\n";
- }
+ llvm::dbgs() << "--- After tiling to invocations ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
});
+ }
- if (options.useWorkgroupMemory) {
- // The promotion patterns are put separate from the tiling patterns to
- // make sure that the allocated scratchspace memory is constant sizes
- // which requires some folding to trigger.
- RewritePatternSet promotionPatterns(&getContext());
- populatePromotionPatterns(context, promotionPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns));
+ {
+ RewritePatternSet tilingPatterns(&getContext());
+ auto marker = getLinalgMatchAndReplaceMarker(getConvFilterTileMarker(),
+ getVectorizeMarker(), context);
+ populateTilingConvFilterPatterns(context, tilingPatterns, marker);
+ populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
+ tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
- RewritePatternSet promotionCanonicalizationPatterns =
- linalg::getLinalgTilingCanonicalizationPatterns(context);
- populateAffineMinCanonicalizationPattern(
- promotionCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(promotionCanonicalizationPatterns));
+ RewritePatternSet convTilingCanonicalizationPatterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ populateAffineMinCanonicalizationPattern(
+ convTilingCanonicalizationPatterns);
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, std::move(convTilingCanonicalizationPatterns));
- LLVM_DEBUG({
- llvm::dbgs() << "--- After workgroup memory promotion ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling convolution filter ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
- // TODO(thomasraoux, antiagainst): Tiling to subgroups shouldn't be
- // controlled by vectorization. This is needed due to historical reasons.
- // Change the second level tiling to cyclic to loops and remove this.
- if (launchConfig.useVectorize()) {
- RewritePatternSet secondLevelTilingPatterns(&getContext());
- populateTilingToSubgroupPatterns(context, launchConfig,
- secondLevelTilingPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(secondLevelTilingPatterns));
+ {
+ RewritePatternSet vectorizationPatterns(&getContext());
+ populateVectorizationPatterns(context, vectorizationPatterns);
+ populateLinalgToVectorVectorizeConvPatterns(context, vectorizationPatterns);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorizationPatterns));
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After vectorization ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
- RewritePatternSet secondLevelTilingCanonicalizationPatterns =
- linalg::getLinalgTilingCanonicalizationPatterns(context);
- populateAffineMinCanonicalizationPattern(
- secondLevelTilingCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(secondLevelTilingCanonicalizationPatterns));
- promoteSingleIterationLoops(funcOp);
+ // TODO: This should be a folding of Add into Contract in core but while
+ // they live in different dialects, it is not possible without unnatural
+ // dependencies.
+ funcOp.walk([&](Operation *op) {
+ if (auto contract = canonicalizeContractionAdd(op))
+ op->replaceAllUsesWith(contract);
+ });
- LLVM_DEBUG({
- llvm::dbgs() << "--- After tiling to subgroups ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
+ applyVectorTransformation(funcOp);
- {
- RewritePatternSet thirdLevelTilingPatterns(&getContext());
- populateTilingToInvocationPatterns(context, launchConfig,
- thirdLevelTilingPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(thirdLevelTilingPatterns));
-
- // Remove trip-one loops created during cyclic loop distribution if we can
- // prove the tiling was perfect.
- RewritePatternSet canoncalizationPatterns(context);
- populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns);
- ArrayRef<int64_t> workgroupSize = launchConfig.getWorkgroupSize();
- auto getThreadRangeFn = [workgroupSize](Value processorValue,
- SmallVectorImpl<Value> &dims,
- SmallVectorImpl<Value> &symbols) {
- return getThreadRange(processorValue, dims, symbols, workgroupSize);
- };
- populateRemoveSingleIterationLoopPattern(canoncalizationPatterns,
- getThreadRangeFn);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canoncalizationPatterns));
-
- // Perform generic canonicalization.
- RewritePatternSet threadLevelTilingCanonicalizationPatterns =
- linalg::getLinalgTilingCanonicalizationPatterns(context);
- populateAffineMinCanonicalizationPattern(
- threadLevelTilingCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(threadLevelTilingCanonicalizationPatterns));
-
- LLVM_DEBUG({
- llvm::dbgs() << "--- After tiling to invocations ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
-
- {
- RewritePatternSet tilingPatterns(&getContext());
- auto marker = getLinalgMatchAndReplaceMarker(
- getConvFilterTileMarker(), getVectorizeMarker(), context);
- populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig,
- marker);
- populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
- tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(
- context);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
-
- RewritePatternSet convTilingCanonicalizationPatterns =
- linalg::getLinalgTilingCanonicalizationPatterns(context);
- populateAffineMinCanonicalizationPattern(
- convTilingCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(convTilingCanonicalizationPatterns));
-
- LLVM_DEBUG({
- llvm::dbgs() << "--- After tiling convolution filter ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
-
- if (launchConfig.useVectorize()) {
- {
- RewritePatternSet vectorizationPatterns(&getContext());
- populateVectorizationPatterns(context, launchConfig,
- vectorizationPatterns);
- populateLinalgToVectorVectorizeConvPatterns(context,
- vectorizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorizationPatterns));
- LLVM_DEBUG({
- llvm::dbgs() << "--- After vectorization ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
- }
-
- // TODO: This should be a folding of Add into Contract in core but while
- // they live in different dialects, it is not possible without unnatural
- // dependencies.
- funcOp.walk([&](Operation *op) {
- if (auto contract = canonicalizeContractionAdd(op))
- op->replaceAllUsesWith(contract);
- });
-
- applyVectorTransformation(funcOp);
- }
-
- // Lower ops that were tiled to invocations but not vectorized to loops.
- // TODO(antiagainst): This is here now to simplify the interaction with
- // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
- // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
- // directly.
- {
- RewritePatternSet patterns(context);
- patterns
- .add<LowerToLoops<linalg::BatchMatmulOp>,
- LowerToLoops<linalg::ConvInputNWCFilterWCFOp>,
- LowerToLoops<linalg::ConvInputNHWCFilterHWCFOp>,
- LowerToLoops<linalg::ConvInputNDHWCFilterDHWCFOp>,
- LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
- LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCOp>,
- LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
- LowerToLoops<linalg::MatmulOp>,
- LowerToLoops<linalg::PoolingNhwcMaxOp>,
- LowerToLoops<linalg::PoolingNhwcMinOp>,
- LowerToLoops<linalg::PoolingNhwcSumOp>>(context);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
- }
-
- launchConfig.finalize(funcOp);
+ // Lower ops that were tiled to invocations but not vectorized to loops.
+ // TODO(antiagainst): This is here now to simplify the interaction with
+ // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
+ // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
+ // directly.
+ {
+ RewritePatternSet patterns(context);
+ patterns.add<LowerToLoops<linalg::BatchMatmulOp>,
+ LowerToLoops<linalg::ConvInputNWCFilterWCFOp>,
+ LowerToLoops<linalg::ConvInputNHWCFilterHWCFOp>,
+ LowerToLoops<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
+ LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCOp>,
+ LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
+ LowerToLoops<linalg::MatmulOp>,
+ LowerToLoops<linalg::PoolingNhwcMaxOp>,
+ LowerToLoops<linalg::PoolingNhwcMinOp>,
+ LowerToLoops<linalg::PoolingNhwcSumOp>>(context);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
}
@@ -736,9 +645,8 @@
// Pass entry point and registration
//===----------------------------------------------------------------------===//
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVTileAndVectorizePass(const SPIRVCodegenOptions &options) {
- return std::make_unique<SPIRVTileAndVectorizePass>(options);
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndVectorizePass() {
+ return std::make_unique<SPIRVTileAndVectorizePass>();
}
} // namespace iree_compiler
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
new file mode 100644
index 0000000..108bf96
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -0,0 +1,13 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- SPIRVVectorize.cpp -------------------------------------------------===//
+//
+// This pass vectorizes Linalg ops on buffers within in a single workgroup.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "iree-spirv-vectorize"
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD
index 3b9ea7f..0f3bc85 100644
--- a/iree/compiler/Codegen/SPIRV/test/BUILD
+++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -19,16 +19,13 @@
name = "lit",
srcs = enforce_glob(
[
- "concretize_workgroup_tiles.mlir",
- "concretize_workgroup_tiles_dynamic.mlir",
"convert_to_gpu.mlir",
"convert_to_spirv.mlir",
"fold_gpu_procid_uses.mlir",
- "materialize_launch_configuration.mlir",
- "materialize_launch_configuration2.mlir",
"pipeline_matmul_cooperative_matrix.mlir",
"pipeline_matmul_vectorization.mlir",
"promote_workgroup_memory.mlir",
+ "remove_one_trip_tiled_loop.mlir",
"tile_and_vectorize.mlir",
"tile_and_vectorize_batch_matmul.mlir",
"tile_and_vectorize_conv.mlir",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index fa0818e..cc4083a 100644
--- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -14,16 +14,13 @@
NAME
lit
SRCS
- "concretize_workgroup_tiles.mlir"
- "concretize_workgroup_tiles_dynamic.mlir"
"convert_to_gpu.mlir"
"convert_to_spirv.mlir"
"fold_gpu_procid_uses.mlir"
- "materialize_launch_configuration.mlir"
- "materialize_launch_configuration2.mlir"
"pipeline_matmul_cooperative_matrix.mlir"
"pipeline_matmul_vectorization.mlir"
"promote_workgroup_memory.mlir"
+ "remove_one_trip_tiled_loop.mlir"
"tile_and_vectorize.mlir"
"tile_and_vectorize_batch_matmul.mlir"
"tile_and_vectorize_conv.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/concretize_workgroup_tiles.mlir b/iree/compiler/Codegen/SPIRV/test/concretize_workgroup_tiles.mlir
deleted file mode 100644
index 8d2acc4..0000000
--- a/iree/compiler/Codegen/SPIRV/test/concretize_workgroup_tiles.mlir
+++ /dev/null
@@ -1,123 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-spirv-workgroup-tile-size=0,4,4,16 -iree-spirv-workgroup-size=4,4,1 -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-concretize-workgroup-tiles))' -canonicalize -cse %s | IreeFileCheck %s
-
-hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @conv2d_static_shape attributes {
- interface = @io,
- ordinal = 0 : index
- }
- module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
- func @conv2d_static_shape() {
- %cst = constant 0.000000e+00 : f32
- %c32 = constant 32 : index
- %c112 = constant 112 : index
- %c0 = constant 0 : index
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<1x225x225x16xf32>
- %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x3x16x32xf32>
- %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<1x112x112x32xf32>
- %workgroup_size_x = hal.interface.workgroup.size[0] : index
- %workgroup_size_y = hal.interface.workgroup.size[1] : index
- %workgroup_size_z = hal.interface.workgroup.size[2] : index
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_count_x = hal.interface.workgroup.count[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_count_y = hal.interface.workgroup.count[1] : index
- %workgroup_id_z = hal.interface.workgroup.id[2] : index
- %workgroup_count_z = hal.interface.workgroup.count[2] : index
- %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
- %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
- scf.for %arg0 = %3 to %c112 step %4 {
- %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
- %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
- scf.for %arg1 = %5 to %c112 step %6 {
- %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
- %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
- scf.for %arg2 = %7 to %c32 step %8 {
- %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
- %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z]
- %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
- %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y]
- %13 = memref.subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>
- %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
- %15 = memref.subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>
- %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
- %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
- %18 = memref.subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
- linalg.fill(%cst, %18) : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
- linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>)
- }
- }
- }
- return
- }
- hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- }
- }
-}
-
-// Check that for a fully static shaped dispatch region, we can:
-// 1) Generate static constant workgroup counts,
-// 2) Replace hal.interface.workgroup.{size|count} ops with constants,
-// 3) Canonicalize loops and memref.subview ops.
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 8)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (9, s0 * -8 + 225)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<()[s0] -> (16, s0 * -16 + 32)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0] -> (4, s0 * -4 + 112)>
-
-// CHECK: hal.executable.entry_point @conv2d_static_shape
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C28:.+]] = constant 28 : index
-// CHECK: hal.return %[[C2]], %[[C28]], %[[C28]] : index, index, index
-
-// CHECK: func @conv2d_static_shape()
-// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
-
-// CHECK-DAG: %[[INPUT:.+]] = hal.interface.binding.subspan @io::@arg0
-// CHECK-DAG: %[[FILTER:.+]] = hal.interface.binding.subspan @io::@arg1
-// CHECK-DAG: %[[OUTPUT:.+]] = hal.interface.binding.subspan @io::@ret0
-
-// CHECK-DAG: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
-// CHECK-DAG: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-// CHECK-DAG: %[[ID_Z:.+]] = hal.interface.workgroup.id[2] : index
-
-// CHECK-DAG: %[[OUTPUT_OFFSET_Z:.+]] = affine.apply #[[MAP0]]()[%[[ID_Z]]]
-// CHECK-DAG: %[[OUTPUT_OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[ID_Y]]]
-// CHECK-DAG: %[[OUTPUT_OFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[ID_X]]]
-// CHECK-DAG: %[[INPUT_OFFSET_Z:.+]] = affine.apply #[[MAP2]]()[%[[ID_Z]]]
-// CHECK-DAG: %[[INPUT_SIZE_Z:.+]] = affine.min #[[MAP3]]()[%[[ID_Z]]]
-// CHECK-DAG: %[[INPUT_OFFSET_Y:.+]] = affine.apply #[[MAP2]]()[%[[ID_Y]]]
-// CHECK-DAG: %[[INPUT_SIZE_Y:.+]] = affine.min #[[MAP3]]()[%[[ID_Y]]]
-
-// CHECK: %[[INPUT_VIEW:.+]] = memref.subview %[[INPUT]]
-// CHECK-SAME: [0, %[[INPUT_OFFSET_Z]], %[[INPUT_OFFSET_Y]], 0]
-// CHECK-SAME: [1, %[[INPUT_SIZE_Z]], %[[INPUT_SIZE_Y]], 16] [1, 1, 1, 1]
-// CHECK-SAME: memref<1x225x225x16xf32> to memref<1x?x?x16xf32, {{.+}}>
-
-// CHECK: %[[OUTPUT_SIZE_X:.+]] = affine.min #[[MAP5]]()[%[[ID_X]]]
-// CHECK: %[[FILTER_VIEW:.+]] = memref.subview %[[FILTER]]
-// CHECK-SAME: [0, 0, 0, %[[OUTPUT_OFFSET_X]]] [3, 3, 16, %[[OUTPUT_SIZE_X]]]
-// CHECK-SAME: memref<3x3x16x32xf32> to memref<3x3x16x?xf32, {{.+}}>
-
-// CHECK-DAG: %[[OUTPUT_SIZE_Z:.+]] = affine.min #[[MAP7]]()[%[[ID_Z]]]
-// CHECK-DAG: %[[OUTPUT_SIZE_Y:.+]] = affine.min #[[MAP7]]()[%[[ID_Y]]]
-// CHECK: %[[OUTPUT_VIEW:.+]] = memref.subview %[[OUTPUT]]
-// CHECK-SAME: [0, %[[OUTPUT_OFFSET_Z]], %[[OUTPUT_OFFSET_Y]], %[[OUTPUT_OFFSET_X]]]
-// CHECK-SAME: [1, %[[OUTPUT_SIZE_Z]], %[[OUTPUT_SIZE_Y]], %[[OUTPUT_SIZE_X]]]
-// CHECK-SAME: memref<1x112x112x32xf32> to memref<1x?x?x?xf32, {{.+}}>
-
-// CHECK: linalg.fill(%{{.+}}, %[[OUTPUT_VIEW]])
-// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
-// CHECK-SAME: ins(%[[INPUT_VIEW]], %[[FILTER_VIEW]] : memref<1x?x?x16xf32, #map{{[0-9]+}}>, memref<3x3x16x?xf32, #map{{[0-9]+}}>)
-// CHECK-SAME: outs(%[[OUTPUT_VIEW]] : memref<1x?x?x?xf32, #map{{[0-9]+}}>)
diff --git a/iree/compiler/Codegen/SPIRV/test/concretize_workgroup_tiles_dynamic.mlir b/iree/compiler/Codegen/SPIRV/test/concretize_workgroup_tiles_dynamic.mlir
deleted file mode 100644
index 315fb57..0000000
--- a/iree/compiler/Codegen/SPIRV/test/concretize_workgroup_tiles_dynamic.mlir
+++ /dev/null
@@ -1,118 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-spirv-workgroup-tile-size=4,16 -iree-spirv-workgroup-size=4,4,1 -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-concretize-workgroup-tiles))' -canonicalize -cse %s | IreeFileCheck %s
-
-hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @matmul_dynamic_shape attributes {
- interface = @io,
- ordinal = 0 : index
- }
- module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
- func @matmul_dynamic_shape() {
- %cst = constant 0.000000e+00 : f32
- %c0 = constant 0 : index
- %0 = hal.interface.load.constant offset = 0 : index
- %1 = hal.interface.load.constant offset = 1 : index
- %2 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?xf32>
- %3 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?x?xf32>
- %4 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?xf32>
- %5 = hal.interface.load.constant offset = 2 : index
- %6 = hal.interface.load.constant offset = 3 : index
- %7 = hal.interface.load.constant offset = 4 : index
- %8 = hal.interface.load.constant offset = 5 : index
- %9 = hal.interface.load.constant offset = 6 : index
- %10 = hal.interface.load.constant offset = 7 : index
- %11 = shapex.make_ranked_shape %5, %6 : (index, index) -> !shapex.ranked_shape<[?,?]>
- %12 = shapex.tie_shape %2, %11 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
- %13 = shapex.make_ranked_shape %7, %8 : (index, index) -> !shapex.ranked_shape<[?,?]>
- %14 = shapex.tie_shape %3, %13 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
- %15 = shapex.make_ranked_shape %9, %10 : (index, index) -> !shapex.ranked_shape<[?,?]>
- %16 = shapex.tie_shape %4, %15 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
- %workgroup_size_x = hal.interface.workgroup.size[0] : index
- %workgroup_size_y = hal.interface.workgroup.size[1] : index
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_count_x = hal.interface.workgroup.count[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_count_y = hal.interface.workgroup.count[1] : index
- %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
- %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
- scf.for %arg0 = %17 to %5 step %18 {
- %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
- %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
- scf.for %arg1 = %19 to %8 step %20 {
- %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_y]
- %22 = memref.subview %12[%arg0, 0] [%21, %6] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- %23 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%8, %workgroup_size_x]
- %24 = memref.subview %14[0, %arg1] [%7, %23] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- %25 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%0, %workgroup_size_y]
- %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%1, %workgroup_size_x]
- %27 = memref.subview %16[%arg0, %arg1] [%25, %26] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- linalg.fill(%cst, %27) {__internal_linalg_transform__ = "workgroup"} : f32, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%22, %24 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%27 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
- }
- }
- return
- }
- hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- }
- }
-}
-
-// Check that for a fully dynamic shaped dispatch region, we can:
-// 1) Generate symbolic workgroup counts,
-// 2) Replace hal.interface.workgroup.size (but not .count) ops with constants.
-
-// CHECK-DAG: #[[DIV16MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-// CHECK-DAG: #[[DIV4MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-// CHECK-DAG: #[[MUL16MAP:.+]] = affine_map<()[s0] -> (s0 * 16)>
-// CHECK-DAG: #[[MUL4MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK-DAG: #[[YBOUNDMAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-// CHECK-DAG: #[[XBOUNDMAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-
-// CHECK: hal.executable.entry_point @matmul_dynamic_shape
-// CHECK: ^{{.+}}(%[[BBARG0:.+]]: index, %[[BBARG1:.+]]: index, %{{.+}}: index):
-// CHECK: %c1 = constant 1 : index
-// CHECK: %[[SIZE0:.+]] = affine.apply #[[DIV16MAP]]()[%[[BBARG0]]]
-// CHECK: %[[SIZE1:.+]] = affine.apply #[[DIV4MAP]]()[%[[BBARG1]]]
-// CHECK: hal.return %[[SIZE0]], %[[SIZE1]], %c1
-
-// CHECK: func @matmul_dynamic_shape()
-// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
-
-// CHECK: %[[C_DIM0:.+]] = hal.interface.load.constant offset = 0 : index
-// CHECK: %[[C_DIM1:.+]] = hal.interface.load.constant offset = 1 : index
-// CHECK: %[[A_DIM0:.+]] = hal.interface.load.constant offset = 2 : index
-// CHECK: %[[A_DIM1:.+]] = hal.interface.load.constant offset = 3 : index
-// CHECK: %[[B_DIM0:.+]] = hal.interface.load.constant offset = 4 : index
-// CHECK: %[[B_DIM1:.+]] = hal.interface.load.constant offset = 5 : index
-
-// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
-// CHECK: %[[COUNT_X:.+]] = hal.interface.workgroup.count[0] : index
-// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-// CHECK: %[[COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index
-
-// CHECK: %[[Y_LB:.+]] = affine.apply #[[MUL4MAP]]()[%[[ID_Y]]]
-// CHECK: %[[Y_STEP:.+]] = affine.apply #[[MUL4MAP]]()[%[[COUNT_Y]]]
-// CHECK: scf.for %[[IV_Y:.+]] = %[[Y_LB]] to %[[A_DIM0]] step %[[Y_STEP]]
-// CHECK: %[[X_LB:.+]] = affine.apply #[[MUL16MAP]]()[%[[ID_X]]]
-// CHECK: %[[X_STEP:.+]] = affine.apply #[[MUL16MAP]]()[%[[COUNT_X]]]
-// CHECK: scf.for %[[IV_X:.+]] = %[[X_LB]] to %[[B_DIM1]] step %[[X_STEP]]
-// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[A_DIM0]]]
-// CHECK: %[[A_TILE:.+]] = memref.subview %{{.+}}[%[[IV_Y]], 0] [%[[Y_SIZE]], %[[A_DIM1]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
-// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[B_DIM1]]]
-// CHECK: %[[B_TILE:.+]] = memref.subview %{{.+}}[0, %[[IV_X]]] [%[[B_DIM0]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
-// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[C_DIM0]]]
-// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[C_DIM1]]]
-// CHECK: %[[C_TILE:.+]] = memref.subview %{{.+}}[%[[IV_Y]], %[[IV_X]]] [%[[Y_SIZE]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
-// CHECK: linalg.fill(%cst, %[[C_TILE]])
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[A_TILE]], %[[B_TILE]]
-// CHECK-SAME: outs(%[[C_TILE]]
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir b/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir
index f20b249..729557e 100644
--- a/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-convert-to-gpu))' -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-convert-to-gpu))))' -canonicalize -cse %s | IreeFileCheck %s
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
hal.executable @parallel_4D attributes {sym_visibility = "private"} {
@@ -46,7 +46,6 @@
}
}
// CHECK-LABEL: func @parallel_4D
-// CHECK-SAME: local_size = dense<[32, 1, 1]>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
@@ -119,14 +118,7 @@
}
}
}
-// CHECK: #[[COUNT_MAP:.+]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) ceildiv 32)>
-// CHECK: hal.executable.entry_point @parallel_4D_static
-// CHECK: ^{{.*}}(%[[WORKLOAD_X:.+]]: index, %[[WORKLOAD_Y:.+]]: index, %[[WORKLOAD_Z:.+]]: index):
-// CHECK-DAG: %[[C1:.+]] = constant 1
-// CHECK-DAG: %[[COUNT:.+]] = affine.apply #[[COUNT_MAP]]()[%[[WORKLOAD_X]], %[[WORKLOAD_Y]], %[[WORKLOAD_Z]]]
-// CHECK: hal.return %[[COUNT]], %[[C1]], %[[C1]]
-// CHECK-LABEL: func @parallel_4D_static()
-// CHECK-SAME: local_size = dense<[32, 1, 1]>
+// CHECK-LABEL: func @parallel_4D_static()
// CHECK-DAG: %[[C360:.+]] = constant 360 : index
// CHECK-DAG: %[[C120:.+]] = constant 120 : index
// CHECK-DAG: %[[C30:.+]] = constant 30 : index
@@ -196,12 +188,6 @@
}
}
}
-// CHECK: #[[COUNT_MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
-// CHECK: hal.executable.entry_point @scalar_add
-// CHECK: ^{{.*}}(%[[WORKLOAD_X:.+]]: index, %[[WORKLOAD_Y:.+]]: index, %[[WORKLOAD_Z:.+]]: index):
-// CHECK-DAG: %[[C1:.+]] = constant 1
-// CHECK-DAG: %[[COUNT:.+]] = affine.apply #[[COUNT_MAP]]()[%[[WORKLOAD_X]], %[[WORKLOAD_Y]], %[[WORKLOAD_Z]]]
-// CHECK: hal.return %[[COUNT]], %[[C1]], %[[C1]]
// CHECK-LABEL: func @scalar_add()
// CHECK: load
// CHECK-NEXT: load
@@ -257,14 +243,7 @@
}
}
}
-// CHECK: #[[COUNT_MAP:.+]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) ceildiv 32)>
-// CHECK: hal.executable.entry_point @reduce_sum
-// CHECK: ^{{.*}}(%[[WORKLOAD_X:.+]]: index, %[[WORKLOAD_Y:.+]]: index, %[[WORKLOAD_Z:.+]]: index):
-// CHECK-DAG: %[[C1:.+]] = constant 1
-// CHECK-DAG: %[[COUNT:.+]] = affine.apply #[[COUNT_MAP]]()[%[[WORKLOAD_X]], %[[WORKLOAD_Y]], %[[WORKLOAD_Z]]]
-// CHECK: hal.return %[[COUNT]], %[[C1]], %[[C1]]
//CHECK-LABEL: func @reduce_sum
-// CHECK-SAME: local_size = dense<[32, 1, 1]> : vector<3xi32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C40:.+]] = constant 40 : index
// CHECK-DAG: %[[C50:.+]] = constant 50 : index
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index a7d78af..5c73cec 100644
--- a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -1,133 +1,187 @@
-// RUN: iree-opt -split-input-file -iree-convert-to-spirv %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(iree-convert-to-spirv)))' %s | IreeFileCheck %s
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
- // CHECK-LABEL: spv.module
- // CHECK: spv.GlobalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
- // CHECK: spv.func @push_constant()
- func @push_constant() {
- // CHECK: %[[INDEX_0:.+]] = spv.Constant 0 : i32
- // CHECK: %[[INDEX_1:.+]] = spv.Constant 2 : i32
- // CHECK: %[[ADDR:.+]] = spv.mlir.addressof @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
- // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
- // CHECK: spv.Load "PushConstant" %[[AC]] : i32
- %0 = hal.interface.load.constant offset = 2 : index
- return
- }
-
+hal.executable @push_constant attributes {sym_visibility = "private"} {
hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
}
+ hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @push_constant attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ // CHECK-LABEL: spv.module
+ // CHECK: spv.GlobalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
+ // CHECK: spv.func @push_constant()
+ func @push_constant() {
+ // CHECK: %[[INDEX_0:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[INDEX_1:.+]] = spv.Constant 2 : i32
+ // CHECK: %[[ADDR:.+]] = spv.mlir.addressof @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
+ // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
+ // CHECK: spv.Load "PushConstant" %[[AC]] : i32
+ %0 = hal.interface.load.constant offset = 2 : index
+ return
+ }
+
+ hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+ }
+ }
}
// -----
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
- // CHECK-LABEL: spv.module
- // CHECK: spv.GlobalVariable @[[ARG0:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv.GlobalVariable @[[ARG1_0:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv.GlobalVariable @[[ARG1_1:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
- // CHECK: spv.GlobalVariable @[[RET0:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv.func @resource_bindings_in_same_entry_func()
- func @resource_bindings_in_same_entry_func() {
- %c0 = constant 0 : index
-
- // Same type
- // CHECK: spv.mlir.addressof @[[ARG0]]
- // CHECK: spv.mlir.addressof @[[ARG0]]
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
- %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
-
- // Different type
- // CHECK: spv.mlir.addressof @[[ARG1_0]]
- // CHECK: spv.mlir.addressof @[[ARG1_1]]
- %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4x4xf32>
- %3 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4xvector<4xf32>>
-
- // CHECK: spv.mlir.addressof @[[RET0]]
- %4 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32>
-
- %5 = memref.load %0[%c0, %c0] : memref<4x4xf32>
- %6 = memref.load %1[%c0, %c0] : memref<4x4xf32>
-
- %7 = memref.load %2[%c0, %c0] : memref<4x4xf32>
- %8 = memref.load %3[%c0] : memref<4xvector<4xf32>>
-
- %9 = memref.load %4[%c0, %c0] : memref<4x4xf32>
-
- return
- }
-
+hal.executable @resource_bindings_in_same_func attributes {sym_visibility = "private"} {
hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=1, binding=3, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
}
+ hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @resource_bindings_in_same_func attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ // CHECK-LABEL: spv.module
+ // CHECK: spv.GlobalVariable @[[ARG0:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.GlobalVariable @[[ARG1_0:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.GlobalVariable @[[ARG1_1:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ // CHECK: spv.GlobalVariable @[[RET0:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.func @resource_bindings_in_same_entry_func()
+ func @resource_bindings_in_same_entry_func() {
+ %c0 = constant 0 : index
+
+ // Same type
+ // CHECK: spv.mlir.addressof @[[ARG0]]
+ // CHECK: spv.mlir.addressof @[[ARG0]]
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
+ %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
+
+ // Different type
+ // CHECK: spv.mlir.addressof @[[ARG1_0]]
+ // CHECK: spv.mlir.addressof @[[ARG1_1]]
+ %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4x4xf32>
+ %3 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4xvector<4xf32>>
+
+ // CHECK: spv.mlir.addressof @[[RET0]]
+ %4 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32>
+
+ %5 = memref.load %0[%c0, %c0] : memref<4x4xf32>
+ %6 = memref.load %1[%c0, %c0] : memref<4x4xf32>
+
+ %7 = memref.load %2[%c0, %c0] : memref<4x4xf32>
+ %8 = memref.load %3[%c0] : memref<4xvector<4xf32>>
+
+ %9 = memref.load %4[%c0, %c0] : memref<4x4xf32>
+
+ return
+ }
+
+ hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=1, binding=3, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
+ }
+ }
+ }
}
// -----
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
- // CHECK-LABEL: spv.module
- // CHECK: spv.GlobalVariable @[[FUNC1_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
- // CHECK: spv.GlobalVariable @[[FUNC2_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv.GlobalVariable @[[FUNC2_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-
- // CHECK: spv.func @resource_bindings_in_entry_func1()
- func @resource_bindings_in_entry_func1() {
- // CHECK: spv.mlir.addressof @[[FUNC1_ARG]]
- // CHECK: spv.mlir.addressof @[[FUNC1_RET]]
- %c0 = constant 0 : index
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
- %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4xvector<4xf32>>
-
- %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
- %3 = memref.load %1[%c0] : memref<4xvector<4xf32>>
-
- return
- }
-
- // CHECK: spv.func @resource_bindings_in_entry_func2()
- func @resource_bindings_in_entry_func2() {
- // CHECK: spv.mlir.addressof @[[FUNC2_ARG]]
- // CHECK: spv.mlir.addressof @[[FUNC2_RET]]
- %c0 = constant 0 : index
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> // Same type as previous function
- %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> // Different type as previous function
-
- %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
- %3 = memref.load %1[%c0, %c0] : memref<4x4xf32>
-
- return
- }
-
+hal.executable @resource_bindings_in_multi_entry_func attributes {sym_visibility = "private"} {
hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
}
+ hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @resource_bindings_in_entry_func1 attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ hal.executable.entry_point @resource_bindings_in_entry_func2 attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ // CHECK-LABEL: spv.module
+ // CHECK: spv.GlobalVariable @[[FUNC1_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ // CHECK: spv.GlobalVariable @[[FUNC2_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.GlobalVariable @[[FUNC2_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+
+ // CHECK: spv.func @resource_bindings_in_entry_func1()
+ func @resource_bindings_in_entry_func1() {
+ // CHECK: spv.mlir.addressof @[[FUNC1_ARG]]
+ // CHECK: spv.mlir.addressof @[[FUNC1_RET]]
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
+ %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4xvector<4xf32>>
+
+ %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
+ %3 = memref.load %1[%c0] : memref<4xvector<4xf32>>
+
+ return
+ }
+
+ // CHECK: spv.func @resource_bindings_in_entry_func2()
+ func @resource_bindings_in_entry_func2() {
+ // CHECK: spv.mlir.addressof @[[FUNC2_ARG]]
+ // CHECK: spv.mlir.addressof @[[FUNC2_RET]]
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> // Same type as previous function
+ %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> // Different type as previous function
+
+ %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
+ %3 = memref.load %1[%c0, %c0] : memref<4x4xf32>
+
+ return
+ }
+
+ hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
+ }
+ }
+ }
}
// -----
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, SwiftShader:CPU, {}>} {
- func @interface_binding() {
- %c0 = constant 0 : index
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<8x5xf32>
- %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<5xf32>
- %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<8x5xf32>
-
- %3 = memref.load %0[%c0, %c0] : memref<8x5xf32>
- %4 = memref.load %1[%c0] : memref<5xf32>
- %5 = memref.load %2[%c0, %c0] : memref<8x5xf32>
-
- return
- }
+hal.executable @interface_binding attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
+ hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @interface_binding attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, SwiftShader:CPU, {}>} {
+ func @interface_binding() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<8x5xf32>
+ %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<5xf32>
+ %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<8x5xf32>
+
+ %3 = memref.load %0[%c0, %c0] : memref<8x5xf32>
+ %4 = memref.load %1[%c0] : memref<5xf32>
+ %5 = memref.load %2[%c0, %c0] : memref<8x5xf32>
+
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
}
// Explicitly check the variable symbols
@@ -143,18 +197,32 @@
// -----
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, SwiftShader:CPU, {}>} {
- func @interface_wg_id() {
- %0 = hal.interface.workgroup.id[0] : index
- %1 = hal.interface.workgroup.id[1] : index
- return
- }
+hal.executable @interface_wg_id attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
+ hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @interface_wg_id attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, SwiftShader:CPU, {}>} {
+ func @interface_wg_id() {
+ %0 = hal.interface.workgroup.id[0] : index
+ %1 = hal.interface.workgroup.id[1] : index
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
}
+
// CHECK-LABEL: spv.module
// CHECK-DAG: spv.GlobalVariable @[[WGID:.+]] built_in("WorkgroupId")
// CHECK: spv.func
@@ -167,17 +235,30 @@
// -----
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, SwiftShader:CPU, {}>} {
- func @interface_wg_count() {
- %0 = hal.interface.workgroup.count[0] : index
- %1 = hal.interface.workgroup.count[1] : index
- return
- }
+hal.executable @interface_wg_count attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
+ hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @interface_wg_count attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
+ }
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, SwiftShader:CPU, {}>} {
+ func @interface_wg_count() {
+ %0 = hal.interface.workgroup.count[0] : index
+ %1 = hal.interface.workgroup.count[1] : index
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
}
// CHECK-LABEL: spv.module
// CHECK-DAG: spv.GlobalVariable @[[WGCOUNT:.+]] built_in("NumWorkgroups")
diff --git a/iree/compiler/Codegen/SPIRV/test/fold_gpu_procid_uses.mlir b/iree/compiler/Codegen/SPIRV/test/fold_gpu_procid_uses.mlir
index 5424b81..9756ed9 100644
--- a/iree/compiler/Codegen/SPIRV/test/fold_gpu_procid_uses.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/fold_gpu_procid_uses.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-fold-gpu-procid-uses))' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-fold-gpu-procid-uses))))' %s | IreeFileCheck %s
hal.executable @fold_block_id attributes {sym_visibility = "private"} {
hal.interface @io {
@@ -76,11 +76,11 @@
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @fold_thread_id attributes {
interface = @io,
- ordinal = 0 : index
+ ordinal = 0 : index,
+ workgroup_size = [8: index, 2: index, 1: index]
}
module {
- func @fold_thread_id() -> (index, index, index)
- attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ func @fold_thread_id() -> (index, index, index) {
%0 = "gpu.thread_id"() {dimension = "x"} : () -> index
%1 = "gpu.thread_id"() {dimension = "y"} : () -> index
%2 = "gpu.thread_id"() {dimension = "z"} : () -> index
@@ -106,10 +106,11 @@
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @does_not_fold_mod attributes {
interface = @io,
- ordinal = 0 : index
+ ordinal = 0 : index,
+ workgroup_size = [8: index, 2: index, 1: index]
}
module {
- func @does_not_fold_mod() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ func @does_not_fold_mod() -> index {
%0 = "gpu.thread_id"() {dimension = "z"} : () -> index
%1 = affine.min affine_map<()[s0] -> (21, s0 mod 5)>()[%0]
return %1: index
@@ -128,10 +129,11 @@
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @does_not_fold_div attributes {
interface = @io,
- ordinal = 0 : index
+ ordinal = 0 : index,
+ workgroup_size = [8: index, 2: index, 1: index]
}
module {
- func @does_not_fold_div() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ func @does_not_fold_div() -> index {
%0 = "gpu.thread_id"() {dimension = "z"} : () -> index
%1 = affine.min affine_map<()[s0] -> (21, s0 ceildiv 5)>()[%0]
return %1: index
@@ -150,10 +152,11 @@
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @does_not_fold_symbol_mul_symbol attributes {
interface = @io,
- ordinal = 0 : index
+ ordinal = 0 : index,
+ workgroup_size = [8: index, 2: index, 1: index]
}
module {
- func @does_not_fold_symbol_mul_symbol() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ func @does_not_fold_symbol_mul_symbol() -> index {
// 5 is in %0's range of [0,7] so we cannot fold the following into 5 or 0.
%0 = "gpu.thread_id"() {dimension = "z"} : () -> index
%1 = affine.min affine_map<()[s0] -> (21, s0 * s0)>()[%0]
diff --git a/iree/compiler/Codegen/SPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/SPIRV/test/materialize_launch_configuration.mlir
deleted file mode 100644
index 6e55d13..0000000
--- a/iree/compiler/Codegen/SPIRV/test/materialize_launch_configuration.mlir
+++ /dev/null
@@ -1,87 +0,0 @@
-// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-concretize-workgroup-tiles))' -canonicalize -cse -split-input-file %s | IreeFileCheck %s
-
-hal.executable @matmul_tensors attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64"> {
- hal.executable.entry_point @matmul_tensors attributes {
- interface = @io,
- ordinal = 0 : index
- }
- module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} {
- func @matmul_tensors() {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?xf32>
- %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?x?xf32>
- %4 = hal.interface.binding.subspan @io::@arg2[%c0] : memref<?x?xf32>
- %6 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?xf32>
- %M = memref.dim %0, %c0 : memref<?x?xf32>
- %N = memref.dim %2, %c1 : memref<?x?xf32>
- %K = memref.dim %0, %c1 : memref<?x?xf32>
- %workgroup_size_x = hal.interface.workgroup.size[0] : index
- %workgroup_size_y = hal.interface.workgroup.size[1] : index
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_count_x = hal.interface.workgroup.count[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_count_y = hal.interface.workgroup.count[1] : index
- %8 = muli %workgroup_size_y, %workgroup_id_y : index
- %9 = muli %workgroup_size_y, %workgroup_count_y : index
- scf.for %arg0 = %8 to %M step %9 {
- %10 = muli %workgroup_size_x, %workgroup_id_x : index
- %11 = muli %workgroup_size_x, %workgroup_count_x : index
- scf.for %arg1 = %10 to %N step %11 {
- %12 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %N]
- %13 = memref.subview %0[%arg0, 0] [%12, %K] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- %14 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %M]
- %15 = memref.subview %2[0, %arg1] [%K, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- %16 = memref.subview %4[%arg0, %arg1] [%12, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- %17 = memref.alloc(%12, %14) : memref<?x?xf32>
- linalg.copy(%16, %17) : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%13, %15 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%17 : memref<?x?xf32>)
- %18 = memref.subview %6[%arg0, %arg1] [%12, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- linalg.copy(%17, %18) : memref<?x?xf32>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
- }
- }
- return
- }
- }
- }
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-// CHECK: hal.executable @matmul_tensors
-// CHECK: hal.executable.entry_point @matmul_tensors
-// CHECK-NEXT: ^{{[a-zA-Z0-9_]+}}(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[WGX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-// CHECK-DAG: %[[WGY:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
-// CHECK: hal.return %[[WGX]], %[[WGY]], %[[C1]]
-// CHECK-NOT: hal.interface.workgroup.size
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C16:.+]] = constant 16 : index
-// CHECK-DAG: %[[C8:.+]] = constant 8 : index
-// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
-// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
-// CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @io::@arg2
-// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[LHS]], %[[C0]]
-// CHECK-DAG: %[[N:.+]] = memref.dim %[[RHS]], %[[C1]]
-// CHECK-DAG: %[[K:.+]] = memref.dim %[[LHS]], %[[C1]]
-// CHECK-DAG: %[[WGID_X:.+]] = hal.interface.workgroup.id[0]
-// CHECK-DAG: %[[WGID_Y:.+]] = hal.interface.workgroup.id[1]
-// CHECK-DAG: %[[WGCOUNT_X:.+]] = hal.interface.workgroup.count[0]
-// CHECK-DAG: %[[WGCOUNT_Y:.+]] = hal.interface.workgroup.count[1]
-// CHECK: %[[OFFSET_Y:.+]] = muli %[[WGID_Y]], %[[C8]]
-// CHECK: %[[STEP_Y:.+]] = muli %[[WGCOUNT_Y]], %[[C8]]
-// CHECK: scf.for %{{.+}} = %[[OFFSET_Y]] to %[[M]] step %[[STEP_Y]]
-// CHECK: %[[OFFSET_X:.+]] = muli %[[WGID_X]], %[[C16]]
-// CHECK: %[[STEP_X:.+]] = muli %[[WGCOUNT_X]], %[[C16]]
-// CHECK: scf.for %{{.+}} = %[[OFFSET_X]] to %[[N]] step %[[STEP_X]]
diff --git a/iree/compiler/Codegen/SPIRV/test/materialize_launch_configuration2.mlir b/iree/compiler/Codegen/SPIRV/test/materialize_launch_configuration2.mlir
deleted file mode 100644
index 72c7684..0000000
--- a/iree/compiler/Codegen/SPIRV/test/materialize_launch_configuration2.mlir
+++ /dev/null
@@ -1,72 +0,0 @@
-// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-convert-to-gpu))' -canonicalize -cse -split-input-file %s | IreeFileCheck %s
-
-hal.executable @add attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @add attributes {
- interface = @io,
- ordinal = 0 : index
- }
- module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} {
- func @add() {
- %c0 = constant 0 : index
- %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x?xf32>
- %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<?xf32>
- %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<?x?xf32>
- linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %1 : memref<?x?xf32>, memref<?xf32>) outs(%2 : memref<?x?xf32>) {
- ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
- %3 = addf %arg0, %arg1 : f32
- linalg.yield %3 : f32
- }
- return
- }
- hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- }
- }
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) ceildiv 32)>
-// CHECK: hal.executable @add
-// CHECK: hal.executable.entry_point @add
-// CHECK-NEXT: ^{{[a-zA-Z0-9_]+}}(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[WGCOUNTX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]], %[[ARG2]]]
-// CHECK: hal.return %[[WGCOUNTX]], %[[C1]], %[[C1]]
-// CHECK: func @add()
-// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}
-// CHECK-DAG: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[C1:.+]] = constant 1
-// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
-// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
-// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@ret0
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[LHS]], %[[C0]]
-// CHECK-DAG: %[[N:.+]] = memref.dim %[[LHS]], %[[C1]]
-// CHECK: %[[UB:.+]] = muli %[[N]], %[[M]]
-// CHECK-DAG: %[[BID:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[BDIM:.+]] = "gpu.block_dim"() {dimension = "x"}
-// CHECK-DAG: %[[TID:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK: %[[BOFFSET:.+]] = muli %[[BID]], %[[BDIM]]
-// CHECK: %[[IV:.+]] = addi %[[BOFFSET]], %[[TID]]
-// CHECK: %[[COND:.+]] = cmpi slt, %[[IV]], %[[UB]]
-// CHECK: scf.if %[[COND]] {
-// CHECK: %[[IV0:.+]] = divi_signed %[[IV]], %[[N]]
-// CHECK: %[[IV1:.+]] = remi_signed %[[IV]], %[[N]]
-// CHECK-DAG: %[[V1:.+]] = memref.load %[[LHS]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[V2:.+]] = memref.load %[[RHS]][%[[IV1]]]
-// CHECK-DAG: %[[STORE:.+]] = addf %[[V1]], %[[V2]]
-// CHECK: store %[[STORE]], %[[RESULT]][%[[IV0]], %[[IV1]]]
diff --git a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
index 936962e..1b4bfe3 100644
--- a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-hlo-to-spirv-pipeline))' -iree-spirv-workgroup-tile-size=8,64,4 -iree-spirv-invocation-tile-size=8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-spirv-pipeline))' %s | IreeFileCheck %s
+
+#config = {tileSizes = [[8, 64, 4], [], [8, 4, 4]]}
hal.executable @fuse_and_vectorize_fill_matmul attributes {sym_visibility = "private"} {
hal.interface @io {
@@ -8,8 +10,9 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @fuse_and_vectorize_fill_matmul attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 1: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [64, 8]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @fuse_and_vectorize_fill_matmul() {
@@ -40,8 +43,8 @@
%13 = affine.min affine_map<(d0)[s0] -> (-d0 + 4096, s0)>(%arg0)[%workgroup_size_y]
%14 = affine.min affine_map<(d0)[s0] -> (-d0 + 4096, s0)>(%arg1)[%workgroup_size_x]
%15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
- %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
- %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x4096xf32>, tensor<4096x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %16 = linalg.fill(%cst, %15) {__internal_linalg_transform__ = "workgroup", lowering.config = #config} : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config} ins(%8, %10 : tensor<?x4096xf32>, tensor<4096x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:4096x4096xf32>
}
}
@@ -66,6 +69,8 @@
// -----
+#config = {tileSizes = [[8, 64, 4], [], [8, 4, 4]]}
+
hal.executable @fuse_and_vectorize_matmul_add attributes {sym_visibility = "private"} {
hal.interface @io {
hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -74,8 +79,9 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @fuse_and_vectorize_matmul_add attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 1: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [64, 8]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @fuse_and_vectorize_matmul_add() {
@@ -112,9 +118,9 @@
%18 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg0)[%workgroup_size_y]
%19 = affine.min affine_map<(d0)[s0] -> (-d0 + 256, s0)>(%arg1)[%workgroup_size_x]
%20 = linalg.init_tensor [%18, %19] : tensor<?x?xf32>
- %21 = linalg.fill(%cst, %20) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
- %22 = linalg.matmul ins(%15, %17 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%21 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%22, %10 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%13 : tensor<?x?xf32>) attrs = {__internal_linalg_transform__ = "workgroup"} {
+ %21 = linalg.fill(%cst, %20) {__internal_linalg_transform__ = "workgroup", lowering.config = #config} : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ %22 = linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config} ins(%15, %17 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%21 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%22, %10 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%13 : tensor<?x?xf32>) attrs = {__internal_linalg_transform__ = "workgroup", lowering.config = #config} {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%24 = addf %arg2, %arg3 : f32
linalg.yield %24 : f32
diff --git a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
index ba46d79..46d2c03 100644
--- a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-tile-and-vectorize,canonicalize,cse))' -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))'
+// TODO(antiagainst): Fix promotion to workgroup and enable the test.
+// | IreeFileCheck %s
hal.executable @matmul_promote_workgroup_memory attributes {sym_visibility = "private"} {
hal.interface @io {
@@ -8,8 +10,8 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @matmul_promote_workgroup_memory attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 8: index, 1: index]
}
module attributes {
spv.target_env =
@@ -39,7 +41,9 @@
%15 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%3]
%16 = affine.min affine_map<()[s0] -> (16, s0 * -16 + 75)>()[%3]
%17 = memref.subview %2[%13, %15] [%14, %16] [1, 1] : memref<25x75xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %12 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 50 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>) outs(%17 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[8, 16, 32], [], [1, 1, 0]]}}
+ ins(%8, %12 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 50 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
+ outs(%17 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
}
return
}
@@ -84,8 +88,8 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @conv_promote_workgroup_memory attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 4: index, 1: index]
}
module attributes {
spv.target_env =
@@ -110,7 +114,9 @@
%13 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
%14 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 11)>()[%3]
%15 = memref.subview %2[%5, %11, %13, 0] [1, %12, %14, 14] [1, 1, 1, 1] : memref<2x13x11x14xf32> to memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>
- linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%10, %0 : memref<1x?x?x6xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1260 + s0 + d1 * 84 + d2 * 6 + d3)>>, memref<3x4x6x14xf32>) outs(%15 : memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>)
+ linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[0, 1, 4, 32], [], [0, 1, 1, 1]]}, dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ ins(%10, %0 : memref<1x?x?x6xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1260 + s0 + d1 * 84 + d2 * 6 + d3)>>, memref<3x4x6x14xf32>)
+ outs(%15 : memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>)
return
}
hal.interface @io attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Codegen/SPIRV/test/remove_one_trip_tiled_loop.mlir b/iree/compiler/Codegen/SPIRV/test/remove_one_trip_tiled_loop.mlir
new file mode 100644
index 0000000..18341f9
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/remove_one_trip_tiled_loop.mlir
@@ -0,0 +1,86 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-remove-one-trip-tiled-loop))))' %s | IreeFileCheck %s
+
+hal.executable @static_shaped_conv attributes {sym_visibility = "private"} {
+ hal.interface @io {
+ hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @static_shaped_conv attributes {
+ interface = @io, ordinal = 0 : index,
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [16, 4, 4]},
+ workgroup_size = [4 : index, 4 : index, 1 : index]
+ }
+ builtin.module {
+ builtin.func @static_shaped_conv() {
+ %cst = constant 0.000000e+00 : f32
+ %c112 = constant 112 : index
+ %c32 = constant 32 : index
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
+ %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
+ %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_z]
+ %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_z]
+ scf.for %arg0 = %3 to %c112 step %4 {
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_y]
+ %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_y]
+ scf.for %arg1 = %5 to %c112 step %6 {
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
+ %8 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
+ scf.for %arg2 = %7 to %c32 step %8 {
+ %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
+ %10 = affine.min affine_map<(d0) -> (9, d0 * -2 + 225)>(%arg0)
+ %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
+ %12 = affine.min affine_map<(d0) -> (9, d0 * -2 + 225)>(%arg1)
+ %13 = memref.subview %0[0, %9, %11, 0] [1, %10, %12, 3] [1, 1, 1, 1] : memref<1x225x225x3xf32> to memref<1x?x?x3xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 151875 + s0 + d1 * 675 + d2 * 3 + d3)>>
+ %14 = affine.min affine_map<(d0) -> (16, -d0 + 32)>(%arg2)
+ %15 = memref.subview %1[0, 0, 0, %arg2] [3, 3, 3, %14] [1, 1, 1, 1] : memref<3x3x3x32xf32> to memref<3x3x3x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2 * 32 + d3)>>
+ %16 = affine.min affine_map<(d0) -> (4, -d0 + 112)>(%arg0)
+ %17 = affine.min affine_map<(d0) -> (4, -d0 + 112)>(%arg1)
+ %18 = memref.subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
+ linalg.fill(%cst, %18) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[0, 4, 4, 16], [], [0, 4, 1, 4], [0, 0, 0, 0, 1, 1, 4]]}} : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
+ linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, lowering.config = {tileSizes = [[0, 4, 4, 16], [], [0, 4, 1, 4], [0, 0, 0, 0, 1, 1, 4]]}, strides = dense<2> : tensor<2xi64>}
+ ins(%13, %15 : memref<1x?x?x3xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 151875 + s0 + d1 * 675 + d2 * 3 + d3)>>, memref<3x3x3x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2 * 32 + d3)>>)
+ outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>)
+ }
+ }
+ }
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (16, -d0 + 32)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (4, -d0 + 112)>
+
+// CHECK: func @static_shaped_conv()
+// CHECK: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0] : index
+// CHECK: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1] : index
+// CHECK: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2] : index
+// CHECK: %[[OFFSET_Z:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Z]]]
+// CHECK: %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]]]
+// CHECK: %[[OFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+// CHECK-NOT: scf.for
+// CHECK-DAG: %[[SIZE_Z:.+]] = affine.min #[[MAP3]](%[[OFFSET_Z]])
+// CHECK-DAG: %[[SIZE_Y:.+]] = affine.min #[[MAP3]](%[[OFFSET_Y]])
+// CHECK-DAG: %[[SIZE_X:.+]] = affine.min #[[MAP2]](%[[OFFSET_X]])
+// CHECK: %[[OUTPUT:.+]] = memref.subview %{{.+}}[0, %[[OFFSET_Z]], %[[OFFSET_Y]], %[[OFFSET_X]]] [1, %[[SIZE_Z]], %[[SIZE_Y]], %[[SIZE_X]]]
+// CHECK: linalg.fill(%{{.+}}, %[[OUTPUT]])
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: outs(%[[OUTPUT]]
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
index 52a6cac..1ce42bd 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-tile-and-vectorize,canonicalize,cse))' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s
#map0 = affine_map<()[s0] -> (s0 * 8)>
#map1 = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)>
@@ -8,6 +8,8 @@
#map5 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#config = {tileSizes = [[8, 16, 0], [], [1, 1, 1]]}
+
hal.executable @matmul attributes {sym_visibility = "private"} {
hal.interface @io {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -15,7 +17,11 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @matmul attributes {interface = @io, ordinal = 0 : index}
+ hal.executable.entry_point @matmul attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 8: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [8, 16]}
+ }
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
@@ -48,7 +54,7 @@
%16 = memref.dim %arg2, %c1 : memref<?x?xf32>
%17 = affine.min #map1()[%1, %16]
%18 = memref.subview %arg2[%3, %10] [%15, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config}
ins(%7, %13 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
outs(%18 : memref<?x?xf32, #map3>)
}
@@ -77,6 +83,8 @@
// -----
+#config = {tileSizes = [[1, 4, 32], [], [1, 1, 1]]}
+
hal.executable @conv_1d attributes {sym_visibility = "private"} {
hal.interface @io {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -84,7 +92,11 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @conv_1d attributes {interface = @io, ordinal = 0 : index}
+ hal.executable.entry_point @conv_1d attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 4: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
+ }
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} {
func @conv_1d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
%cst = constant 0.000000e+00 : f32
@@ -107,7 +119,7 @@
%15 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 1)>()[%3]
%16 = memref.subview %0[%5, %12, %14] [1, %13, %15] [1, 1, 1] : memref<3x6x1xf32> to memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>
%17 = memref.subview %0[%5, %12, %9] [1, %13, %10] [1, 1, 1] : memref<3x6x1xf32> to memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>
- linalg.conv_1d_input_nwc_filter_wcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} ins(%8, %11 : memref<1x?x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 + d2)>>, memref<3x1x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>) outs(%16 : memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>)
+ linalg.conv_1d_input_nwc_filter_wcf { __internal_linalg_transform__ = "workgroup", lowering.config = #config, dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} ins(%8, %11 : memref<1x?x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 + d2)>>, memref<3x1x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>) outs(%16 : memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 6 + s0 + d1 + d2)>>)
return
}
hal.interface @io attributes {sym_visibility = "private"} {
@@ -157,6 +169,8 @@
#map6 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
#map7 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+#config = {tileSizes = [[0, 1, 4, 32], [], [0, 1, 1, 1]]}
+
hal.executable @conv_no_padding attributes {sym_visibility = "private"} {
hal.interface @io {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -164,7 +178,11 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @conv_no_padding attributes {interface = @io, ordinal = 0 : index}
+ hal.executable.entry_point @conv_no_padding attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 4: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
+ }
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
@@ -213,6 +231,7 @@
: memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
linalg.conv_2d_input_nhwc_filter_hwcf {
__internal_linalg_transform__ = "workgroup",
+ lowering.config = #config,
dilations = dense<1> : tensor<2xi64>,
strides = dense<2> : tensor<2xi64>}
ins(%21, %arg0 : memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32>)
@@ -270,6 +289,8 @@
// -----
+#config = {tileSizes = [[0, 0, 1, 4, 32], [], [0, 0, 1, 1, 1]]}
+
hal.executable @conv_3d attributes {sym_visibility = "private"} {
hal.interface @io {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -277,7 +298,11 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @conv_3d attributes {interface = @io, ordinal = 0 : index}
+ hal.executable.entry_point @conv_3d attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 4: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
+ }
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} {
func @conv_3d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
%cst = constant 0.000000e+00 : f32
@@ -299,7 +324,7 @@
%14 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 7)>()[%3]
%15 = memref.subview %0[%5, %11, %13, 0, 0] [1, %12, %14, 7, 2] [1, 1, 1, 1, 1] : memref<2x7x7x7x2xf32> to memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>
%16 = memref.subview %0[%5, %11, %13, 0, 0] [1, %12, %14, 7, 2] [1, 1, 1, 1, 1] : memref<2x7x7x7x2xf32> to memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>
- linalg.conv_3d_input_ndhwc_filter_dhwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%10, %2 : memref<1x?x?x8x3xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 1536 + s0 + d1 * 192 + d2 * 24 + d3 * 3 + d4)>>, memref<2x2x2x3x2xf32>) outs(%15 : memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>)
+ linalg.conv_3d_input_ndhwc_filter_dhwcf {__internal_linalg_transform__ = "workgroup", lowering.config = #config, dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%10, %2 : memref<1x?x?x8x3xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 1536 + s0 + d1 * 192 + d2 * 24 + d3 * 3 + d4)>>, memref<2x2x2x3x2xf32>) outs(%15 : memref<1x?x?x7x2xf32, affine_map<(d0, d1, d2, d3, d4)[s0] -> (d0 * 686 + s0 + d1 * 98 + d2 * 14 + d3 * 2 + d4)>>)
return
}
hal.interface @io attributes {sym_visibility = "private"} {
@@ -334,6 +359,9 @@
#map5 = affine_map<()[s0] -> (4, s0 * -4 + 14)>
#map6 = affine_map<()[s0] -> (32, s0 * -32 + 13)>
#map7 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1092 + s0 + d1 * 78 + d2 * 6 + d3)>
+
+#config = {tileSizes = [[1, 4, 32], [], [1, 1, 1]]}
+
module {
hal.executable @pooling_nhwc_max attributes {sym_visibility = "private"} {
hal.interface @io {
@@ -342,11 +370,10 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @pooling_nhwc_max attributes {interface = @io, ordinal = 0 : index} {
- ^bb0(%arg0: index, %arg1: index, %arg2: index): // no predecessors
- %c4 = constant 4 : index
- %c1 = constant 1 : index
- hal.return %c1, %c4, %c1 : index, index, index
+ hal.executable.entry_point @pooling_nhwc_max attributes {
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 4: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
func @pooling_nhwc_max() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
@@ -364,7 +391,7 @@
%10 = affine.min #map5()[%4]
%11 = affine.min #map6()[%3]
%12 = memref.subview %2[0, %5, %7, 0] [2, %10, %11, 6] [1, 1, 1, 1] : memref<2x14x13x6xf32> to memref<2x?x?x6xf32, #map7>
- linalg.pooling_nhwc_max {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%9, %1 : memref<2x?x?x6xf32, #map4>, memref<3x4xf32>) outs(%12 : memref<2x?x?x6xf32, #map7>)
+ linalg.pooling_nhwc_max {__internal_linalg_transform__ = "workgroup", lowering.config = #config, dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%9, %1 : memref<2x?x?x6xf32, #map4>, memref<3x4xf32>) outs(%12 : memref<2x?x?x6xf32, #map7>)
return
}
hal.interface @io attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
index 7378c56..1cbb881 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-concretize-workgroup-tiles,iree-spirv-tile-and-vectorize))' -canonicalize -cse -iree-spirv-workgroup-tile-size=1,8,64,4 -iree-spirv-invocation-tile-size=1,8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
+
+#config = {tileSizes = [[1, 8, 64, 4], [], [1, 8, 4, 4]]}
hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
@@ -8,8 +10,9 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @batch_matmul_static_shape attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 1: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [64, 8, 1]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @batch_matmul_static_shape() {
@@ -43,7 +46,9 @@
%12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg2)[%workgroup_size_x]
%13 = memref.subview %1[%arg0, 0, %arg2] [%9, 1024, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
%14 = memref.subview %2[%arg0, %arg1, %arg2] [%9, %10, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
- linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+ linalg.batch_matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config}
+ ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+ outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
}
}
}
@@ -74,21 +79,29 @@
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK-DAG: %[[C7:.+]] = constant 7 : index
// CHECK: %[[BIDX:.+]] = hal.interface.workgroup.id[0]
+// CHECK: %[[BCNTX:.+]] = hal.interface.workgroup.count[0]
// CHECK: %[[BIDY:.+]] = hal.interface.workgroup.id[1]
+// CHECK: %[[BCNTY:.+]] = hal.interface.workgroup.count[1]
// CHECK: %[[BIDZ:.+]] = hal.interface.workgroup.id[2]
-// CHECK-DAG: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK-DAG: %[[BOFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+// CHECK: %[[BCNTZ:.+]] = hal.interface.workgroup.count[2]
+// CHECK: scf.for %[[IVZ:.+]] = %[[BIDZ]] to %{{.+}} step %[[BCNTZ]]
+// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[UBY:.+]] = affine.apply #[[MAP0]]()[%[[BCNTY]]]
+// CHECK: scf.for %[[IVY:.+]] = %[[BOFFSET_Y]] to %{{.+}} step %[[UBY]]
+// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+// CHECK: %[[UBX:.+]] = affine.apply #[[MAP1]]()[%[[BCNTX]]]
// CHECK: %[[SUBVIEW_ARG0:.+]] = memref.subview %[[ARG0]]
-// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], 0] [1, 8, 1024]
-// CHECK: %[[SUBVIEW_ARG1:.+]] = memref.subview %[[ARG1]]
-// CHECK-SAME: [%[[BIDZ]], 0, %[[BOFFSET_X]]] [1, 1024, 64]
-// CHECK: %[[SUBVIEW_RESULT:.+]] = memref.subview %[[RET0]]
-// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], %[[BOFFSET_X]]] [1, 8, 64]
+// CHECK-SAME: [%[[IVZ]], %[[IVY]], 0] [1, 8, 1024]
// CHECK: %[[IIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
// CHECK: %[[IIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
// CHECK: %[[IIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
// CHECK-DAG: %[[IOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[IIDY]]]
// CHECK-DAG: %[[IOFFSET_X:.+]] = affine.apply #[[MAP2]]()[%[[IIDX]]]
+// CHECK: scf.for %[[IVX:.+]] = %[[BOFFSET_X]] to %{{.+}} step %[[UBX]]
+// CHECK: %[[SUBVIEW_ARG1:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME: [%[[IVZ]], 0, %[[IVX]]] [1, 1024, 64]
+// CHECK: %[[SUBVIEW_RESULT:.+]] = memref.subview %[[RET0]]
+// CHECK-SAME: [%[[IVZ]], %[[IVY]], %[[IVX]]] [1, 8, 64]
// CHECK: %[[SUBVIEW_RESULT_2:.+]] = memref.subview %[[SUBVIEW_RESULT]]
// CHECK-SAME: [%[[IIDZ]], %[[IOFFSET_Y]], %[[IOFFSET_X]]] [1, 8, 4]
// CHECK-DAG: %[[READ_INIT_0:.+]] = vector.transfer_read
@@ -108,7 +121,7 @@
// CHECK-DAG: %[[READ_INIT_7:.+]] = vector.transfer_read
// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C7]], %[[C0]]]
-// CHECK: %[[FOR_RES:.+]]:8 = scf.for %[[IV0:.+]] = {{.*}} to
+// CHECK: %[[FOR_RES:.+]]:8 = scf.for %[[IV3:.+]] = {{.*}} to
// CHECK-SAME: iter_args(%[[ACC_0:.+]] = %[[READ_INIT_0]],
// CHECK-SAME: %[[ACC_1:.+]] = %[[READ_INIT_1]],
// CHECK-SAME: %[[ACC_2:.+]] = %[[READ_INIT_2]],
@@ -118,9 +131,9 @@
// CHECK-SAME: %[[ACC_6:.+]] = %[[READ_INIT_6]],
// CHECK-SAME: %[[ACC_7:.+]] = %[[READ_INIT_7]])
// CHECK-DAG: %[[SUBVIEW_LHS:.+]] = memref.subview %[[SUBVIEW_ARG0]]
-// CHECK-SAME: [%[[IIDZ]], %[[IOFFSET_Y]], %[[IV0]]] [1, 8, 4]
+// CHECK-SAME: [%[[IIDZ]], %[[IOFFSET_Y]], %[[IV3]]] [1, 8, 4]
// CHECK-DAG: %[[SUBVIEW_RHS:.+]] = memref.subview %[[SUBVIEW_ARG1]]
-// CHECK-SAME: [%[[IIDZ]], %[[IV0]], %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
+// CHECK-SAME: [%[[IIDZ]], %[[IV3]], %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
// CHECK-DAG: %[[READ_LHS_0:.+]] = vector.transfer_read %[[SUBVIEW_LHS]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK-DAG: %[[READ_LHS_1:.+]] = vector.transfer_read %[[SUBVIEW_LHS]][%[[C0]], %[[C1]], %[[C0]]]
@@ -355,6 +368,8 @@
// -----
+#config = {tileSizes = [[1, 8, 64, 4], [], [1, 8, 4, 4]]}
+
hal.executable @fused_fill_batch_matmul attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -363,8 +378,9 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @fused_fill_batch_matmul attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 1: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [64, 8, 1]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @fused_fill_batch_matmul() {
@@ -399,8 +415,8 @@
%12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg2)[%workgroup_size_x]
%13 = memref.subview %1[%arg0, 0, %arg2] [%9, 1024, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
%14 = memref.subview %2[%arg0, %arg1, %arg2] [%9, %10, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
- linalg.fill(%zero, %14) : f32, memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
- linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+ linalg.fill(%zero, %14) {__internal_linalg_transform__ = "workgroup", lowering.config = #config} : f32, memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ linalg.batch_matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
}
}
}
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index 77d9b5e..9318352 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-concretize-workgroup-tiles,iree-spirv-tile-and-vectorize))' -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(canonicalize,iree-spirv-remove-one-trip-tiled-loop,iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
+
+#config = {tileSizes = [[0, 4, 4, 16], [], [0, 4, 1, 4], [0, 0, 0, 0, 1, 1, 4]]}
hal.executable @conv_static_shape_f32 attributes {sym_visibility = "private"} {
hal.interface @io {
@@ -9,7 +11,15 @@
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @conv_static_shape_f32 attributes {
interface = @io,
- ordinal = 0 : index
+ ordinal = 0 : index,
+ workgroup_size = [4: index, 4: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [16, 4, 4]}
+ } {
+ ^bb0(%arg0 : index, %arg1 : index, %arg2 : index):
+ %x = constant 2: index
+ %y = constant 28: index
+ %z = constant 28: index
+ hal.return %x, %y, %z: index, index, index
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @conv_static_shape_f32() {
@@ -48,8 +58,8 @@
%16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
%17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
%18 = memref.subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
- linalg.fill(%cst, %18) {__internal_linalg_transform__ = "workgroup"} : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
- linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>)
+ linalg.fill(%cst, %18) {__internal_linalg_transform__ = "workgroup", lowering.config = #config} : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>
+ linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", lowering.config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>)
}
}
}
@@ -89,6 +99,8 @@
// -----
+#config = {tileSizes = [[0, 2, 2, 32], [], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1]]}
+
hal.executable @depthwise_conv_static_shape_f32 attributes {sym_visibility = "private"} {
hal.interface @io {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -98,7 +110,15 @@
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @depthwise_conv_static_shape_f32 attributes {
interface = @io,
- ordinal = 0 : index
+ ordinal = 0 : index,
+ workgroup_size = [8: index, 2: index, 2: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [16, 4, 4]}
+ } {
+ ^bb0(%arg0 : index, %arg1 : index, %arg2 : index):
+ %x = constant 6: index
+ %y = constant 14: index
+ %z = constant 14: index
+ hal.return %x, %y, %z: index, index, index
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @depthwise_conv_static_shape_f32() {
@@ -139,8 +159,8 @@
%18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg0)[%workgroup_size_z]
%19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg1)[%workgroup_size_y]
%20 = memref.subview %2[0, %arg0, %arg1, %arg2] [1, %18, %19, %15] [1, 1, 1, 1] : memref<1x56x56x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>
- linalg.fill(%cst, %20) {__internal_linalg_transform__ = "workgroup"} : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>
- linalg.depthwise_conv_2d_input_nhwc_filter_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>)
+ linalg.fill(%cst, %20) {__internal_linalg_transform__ = "workgroup", lowering.config = #config} : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {__internal_linalg_transform__ = "workgroup", lowering.config = #config, dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>)
}
}
}
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
index 01f71a3..e3c3076 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-concretize-workgroup-tiles,iree-spirv-tile-and-vectorize))' -canonicalize -cse -iree-spirv-workgroup-tile-size=8,64,4 -iree-spirv-invocation-tile-size=8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
+
+#config = {tileSizes = [[8, 64, 4], [], [8, 4, 4]]}
hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
@@ -8,8 +10,9 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @matmul_static_shape_f16 attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 1: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [64, 8]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @matmul_static_shape_f16() {
@@ -36,8 +39,8 @@
%9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
%10 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf16> to memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
%11 = memref.subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf16> to memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
- linalg.fill(%cst, %10) {__internal_linalg_transform__ = "workgroup"} : f16, memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %11 : memref<?x4096xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%10 : memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ linalg.fill(%cst, %10) {__internal_linalg_transform__ = "workgroup", lowering.config = #config} : f16, memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config} ins(%8, %11 : memref<?x4096xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%10 : memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
}
}
return
@@ -63,6 +66,8 @@
// -----
+#config = {tileSizes = [[8, 64, 4], [], [8, 4, 4]]}
+
hal.executable @matmul_static_shape_f32 attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -71,8 +76,9 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @matmul_static_shape_f32 attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [16: index, 1: index, 1: index],
+ translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [64, 8]}
}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @matmul_static_shape_f32() {
@@ -99,8 +105,8 @@
%9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
%10 = memref.subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf32> to memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
%11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
- linalg.fill(%cst, %11) : f32, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x4096xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ linalg.fill(%cst, %11) {__internal_linalg_transform__ = "workgroup", lowering.config = #config}: f32, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config} ins(%8, %10 : memref<?x4096xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
}
}
return
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
index 9967d3c..d84bb33 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-tile-and-vectorize,canonicalize,cse))' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' %s | IreeFileCheck %s
// CHECK-LABEL: func @elementwise_static_shape
// CHECK: vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
@@ -13,15 +13,13 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @elementwise_static_shape attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
}
module attributes {
spv.target_env =
- #spv.target_env<#spv.vce<v1.5,
- [Shader],
- []>, NVIDIA:DiscreteGPU,
- {subgroup_size = 32 : i32}>} {
+ #spv.target_env<#spv.vce<v1.5, [Shader], []>,
+ NVIDIA:DiscreteGPU, {subgroup_size = 32 : i32}>} {
func @elementwise_static_shape() {
%c0 = constant 0 : index
%arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<128xf32>
@@ -29,6 +27,7 @@
%ret0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<128xf32>
linalg.generic {
__internal_linalg_transform__ = "workgroup",
+ lowering.config = {tileSizes = [[128], [], [4]]},
indexing_maps = [affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>],
@@ -66,15 +65,13 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"llvm", "embedded-elf-x86_64"> {
hal.executable.entry_point @elementwise_transpose attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
}
module attributes {
spv.target_env =
- #spv.target_env<#spv.vce<v1.5,
- [Shader],
- []>, NVIDIA:DiscreteGPU,
- {subgroup_size = 32 : i32}>} {
+ #spv.target_env<#spv.vce<v1.5, [Shader], []>,
+ NVIDIA:DiscreteGPU, {subgroup_size = 32 : i32}>} {
func @elementwise_transpose() {
%c0 = constant 0 : index
%arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<128x8xf32>
@@ -82,6 +79,7 @@
%ret0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<128x8xf32>
linalg.generic {
__internal_linalg_transform__ = "workgroup",
+ lowering.config = {tileSizes = [[1, 32], [], [1, 1]]},
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index 94b650c..f343f6b 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -1,5 +1,6 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-tile-and-vectorize,canonicalize,cse))' %s | IreeFileCheck %s
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-tile-and-vectorize,canonicalize,cse))' -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s
+// TODO(antiagainst): Fix promotion to workgroup and enable the test.
+// | IreeFileCheck %s -check-prefix=PROMOTE
hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
@@ -9,8 +10,8 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @matmul_static_shape attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
}
module attributes {
spv.target_env =
@@ -54,7 +55,9 @@
%9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%4]
%10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%3]
%11 = memref.subview %2[%9, %10] [64, 64] [1, 1] : memref<4096x4096xf16> to memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%6, %8 : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[64, 64, 32], [64, 64]]}}
+ ins(%6, %8 : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ outs(%11 : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
}
return
}
@@ -269,8 +272,8 @@
}
hal.executable.variant @vulkan, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> {
hal.executable.entry_point @matmul_static_shape attributes {
- interface = @io,
- ordinal = 0 : index
+ interface = @io, ordinal = 0 : index,
+ workgroup_size = [32: index, 1: index, 1: index]
}
module attributes {
spv.target_env =
@@ -314,7 +317,9 @@
%9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
%10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
%11 = memref.subview %2[%9, %10] [128, 128] [1, 1] : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op, launch_info_key = "__op_num_0__"} ins(%6, %8 : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[64, 64, 32], [64, 64]]}}
+ ins(%6, %8 : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ outs(%11 : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
}
return
}
diff --git a/iree/compiler/Codegen/Transforms/Transforms.cpp b/iree/compiler/Codegen/Transforms/Transforms.cpp
index a12baf7..5a4778f 100644
--- a/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -27,38 +27,6 @@
namespace {
-/// Sets the hal.interace.workgroup.size operation to the constant value passed
-/// in as `workloadPerWorkgroup`. The number of entries in
-/// `workloadPerWorkgroup` is at least as much as the dimensionality of the
-/// workgroup. It is assumed that the inner-most loop is mapped to the fastest
-/// varying dimension in flow.dispatch.workgroup_size.
-class SetWorkgroupSizePattern
- : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
- public:
- SetWorkgroupSizePattern(MLIRContext *context,
- ArrayRef<int64_t> workloadPerWorkgroupRef,
- PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit),
- workloadPerWorkgroup(llvm::to_vector<4>(
- workloadPerWorkgroupRef.size() > kNumMaxParallelDims
- ? workloadPerWorkgroupRef.take_front(kNumMaxParallelDims)
- : workloadPerWorkgroupRef)) {}
-
- LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceWorkgroupSizeOp workgroupSizeOp,
- PatternRewriter &rewriter) const override {
- int64_t dim = workgroupSizeOp.dimension().getSExtValue();
- if (dim >= workloadPerWorkgroup.size()) {
- return failure();
- }
- rewriter.replaceOpWithNewOp<ConstantIndexOp>(workgroupSizeOp,
- workloadPerWorkgroup[dim]);
- return success();
- }
-
- private:
- SmallVector<int64_t, 4> workloadPerWorkgroup;
-};
} // namespace
LogicalResult defineWorkgroupCountRegion(
@@ -101,34 +69,6 @@
return success();
}
-LogicalResult materializeStaticLaunchInformation(
- FuncOp funcOp, ArrayRef<int64_t> workloadPerWorkgroup) {
- OwningRewritePatternList patterns(funcOp.getContext());
- patterns.insert<SetWorkgroupSizePattern>(funcOp.getContext(),
- workloadPerWorkgroup);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return failure();
- }
- assert(workloadPerWorkgroup.size() <= kNumMaxParallelDims &&
- "workloadPerWorkgroup size greater than max num parallel dims");
- WorkgroupCountRegionBuilder regionBuilder =
- [&workloadPerWorkgroup](
- OpBuilder &b, Location loc,
- std::array<Value, 3> workload) -> std::array<Value, 3> {
- Value one = b.create<ConstantIndexOp>(loc, 1);
- std::array<Value, 3> returnValues = {one, one, one};
- for (auto ts : llvm::enumerate(workloadPerWorkgroup)) {
- returnValues[ts.index()] = linalg::applyMapToValues(
- b, loc,
- AffineMap::get(0, 1, b.getAffineSymbolExpr(0).ceilDiv(ts.value())),
- workload[ts.index()])[0];
- }
- return returnValues;
- };
- OpBuilder builder(funcOp.getContext());
- return defineWorkgroupCountRegion(builder, funcOp, regionBuilder);
-}
-
/// Return a fused vector::ContractionOp which represents a patterns such as:
///
/// ```mlir
diff --git a/iree/compiler/Codegen/Transforms/Transforms.h b/iree/compiler/Codegen/Transforms/Transforms.h
index 025b19b..1b30ff2 100644
--- a/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/iree/compiler/Codegen/Transforms/Transforms.h
@@ -32,18 +32,6 @@
OpBuilder &builder, FuncOp funcOp,
WorkgroupCountRegionBuilder regionBuilder);
-/// Using linalg on tensors for dispatch region creation does first-level of
-/// tile (fuse and distribute) during dispatch region formation. At that point
-/// the workload per workgroup is set to the dynamic value represented by
-/// `flow.dispatch.workgroup.size` and is later lowered to
-/// `hal.dispatch.workgroup.size`. This method is to materialize the static
-/// information of the workload per workgroup determined based on target
-/// architecture. Note that the value of hal.dispatch.workgroup.size is now
-/// different after this function is called and represents the actual value used
-/// at runtime.
-LogicalResult materializeStaticLaunchInformation(
- FuncOp funcOp, ArrayRef<int64_t> workloadPerWorkgroup);
-
/// Return a fused vector::ContractionOp which represents a patterns such as:
///
/// ```mlir
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 210352c..7fec29c 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/SymbolTable.h"
@@ -27,8 +28,7 @@
}
IREE::HAL::ExecutableEntryPointOp getEntryPoint(FuncOp funcOp) {
- auto variantOp =
- funcOp.getOperation()->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+ auto variantOp = funcOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
for (auto op : variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
if (op.sym_name() == funcOp.getName()) {
return op;
@@ -39,8 +39,7 @@
llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> getAllEntryPoints(
ModuleOp module) {
- auto variantOp =
- module.getOperation()->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+ auto variantOp = module->getParentOfType<IREE::HAL::ExecutableVariantOp>();
llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPointOps;
for (auto op : variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
entryPointOps[op.sym_name()] = op;
@@ -50,11 +49,11 @@
void setTranslationInfo(FuncOp entryPointFn,
IREE::HAL::DispatchLoweringPassPipeline passPipeline,
- ArrayRef<int64_t> workgroupSize) {
+ ArrayRef<int64_t> workgroupSize,
+ ArrayRef<int64_t> workloadPerWorkgroup) {
auto entryPointOp = getEntryPoint(entryPointFn);
auto translationInfo = buildTranslationInfo(
- passPipeline, /*workloadPerWorkgroup=*/ArrayRef<int64_t>{},
- entryPointFn.getContext());
+ passPipeline, workloadPerWorkgroup, entryPointFn.getContext());
setTranslationInfo(entryPointOp, translationInfo, workgroupSize);
}
@@ -161,6 +160,30 @@
.Default([&](Type type) { return ArrayRef<int64_t>{}; });
}
+/// Returns the untiled shape of the output of a `LinalgOp`.
+// TODO(ravishankarm): Using the result shape for vectorization should be
+// avoided. Ideally the tile size is enough. But there is a phase ordering issue
+// which prevents the tile size from being known at this point.
+ArrayRef<int64_t> getUntiledResultShape(linalg::LinalgOp linalgOp,
+ unsigned resultNum) {
+ // Check the shape of the `outs` operand.
+ ArrayRef<int64_t> outputShape =
+ getUntiledShape(linalgOp.outputs()[resultNum]);
+ if (!llvm::any_of(outputShape, ShapedType::isDynamic)) return outputShape;
+ // Try to use the result value and check if the untiled shape can be obtained
+ // based on the uses.
+ Value result = linalgOp->getResult(resultNum);
+ for (Operation *user : result.getUsers()) {
+ if (auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(user)) {
+ return storeOp.target()
+ .getType()
+ .cast<IREE::Flow::DispatchTensorType>()
+ .getShape();
+ }
+ }
+ return result.getType().cast<ShapedType>().getShape();
+}
+
LogicalResult getFilteredOps(FuncOp funcOp, RootOpFilteringFn filteringFn,
SmallVectorImpl<Operation *> &filteredOps,
SmallVectorImpl<Operation *> &tiledLoops) {
diff --git a/iree/compiler/Codegen/Utils/Utils.h b/iree/compiler/Codegen/Utils/Utils.h
index b8e9beb..6d195dd 100644
--- a/iree/compiler/Codegen/Utils/Utils.h
+++ b/iree/compiler/Codegen/Utils/Utils.h
@@ -34,7 +34,8 @@
/// set.
void setTranslationInfo(FuncOp entryPointFn,
IREE::HAL::DispatchLoweringPassPipeline passPipeline,
- ArrayRef<int64_t> workgroupSize = {});
+ ArrayRef<int64_t> workgroupSize,
+ ArrayRef<int64_t> workloadPerWorkgroup);
/// Returns the loops that are partitioned during dispatch region formations, in
/// order, i.e. starting from the outer-most to innermost.
@@ -72,6 +73,12 @@
/// `subtensor` op chain (for tensors).
ArrayRef<int64_t> getUntiledShape(Value tiledView);
+/// Returns the shape of the result of the untiled operation for
+/// `LinalgOp`s. First looks at definitions of the corresponding `outs`
+/// operands. If that fails, then looks at uses of the `result`.
+ArrayRef<int64_t> getUntiledResultShape(linalg::LinalgOp linalgOp,
+ unsigned resultNum);
+
/// Assuming that `funcOp` contains a single nested scf.for that represented the
/// tiled+fused+distributed loops with the distribution being across workgroups,
/// i.e.
diff --git a/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp b/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp
index c6fc59d..1c82197 100644
--- a/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp
+++ b/iree/compiler/Dialect/HAL/IR/LoweringConfig.cpp
@@ -44,6 +44,19 @@
kTranslationInfoAttrName);
}
+SmallVector<int64_t> getWorkgroupSize(
+ IREE::HAL::ExecutableEntryPointOp entryPointOp) {
+ SmallVector<int64_t> workgroupSize;
+ if (Optional<ArrayAttr> workgroupSizeAttrList =
+ entryPointOp.workgroup_size()) {
+ workgroupSize.resize(workgroupSizeAttrList->size());
+ for (auto attr : llvm::enumerate(workgroupSizeAttrList.getValue())) {
+ workgroupSize[attr.index()] = attr.value().cast<IntegerAttr>().getInt();
+ }
+ }
+ return workgroupSize;
+}
+
void setTranslationInfo(IREE::HAL::ExecutableEntryPointOp entryPointOp,
IREE::HAL::TranslationInfo translationInfo,
ArrayRef<int64_t> workgroupSize) {
diff --git a/iree/compiler/Dialect/HAL/IR/LoweringConfig.h b/iree/compiler/Dialect/HAL/IR/LoweringConfig.h
index a96ebcd..c2bf095 100644
--- a/iree/compiler/Dialect/HAL/IR/LoweringConfig.h
+++ b/iree/compiler/Dialect/HAL/IR/LoweringConfig.h
@@ -59,6 +59,10 @@
IREE::HAL::TranslationInfo getTranslationInfo(
IREE::HAL::ExecutableEntryPointOp entryPointOp);
+/// Returns the workgroup size specified on the `entryPointOp`.
+SmallVector<int64_t> getWorkgroupSize(
+ IREE::HAL::ExecutableEntryPointOp entryPointOp);
+
/// Set the translate executable info with the entry point op. Overwrites the
/// existing attributes.
// TODO(ravishankarm, benvanik): Eventually all the information needed for the
diff --git a/iree/compiler/Dialect/HAL/IR/LoweringConfig.td b/iree/compiler/Dialect/HAL/IR/LoweringConfig.td
index d0f4111..519c754 100644
--- a/iree/compiler/Dialect/HAL/IR/LoweringConfig.td
+++ b/iree/compiler/Dialect/HAL/IR/LoweringConfig.td
@@ -20,6 +20,12 @@
: I32EnumAttrCase<"LLVMGPUVectorize", 3>;
def LLVMGPU_MatmulSimt
: I32EnumAttrCase<"LLVMGPUMatmulSimt", 4>;
+def SPIRV_SimpleDistribute
+ : I32EnumAttrCase<"SPIRVDistribute", 5>;
+def SPIRV_Vectorize
+ : I32EnumAttrCase<"SPIRVVectorize", 6>;
+def SPIRV_DistributeToGlobalID
+ : I32EnumAttrCase<"SPIRVDistributeToGlobalID", 7>;
// EnumAttrCase for all known lowerings for ops within dispatch region
// to scalar/native-vector code.
@@ -27,7 +33,8 @@
"DispatchLoweringPassPipeline",
"identifier for pass pipeline use to lower dispatch region",
[CPU_Default, CPU_Vectorization, LLVMGPU_SimpleDistribute,
- LLVMGPU_Vectorize, LLVMGPU_MatmulSimt]> {
+ LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, SPIRV_SimpleDistribute, SPIRV_Vectorize,
+ SPIRV_DistributeToGlobalID]> {
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index 0cc87be..33ded2e 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -61,7 +61,7 @@
}
void buildTranslationPassPipeline(OpPassManager &passManager) override {
- buildSPIRVCodegenPassPipeline(passManager, SPIRVCodegenOptions());
+ buildSPIRVCodegenPassPipeline(passManager);
}
LogicalResult serializeExecutable(IREE::HAL::ExecutableVariantOp variantOp,
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 189610a..3cab856 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -82,7 +82,6 @@
llvm::cl::init(""));
VulkanSPIRVTargetOptions targetOptions;
- targetOptions.codegenOptions = SPIRVCodegenOptions::getFromCLOptions();
targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
targetOptions.vulkanTargetTriple = clVulkanTargetTriple;
@@ -152,7 +151,7 @@
}
void buildTranslationPassPipeline(OpPassManager &passManager) override {
- buildSPIRVCodegenPassPipeline(passManager, options_.codegenOptions);
+ buildSPIRVCodegenPassPipeline(passManager);
}
// TODO(antiagainst): Re-enable SPIR-V linking once the tensorflow integration
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
index 677a549..768b56e 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h
@@ -19,8 +19,6 @@
// Options controlling the SPIR-V translation.
struct VulkanSPIRVTargetOptions {
- // SPIR-V codegeneration options
- SPIRVCodegenOptions codegenOptions;
// Vulkan target environment as #vk.target_env attribute assembly.
std::string vulkanTargetEnv;
// Vulkan target triple.
diff --git a/iree/test/e2e/regression/linalg_ops.mlir b/iree/test/e2e/regression/linalg_ops.mlir
index 4a0660a..8a742f9 100644
--- a/iree/test/e2e/regression/linalg_ops.mlir
+++ b/iree/test/e2e/regression/linalg_ops.mlir
@@ -31,3 +31,29 @@
[189, 220, 253, 288]]> : tensor<3x4xi32>) : tensor<3x4xi32>
return
}
+
+func @operand_fusion() {
+ %input = util.unfoldable_constant dense<1.0> : tensor<1x225x225x3xf32>
+ %filter = util.unfoldable_constant dense<1.0> : tensor<3x3x3x16xf32>
+ %bias = util.unfoldable_constant dense<1.0> : tensor<16xf32>
+ %init = linalg.init_tensor [1, 112, 112, 16] : tensor<1x112x112x16xf32>
+ %cst = constant 0.0 : f32
+ %fill = linalg.fill(%cst, %init) : f32, tensor<1x112x112x16xf32> -> tensor<1x112x112x16xf32>
+ %conv = linalg.conv_2d_input_nhwc_filter_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x16xf32>)
+ outs(%fill : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%conv, %bias : tensor<1x112x112x16xf32>, tensor<16xf32>)
+ outs(%init : tensor<1x112x112x16xf32>) {
+ ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32):
+ %0 = addf %arg0, %arg1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<1x112x112x16xf32>
+ check.expect_eq_const(%result, dense<28.0> : tensor<1x112x112x16xf32>) : tensor<1x112x112x16xf32>
+ return
+}
\ No newline at end of file
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 69c50fc..3dee341 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -39,20 +39,6 @@
)
iree_check_single_backend_test_suite(
- name = "check_vulkan-spirv_vulkan_wgmem",
- srcs = [
- "conv.mlir",
- "gemm.mlir",
- ],
- compiler_flags = [
- "-iree-input-type=mhlo",
- "-iree-spirv-use-workgroup-memory",
- ],
- driver = "vulkan",
- target_backend = "vulkan-spirv",
-)
-
-iree_check_single_backend_test_suite(
name = "check_vulkan-spirv_vulkan_vectorized_conv",
srcs = [
"vectorized_conv.mlir",
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 9e911e6..b32120c 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -31,21 +31,6 @@
iree_check_single_backend_test_suite(
NAME
- check_vulkan-spirv_vulkan_wgmem
- SRCS
- "conv.mlir"
- "gemm.mlir"
- TARGET_BACKEND
- "vulkan-spirv"
- DRIVER
- "vulkan"
- COMPILER_FLAGS
- "-iree-input-type=mhlo"
- "-iree-spirv-use-workgroup-memory"
-)
-
-iree_check_single_backend_test_suite(
- NAME
check_vulkan-spirv_vulkan_vectorized_conv
SRCS
"vectorized_conv.mlir"