[spirv] Push GPU target conversion to before SPIR-V conversion (#17816)
This commit moves the `SPIRVConvertGPUTargetPass` to right before the
`ConvertToSPIRVPass` in the pipeline. This makes sure we use the same
`#iree_gpu.target` in the majority of the configuration and lowering
passes in the CodeGen flow, and scopes the SPIR-V target environment to
only the final SPIR-V conversion. With this, we are able to unify and
simplify lots of SPIR-V tests.
Progress towards https://github.com/iree-org/iree/issues/16341
ci-extra:
test_nvidia_gpu,test_nvidia_a100,test_amd_mi250,test_amd_w7900,build_test_all_macos_arm64,build_and_test_android
---------
Signed-off-by: Lei Zhang <antiagainst@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index af69895..e421f4e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -362,15 +362,47 @@
let assemblyFormat = "`<` struct(params) `>`";
let extraClassDeclaration = [{
- int getPreferredSubgroupSize() const {
- return getWgp().getSubgroupSizeChoices().asArrayRef().front();
+ // Subgroup size related APIs
+
+ int getMinSubgroupSize() const {
+ return *llvm::min_element(getWgp().getSubgroupSizeChoices().asArrayRef());
}
+ int getMaxSubgroupSize() const {
+ return *llvm::max_element(getWgp().getSubgroupSizeChoices().asArrayRef());
+ }
+ // Returns the preferred subgroup size. If the target supports multiple
+ // subgroup sizes, pickLargest controls whether to return the largest one.
+ //
+ // AMD RDNA GPUs supports multiple subgroup sizes and the preferred one
+ // differ given the API--HIP prefers 32 while Vulkan prefers 64.
+ // TODO: We should be able to force Vulkan side to use 32 consistently
+ // too with subgroup size control; it might have perf implications though.
+ int getPreferredSubgroupSize(bool pickLargest=false) const {
+ if (pickLargest) {
+ return getMaxSubgroupSize();
+ }
+ return getMinSubgroupSize();
+ }
+
+ // Hardware feature related APIs
bool supportsSubgroupShuffle() const {
return bitEnumContainsAll(getWgp().getSubgroup().getValue(),
SubgroupOps::Shuffle);
}
+ // Vendor querying APIs
+
+ bool isAMD() const {
+ return getArch().starts_with("gfx") || getArch().starts_with("rdna");
+ }
+ bool isApple() const { return getArch().starts_with("apple"); }
+ bool isARM() const { return getArch().starts_with("valhall"); }
+ bool isNVIDIA() const { return getArch().starts_with("sm_"); }
+ bool isQualcomm() const { return getArch().starts_with("adreno"); }
+
+ // CUDA specific querying APIs
+
std::optional<int> getCUDAComputeCapability() const;
// Returns true if this target supports TensoreCore MMA ops with TF32
// input types.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 993963c..f4584fe 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -563,7 +563,7 @@
//===----------------------------------------------------------------------===//
TargetAttr getMetalTargetDetails(MLIRContext *context) {
- return createTargetAttr(*getAppleTargetDetails(), /*arch=*/"",
+ return createTargetAttr(*getAppleTargetDetails(), /*arch=*/"apple",
/*features=*/"spirv:v1.3,cap:Shader", context);
}
@@ -603,6 +603,8 @@
// SPIR-V 1.4. For non-mobile GPUs we target Vulkan 1.3, which accepts
// SPIR-V 1.6 as the maximum.
+ // TODO: Add feature bits for physical storage buffer.
+
if (std::optional<TargetDetails> details = getAMDGPUTargetDetails(target)) {
return createTargetAttr(*details, normalizeAMDGPUTarget(target),
/*features=*/"spirv:v1.6,cap:Shader", context);
@@ -654,7 +656,8 @@
StringRef features, MLIRContext *context) {
return llvm::StringSwitch<TargetAttr>(targetAPI)
.Case("cuda", getCUDATargetDetails(aliasTarget, features, context))
- .Case("rocm", getHIPTargetDetails(aliasTarget, features, context))
+ .Case("hip", getHIPTargetDetails(aliasTarget, features, context))
+ .Case("vulkan", getVulkanTargetDetails(aliasTarget, context))
.Default(nullptr);
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
index f7d10c1..ee74582 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -10,15 +10,10 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinOps.h"
#define DEBUG_TYPE "iree-spirv-amd-config"
@@ -35,15 +30,14 @@
constexpr unsigned AMDNumMNTilesPerSubgroup = 8;
static LogicalResult setAMDMatmulConfig(linalg::LinalgOp op,
- const spirv::TargetEnv &targetEnv) {
+ IREE::GPU::TargetAttr target) {
if (succeeded(setCooperativeMatrixConfig(
- targetEnv, op, AMDNumSubgroupsPerWorkgroup, AMDNumMNTilesPerSubgroup,
+ target, op, AMDNumSubgroupsPerWorkgroup, AMDNumMNTilesPerSubgroup,
AMDCoopMatrixSoftwarePipelineDepth,
AMDCoopMatrixSoftwarePipelineStoreStage)))
return success();
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- const int subgroupSize = limits.getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
const std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 8};
std::array<int64_t, 3> threadMNK;
auto inputType =
@@ -53,7 +47,7 @@
} else {
threadMNK = {8, 4, 16};
}
- return setMatmulOpConfig(limits, op, workgroupXY, threadMNK,
+ return setMatmulOpConfig(target, op, workgroupXY, threadMNK,
/*enablePromotion=*/true,
AMDSimtSoftwarePipelineDepth,
AMDSimtSoftwarePipelineStoreStage);
@@ -71,14 +65,13 @@
// * Max 20 waves per SIMD32
// * Max 64KB LDS per workgroup
-LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setAMDCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp) {
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- int subgroupSize = limits.getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp)) {
if (isMatmulOrBatchMatmul(linalgOp))
- return setAMDMatmulConfig(linalgOp, targetEnv);
+ return setAMDMatmulConfig(linalgOp, target);
}
if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
index dd76a97..6d8b815 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
@@ -14,15 +14,13 @@
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/IR/BuiltinOps.h"
namespace mlir::iree_compiler::detail {
static LogicalResult setAdrenoMatmulConfig(linalg::LinalgOp op,
- spirv::ResourceLimitsAttr limits) {
- const int subgroupSize = limits.getSubgroupSize();
+ IREE::GPU::TargetAttr target) {
+ const int subgroupSize = target.getPreferredSubgroupSize();
const std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 2};
std::array<int64_t, 3> threadMNK;
auto inputType =
@@ -32,24 +30,23 @@
} else {
threadMNK = {16, 4, 4};
}
- return setMatmulOpConfig(limits, op, workgroupXY, threadMNK);
+ return setMatmulOpConfig(target, op, workgroupXY, threadMNK);
}
//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//
-LogicalResult setAdrenoCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setAdrenoCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp) {
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- int subgroupSize = limits.getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize();
if (!isa<linalg::LinalgOp>(rootOp))
return failure();
auto linalgOp = cast<linalg::LinalgOp>(rootOp);
if (isMatmulOrBatchMatmul(linalgOp))
- return setAdrenoMatmulConfig(linalgOp, limits);
+ return setAdrenoMatmulConfig(linalgOp, target);
if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
// Use the result type in case of larger bitwidth for accumulators.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
index 7f94716..8ec023b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
@@ -14,15 +14,12 @@
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
namespace mlir::iree_compiler::detail {
static LogicalResult setAppleMatmulConfig(linalg::LinalgOp op,
- spirv::ResourceLimitsAttr limits) {
+ IREE::GPU::TargetAttr target) {
const std::array<int64_t, 2> workgroupXY = {256, 1};
std::array<int64_t, 3> threadMNK;
auto inputType =
@@ -32,21 +29,20 @@
} else {
threadMNK = {4, 4, 4};
}
- return setMatmulOpConfig(limits, op, workgroupXY, threadMNK);
+ return setMatmulOpConfig(target, op, workgroupXY, threadMNK);
}
//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//
-LogicalResult setAppleCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setAppleCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp) {
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- int subgroupSize = limits.getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize();
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp)) {
if (isMatmulOrBatchMatmul(linalgOp))
- return setAppleMatmulConfig(linalgOp, limits);
+ return setAppleMatmulConfig(linalgOp, target);
}
if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index e72fdc5..b6e7a70 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -91,6 +91,7 @@
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
"//compiler/src/iree/compiler/Codegen/TransformStrategies/GPU",
"//compiler/src/iree/compiler/Codegen/Transforms",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 1378bbc..a2ced1e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -140,6 +140,7 @@
iree::compiler::Codegen::Common::GPU::GPUHeuristics
iree::compiler::Codegen::Common::TransformDialectInterpreterPass
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+ iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
iree::compiler::Codegen::TransformStrategies::GPU
iree::compiler::Codegen::Transforms
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 08b358d..1d6bca4 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -17,16 +17,12 @@
#include <cstdint>
#include <tuple>
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
-#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -41,15 +37,11 @@
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -596,17 +588,21 @@
}
}
- spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(moduleOp);
- moduleOp->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
-
if (indexBits != 32 && indexBits != 64) {
moduleOp.emitOpError(
- "Only 32-bit or 64-bit indices are supported for SPIR-V");
+ "only 32-bit or 64-bit indices are supported for SPIR-V");
return signalPassFailure();
}
-
bool use64bitIndex = indexBits == 64;
+
+ auto targetAttr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
+ spirv::getTargetEnvAttrName());
+ if (!targetAttr) {
+ moduleOp.emitOpError("should contain a spirv.target_env attribute");
+ return signalPassFailure();
+ }
spirv::TargetEnv targetEnv(targetAttr);
+
if (use64bitIndex && !targetEnv.allows(spirv::Capability::Int64)) {
moduleOp.emitOpError(
"64-bit indices are not supported for the specified target "
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index e491182..5a13996 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -8,8 +8,8 @@
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
@@ -26,8 +26,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -605,7 +603,7 @@
namespace detail {
-LogicalResult setMatmulOpConfig(spirv::ResourceLimitsAttr limits,
+LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target,
linalg::LinalgOp op,
std::array<int64_t, 2> bestWorkgroupSizeXY,
std::array<int64_t, 3> bestThreadTileSizeMNK,
@@ -704,8 +702,8 @@
llvm::dbgs() << ")\n";
});
- const int subgroupSize = limits.getSubgroupSize();
- const int maxBytes = limits.getMaxComputeSharedMemorySize();
+ int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ const int maxBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
// We want a 2-stage pipeline without multi-buffering if the depth is 0 to
// keep the default for compilation configs that don't specify a pipeline
@@ -844,17 +842,16 @@
namespace detail {
-LogicalResult setCooperativeMatrixConfig(
- const spirv::TargetEnv &targetEnv, linalg::LinalgOp op,
- const unsigned numSubgroupsPerWorkgroup,
- const unsigned numMNTilesPerSubgroup, unsigned softwarePipelineDepth,
- unsigned softwarePipelineStoreStage) {
- LLVM_DEBUG(llvm::dbgs() << "trying to matmul tensorcore config...\n");
+LogicalResult
+setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op,
+ const unsigned numSubgroupsPerWorkgroup,
+ const unsigned numMNTilesPerSubgroup,
+ unsigned softwarePipelineDepth,
+ unsigned softwarePipelineStoreStage) {
+ LLVM_DEBUG(llvm::dbgs() << "trying to matmul cooperative matrix config...\n");
// This configuration is only for cooperative matrix.
- if (!targetEnv.allows(spirv::Capability::CooperativeMatrixKHR) ||
- !targetEnv.allows(spirv::Extension::SPV_KHR_cooperative_matrix)) {
+ if (target.getWgp().getMma().empty())
return failure();
- }
if (op.hasDynamicShape())
return failure();
@@ -895,31 +892,23 @@
Type initElem = getElementType(init);
GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem);
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- auto properties =
- limits.getCooperativeMatrixPropertiesKhr()
- .getAsRange<spirv::CooperativeMatrixPropertiesKHRAttr>();
-
SmallVector<GPUMatmulShapeType> intrinsics;
- intrinsics.reserve(limits.getCooperativeMatrixPropertiesKhr().size());
- for (auto p : properties) {
- intrinsics.emplace_back(p.getMSize(), p.getNSize(), p.getKSize(),
- p.getAType(), p.getBType(), p.getCType());
+ intrinsics.reserve(target.getWgp().getMma().size());
+ for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+ auto [mSize, nSize, kSize] = mma.getMNKShape();
+ auto [aType, bType, cType] = mma.getABCElementTypes();
+ intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
}
GPUMMAHeuristicSeeds seeds{numSubgroupsPerWorkgroup, numMNTilesPerSubgroup,
numKTilesPerSubgroup};
int64_t sharedMemoryLimitInBytes =
- targetEnv.getResourceLimits().getMaxComputeSharedMemorySize();
+ target.getWgp().getMaxWorkgroupMemoryBytes();
- std::optional<int64_t> subgroupSize = limits.getSubgroupSize();
// AMD RDNA architectures supports both wave32 and wave64 modes. Prefer to use
// wave32 mode for better performance.
- if (targetEnv.getVendorID() == spirv::Vendor::AMD) {
- if (std::optional<int> minSize = limits.getMinSubgroupSize())
- subgroupSize = *minSize;
- }
+ int64_t subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/false);
// Infer if lhs or rhs is transposed to help generate better schedule.
SmallVector<AffineMap> maps = op.getIndexingMapsArray();
@@ -932,13 +921,13 @@
FailureOr<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes,
- *subgroupSize, transposedLhs, transposedRhs);
+ subgroupSize, transposedLhs, transposedRhs);
if (failed(schedule))
return failure();
auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
- std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * *subgroupSize,
+ std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * subgroupSize,
schedule->mWarpCount, 1};
SmallVector<int64_t> vectorSizes(kIndex + 1, 0);
@@ -982,7 +971,7 @@
bool promoteC = needToPrmoteCForCooperativeMatrix(op);
// Decrease pipeline depth until it fits in shared memory.
- const int maxBytes = limits.getMaxComputeSharedMemorySize();
+ const int maxBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
auto usedBytes =
getTileBytes(workgroupTileSizes[mIndex], workgroupTileSizes[nIndex],
reductionTileSizes[kIndex],
@@ -1007,10 +996,10 @@
// FFT Default Configuration
//===----------------------------------------------------------------------===//
-static LogicalResult setFftOpConfig(spirv::ResourceLimitsAttr limits,
+static LogicalResult setFftOpConfig(IREE::GPU::TargetAttr target,
IREE::LinalgExt::FftOp op) {
LLVM_DEBUG(llvm::dbgs() << "trying to deduce config as fft...\n");
- const int subgroupSize = limits.getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
auto pipeline = CodeGenPipeline::SPIRVBaseDistribute;
std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
@@ -1046,7 +1035,7 @@
// Winograd Default Configuration
//===----------------------------------------------------------------------===//
-static LogicalResult setWinogradOpConfig(spirv::ResourceLimitsAttr limits,
+static LogicalResult setWinogradOpConfig(IREE::GPU::TargetAttr target,
IREE::LinalgExt::LinalgExtOp op) {
// Tiling is already done by tile and decompose, so we only set pipeline and
// workgroup size. The tile sizes below are placeholders and were obtained
@@ -1065,13 +1054,13 @@
//===----------------------------------------------------------------------===//
/// Set the configuration for reductions that can be mapped to warp reductions.
-static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,
+static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target,
linalg::LinalgOp op) {
LLVM_DEBUG(llvm::dbgs() << "trying to deduce config as reduction...\n");
// This pipeline eventually generates non-uniform group shuffle ops, which
// requires special capability.
- if (!targetEnv.allows(spirv::Capability::GroupNonUniformShuffle))
+ if (!target.supportsSubgroupShuffle())
return failure();
SmallVector<unsigned> parallelDims;
@@ -1132,7 +1121,7 @@
if (!foundSingleReductionOutput)
return failure();
- const int subgroupSize = targetEnv.getResourceLimits().getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
// Tile all the parallel dimension to 1.
SmallVector<unsigned> partitionedLoops =
@@ -1192,7 +1181,7 @@
// the workgroup size we use can divide the total reduction size, and it's
// also within hardware limitations.
const int64_t maxWorkgroupSize =
- targetEnv.getResourceLimits().getMaxComputeWorkgroupInvocations();
+ target.getWgp().getMaxThreadCountPerWorkgroup();
int64_t groupSize = reductionSize / vectorSize;
if (groupSize > maxWorkgroupSize) {
groupSize = GreatestCommonDivisor(APInt(64, uint64_t(groupSize)),
@@ -1308,7 +1297,7 @@
return bitwidth;
};
-static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
+static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target,
Operation *op,
bool allowVectorization = true) {
LLVM_DEBUG(llvm::dbgs() << "trying to deduce as default op...\n");
@@ -1327,7 +1316,7 @@
workgroupSize);
}
- const int subgroupSize = limits.getSubgroupSize();
+ int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
const unsigned loopDepth = partitionedLoops.back() + 1;
// Configurations we need to decide.
@@ -1542,7 +1531,7 @@
static LogicalResult
setTransformDialectConfig(mlir::FunctionOpInterface entryPoint, Operation *op,
- const spirv::TargetEnv &targetEnv) {
+ IREE::GPU::TargetAttr target) {
if (!clSPIRVEnableTransformDialectJit) {
return failure();
}
@@ -1551,30 +1540,24 @@
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
context, CodeGenPipeline::TransformDialectCodegen);
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
-
// TODO: unify the target information into one structure.
iree_compiler::gpu::GPUModel gpuModel;
- gpuModel.hasWarpShuffle =
- targetEnv.allows(spirv::Capability::GroupNonUniformShuffle);
+ gpuModel.hasWarpShuffle = target.supportsSubgroupShuffle();
gpuModel.hasTF32TensorCore = false;
gpuModel.hasMmaSync = false;
gpuModel.hasTF32TensorCore = false;
- gpuModel.minSubgroupSize = limits.getMinSubgroupSize();
- gpuModel.maxSubgroupSize = limits.getMaxSubgroupSize();
- gpuModel.maxWorkGroupInvocations = limits.getMaxComputeWorkgroupInvocations();
+ gpuModel.minSubgroupSize = target.getMinSubgroupSize();
+ gpuModel.maxSubgroupSize = target.getMaxSubgroupSize();
+ gpuModel.maxWorkGroupInvocations =
+ target.getWgp().getMaxThreadCountPerWorkgroup();
// Populates the supported WMMA fragment combinations from the target
// environment. Infer tf32 support from the list of supported fragment types.
- auto properties =
- limits.getCooperativeMatrixPropertiesKhr()
- .getAsRange<spirv::CooperativeMatrixPropertiesKHRAttr>();
- for (auto property : properties) {
- if (property.getScope().getValue() != spirv::Scope::Subgroup)
- continue;
+ for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+ auto [mSize, nSize, kSize] = mma.getMNKShape();
+ auto [aType, bType, cType] = mma.getABCElementTypes();
gpuModel.supportedWMMAConfigs.emplace_back(iree_compiler::gpu::MMAConfig{
- property.getMSize(), property.getNSize(), property.getKSize(),
- property.getAType(), property.getBType(), property.getCType()});
+ mSize, nSize, kSize, aType, bType, cType});
}
if (failed(iree_compiler::gpu::matchAndSetTransformStrategy(entryPoint, op,
@@ -1589,44 +1572,32 @@
/// Sets the CodeGen configuration as attributes to the given `rootOp` if it's a
/// known Linalg matmul/convolution op with good configurations.
-static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv,
+static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPointFn,
Operation *rootOp) {
// First try to see if there is a matching transform dialect configuration.
- if (succeeded(setTransformDialectConfig(entryPointFn, rootOp, targetEnv))) {
+ if (succeeded(setTransformDialectConfig(entryPointFn, rootOp, target))) {
return success();
}
// First try to find a proper CodeGen configuration to tile and vectorize for
// the current target architecture.
- switch (targetEnv.getVendorID()) {
- case spirv::Vendor::AMD:
- if (succeeded(detail::setAMDCodeGenConfig(targetEnv, rootOp)))
- return success();
- break;
- case spirv::Vendor::Apple:
- if (succeeded(detail::setAppleCodeGenConfig(targetEnv, rootOp)))
- return success();
- break;
- case spirv::Vendor::ARM:
- if (succeeded(detail::setMaliCodeGenConfig(targetEnv, rootOp)))
- return success();
- break;
- case spirv::Vendor::NVIDIA:
- if (succeeded(detail::setNVIDIACodeGenConfig(targetEnv, rootOp)))
- return success();
- break;
- case spirv::Vendor::Qualcomm:
- if (succeeded(detail::setAdrenoCodeGenConfig(targetEnv, rootOp)))
- return success();
- break;
- default:
- break;
- }
+ if (target.isAMD() && succeeded(detail::setAMDCodeGenConfig(target, rootOp)))
+ return success();
+ if (target.isApple() &&
+ succeeded(detail::setAppleCodeGenConfig(target, rootOp)))
+ return success();
+ if (target.isARM() && succeeded(detail::setMaliCodeGenConfig(target, rootOp)))
+ return success();
+ if (target.isNVIDIA() &&
+ succeeded(detail::setNVIDIACodeGenConfig(target, rootOp)))
+ return success();
+ if (target.isQualcomm() &&
+ succeeded(detail::setAdrenoCodeGenConfig(target, rootOp)))
+ return success();
// Otherwise fallback to use a default configuration that tiles and
// distributes/vectorizes.
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
return TypeSwitch<Operation *, LogicalResult>(rootOp)
.Case<linalg::BatchMatmulOp, linalg::MatmulOp>([&](auto op) {
// Try to tile and vectorize first. It's common to see 32 threads
@@ -1640,19 +1611,19 @@
threadMNK = {8, 8, 4};
}
auto result =
- detail::setMatmulOpConfig(limits, op, workgroupXY, threadMNK);
+ detail::setMatmulOpConfig(target, op, workgroupXY, threadMNK);
if (succeeded(result))
return success();
LLVM_DEBUG(llvm::dbgs()
<< "failed to set matmul op config, trying reduction\n");
- if (succeeded(setReductionConfig(targetEnv, op)))
+ if (succeeded(setReductionConfig(target, op)))
return success();
// If unsuccessful, try to tile and distribute.
- return setDefaultOpConfig(limits, op);
+ return setDefaultOpConfig(target, op);
})
- .Case<linalg::ConvolutionOpInterface>([limits](auto op) {
+ .Case<linalg::ConvolutionOpInterface>([target](auto op) {
// Use the result type in case of larger bitwidth for accumulators.
auto type = cast<ShapedType>(op->getResult(0).getType());
const int bitwidth = type.getElementTypeBitWidth();
@@ -1667,27 +1638,27 @@
}
// If unsuccessful, try to tile and distribute/vectorize.
- return setDefaultOpConfig(limits, op);
+ return setDefaultOpConfig(target, op);
})
.Case<linalg::GenericOp>([&](linalg::GenericOp op) {
LLVM_DEBUG(llvm::dbgs() << "figuring configuration for generic op\n");
- if (succeeded(setReductionConfig(targetEnv, op)))
+ if (succeeded(setReductionConfig(target, op)))
return success();
// If a generic op has reduction iterator types, it can be treated as a
// root op for configuration as well. Use the default configuration,
// which will mark it as a root.
if (op.getNumLoops() != op.getNumParallelLoops()) {
- return setDefaultOpConfig(limits, op);
+ return setDefaultOpConfig(target, op);
}
return failure();
})
- .Case<IREE::LinalgExt::FftOp>([limits](IREE::LinalgExt::FftOp op) {
- return setFftOpConfig(limits, op);
+ .Case<IREE::LinalgExt::FftOp>([target](IREE::LinalgExt::FftOp op) {
+ return setFftOpConfig(target, op);
})
.Case<IREE::LinalgExt::WinogradInputTransformOp,
IREE::LinalgExt::WinogradOutputTransformOp>(
- [&](auto op) { return setWinogradOpConfig(limits, op); })
+ [&](auto op) { return setWinogradOpConfig(target, op); })
.Default([](Operation *) { return failure(); });
};
@@ -1695,7 +1666,7 @@
// Entry Point
//===----------------------------------------------------------------------===//
-static LogicalResult setConfigForKernel(const spirv::TargetEnv &targetEnv,
+static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface funcOp) {
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
if (computeOps.empty()) {
@@ -1717,14 +1688,13 @@
}
for (Operation *computeOp : roots) {
- if (succeeded(setSPIRVOpConfig(targetEnv, funcOp, computeOp)))
+ if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp)))
return success();
}
Operation *computeOp = roots.back();
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
// If there are still no root op, check for any linalg.generic op.
- if (succeeded(setDefaultOpConfig(limits, computeOp)))
+ if (succeeded(setDefaultOpConfig(target, computeOp)))
return success();
// Check if the op configuration was set.
@@ -1734,15 +1704,12 @@
}
LogicalResult initSPIRVLaunchConfig(FunctionOpInterface funcOp) {
- spirv::TargetEnvAttr targetEnvAttr = getSPIRVTargetEnvAttr(funcOp);
- if (!targetEnvAttr) {
- return funcOp.emitOpError(
- "expected parent hal.executable.variant to have spirv.target_env "
- "attribute");
- }
- if (getTranslationInfo(funcOp)) {
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+ if (!target)
+ return funcOp.emitError("missing GPU target in #hal.executable.target");
+
+ if (getTranslationInfo(funcOp))
return success();
- }
if (auto exportOp = getEntryPoint(funcOp)) {
// If no translation info set, first check whether we already have workgroup
@@ -1762,8 +1729,7 @@
}
}
- spirv::TargetEnv targetEnv(targetEnvAttr);
- if (failed(setConfigForKernel(targetEnv, funcOp))) {
+ if (failed(setConfigForKernel(target, funcOp))) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.h b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.h
index e831200..6d59589 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.h
@@ -17,9 +17,9 @@
#include <array>
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir::iree_compiler {
@@ -64,7 +64,7 @@
/// Sets CodeGen configurations via attributes to the given matmul `linalgOp`
/// with the given best workgroup size and tile size hints.
LogicalResult setMatmulOpConfig(
- spirv::ResourceLimitsAttr limits, linalg::LinalgOp linalgOp,
+ IREE::GPU::TargetAttr target, linalg::LinalgOp linalgOp,
std::array<int64_t, 2> bestWorkgroupSizeXY,
std::array<int64_t, 3> bestThreadTileSizeMNK, bool enablePromotion = false,
unsigned softwarePipelineDepth = defaultSimtSoftwarePipelineDepth,
@@ -75,7 +75,7 @@
/// with tile sizes for cooperative matrix, if possible for the given matmul
/// size.
LogicalResult setCooperativeMatrixConfig(
- const spirv::TargetEnv &targetEnv, linalg::LinalgOp op,
+ IREE::GPU::TargetAttr target, linalg::LinalgOp op,
const unsigned numSubgroupsPerWorkgroup,
const unsigned numMNTilesPerSubgroup,
unsigned softwarePipelineDepth = defaultCoopMatrixSoftwarePipelineDepth,
@@ -91,15 +91,15 @@
/// Returns success when a configuration is successfullly attached as attribute.
/// Returns failure otherwise.
-LogicalResult setAdrenoCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setAdrenoCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp);
-LogicalResult setAppleCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setAppleCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp);
-LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setAMDCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp);
-LogicalResult setMaliCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setMaliCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp);
-LogicalResult setNVIDIACodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setNVIDIACodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp);
} // namespace detail
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index a68cf33..7caab11 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -14,15 +14,14 @@
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir::iree_compiler::detail {
static LogicalResult setMaliMatmulConfig(linalg::LinalgOp op,
- spirv::ResourceLimitsAttr limits) {
- const int subgroupSize = limits.getSubgroupSize();
+ IREE::GPU::TargetAttr target) {
+ const int subgroupSize = target.getPreferredSubgroupSize();
const std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 2};
std::array<int64_t, 3> threadMNK;
Type inputType = op.getDpsInputOperand(0)->get().getType();
@@ -34,21 +33,20 @@
} else {
threadMNK = {6, 4, 4};
}
- return setMatmulOpConfig(limits, op, workgroupXY, threadMNK);
+ return setMatmulOpConfig(target, op, workgroupXY, threadMNK);
}
//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//
-LogicalResult setMaliCodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setMaliCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp) {
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- int subgroupSize = limits.getSubgroupSize();
+ const int subgroupSize = target.getPreferredSubgroupSize();
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp)) {
if (isMatmulOrBatchMatmul(linalgOp))
- return setMaliMatmulConfig(linalgOp, limits);
+ return setMaliMatmulConfig(linalgOp, target);
}
if (auto convOp = dyn_cast<linalg::ConvolutionOpInterface>(rootOp)) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
index f5f7b9c..357a626 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
@@ -10,16 +10,10 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/APInt.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
#define DEBUG_TYPE "iree-spirv-nvidia-config"
@@ -32,15 +26,14 @@
namespace mlir::iree_compiler::detail {
static LogicalResult setNVIDIAMatmulConfig(linalg::LinalgOp op,
- const spirv::TargetEnv &targetEnv) {
+ IREE::GPU::TargetAttr target) {
// First try to see if we can use tensor cores.
- spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
- if (succeeded(setCooperativeMatrixConfig(targetEnv, op,
+ if (succeeded(setCooperativeMatrixConfig(target, op,
NVIDIANumSubgroupsPerWorkgroup,
NVIDIANumMNTilesPerSubgroup)))
return success();
- const int subgroupSize = limits.getSubgroupSize();
+ const int subgroupSize = target.getPreferredSubgroupSize();
const std::array<int64_t, 2> workgroupXY = {subgroupSize, 8};
std::array<int64_t, 3> threadMNK;
auto inputType =
@@ -50,7 +43,7 @@
} else {
threadMNK = {4, 4, 32};
}
- return setMatmulOpConfig(limits, op, workgroupXY, threadMNK,
+ return setMatmulOpConfig(target, op, workgroupXY, threadMNK,
/*enablePromotion=*/true);
}
@@ -84,11 +77,11 @@
// Note that the above numbers are from CUDA docs; for Vulkan the drivder can
// expose slightly different numbers, e.g., max shared memory size is smaller.
-LogicalResult setNVIDIACodeGenConfig(const spirv::TargetEnv &targetEnv,
+LogicalResult setNVIDIACodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp) {
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp)) {
if (isMatmulOrBatchMatmul(linalgOp))
- return setNVIDIAMatmulConfig(linalgOp, targetEnv);
+ return setNVIDIAMatmulConfig(linalgOp, target);
}
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 2963476..fd2b021 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -18,7 +18,6 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Utils/PassUtils.h"
@@ -29,12 +28,13 @@
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
@@ -183,10 +183,9 @@
.addPass(createPadDynamicAlloc);
// Check to make sure we are not exceeding shared memory usage limit.
- auto getSharedMemoryLimit = [](mlir::FunctionOpInterface func) {
- auto moduleOp = func->getParentOfType<ModuleOp>();
- spirv::TargetEnvAttr target = getSPIRVTargetEnvAttr(moduleOp);
- return target.getResourceLimits().getMaxComputeSharedMemorySize();
+ auto getSharedMemoryLimit = [](mlir::FunctionOpInterface fn) {
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(fn);
+ return target.getWgp().getMaxWorkgroupMemoryBytes();
};
// TODO: query this from the target.
auto getIndexBitwidth = [](mlir::FunctionOpInterface) { return 32; };
@@ -252,10 +251,12 @@
.addPass(createCanonicalizerPass)
.addPass(createCSEPass);
+ modulePassManager.addPass(createSPIRVConvertGPUTargetPass());
modulePassManager.addPass(createConvertToSPIRVPass(clSPIRVIndexingBits));
auto getTargetEnv = [](spirv::ModuleOp moduleOp) {
- return getSPIRVTargetEnvAttr(moduleOp);
+ return moduleOp->getParentOfType<mlir::ModuleOp>()
+ ->getAttrOfType<spirv::TargetEnvAttr>(spirv::getTargetEnvAttrName());
};
OpPassManager &spirvModulePassManager =
@@ -599,8 +600,7 @@
auto getWarpSize = [](mlir::FunctionOpInterface func) -> int {
// TODO: This kind of call back function is a really really bad idea
// This should be easier to resolve than doing this.
- std::optional<int64_t> subgroupSize = getSPIRVSubgroupSize(func);
- return subgroupSize.value_or(32);
+ return *getGPUSubgroupSize(func, /*pickLargest=*/true);
};
// Handle vector reduction operations specifically.
@@ -631,8 +631,6 @@
void buildSPIRVCodegenConfigurationPassPipeline(
OpPassManager &variantPassManager) {
- // TODO: move the following pass to be immediately before ConvertToSPIRVPass.
- variantPassManager.addPass(createSPIRVConvertGPUTargetPass());
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
buildSPIRVCodegenConfigurationPassPipelineImpl(modulePassManager);
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index a0b0d16..31f9202 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -87,8 +87,7 @@
createSPIRVBreakDownLargeVectorPass();
// Converts #iree_gpu.target into #spirv.target_env.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConvertGPUTargetPass();
+std::unique_ptr<OperationPass<ModuleOp>> createSPIRVConvertGPUTargetPass();
/// Emulates bfloat 16 ops with 32-bit float ops.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
index dc94eb2..6436102 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
@@ -33,9 +33,7 @@
let constructor = "mlir::iree_compiler::createSPIRVBreakDownLargeVectorPass()";
}
-def SPIRVConvertGPUTarget :
- Pass<"iree-spirv-convert-gpu-target",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+def SPIRVConvertGPUTarget : Pass<"iree-spirv-convert-gpu-target", "mlir::ModuleOp"> {
let summary = "Convert #iree_gpu.target into #spirv.target_env";
let constructor = "mlir::iree_compiler::createSPIRVConvertGPUTargetPass()";
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
index fc9cfe3..d82ab9d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
@@ -96,14 +95,16 @@
.Default(ClientAPI::Unknown);
}
-Vendor deduceVendor(StringRef arch) {
- if (arch.starts_with("gfx") || arch.starts_with("rdna"))
+Vendor deduceVendor(IREE::GPU::TargetAttr target) {
+ if (target.isAMD())
return Vendor::AMD;
- if (arch.starts_with("valhall"))
+ if (target.isApple())
+ return Vendor::Apple;
+ if (target.isARM())
return Vendor::ARM;
- if (arch.starts_with("sm_"))
+ if (target.isNVIDIA())
return Vendor::NVIDIA;
- if (arch.starts_with("adreno"))
+ if (target.isQualcomm())
return Vendor::Qualcomm;
return Vendor::Unknown;
}
@@ -181,9 +182,9 @@
}
}
-spirv::ResourceLimitsAttr convertLimits(StringRef arch,
- IREE::GPU::TargetWgpAttr wgp) {
- MLIRContext *context = wgp.getContext();
+spirv::ResourceLimitsAttr convertLimits(IREE::GPU::TargetAttr target) {
+ MLIRContext *context = target.getContext();
+ IREE::GPU::TargetWgpAttr wgp = target.getWgp();
Builder b(context);
SmallVector<Attribute, 4> coopMatAttrs;
@@ -196,19 +197,18 @@
spirv::ScopeAttr::get(context, spirv::Scope::Subgroup)));
}
- ArrayRef<int> subgroupSizes = wgp.getSubgroupSizeChoices().asArrayRef();
- const int minSubgroupSize = *llvm::min_element(subgroupSizes);
- const int maxSubgroupSize = *llvm::max_element(subgroupSizes);
// This is mostly to match RDNA behavior on Vulkan--RDNA supports either 32 or
// 64 as subgroup sizes; the default subgroup size is 64.
- const int preferredSubgroupSize = maxSubgroupSize;
+ const int preferredSubgroupSize =
+ target.getPreferredSubgroupSize(/*pickLargest=*/true);
return spirv::ResourceLimitsAttr::get(
context, wgp.getMaxWorkgroupMemoryBytes(),
wgp.getMaxThreadCountPerWorkgroup(),
b.getI32ArrayAttr(wgp.getMaxWorkgroupSizes().asArrayRef()),
- preferredSubgroupSize, minSubgroupSize, maxSubgroupSize,
- ArrayAttr::get(context, coopMatAttrs), ArrayAttr{});
+ preferredSubgroupSize, target.getMinSubgroupSize(),
+ target.getMaxSubgroupSize(), ArrayAttr::get(context, coopMatAttrs),
+ ArrayAttr{});
}
//===----------------------------------------------------------------------===//
@@ -246,9 +246,9 @@
auto triple = spirv::VerCapExtAttr::get(
*version, caps.getArrayRef(), exts.getArrayRef(), variant.getContext());
return spirv::TargetEnvAttr::get(
- triple, convertLimits(gpuTarget.getArch(), wgp),
- deduceClientAPI(target.getBackend()), deduceVendor(gpuTarget.getArch()),
- spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
+ triple, convertLimits(gpuTarget), deduceClientAPI(target.getBackend()),
+ deduceVendor(gpuTarget), spirv::DeviceType::Unknown,
+ spirv::TargetEnvAttr::kUnknownDeviceID);
}
struct SPIRVConvertGPUTargetPass final
@@ -258,29 +258,20 @@
}
void runOnOperation() override {
- IREE::HAL::ExecutableVariantOp variant = getOperation();
- IREE::HAL::ExecutableTargetAttr target = variant.getTarget();
+ mlir::ModuleOp moduleOp = getOperation();
+ auto variant = moduleOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
FailureOr<spirv::TargetEnvAttr> spirvTarget = convertGPUTarget(variant);
if (failed(spirvTarget))
return signalPassFailure();
- Builder b(&getContext());
- auto attrs = llvm::to_vector(target.getConfiguration().getValue());
- attrs.emplace_back(b.getStringAttr(spirv::getTargetEnvAttrName()),
- *spirvTarget);
- auto configAttr = b.getDictionaryAttr(attrs);
-
- auto halTarget = IREE::HAL::ExecutableTargetAttr::get(
- target.getContext(), target.getBackend(), target.getFormat(),
- configAttr);
- variant.setTargetAttr(halTarget);
+ moduleOp->setAttr(spirv::getTargetEnvAttrName(), *spirvTarget);
}
};
} // namespace
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
createSPIRVConvertGPUTargetPass() {
return std::make_unique<SPIRVConvertGPUTargetPass>();
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp
index 9649e21..0e0843f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp
@@ -11,9 +11,10 @@
//
//===----------------------------------------------------------------------===//
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/SmallVector.h"
@@ -153,10 +154,10 @@
}
static bool supportsI64(FunctionOpInterface op) {
- spirv::TargetEnvAttr attr = getSPIRVTargetEnvAttr(op);
- assert(attr && "Not a valid spirv module");
- spirv::TargetEnv env(attr);
- return env.allows(spirv::Capability::Int64);
+ IREE::GPU::TargetAttr attr = getGPUTargetAttr(op);
+ assert(attr && "Missing GPU target");
+ return IREE::GPU::bitEnumContainsAll(attr.getWgp().getCompute().getValue(),
+ IREE::GPU::ComputeBitwidths::Int64);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
index 12fd88c..c39d4f8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
@@ -13,16 +13,16 @@
//===----------------------------------------------------------------------===//
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -260,27 +260,17 @@
/// Returns true when the target environment support integer dot product ops.
bool supportsIntegerDotProductOps(mlir::FunctionOpInterface fn) {
- spirv::TargetEnvAttr targetEnvAttr = getSPIRVTargetEnvAttr(fn);
- if (!targetEnvAttr) {
- // Alternatively, check if the function op itself has a target env
- // attribute. This may be preferred in tests.
- targetEnvAttr =
- fn->getAttrOfType<spirv::TargetEnvAttr>(spirv::getTargetEnvAttrName());
- if (!targetEnvAttr)
- return false;
- }
-
- spirv::TargetEnv targetEnv(targetEnvAttr);
- if (!targetEnv.allows(spirv::Extension::SPV_KHR_integer_dot_product))
+ // First check if the function op itself has a target env attribute. This may
+ // be preferred in tests.
+ auto targetEnvAttr =
+ fn->getAttrOfType<IREE::GPU::TargetAttr>("iree.gpu.target");
+ if (!targetEnvAttr)
+ targetEnvAttr = getGPUTargetAttr(fn);
+ if (!targetEnvAttr)
return false;
- // Query all the dot prod capabilities except for the packed one -- none of
- // the vectorization patterns need it.
- if (!targetEnv.allows(spirv::Capability::DotProduct))
- return false;
- if (!targetEnv.allows(spirv::Capability::DotProductInput4x8Bit))
- return false;
- if (!targetEnv.allows(spirv::Capability::DotProductInputAll))
+ if (!IREE::GPU::bitEnumContainsAll(targetEnvAttr.getWgp().getDot().getValue(),
+ IREE::GPU::DotProductOps::DP4xI8ToI32))
return false;
return true;
@@ -292,7 +282,7 @@
void getDependentDialects(DialectRegistry ®istry) const override {
// vector.gather lowering patterns target scf ops.
registry.insert<linalg::LinalgDialect, vector::VectorDialect,
- scf::SCFDialect>();
+ scf::SCFDialect, spirv::SPIRVDialect>();
}
void runOnOperation() override {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp
index 2c7fdbc..80f0b8a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp
@@ -4,15 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "llvm/ADT/StringExtras.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -71,6 +73,22 @@
return spirv::mapMemorySpaceToOpenCLStorageClass(attr);
}
+bool allowsShaderCapability(ArrayRef<StringRef> features) {
+ for (StringRef feature : features) {
+ if (feature.consume_front("cap:") && feature == "Shader")
+ return true;
+ }
+ return false;
+}
+
+bool allowsKernelCapability(ArrayRef<StringRef> features) {
+ for (StringRef feature : features) {
+ if (feature.consume_front("cap:") && feature == "Kernel")
+ return true;
+ }
+ return false;
+}
+
struct SPIRVMapMemRefStorageClassPass final
: public SPIRVMapMemRefStorageClassBase<SPIRVMapMemRefStorageClassPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -88,13 +106,14 @@
spirv::MemorySpaceToStorageClassMap memorySpaceMap;
- if (spirv::TargetEnvAttr attr = getSPIRVTargetEnvAttr(op)) {
- spirv::TargetEnv targetEnv(attr);
- if (targetEnv.allows(spirv::Capability::Shader)) {
+ if (IREE::GPU::TargetAttr target = getGPUTargetAttr(op)) {
+ SmallVector<StringRef> features;
+ llvm::SplitString(target.getFeatures(), features, ",");
+ if (allowsShaderCapability(features)) {
memorySpaceMap = useIndirectBindings
? &mapHALDescriptorTypeForVulkan<true>
: &mapHALDescriptorTypeForVulkan<false>;
- } else if (targetEnv.allows(spirv::Capability::Kernel)) {
+ } else if (allowsKernelCapability(features)) {
memorySpaceMap = mapHALDescriptorTypeForOpenCL;
}
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
index b1f1672..f1606cb 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
@@ -6,11 +6,11 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -42,12 +42,12 @@
// TODO(qedawkins): Once TransformStrategies is deprecated, drop the
// unnecessary dialect registrations.
registry
- .insert<IREE::Codegen::IREECodegenDialect, affine::AffineDialect,
- gpu::GPUDialect, IREE::HAL::HALDialect, linalg::LinalgDialect,
- IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
- bufferization::BufferizationDialect, scf::SCFDialect,
- spirv::SPIRVDialect, transform::TransformDialect,
- vector::VectorDialect>();
+ .insert<IREE::Codegen::IREECodegenDialect, IREE::GPU::IREEGPUDialect,
+ affine::AffineDialect, gpu::GPUDialect, IREE::HAL::HALDialect,
+ linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect,
+ memref::MemRefDialect, bufferization::BufferizationDialect,
+ scf::SCFDialect, spirv::SPIRVDialect,
+ transform::TransformDialect, vector::VectorDialect>();
}
void runOnOperation() override;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
index 5bdca3f..59d6071 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -13,19 +13,16 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -33,11 +30,8 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-spirv-tile-and-distribute"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
index 1f780ae..9ec9e77 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
@@ -112,7 +113,7 @@
: promoteCMatrix(promoteCMatrix), skipThreadLevel(skipThreadLevel) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<gpu::GPUDialect>();
+ registry.insert<gpu::GPUDialect, IREE::GPU::IREEGPUDialect>();
}
LogicalResult initializeOptions(
@@ -198,9 +199,10 @@
SmallVector<int64_t> &workgroupSize = maybeWorkgroupSize.value();
int64_t totalThreads = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
- std::optional<int> subgroupSize = getSPIRVSubgroupSize(funcOp);
+ std::optional<int> subgroupSize =
+ getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
if (!subgroupSize) {
- funcOp->emitError("failed to query subgroup size");
+ funcOp.emitError("failed to query subgroup size");
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 0743d97..4a61471 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -16,10 +16,10 @@
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
@@ -68,33 +68,24 @@
constexpr char coopMatShapeAttrName[] = "iree.spirv.coopmatrix.shape";
/// Sets the chosen cooperative matrix type/shape for CodeGen onto the
-/// hal.executable.export op for the given `funcOp`.
+/// the given `funcOp`.
void setSPIRVCooperativeMatrixInfo(mlir::FunctionOpInterface funcOp,
linalg::LinalgOp rootOp,
ArrayRef<int64_t> shape) {
- auto exportOp = getEntryPoint(funcOp);
- if (!exportOp) {
- return;
- }
-
Builder b(funcOp.getContext());
- exportOp.value()->setAttr(coopMatShapeAttrName,
- b.getDenseI64ArrayAttr(shape));
+ funcOp->setAttr(coopMatShapeAttrName, b.getDenseI64ArrayAttr(shape));
auto inputType = cast<ShapedType>(rootOp.getDpsInputs().front().getType());
auto outputType = cast<ShapedType>(rootOp.getDpsInits().front().getType());
auto elementTypes = b.getTypeArrayAttr(
{inputType.getElementType(), outputType.getElementType()});
- exportOp.value()->setAttr(coopMatTypeAttrName, elementTypes);
+ funcOp->setAttr(coopMatTypeAttrName, elementTypes);
}
-/// Returns the chosen cooperative matrix shape for CodeGen from the
-/// hal.executable.export op for the given `funcOp`. Returns an empty
-/// ArrayRef if cannot query.
+/// Returns the chosen cooperative matrix shape for CodeGen from the given
+/// `funcOp`. Returns an empty ArrayRef if cannot query.
ArrayRef<int64_t>
getSPIRVCooperativeMatrixShape(mlir::FunctionOpInterface funcOp) {
- auto exportOp = getEntryPoint(funcOp);
- auto attr =
- exportOp.value()->getAttrOfType<DenseI64ArrayAttr>(coopMatShapeAttrName);
+ auto attr = funcOp->getAttrOfType<DenseI64ArrayAttr>(coopMatShapeAttrName);
if (!attr)
return {};
return attr.asArrayRef();
@@ -336,8 +327,8 @@
: public SPIRVTileToCooperativeOpsBase<SPIRVTileToCooperativeOpsPass> {
public:
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<gpu::GPUDialect, linalg::LinalgDialect,
- vector::VectorDialect>();
+ registry.insert<gpu::GPUDialect, IREE::GPU::IREEGPUDialect,
+ linalg::LinalgDialect, vector::VectorDialect>();
}
void runOnOperation() override {
@@ -371,11 +362,13 @@
// Then tile and distribute to subgroups.
{
- std::optional<int> subgroupSize = getSPIRVSubgroupSize(funcOp);
+ std::optional<int> subgroupSize =
+ getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
if (!subgroupSize) {
funcOp.emitError("failed to query subgroup size");
return signalPassFailure();
}
+
SmallVector<int64_t> subgroupTileSizes = getTileSizes(rootOp, 1);
if (failed(tileToSubgroup(funcOp, subgroupCounts, *subgroupSize,
subgroupTileSizes))) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
index 23d7a48..defbc89 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -41,13 +41,6 @@
return targetAttr.getConfiguration();
}
-spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op) {
- DictionaryAttr config = getTargetConfigAttr(op);
- if (!config)
- return nullptr;
- return config.getAs<spirv::TargetEnvAttr>(spirv::getTargetEnvAttrName());
-}
-
UnitAttr getIndirectBindingsAttr(Operation *op) {
DictionaryAttr config = getTargetConfigAttr(op);
if (!config)
@@ -56,18 +49,6 @@
return config.getAs<UnitAttr>("hal.bindings.indirect");
}
-std::optional<int> getSPIRVSubgroupSize(mlir::FunctionOpInterface funcOp) {
- std::optional<int64_t> subgroupSize = getSubgroupSize(funcOp);
- if (subgroupSize) {
- return subgroupSize.value();
- }
-
- spirv::TargetEnvAttr target = getSPIRVTargetEnvAttr(funcOp);
- if (!target)
- return std::nullopt;
- return target.getResourceLimits().getSubgroupSize();
-}
-
FailureOr<SmallVector<int64_t>>
getSPIRVTileSize(mlir::FunctionOpInterface funcOp, int tilingLevel) {
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
index f54fd62..2462fcd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -32,17 +31,9 @@
/// Given an operation, returns the HAL target config attribute.
DictionaryAttr getTargetConfigAttr(Operation *op);
-/// Given an operation, returns the `spirv.target_env` attribute.
-spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op);
-
/// Given an operation, returns the `hal.bindings.indirect` attribute.
UnitAttr getIndirectBindingsAttr(Operation *op);
-/// Given a FuncOp, returns the subgroup size to use for CodeGen, by first
-/// querying the hal.executable.export op, and then the SPIR-V target
-/// environment. Returns std::nullopt on failures.
-std::optional<int> getSPIRVSubgroupSize(mlir::FunctionOpInterface funcOp);
-
/// Returns the tile sizes at the given `tilingLevel` for compute ops in
/// `funcOp`.
FailureOr<SmallVector<int64_t>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
index 102c442..12ccd8f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
@@ -4,14 +4,12 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
-#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
-#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -54,20 +52,17 @@
"invalid matmul configuration without pipelining config");
}
- // Get spirv.target_env attributes
- const spirv::TargetEnvAttr targetEnvAttr = getSPIRVTargetEnvAttr(op);
- const spirv::TargetEnv targetEnv(targetEnvAttr);
- const auto limits = targetEnv.getResourceLimits();
- LLVM_DEBUG(llvm::dbgs() << "target environment: " << targetEnvAttr << "\n");
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(op);
+ LLVM_DEBUG(llvm::dbgs() << "target: " << target << "\n");
auto funcOp = op->getParentOfType<mlir::FunctionOpInterface>();
- const std::optional<int> subgroupSize = getSPIRVSubgroupSize(funcOp);
+ std::optional<int> subgroupSize =
+ getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
if (!subgroupSize)
return funcOp->emitError("failed to query subgroup size");
- const int maxThreads = limits.getMaxComputeWorkgroupInvocations();
- const auto maxWorkGroupSize = llvm::map_to_vector<3>(
- limits.getMaxComputeWorkgroupSize().getAsValueRange<IntegerAttr>(),
- [](const APInt &dim) { return dim.getSExtValue(); });
+ const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
+ const auto maxWorkGroupSize =
+ target.getWgp().getMaxWorkgroupSizes().asArrayRef();
if (workgroupSize.size() < 3) {
return funcOp->emitOpError("expected workgroup size to have three "
@@ -170,20 +165,17 @@
"invalid cooperative matrix configuration without pipelining config");
}
- // Get spirv.target_env attributes
- const spirv::TargetEnvAttr targetEnvAttr = getSPIRVTargetEnvAttr(op);
- const spirv::TargetEnv targetEnv(targetEnvAttr);
- const auto limits = targetEnv.getResourceLimits();
- LLVM_DEBUG(llvm::dbgs() << "target environment: " << targetEnvAttr << "\n");
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(op);
+ LLVM_DEBUG(llvm::dbgs() << "target: " << target << "\n");
auto funcOp = op->getParentOfType<mlir::FunctionOpInterface>();
- const std::optional<int> subgroupSize = getSPIRVSubgroupSize(funcOp);
+ std::optional<int> subgroupSize =
+ getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
if (!subgroupSize)
return funcOp->emitError("failed to query subgroup size");
- const int maxThreads = limits.getMaxComputeWorkgroupInvocations();
- const auto maxWorkGroupSize = llvm::map_to_vector<3>(
- limits.getMaxComputeWorkgroupSize().getAsValueRange<IntegerAttr>(),
- [](const APInt &dim) { return dim.getSExtValue(); });
+ const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
+ const auto maxWorkGroupSize =
+ target.getWgp().getMaxWorkgroupSizes().asArrayRef();
// Verify each dimension of workgroupSize should be power of two.
if (!llvm::isPowerOf2_64(workgroupSize[0]) ||
@@ -256,29 +248,24 @@
Type rhsType = getElementType(op->getOperand(1));
Type resultType = getElementType(op->getOperand(2));
- auto properties =
- limits.getCooperativeMatrixPropertiesKhr()
- .getAsRange<spirv::CooperativeMatrixPropertiesKHRAttr>();
-
// Verify that the fourth level tile sizes match cooperative matrix,
// and subgroup tile sizes should be multiple of cooperative matrix (M, N, K)
// sizes.
bool isNativeVectorSizeAccepted = false;
- for (auto p : properties) {
- if (p.getAType() == lhsType && p.getBType() == rhsType &&
- p.getCType() == resultType &&
- p.getScope().getValue() == spirv::Scope::Subgroup &&
- p.getMSize() == nativeVectorSizes[0] &&
- p.getNSize() == nativeVectorSizes[1] &&
- p.getKSize() == nativeVectorSizes[2]) {
+ for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+ auto [mSize, nSize, kSize] = mma.getMNKShape();
+ auto [aType, bType, cType] = mma.getABCElementTypes();
+
+ if (aType == lhsType && bType == rhsType && cType == resultType &&
+ mSize == nativeVectorSizes[0] && nSize == nativeVectorSizes[1] &&
+ kSize == nativeVectorSizes[2]) {
isNativeVectorSizeAccepted = true;
- if (subgroupTileSizes[0] % p.getMSize() != 0 ||
- subgroupTileSizes[1] % p.getNSize() != 0 ||
- reductionTileSizes[2] % p.getKSize() != 0) {
+ if (subgroupTileSizes[0] % mSize != 0 ||
+ subgroupTileSizes[1] % nSize != 0 ||
+ reductionTileSizes[2] % kSize != 0) {
return op->emitOpError(
"expected subgroup tile sizes to be multiple of ")
- << "[" << p.getMSize() << ", " << p.getNSize() << ", "
- << p.getKSize() << "]";
+ << "[" << mSize << ", " << nSize << ", " << kSize << "]";
}
}
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index eb59474..3886c6e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -25,11 +25,11 @@
"config_amd_conv.mlir",
"config_amd_matmul.mlir",
"config_amd_matmul_cooperative_ops.mlir",
+ "config_amd_matvec.mlir",
"config_default_conv.mlir",
"config_default_linalg_ext_ops.mlir",
"config_default_linalg_ops.mlir",
"config_default_matmul.mlir",
- "config_default_matvec.mlir",
"config_default_misc.mlir",
"config_default_reduction.mlir",
"config_default_sub_byte_types.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 273e581..078f92a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -21,11 +21,11 @@
"config_amd_conv.mlir"
"config_amd_matmul.mlir"
"config_amd_matmul_cooperative_ops.mlir"
+ "config_amd_matvec.mlir"
"config_default_conv.mlir"
"config_default_linalg_ext_ops.mlir"
"config_default_linalg_ops.mlir"
"config_default_matmul.mlir"
- "config_default_matvec.mlir"
"config_default_misc.mlir"
"config_default_reduction.mlir"
"config_default_sub_byte_types.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
index d7c32a6..fe7f86b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=adreno --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
// Conv - large OC - distribute to only one workgroup dimension.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @conv_112x112x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @conv_112x112x512() {
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c112 = arith.constant 112 : index
@@ -33,9 +32,8 @@
// Conv - medium OC/OW/OH - distribute to two workgroup dimensions.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @conv_112x112x32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @conv_112x112x32() {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c112 = arith.constant 112 : index
@@ -64,9 +62,8 @@
// Conv - small OC/OW/OH - distribute to all three workgroup dimensions.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @conv_16x16x16() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @conv_16x16x16() {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
@@ -93,9 +90,8 @@
// Depthwise conv - small OC/OW/OH - distribute to all three workgroup dimensions.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @dwconv_28x28x144() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @dwconv_28x28x144() {
%c0 = arith.constant 0 : index
%c144 = arith.constant 144 : index
%c28 = arith.constant 28 : index
@@ -124,9 +120,8 @@
// Depthwise conv - tiny OC/OW/OH - starving the GPU.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @dwconv_4x4x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @dwconv_4x4x8() {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir
index ee8eded..ca83457 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=adreno --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
// Large matmul that can match the best tiling scheme.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @matmul_1024x2048x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_1024x2048x512() {
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c1024 = arith.constant 1024 : index
@@ -31,9 +30,8 @@
// -----
// Small matmul N that can still tile to all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @matmul_3136x24x96() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_3136x24x96() {
%c0 = arith.constant 0 : index
%c24 = arith.constant 24 : index
%c3136 = arith.constant 3136 : index
@@ -61,9 +59,8 @@
// -----
// Small matmul M that can still tile to all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @matmul_196x64x192() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_196x64x192() {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c196 = arith.constant 196 : index
@@ -92,9 +89,8 @@
// Small matmul K that can still tile to all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @matmul_12544x96x16() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_12544x96x16() {
%c0 = arith.constant 0 : index
%c96 = arith.constant 96 : index
%c12544 = arith.constant 12544 : index
@@ -119,9 +115,8 @@
// Odd matmul M and small N that cannot utilize all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @matmul_49x160x576() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_49x160x576() {
%c0 = arith.constant 0 : index
%c160 = arith.constant 160 : index
%c49 = arith.constant 49 : index
@@ -150,9 +145,8 @@
// Large batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @batch_matmul_4x384x384() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_4x384x384() {
%c0 = arith.constant 0 : index
%c384 = arith.constant 384 : index
%c4 = arith.constant 4 : index
@@ -181,9 +175,8 @@
// Small batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64>>}>
module {
- func.func @batch_matmul_4x8x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_4x8x8() {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir
index fa81773..25dcfa0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir
@@ -1,9 +1,8 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=rdna2@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
- func.func @nhwc_conv_pointwise_2x64x64x320() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @nhwc_conv_pointwise_2x64x64x320() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x66x66x320xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir
index e578411..7705422 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir
@@ -1,8 +1,7 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=rdna2@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
module {
- func.func @batch_matmul_f32_16x4096x40x4096() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_f32_16x4096x40x4096() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x4096x4096xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x4096x40xf32>>
@@ -25,9 +24,9 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+
module {
- func.func @matmul_f16_64x640x320() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_f16_64x640x320() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x320xf16>>
@@ -51,9 +50,9 @@
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+
module {
- func.func @batch_matmul_f32_16x4096x40x4096() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_f32_16x4096x40x4096() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x4096x4096xf32>>
@@ -76,12 +75,11 @@
// CHECK: linalg.batch_matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]
-
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @batch_matmul_f16_1x4096x4096x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_f16_1x4096x4096x512() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<1x4096x512xf16>>
@@ -111,14 +109,14 @@
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
module {
- func.func @matmul_multi_reduce_i4xf32xf32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_multi_reduce_i4xf32xf32() {
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
index f7c04fd..6e2d836 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
@@ -1,9 +1,8 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=rdna3@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @matmul_256x1024x128_div_add() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_256x1024x128_div_add() {
%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
%c256 = arith.constant 256 : index
@@ -40,10 +39,10 @@
// CHECK-SAME: lowering_config = #[[$CONFIG]]
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @batch_matmul_16x128x256x512_div() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_16x128x256x512_div() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x128x512xf16>>
@@ -76,7 +75,7 @@
// -----
// Linalg.generic that is a batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+
#map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -84,7 +83,7 @@
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
- func.func @generic_batch_matmul_32x8x512x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @generic_batch_matmul_32x8x512x64() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<128x32x64xf16>>
@@ -116,9 +115,8 @@
// K dim size not divisble by 32.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
module {
- func.func @batch_matmul_16x1024x1024x80() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_16x1024x1024x80() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x1024x80xf16>>
@@ -144,9 +142,9 @@
// -----
// Small K - not supported by cooperative matrix.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+
module {
- func.func @matmul_256x1024x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_256x1024x8() {
%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
%c256 = arith.constant 256 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir
similarity index 88%
rename from compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
rename to compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir
index ba468be..7b0480a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matvec.mlir
@@ -1,12 +1,11 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=cdna2@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0)>
module {
- func.func @i4_dequant_matvec_f32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec_f32() {
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
@@ -48,14 +47,13 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
- func.func @i4_dequant_matvec_f32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec_f32() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%c4294967296_i64 = arith.constant 4294967296 : i64
@@ -99,14 +97,13 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
module {
- func.func @i4_dequant_matvec_f32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec_f32() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
@@ -177,14 +174,13 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, StorageBuffer16BitAccess, GroupNonUniform, GroupNonUniformShuffle], [SPV_KHR_16bit_storage]>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64]>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
- func.func @i4_dequant_matvec_f16() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec_f16() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
@@ -219,7 +215,7 @@
}
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 1], [0, 0, 0, 2, 128]{{\]}}>
-// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce workgroup_size = [32, 1, 1]>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce workgroup_size = [64, 1, 1]>
// CHECK: func.func @i4_dequant_matvec_f16()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: linalg.generic
@@ -227,14 +223,13 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
module {
- func.func @i4_dequant_matvec() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
@@ -305,14 +300,13 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
module {
- func.func @i4_dequant_matvec() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
@@ -377,9 +371,8 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
module {
- func.func @dynamic_batch_matvec() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @dynamic_batch_matvec() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
index 7a14111..fc6da02 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
// Convolution with consumer pointwise ops
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
- func.func @nhwc_conv_pointwise_112x112x32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @nhwc_conv_pointwise_112x112x32() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%c112 = arith.constant 112 : index
@@ -38,9 +37,9 @@
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+
module {
- func.func @nchw_conv_2x1280x8x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @nchw_conv_2x1280x8x8() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x1280x10x10xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<1280x1280x3x3xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
index 6c7fe88..39a3ec1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
@@ -1,12 +1,11 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @static_1d_sort() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @static_1d_sort() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readwrite:tensor<1000xi32>>
%1 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [1000], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<1000xi32>> -> tensor<1000xi32>
- %2 = iree_linalg_ext.sort {__internal_linalg_transform__ = "workgroup"} dimension(0) outs(%1 : tensor<1000xi32>) {
+ %2 = iree_linalg_ext.sort dimension(0) outs(%1 : tensor<1000xi32>) {
^bb0(%arg0: i32, %arg1: i32):
%3 = arith.cmpi slt, %arg0, %arg1 : i32
iree_linalg_ext.yield %3 : i1
@@ -26,10 +25,9 @@
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @static_3d_sort() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @static_3d_sort() {
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
@@ -39,7 +37,7 @@
^bb0(%in: i32, %out: i32):
linalg.yield %in : i32
}
- iree_linalg_ext.sort {__internal_linalg_transform__ = "workgroup"} dimension(1) outs(%1 : memref<64x32x128xi32>) {
+ iree_linalg_ext.sort dimension(1) outs(%1 : memref<64x32x128xi32>) {
^bb0(%arg0: i32, %arg1: i32):
%2 = arith.cmpi slt, %arg0, %arg1 : i32
iree_linalg_ext.yield %2 : i1
@@ -48,17 +46,16 @@
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 0, 16], [1, 0, 1]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [16, 1, 1]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 0, 64], [1, 0, 1]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [64, 1, 1]>
// CHECK: func.func @static_3d_sort()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.sort
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirvfb = #hal.executable.target<"vulkan-spirv", "vulkan-spirvfb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @static_1d_fft_stage2() attributes {hal.executable.target = #executable_target_vulkan_spirvfb} {
+ func.func @static_1d_fft_stage2() {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%cst = arith.constant dense<[1.000000e+00, 6.12323426E-17]> : tensor<2xf32>
@@ -67,7 +64,7 @@
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readwrite:tensor<32xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<32xf32>> -> tensor<32xf32>
%3 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<32xf32>> -> tensor<32xf32>
- %4:2 = iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup"} ins(%c2, %cst, %cst_0 : index, tensor<2xf32>, tensor<2xf32>) outs(%2, %3 : tensor<32xf32>, tensor<32xf32>) : tensor<32xf32>, tensor<32xf32>
+ %4:2 = iree_linalg_ext.fft ins(%c2, %cst, %cst_0 : index, tensor<2xf32>, tensor<2xf32>) outs(%2, %3 : tensor<32xf32>, tensor<32xf32>) : tensor<32xf32>, tensor<32xf32>
flow.dispatch.tensor.store %4#0, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor<readwrite:tensor<32xf32>>
flow.dispatch.tensor.store %4#1, %1, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor<readwrite:tensor<32xf32>>
return
@@ -75,16 +72,15 @@
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[4]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [16, 1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [64, 1, 1]>
// CHECK: func.func @static_1d_fft_stage2()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.fft
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirvfb = #hal.executable.target<"vulkan-spirv", "vulkan-spirvfb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @static_3d_fft_stage3() attributes {hal.executable.target = #executable_target_vulkan_spirvfb} {
+ func.func @static_3d_fft_stage3() {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c64 = arith.constant 64 : index
@@ -96,22 +92,21 @@
%1 = bufferization.to_memref %cst : memref<4xf32>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<64x128x32xf32>
%3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<64x128x32xf32>
- iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup"} ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>) outs(%2, %3 : memref<64x128x32xf32>, memref<64x128x32xf32>)
+ iree_linalg_ext.fft ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>) outs(%2, %3 : memref<64x128x32xf32>, memref<64x128x32xf32>)
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 8]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [16, 1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [64, 1, 1]>
// CHECK: func.func @static_3d_fft_stage3()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.fft
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirvfb = #hal.executable.target<"vulkan-spirv", "vulkan-spirvfb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @winograd_input_transform() attributes {hal.executable.target = #executable_target_vulkan_spirvfb} {
+ func.func @winograd_input_transform() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x128xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<8x8x2x6x6x128xf16>>
@@ -131,9 +126,8 @@
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
-#executable_target_vulkan_spirvfb = #hal.executable.target<"vulkan-spirv", "vulkan-spirvfb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @winograd_output_transform() attributes {hal.executable.target = #executable_target_vulkan_spirvfb} {
+ func.func @winograd_output_transform() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8x8x2x6x6x128xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x36x36x128xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
index 5cc86b6..e579c77 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
@@ -1,6 +1,11 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [16], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @copy_as_generic() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -25,7 +30,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func.func @copy() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -53,7 +63,12 @@
// Average pooling op with nice tilable input.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
module {
func.func @avg_pool() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
@@ -81,7 +96,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 4>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [4], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func.func @avg_pool() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -116,7 +136,12 @@
// Max pooling op with odd size-1 dimension sizes.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
module {
func.func @max_pool() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%cst = arith.constant 0xFF800000 : f32
@@ -147,7 +172,12 @@
// Element wise op with mismatched input and output rank.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
module {
@@ -179,7 +209,12 @@
// Fused depthwise convolution and element wise ops: don't vectorize with partially active subgroups.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
module {
func.func @dwconv_elementwise() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
@@ -218,7 +253,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
@@ -250,7 +290,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
@@ -290,7 +335,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [16], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
@@ -318,7 +368,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
@@ -356,7 +411,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
@@ -394,7 +454,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2, d3) -> ()>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
index a3c7c56..f4a803a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -2,7 +2,12 @@
// Odd K that forbids vectorization.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
module {
func.func @batch_matmul_1x3x32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
@@ -34,7 +39,12 @@
// 8-bit integers can be vectorized.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
module {
func.func @matmul_64x16xi8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
@@ -65,7 +75,12 @@
// Vectorize non-32 bit types.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Int64], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int64|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
module {
func.func @matmul_64x16xi64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
@@ -96,7 +111,12 @@
// Odd N that forbids vectorization.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
@@ -138,7 +158,12 @@
// Odd M and non-4-multiplier N
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
@@ -180,7 +205,12 @@
// Matmul with consumer pointwise ops
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_pointwise_256x1024() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_misc.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_misc.mlir
index 9273827..2a20a2c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_misc.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_misc.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
#map = affine_map<(d0, d1, d2) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @complex_view_as_real() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @complex_view_as_real() {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1xi32>>
@@ -38,8 +37,8 @@
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[4, 2, 2], [1, 1, 1]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [2, 2, 4]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[16, 2, 2], [1, 1, 1]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [2, 2, 16]>
// CHECK: func.func @complex_view_as_real()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
index f521960..960d0ac 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
@@ -1,6 +1,11 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle, dot = none, mma = [],
+ subgroup_size_choices = [16], max_workgroup_sizes = [512, 512, 512],
+ max_thread_count_per_workgroup = 512, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
@@ -31,7 +36,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
@@ -69,7 +79,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[1], [0, 64]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, GroupNonUniformShuffle], []>, api=Vulkan, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_sub_byte_types.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_sub_byte_types.mlir
index 9980423..9398820 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_sub_byte_types.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_sub_byte_types.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 64>>}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
- func.func @i4_dequant() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<131072x128xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<131072xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
index 0186b14..d56d0c3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
// Conv - large OC - distribute to only one workgroup dimension.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @conv_112x112x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @conv_112x112x512() {
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c112 = arith.constant 112 : index
@@ -33,9 +32,8 @@
// Conv - medium OC/OW/OH - distribute to two workgroup dimensions.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @conv_112x112x32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @conv_112x112x32() {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c112 = arith.constant 112 : index
@@ -64,9 +62,8 @@
// Conv - small OC/OW/OH - distribute to all three workgroup dimensions.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @conv_16x16x16() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @conv_16x16x16() {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
@@ -94,9 +91,8 @@
// Depthwise conv - small OC/OW/OH - distribute to all three workgroup dimensions.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @dwconv_28x28x144() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @dwconv_28x28x144() {
%c0 = arith.constant 0 : index
%c144 = arith.constant 144 : index
%c28 = arith.constant 28 : index
@@ -125,9 +121,8 @@
// Depthwise conv - tiny OC/OW/OH - starving the GPU.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @dwconv_1x2x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @dwconv_1x2x8() {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c2 = arith.constant 2 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir
index 68b0aaf..f3adbbc 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir
@@ -1,10 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
// Large matmul that can match the best tiling scheme.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_1024x2048x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_1024x2048x512() {
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c1024 = arith.constant 1024 : index
@@ -33,9 +32,8 @@
// Small matmul N that can still tile to all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_3136x24x96() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_3136x24x96() {
%c0 = arith.constant 0 : index
%c24 = arith.constant 24 : index
%c3136 = arith.constant 3136 : index
@@ -64,9 +62,8 @@
// Small matmul M that can still tile to all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_196x64x192() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_196x64x192() {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c196 = arith.constant 196 : index
@@ -95,9 +92,8 @@
// Small matmul K that can still tile to all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_12544x96x16() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_12544x96x16() {
%c0 = arith.constant 0 : index
%c96 = arith.constant 96 : index
%c12544 = arith.constant 12544 : index
@@ -122,9 +118,8 @@
// Odd matmul M and small N that cannot utilize all threads in a workgroup.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_49x160x576() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_49x160x576() {
%c0 = arith.constant 0 : index
%c160 = arith.constant 160 : index
%c49 = arith.constant 49 : index
@@ -153,9 +148,8 @@
// Small matmul M to "shift" parallelism to N.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_2x1024x576() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_2x1024x576() {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 3.000000e+00 : f32
%cst_1 = arith.constant 6.000000e+00 : f32
@@ -190,9 +184,8 @@
// Large matmul with i8 inputs.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @matmul_1024x2048x512xi8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_1024x2048x512xi8() {
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c1024 = arith.constant 1024 : index
@@ -211,9 +204,8 @@
}
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @batch_matmul_4x384x384() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_4x384x384() {
%c0 = arith.constant 0 : index
%c384 = arith.constant 384 : index
%c4 = arith.constant 4 : index
@@ -242,9 +234,8 @@
// Small batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
module {
- func.func @batch_matmul_4x2x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_4x2x8() {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c2 = arith.constant 2 : index
@@ -274,7 +265,6 @@
// Linalg.generic that is a batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
#map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -282,7 +272,7 @@
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
- func.func @generic_batch_matmul_32x2x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @generic_batch_matmul_32x2x512() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<8x32x64xf32>>
@@ -314,13 +304,12 @@
// Linalg.generic that is a batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16>>}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @generic_batch_matmul_8x2500x512x4608() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @generic_batch_matmul_8x2500x512x4608() {
%c168607744 = arith.constant 168607744 : index
%c537247744 = arith.constant 537247744 : index
%c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
index 430750f..fe1e2ad 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
@@ -1,8 +1,7 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=pascal@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64]>>}>
module {
- func.func @matmul_4x4096x9216() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_4x4096x9216() {
%c36864 = arith.constant 36864 : index
%c667974912 = arith.constant 667974912 : index
%c209920 = arith.constant 209920 : index
@@ -32,9 +31,8 @@
// Matvec does not go down matmul pipelines.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64]>>}>
module {
- func.func @matmul_1x4096x9216() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_1x4096x9216() {
%c36864 = arith.constant 36864 : index
%c667974912 = arith.constant 667974912 : index
%c209920 = arith.constant 209920 : index
@@ -64,12 +62,11 @@
// Multi-reduction-dimension transposed-B matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64]>>}>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
module {
- func.func @multi_reduction_transposed_b_matmul() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @multi_reduction_transposed_b_matmul() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
index f67ab1f..f274f4f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
@@ -1,11 +1,10 @@
-// RUN: iree-opt --split-input-file \
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=volta@vulkan \
// RUN: --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass)' %s | \
// RUN: FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @matmul_256x1024x128_div_add() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_256x1024x128_div_add() {
%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
%c256 = arith.constant 256 : index
@@ -43,10 +42,9 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @batch_matmul_16x128x256x512_div() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_16x128x256x512_div() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x128x512xf16>>
@@ -80,7 +78,6 @@
// Linalg.generic that is a batch matmul.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -88,7 +85,7 @@
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
- func.func @generic_batch_matmul_32x8x512x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @generic_batch_matmul_32x8x512x64() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<128x32x64xf16>>
@@ -120,9 +117,8 @@
// K dim size not divisble by 32.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
module {
- func.func @batch_matmul_16x1024x1024x80() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @batch_matmul_16x1024x1024x80() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x1024x80xf16>>
@@ -149,9 +145,8 @@
// Small K - not supported by cooperative matrix.
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
module {
- func.func @matmul_256x1024x8() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_256x1024x8() {
%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
%c256 = arith.constant 256 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
index 84ddf22..ca52b8a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
@@ -1,11 +1,10 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-codegen-materialize-user-configs, iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(iree-codegen-materialize-user-configs, iree-spirv-select-lowering-strategy-pass)' %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
#translation = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [16, 8, 1] subgroup_size = 64>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
- func.func @matmul_128x1024x256() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_128x1024x256() {
%cst = arith.constant 0.000000e+00 : f32
%c128 = arith.constant 128 : index
%c1024 = arith.constant 1024 : index
@@ -17,7 +16,7 @@
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x1024xf32>> -> tensor<256x1024xf32>
%5 = tensor.empty() : tensor<128x1024xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
- %7 = linalg.matmul {__internal_linalg_transform__ = "workgroup", compilation_info = #compilation} ins(%3, %4 : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%6 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
+ %7 = linalg.matmul {compilation_info = #compilation} ins(%3, %4 : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%6 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 1024], strides = [1, 1] : tensor<128x1024xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x1024xf32>>
return
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
index b1f8092..d8d3770 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-convert-gpu-target)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-spirv-convert-gpu-target))))' %s | FileCheck %s
hal.executable @dispatch {
hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
@@ -18,7 +18,8 @@
}
}
-// CHECK: spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+// CHECK: builtin.module attributes
+// CHECK-SAME: spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
// CHECK-SAME: [Shader, Float64, Float16, Int64, Int16, Int8,
// CHECK-SAME: StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
// CHECK-SMAE: StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index 32fd746..7b80a76 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -8,12 +8,11 @@
]>
]>
hal.executable private @push_constant {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @push_constant layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
// CHECK-LABEL: spirv.module
// CHECK: spirv.GlobalVariable @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
// CHECK: spirv.func @push_constant()
@@ -49,12 +48,11 @@
]>
]>
hal.executable private @resource_bindings_in_same_func {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @resource_bindings_in_same_func layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
// CHECK-LABEL: spirv.module
// CHECK: spirv.GlobalVariable @[[ARG0:.+]] bind(1, 2) : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: spirv.GlobalVariable @[[ARG1_0:.+]] bind(1, 3) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -110,15 +108,14 @@
]>
]>
hal.executable private @resource_bindings_in_multi_entry_func {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @resource_bindings_in_entry_func1 layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
hal.executable.export @resource_bindings_in_entry_func2 layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
// CHECK-LABEL: spirv.module
// CHECK: spirv.GlobalVariable @[[FUNC1_ARG:.+]] bind(1, 2) : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: spirv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
@@ -171,12 +168,11 @@
]>
]>
hal.executable private @interface_binding {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @interface_binding layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
func.func @interface_binding() -> f32 {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<8x5xf32, #spirv.storage_class<StorageBuffer>>
@@ -217,12 +213,11 @@
]>
]>
hal.executable private @interface_wg_id {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @interface_wg_id layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
func.func @interface_wg_id() -> index {
%0 = hal.interface.workgroup.id[0] : index
%1 = hal.interface.workgroup.id[1] : index
@@ -253,12 +248,11 @@
]>
]>
hal.executable private @interface_wg_count {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @interface_wg_count layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
func.func @interface_wg_count() -> index {
%0 = hal.interface.workgroup.count[0] : index
%1 = hal.interface.workgroup.count[1] : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir
index 14263f0..4a2cf5d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir
@@ -2,7 +2,12 @@
// RUN: --pass-pipeline='builtin.module(func.func(iree-spirv-emulate-i64))' %s | \
// RUN: FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
module {
func.func @buffer_types() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
@@ -30,7 +35,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
module {
func.func @emulate_1d_vector() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c95232 = arith.constant 95232 : index
@@ -72,7 +82,13 @@
// CHECK: return
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int64], []>, #spirv.resource_limits<>>}>
+
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int64|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
module {
func.func @no_emulation() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
index c651363..cc8f964 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/illegal_configuration.mlir
@@ -3,7 +3,12 @@
// RUN: --verify-diagnostics --split-input-file %s
#config = #iree_codegen.lowering_config<tile_sizes = []>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -21,7 +26,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 64], [4, 4], [0, 0, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -39,7 +49,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 64], [4, 4], [0, 0, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 128], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -57,7 +72,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 64], [4, 2], [0, 0, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -75,7 +95,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 64], [16, 8], [0, 0, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [8, 2, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -93,7 +118,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 60], [4, 4], [0, 0, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [15, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -111,7 +141,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 64, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -129,7 +164,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 64, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -147,7 +187,13 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [32, 32], [0, 0, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
+ mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
+ subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -170,7 +216,13 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [32, 32], [0, 0, 16], [8, 8, 8]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
+ mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
+ subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -194,7 +246,13 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 32], [8, 8], [0, 0, 4], [16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
+ mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
+ subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [256, 4, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -218,7 +276,13 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [32, 32], [0, 0, 16], [16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
+ mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
+ subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [64, 2, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -242,7 +306,13 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [32, 32], [0, 0, 16], [16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = none, dot = none,
+ mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
+ subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 4, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -266,7 +336,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 2, 2, 2], [0, 0, 0, 0, 1, 1, 4]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<()[s0] -> (s0 * 4)>
#map1 = affine_map<()[s0] -> (s0 * 16)>
#map2 = affine_map<(d0) -> (d0 * 2)>
@@ -315,7 +390,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 6, 6, 16], [0, 3, 3, 2], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<()[s0] -> (s0 * 4)>
#map1 = affine_map<()[s0] -> (s0 * 16)>
#map2 = affine_map<(d0) -> (d0 * 2)>
@@ -364,7 +444,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#map = affine_map<()[s0] -> (s0 * 4)>
#map1 = affine_map<()[s0] -> (s0 * 16)>
#map2 = affine_map<(d0) -> (d0 * 2)>
@@ -413,7 +498,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 7, 64], [0, 1, 7, 2], [0, 0, 0, 0, 5, 5], [0, 1, 0, 0]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [32, 1, 1]>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
@@ -431,7 +521,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 7, 64], [0, 1, 7, 2], [0, 0, 0, 0, 1, 1], [0, 0, 1, 1]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+}>
#translation = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [32, 1, 1]>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
index 756ba3a..55bb2be 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
@@ -1,7 +1,6 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-codegen-spirv-configuration-pipeline, func.func(iree-spirv-lower-executable-target-pass))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=cdna2@vulkan --pass-pipeline='builtin.module(iree-codegen-spirv-configuration-pipeline, func.func(iree-spirv-lower-executable-target-pass))' %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 1, 32]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [65535, 65535, 65535]>>}>
#map = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -12,7 +11,7 @@
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
module {
- func.func @matmul_i4_quant_weight() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul_i4_quant_weight() {
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
index 27831cf..838d95c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline, func.func(iree-spirv-lower-executable-target-pass)))))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=volta@vulkan --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline, func.func(iree-spirv-lower-executable-target-pass)))))' %s | FileCheck %s
// TODO (MaheshRavishankar): This test should be modified to run just on the inner module/func.func. Blocked
// today since `TileAndDistributeToWorkgroups` runs the `FoldAffineMinOverWorkgroupIds` pattern that
@@ -20,12 +20,7 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
hal.executable @matmul_f32_128x256x64 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [65535, 65535, 65535],
- subgroup_size = 32>>}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_f32_128x256x64 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -59,16 +54,16 @@
}
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 2)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1]
-// CHECK: func.func @matmul_f32_128x256x64()
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 2)>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1]
+// CHECK-LABEL: func.func @matmul_f32_128x256x64()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: %[[CST0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: memref.alloc() : memref<2x64x20xf32, #gpu.address_space<workgroup>>
// CHECK: memref.alloc() : memref<2x16x68xf32, #gpu.address_space<workgroup>>
// CHECK: scf.for
// CHECK: gpu.barrier
-// CHECK: affine.apply #[[MAP]]
+// CHECK: affine.apply #[[$MAP]]
// CHECK-COUNT-2: vector.transfer_write %{{.+}}, %{{.+}} {in_bounds = [true]} : vector<4xf32>, memref<2x64x20xf32, #gpu.address_space<workgroup>>
// CHECK-COUNT-2: vector.transfer_write %{{.+}}, %{{.+}} {in_bounds = [true]} : vector<4xf32>, memref<2x16x68xf32, #gpu.address_space<workgroup>>
// CHECK: gpu.barrier
@@ -100,6 +95,9 @@
// Store in stage 0 of pipeline.
+#compilation = #iree_codegen.compilation_info<
+ lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 16]]>,
+ translation_info = <SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 2, store_stage = 0}>>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
@@ -111,12 +109,7 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
hal.executable @matmul_f32_128x256x64 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [65535, 65535, 65535],
- subgroup_size = 32>>}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_f32_128x256x64 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -135,7 +128,8 @@
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x256xf32>> -> tensor<128x256xf32>
%7 = tensor.empty() : tensor<128x256xf32>
%8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x256xf32>) -> tensor<128x256xf32>
- %9 = linalg.matmul ins(%4, %5 : tensor<128x512xf32>, tensor<512x256xf32>) outs(%8 : tensor<128x256xf32>) -> tensor<128x256xf32>
+ %9 = linalg.matmul {compilation_info = #compilation}
+ ins(%4, %5 : tensor<128x512xf32>, tensor<512x256xf32>) outs(%8 : tensor<128x256xf32>) -> tensor<128x256xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
ins(%9, %6 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%7 : tensor<128x256xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
@@ -149,10 +143,10 @@
}
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 3)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1]
-// CHECK: func.func @matmul_f32_128x256x64()
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 3)>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1]
+// CHECK-LABEL: func.func @matmul_f32_128x256x64()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: %[[CST0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: memref.alloc() : memref<3x64x20xf32, #gpu.address_space<workgroup>>
// CHECK: memref.alloc() : memref<3x16x68xf32, #gpu.address_space<workgroup>>
@@ -178,7 +172,7 @@
// CHECK-COUNT-32: vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x64x20xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK-COUNT-16: vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x16x68xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK-COUNT-128: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK-DAG: %[[APPLY:.+]] = affine.apply #[[MAP]]
+// CHECK-DAG: %[[APPLY:.+]] = affine.apply #[[$MAP]]
// CHECK-DAG: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
// CHECK: vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
// CHECK: vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
@@ -216,12 +210,7 @@
]>
hal.executable @matmul_f16_4096x512x512 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [65535, 65535, 65535],
- subgroup_size = 32>>}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_f16_4096x512x512 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir
index 3a4a0ed..90e8dc3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir
@@ -1,12 +1,11 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass, func.func(iree-spirv-lower-executable-target-pass))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=cdna2@vulkan --pass-pipeline='builtin.module(iree-spirv-select-lowering-strategy-pass, func.func(iree-spirv-lower-executable-target-pass))' %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 64>>}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0)>
module {
- func.func @i4_dequant_matvec_f32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @i4_dequant_matvec_f32() {
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
index 69f2583..7e703df 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
@@ -2,7 +2,12 @@
// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-decompose-softmax), iree-spirv-select-lowering-strategy-pass, func.func(iree-spirv-lower-executable-target-pass))' \
// RUN: %s | FileCheck %s
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformShuffle], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], cooperative_matrix_properties_khr = []>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 65536>>
+}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
@@ -81,7 +86,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformShuffle], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64],
+ max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 65536>>
+}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
@@ -155,7 +165,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64]>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 64],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
module {
func.func @softmax() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c0 = arith.constant 0 : index
@@ -253,7 +268,12 @@
// -----
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, GroupNonUniformShuffle], [SPV_KHR_16bit_storage]>, api=Vulkan, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = shuffle|arithmetic, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+}>
module {
func.func @dynamic_softmax() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
%c32_i64 = arith.constant 32 : i64
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_scalar_dispatch.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_scalar_dispatch.mlir
index e3de27b..ceae10c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_scalar_dispatch.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_scalar_dispatch.mlir
@@ -1,13 +1,9 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-spirv-select-lowering-strategy-pass, func.func(iree-spirv-lower-executable-target-pass)))))' -mlir-print-local-scope %s | FileCheck %s
-
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, Unknown:Unknown,
- #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 32>>}>
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=pascal@vulkan --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-spirv-select-lowering-strategy-pass, func.func(iree-spirv-lower-executable-target-pass)))))' -mlir-print-local-scope %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
hal.executable @scalar_dispatch {
- hal.executable.variant public @vulkan_spirv_fb target(#executable_target_vulkan_spirv_fb) {
+ hal.executable.variant public @vulkan_spirv_fb target(#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @scalar_dispatch ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir
index 2abb856..9c622d8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/map_memref_storage_class.mlir
@@ -1,35 +1,27 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-map-memref-storage-class)))))' --allow-unregistered-dialect %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-spirv-map-memref-storage-class))' --allow-unregistered-dialect %s | FileCheck %s
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-hal.executable private @vulkan_client_api {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>}>) {
- hal.executable.export @vulkan_client_api layout(#pipeline_layout) attributes {
- workgroup_size = [32: index, 1: index, 1: index]
- }
- builtin.module {
- func.func @vulkan_client_api() {
- %0 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>)
- "dialect.memref_consumer"(%0) : (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>) -> ()
+#target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<
+ arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
+ dot = none, mma = [], subgroup_size_choices = [64],
+ max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
+ max_workgroup_memory_bytes = 16384>>}>
- %1 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>)
- "dialect.memref_consumer"(%1) : (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>) -> ()
+func.func @vulkan_client_api() attributes {hal.executable.target = #target} {
+ %0 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>)
+ "dialect.memref_consumer"(%0) : (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>) -> ()
- %2 = "dialect.memref_producer"() : () -> (memref<?x8xf32>)
- "dialect.memref_consumer"(%2) : (memref<?x8xf32>) -> ()
+ %1 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>)
+ "dialect.memref_consumer"(%1) : (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>) -> ()
- %3 = "dialect.memref_producer"() : () -> (memref<?x8xf32, 3>)
- "dialect.memref_consumer"(%3) : (memref<?x8xf32, 3>) -> ()
+ %2 = "dialect.memref_producer"() : () -> (memref<?x8xf32>)
+ "dialect.memref_consumer"(%2) : (memref<?x8xf32>) -> ()
- return
- }
- }
- }
+ %3 = "dialect.memref_producer"() : () -> (memref<?x8xf32, 3>)
+ "dialect.memref_consumer"(%3) : (memref<?x8xf32, 3>) -> ()
+
+ return
}
// CHECK-LABEL: func.func @vulkan_client_api()
@@ -47,36 +39,28 @@
// -----
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-hal.executable private @opencl_client_api {
- hal.executable.variant @opencl target(<"opencl-spirv", "opencl-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel], []>, #spirv.resource_limits<>>}>) {
- hal.executable.export @opencl_client_api layout(#pipeline_layout) attributes {
- workgroup_size = [32: index, 1: index, 1: index]
- }
- builtin.module {
- func.func @opencl_client_api() {
- %0 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>)
- "dialect.memref_consumer"(%0) : (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>) -> ()
+#target = #hal.executable.target<"opencl-spirv", "opencl-spirv-fb", {
+ iree.gpu.target = #iree_gpu.target<
+ arch = "", features = "spirv:v1.3,cap:Kernel", wgp = <
+ compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
+ dot = none, mma = [], subgroup_size_choices = [64],
+ max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
+ max_workgroup_memory_bytes = 16384>>}>
- %1 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>)
- "dialect.memref_consumer"(%1) : (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>) -> ()
+func.func @opencl_client_api() attributes {hal.executable.target = #target} {
+ %0 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>)
+ "dialect.memref_consumer"(%0) : (memref<?x8xf32, #hal.descriptor_type<uniform_buffer>>) -> ()
- %2 = "dialect.memref_producer"() : () -> (memref<?x8xf32>)
- "dialect.memref_consumer"(%2) : (memref<?x8xf32>) -> ()
+ %1 = "dialect.memref_producer"() : () -> (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>)
+ "dialect.memref_consumer"(%1) : (memref<?x8xf32, #hal.descriptor_type<storage_buffer>>) -> ()
- %3 = "dialect.memref_producer"() : () -> (memref<?x8xf32, 3>)
- "dialect.memref_consumer"(%3) : (memref<?x8xf32, 3>) -> ()
+ %2 = "dialect.memref_producer"() : () -> (memref<?x8xf32>)
+ "dialect.memref_consumer"(%2) : (memref<?x8xf32>) -> ()
- return
- }
- }
- }
+ %3 = "dialect.memref_producer"() : () -> (memref<?x8xf32, 3>)
+ "dialect.memref_consumer"(%3) : (memref<?x8xf32, 3>) -> ()
+
+ return
}
// CHECK-LABEL: func.func @opencl_client_api()
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/physical_storage_buffer_addresses.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/physical_storage_buffer_addresses.mlir
index 265dd86..fb61908 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/physical_storage_buffer_addresses.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/physical_storage_buffer_addresses.mlir
@@ -10,14 +10,15 @@
], flags = Indirect>
]>
hal.executable private @interface_binding {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb-ptr", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Int64, Shader, PhysicalStorageBufferAddresses],
- [SPV_KHR_physical_storage_buffer]>, #spirv.resource_limits<>>,
- hal.bindings.indirect}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb-ptr", {hal.bindings.indirect}>) {
hal.executable.export @interface_binding layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
- builtin.module {
+ builtin.module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
+ [Int64, Shader, PhysicalStorageBufferAddresses],
+ [SPV_KHR_physical_storage_buffer]>, #spirv.resource_limits<>>
+ } {
func.func @interface_binding() -> f32 {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<8x5xf32, #spirv.storage_class<PhysicalStorageBuffer>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
index 8102908..b7c9485 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
@@ -1,7 +1,11 @@
-// RUN: iree-opt --split-input-file \
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=volta@vulkan \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline, canonicalize, cse)))' \
// RUN: %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=rdna3@vulkan \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline, canonicalize, cse)))' \
+// RUN: %s | FileCheck %s --check-prefix=RDNA3
+
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
@@ -13,28 +17,7 @@
]>
hal.executable public @matmul_256x1024x128_div_exp {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_256x1024x128_div_exp layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -205,6 +188,12 @@
// CHECK-COUNT-2: spirv.GL.FAbs %{{.+}} : vector<4xf16>
// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
+// RDNA3-LABEL: spirv.module Logical GLSL450
+// RDNA3-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
+// RDNA3-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
+// RDNA3-DAG: spirv.GlobalVariable @[[C_MEM:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
+// RDNA3: spirv.func @matmul_256x1024x128_div_exp
+
// -----
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
@@ -217,28 +206,7 @@
]>
]>
hal.executable public @batch_matmul_16x128x256x512_div {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @batch_matmul_16x128x256x512_div layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
@@ -325,6 +293,10 @@
// CHECK-COUNT-4: %{{.+}} = spirv.FDiv %{{.+}}, %{{.+}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-COUNT-4: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %[[C32]], <RowMajor>
+// RDNA3-LABEL: spirv.module Logical GLSL450
+// RDNA3-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
+// RDNA3-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
+// RDNA3: spirv.func @batch_matmul_16x128x256x512_div
// -----
@@ -341,28 +313,7 @@
]>
hal.executable public @matmul_32x32x32_div {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_32x32x32_div layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -415,28 +366,7 @@
]>
hal.executable public @generic_batch_matmul_32x128x512x64 {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
+ hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @generic_batch_matmul_32x128x512x64 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4
@@ -516,425 +446,7 @@
// CHECK-COUNT-4: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %[[C64]], <RowMajor>
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>,
- #hal.descriptor_set.binding<3, storage_buffer>,
- #hal.descriptor_set.binding<4, storage_buffer>
- ]>
-]>
-
-hal.executable public @matmul_256x1024x128_div_exp {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 65536,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [1024, 1024, 1024],
- subgroup_size = 64>
- >}>) {
- hal.executable.export public @matmul_256x1024x128_div_exp layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @matmul_256x1024x128_div_exp() {
- %c0 = arith.constant 0 : index
- %c1024 = arith.constant 1024 : index
- %c256 = arith.constant 256 : index
- %cst = arith.constant 0.000000e+00 : f16
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<256x1024xf16>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<256x1024xf16>>
- %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<256x128xf16>>
- %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<128x1024xf16>>
- %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) : !flow.dispatch.tensor<writeonly:tensor<256x1024xf16>>
- %11 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x1024xf16>> -> tensor<256x1024xf16>
- %14 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x1024xf16>> -> tensor<256x1024xf16>
- %17 = tensor.empty() : tensor<256x1024xf16>
- %19 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [256, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x128xf16>> -> tensor<256x128xf16>
- %21 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [128, 1204], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x1024xf16>> -> tensor<128x1024xf16>
- %24 = tensor.empty() : tensor<256x1024xf16>
- %25 = linalg.fill ins(%cst : f16) outs(%24 : tensor<256x1024xf16>) -> tensor<256x1024xf16>
- %26 = linalg.matmul ins(%19, %21 : tensor<256x128xf16>, tensor<128x1024xf16>) outs(%25 : tensor<256x1024xf16>) -> tensor<256x1024xf16>
- %27 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%26, %11, %14 : tensor<256x1024xf16>, tensor<256x1024xf16>, tensor<256x1024xf16>)
- outs(%17 : tensor<256x1024xf16>) {
- ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):
- %28 = arith.divf %arg2, %arg3 : f16
- // spirv.GL.FAbs is not permitted to use cooperative matrix types per the spec.
- %29 = math.absf %28 : f16
- linalg.yield %29 : f16
- } -> tensor<256x1024xf16>
- flow.dispatch.tensor.store %27, %4, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : tensor<256x1024xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x1024xf16>>
- return
- }
- }
- }
-}
-
-// CHECK-LABEL: spirv.module Logical GLSL450
-
-// CHECK-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
-// CHECK-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
-// CHECK-DAG: spirv.GlobalVariable @[[C_MEM:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
-
-// CHECK: spirv.func @matmul_256x1024x128_div_exp
-
-// CHECK-DAG: %[[C5:.+]] = spirv.Constant 5 : i32
-// CHECK-DAG: %[[C17:.+]] = spirv.Constant 17 : i32
-// CHECK-DAG: %[[C32:.+]] = spirv.Constant 32 : i32
-// CHECK-DAG: %[[C128:.+]] = spirv.Constant 128 : i32
-// CHECK-DAG: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f16
-// CHECK: %{{.+}} = spirv.CompositeConstruct %[[F0]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-
-// CHECK-COUNT-8: %{{.+}} = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, Function>
-// CHECK: spirv.mlir.loop
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C5]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
-// CHECK-COUNT-8: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C17]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
-
-//CHECK-COUNT-16: %{{.+}} = spirv.KHR.CooperativeMatrixMulAdd %{{.+}}, %{{.+}}, %{{.+}}
-
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-// CHECK-COUNT-8: spirv.Store "Function" %{{.+}}, %{{.+}}
-// CHECK: spirv.mlir.merge
-
-// CHECK-COUNT-8: %{{.+}} = spirv.Load "Function" %{{.+}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C5]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
-// CHECK-COUNT-8: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C17]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
-
-//CHECK-COUNT-16: %{{.+}} = spirv.KHR.CooperativeMatrixMulAdd %{{.+}}, %{{.+}}, %{{.+}}
-
-// CHECK-COUNT-8: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %[[C17]], <RowMajor>
-
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-// CHECK-COUNT-2: spirv.FDiv %{{.+}}, %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.GL.FAbs %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.FDiv %{{.+}}, %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.GL.FAbs %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.FDiv %{{.+}}, %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.GL.FAbs %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.FDiv %{{.+}}, %{{.+}} : vector<4xf16>
-// CHECK-COUNT-2: spirv.GL.FAbs %{{.+}} : vector<4xf16>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>,
- #hal.descriptor_set.binding<3, storage_buffer>,
- #hal.descriptor_set.binding<4, storage_buffer>
- ]>
-]>
-hal.executable public @batch_matmul_16x128x256x512_div {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 65536,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [1024, 1024, 1024],
- subgroup_size = 64>
- >}>) {
- hal.executable.export public @batch_matmul_16x128x256x512_div layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @batch_matmul_16x128x256x512_div() {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f16
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x128x512xf16>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x512x256xf16>>
- %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x128x256xf16>>
- %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<16x128x256xf16>>
- %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [16, 128, 512], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x128x512xf16>> -> tensor<16x128x512xf16>
- %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [16, 512, 256], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x512x256xf16>> -> tensor<16x512x256xf16>
- %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [16, 128, 256], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x128x256xf16>> -> tensor<16x128x256xf16>
- %7 = tensor.empty() : tensor<16x128x256xf16>
- %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<16x128x256xf16>) -> tensor<16x128x256xf16>
- %9 = linalg.batch_matmul ins(%4, %5 : tensor<16x128x512xf16>, tensor<16x512x256xf16>) outs(%8 : tensor<16x128x256xf16>) -> tensor<16x128x256xf16>
- %10 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%9, %6 : tensor<16x128x256xf16>, tensor<16x128x256xf16>) outs(%7 : tensor<16x128x256xf16>) {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %11 = arith.divf %in, %in_0 : f16
- linalg.yield %11 : f16
- } -> tensor<16x128x256xf16>
- flow.dispatch.tensor.store %10, %3, offsets = [0, 0, 0], sizes = [16, 128, 256], strides = [1, 1, 1] : tensor<16x128x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x128x256xf16>>
- return
- }
- }
- }
-}
-
-// CHECK-LABEL: spirv.module Logical GLSL450
-
-// CHECK-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
-// CHECK-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
-
-// CHECK: spirv.func @batch_matmul_16x128x256x512_div
-
-// CHECK-DAG: %[[C5:.+]] = spirv.Constant 5 : i32
-// CHECK-DAG: %[[C17:.+]] = spirv.Constant 17 : i32
-// CHECK-DAG: %[[C32:.+]] = spirv.Constant 32 : i32
-// CHECK-DAG: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f16
-// CHECK: %{{.+}} = spirv.CompositeConstruct %[[F0]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-
-// CHECK-COUNT-4: %{{.+}} = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, Function>
-// CHECK: spirv.mlir.loop
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C5]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C17]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
-
-// CHECK-COUNT-8: %{{.+}} = spirv.KHR.CooperativeMatrixMulAdd %{{.+}}, %{{.+}}, %{{.+}}
-
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-// CHECK-COUNT-4: spirv.Store "Function" %{{.+}}, %{{.+}}
-// CHECK: spirv.mlir.merge
-
-// CHECK-COUNT-8: %{{.+}} = spirv.Load "Function" %{{.+}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C5]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C17]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
-
-// CHECK-COUNT-8: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C32]], <RowMajor> : !spirv.ptr<vector<4xf32>, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-// CHECK-COUNT-8: %{{.+}} = spirv.FDiv %{{.+}}, %{{.+}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-// CHECK-COUNT-8: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %[[C32]], <RowMajor>
-
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>
- ]>
-]>
-
-hal.executable public @generic_batch_matmul_32x128x512x64 {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 65536,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [1024, 1024, 1024],
- subgroup_size = 64>
- >}>) {
- hal.executable.export public @generic_batch_matmul_32x128x512x64 layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @generic_batch_matmul_32x128x512x64() {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f16
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<32x128x64xf16>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x512xf16>>
- %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<32x128x512xf16>>
- %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [32, 128, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128x64xf16>> -> tensor<32x128x64xf16>
- %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [64, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x512xf16>> -> tensor<64x512xf16>
- %5 = tensor.empty() : tensor<32x128x512xf16>
- %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<32x128x512xf16>) -> tensor<32x128x512xf16>
- %7 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
- ins(%3, %4 : tensor<32x128x64xf16>, tensor<64x512xf16>) outs(%6 : tensor<32x128x512xf16>) {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %8 = arith.mulf %in, %in_0 : f16
- %9 = arith.addf %out, %8 : f16
- linalg.yield %9 : f16
- } -> tensor<32x128x512xf16>
- flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [32, 128, 512], strides = [1, 1, 1] : tensor<32x128x512xf16> -> !flow.dispatch.tensor<writeonly:tensor<32x128x512xf16>>
- return
- }
- }
- }
-}
-
-// CHECK-LABEL: spirv.module Logical GLSL450
-
-// CHECK-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
-// CHECK-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
-
-// CHECK: spirv.func @generic_batch_matmul_32x128x512x64
-
-// CHECK-DAG: %[[C5:.+]] = spirv.Constant 5 : i32
-// CHECK-DAG: %[[C17:.+]] = spirv.Constant 17 : i32
-// CHECK-DAG: %[[C64:.+]] = spirv.Constant 64 : i32
-// CHECK-DAG: %[[C256:.+]] = spirv.Constant 256 : i32
-// CHECK-DAG: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f16
-// CHECK: %{{.+}} = spirv.CompositeConstruct %[[F0]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C5]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
-// CHECK-COUNT-8: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C17]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
-//CHECK-COUNT-16: %{{.+}} = spirv.KHR.CooperativeMatrixMulAdd %{{.+}}, %{{.+}}, %{{.+}}
-
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: %{{.+}} = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
-// CHECK: spirv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-// CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
-
-// CHECK-COUNT-4: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C5]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
-// CHECK-COUNT-8: %{{.+}} = spirv.KHR.CooperativeMatrixLoad %{{.+}}, %[[C17]], <RowMajor> : !spirv.ptr<vector<4xf32>, Workgroup>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
-//CHECK-COUNT-16: %{{.+}} = spirv.KHR.CooperativeMatrixMulAdd %{{.+}}, %{{.+}}, %{{.+}}
-
-// CHECK-COUNT-8: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %[[C64]], <RowMajor>
-
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>
- ]>
-]>
-
-#compilation = #iree_codegen.compilation_info<
- lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 64], [1, 16, 64], [0, 0, 0, 16], [1, 16, 16, 16]]>,
- translation_info = <SPIRVCooperativeMatrixVectorize workgroup_size = [32, 4, 1] subgroup_size = 32, {pipeline_depth = 1, store_stage = 1}>>
-
-hal.executable public @batch_matmul_f16_16x4096x4096x64_truncf_mulf {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, AMD:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 65536,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [1024, 1024, 1024],
- subgroup_size = 64>
- >}>) {
- hal.executable.export public @batch_matmul_f16_16x4096x4096x64_truncf_mulf layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @batch_matmul_f16_16x4096x4096x64_truncf_mulf() {
- %cst = arith.constant 0.158113882 : f32
- %cst_0 = arith.constant 0.000000e+00 : f16
- %c0 = arith.constant 0 : index
- %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x4096x64xf16>>
- %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x64x4096xf16>>
- %8 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<16x4096x4096xf16>>
- %9 = flow.dispatch.tensor.load %6, offsets = [0, 0, 0], sizes = [16, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x4096x64xf16>> -> tensor<16x4096x64xf16>
- %10 = flow.dispatch.tensor.load %7, offsets = [0, 0, 0], sizes = [16, 64, 4096], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x64x4096xf16>> -> tensor<16x64x4096xf16>
- %11 = tensor.empty() : tensor<16x4096x4096xf16>
- %12 = linalg.fill ins(%cst_0 : f16) outs(%11 : tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16>
- %13 = linalg.batch_matmul {compilation_info = #compilation}
- ins(%9, %10 : tensor<16x4096x64xf16>, tensor<16x64x4096xf16>)
- outs(%12 : tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16>
- %14 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%13 : tensor<16x4096x4096xf16>) outs(%11 : tensor<16x4096x4096xf16>) {
- ^bb0(%in: f16, %out: f16):
- %15 = arith.truncf %cst : f32 to f16
- %16 = arith.mulf %in, %15 : f16
- linalg.yield %16 : f16
- } -> tensor<16x4096x4096xf16>
- flow.dispatch.tensor.store %14, %8, offsets = [0, 0, 0], sizes = [16, 4096, 4096], strides = [1, 1, 1] : tensor<16x4096x4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x4096x4096xf16>>
- return
- }
- }
- }
-}
-
-// CHECK-LABEL: spirv.module Logical GLSL450
-
-// CHECK-NOT: spirv.GlobalVariable {{.+}} Workgroup
-// CHECK-COUNT-2: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<{{.+}}>)>, Workgroup>
-// CHECK-NOT: spirv.GlobalVariable {{.+}} Workgroup
-
-// CHECK: spirv.func @batch_matmul_f16_16x4096x4096x64_truncf_mulf
-
-// CHECK-DAG: %[[C512:.+]] = spirv.Constant 512 : i32
-// CHECK-DAG: %[[SCALAR:.+]] = spirv.Constant 0.158113882 : f32
-
-
-// CHECK-COUNT-4: %{{.+}} = spirv.Load "Function" %{{.+}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
-// CHECK: %[[CONVERT:.+]] = spirv.FConvert %[[SCALAR]] : f32 to f16
-// CHECK-COUNT-4: %{{.+}} = spirv.MatrixTimesScalar %{{.+}}, %[[CONVERT]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
-
-// CHECK-COUNT-4: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %[[C512]], <RowMajor>
+// RDNA3-LABEL: spirv.module Logical GLSL450
+// RDNA3-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<1088 x vector<4xf32>>)>, Workgroup>
+// RDNA3-DAG: spirv.GlobalVariable @{{.+}} : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
+// RDNA3: spirv.func @generic_batch_matmul_32x128x512x64
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
index 854fc57..37fbea0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=pascal@vulkan --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -11,12 +11,7 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
hal.executable @matmul_f32_128x256x64 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [65535, 65535, 65535],
- subgroup_size = 32>>}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_f32_128x256x64 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -91,12 +86,7 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
hal.executable @matmul_f16_128x256x64 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader, Float16], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [65535, 65535, 65535],
- subgroup_size = 32>>}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_f16_128x256x64 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -177,12 +167,7 @@
translation_info = <SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0, store_stage = 1}>>
hal.executable @matmul_f16_32x1280x1280 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader, Float16, StorageBuffer16BitAccess], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [65535, 65535, 65535],
- subgroup_size = 32>>}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @matmul_f16_32x1280x1280 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
index 5b3eea9..1abafdd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -8,13 +8,7 @@
]>
]>
hal.executable private @fuse_and_vectorize_fill_matmul {
- hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 512,
- max_compute_workgroup_size = [512, 512, 512],
- subgroup_size = 16>>
- }>) {
+ hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @fuse_and_vectorize_fill_matmul layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
@@ -59,13 +53,7 @@
]>
]>
hal.executable private @fuse_and_vectorize_matmul_add {
- hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 512,
- max_compute_workgroup_size = [512, 512, 512],
- subgroup_size = 16>>
- }>) {
+ hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @fuse_and_vectorize_matmul_add layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
index 8f505f7..34acbde 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file \
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=pascal@vulkan \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' \
// RUN: %s | FileCheck %s
@@ -13,16 +13,10 @@
]>
hal.executable @i4_dequant_unit_matmul_f16 {
hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [
- Shader, Float16, StorageBuffer16BitAccess, GroupNonUniform,
- GroupNonUniformArithmetic, GroupNonUniformShuffle
- ], [SPV_KHR_16bit_storage]>, Unknown:IntegratedGPU,
- #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [1024, 1024, 64],
- subgroup_size = 32
- >>
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = shuffle|arithmetic, dot = none, mma = [],
+ subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
}>) {
hal.executable.export @i4_dequant_unit_matmul_f16 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
@@ -130,17 +124,11 @@
]>
hal.executable @i4_dequant_matvec_f16_subgroup_64 {
hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [
- Shader, Float16, StorageBuffer16BitAccess, GroupNonUniform,
- GroupNonUniformArithmetic, GroupNonUniformShuffle
- ], [SPV_KHR_16bit_storage]>, Unknown:IntegratedGPU,
- #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [1024, 1024, 64],
- subgroup_size = 64
- >>
- }>) {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|fp16|int32, storage = b32|b16, subgroup = shuffle|arithmetic, dot = none, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+ }>) {
hal.executable.export @i4_dequant_matvec_f16_subgroup_64 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_reduction_subgroup.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_reduction_subgroup.mlir
index a402e2d..9c8825c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_reduction_subgroup.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_reduction_subgroup.mlir
@@ -1,4 +1,5 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s --check-prefix=NOSHUFFLE
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -7,13 +8,7 @@
]>
]>
hal.executable private @subgroup_reduce {
- hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, ARM:IntegratedGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 512,
- max_compute_workgroup_size = [512, 512, 512],
- subgroup_size = 16>>
- }>) {
+ hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export public @subgroup_reduce ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
@@ -94,51 +89,5 @@
// CHECK: spirv.ExecutionMode @{{.+}} "LocalSize", 128, 1, 1
-// -----
-
-// Check the case of no GroupNonUniformShuffle capability.
-
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-hal.executable private @subgroup_reduce {
- hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 512,
- max_compute_workgroup_size = [512, 512, 512],
- subgroup_size = 16>>
- }>) {
- hal.executable.export public @subgroup_reduce ordinal(0) layout(#pipeline_layout) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @subgroup_reduce() {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x512xf32>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2xf32>>
- %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x512xf32>> -> tensor<2x512xf32>
- %3 = tensor.empty() : tensor<2xf32>
- %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<2xf32>) -> tensor<2xf32>
- %5 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]
- } ins(%2 : tensor<2x512xf32>) outs(%4 : tensor<2xf32>) {
- ^bb0(%arg0: f32, %arg1: f32):
- %6 = arith.addf %arg1, %arg0 : f32
- linalg.yield %6 : f32
- } -> tensor<2xf32>
- flow.dispatch.tensor.store %5, %1, offsets = [0], sizes = [2], strides = [1] : tensor<2xf32> -> !flow.dispatch.tensor<writeonly:tensor<2xf32>>
- return
- }
- }
- }
-}
-
-// CHECK-NOT: spirv.GroupNonUniformShuffleXor
+// NOSHUFFLE-LABEL: spirv.func @subgroup_reduce()
+// NOSHUFFLE-NOT: spirv.GroupNonUniformShuffleXor
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_sub_byte_dequant.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_sub_byte_dequant.mlir
index b418a5e..e995a93 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_sub_byte_dequant.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_sub_byte_dequant.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=vp_android_baseline_2022@vulkan --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-spirv-configuration-pipeline), iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -9,13 +9,7 @@
]>
]>
hal.executable @i4_dequant {
- hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
- max_compute_shared_memory_size = 32768,
- max_compute_workgroup_invocations = 512,
- max_compute_workgroup_size = [512, 512, 512],
- subgroup_size = 64>>
- }>) {
+ hal.executable.variant @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @i4_dequant layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
index 21bc445..66b4d3c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/set_transform_strategy.mlir
@@ -1,10 +1,13 @@
-// RUN: iree-opt %s --split-input-file \
+// RUN: iree-opt %s --split-input-file --iree-gpu-test-target=volta@vulkan \
// RUN: --pass-pipeline="builtin.module(iree-spirv-select-lowering-strategy-pass)"\
-// RUN: --iree-spirv-enable-transform-dialect-jit=true | FileCheck %s
+// RUN: --iree-spirv-enable-transform-dialect-jit=true
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 8, a_type = f32, b_type = f32, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
+// TODO: Transform script based CodeGen expects fp32-input to target tensor
+// core, but there are no such wmma intrinsics. Fix it to support fp16-input.
+// TODO: | FileCheck %s
+
module {
- func.func @matmul() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+ func.func @matmul() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2052x2556xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir
index a35d8aa..23b272c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir
@@ -1,20 +1,19 @@
-// RUN: iree-opt --split-input-file --mlir-print-local-scope \
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --iree-gpu-test-target=volta@vulkan \
// RUN: --pass-pipeline='builtin.module(func.func(iree-spirv-tile-and-promote{promote-c=false skip-thread=true}, cse))' \
// RUN: %s | FileCheck %s
-// RUN: iree-opt --split-input-file --mlir-print-local-scope \
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --iree-gpu-test-target=volta@vulkan \
// RUN: --pass-pipeline='builtin.module(func.func(iree-spirv-tile-and-promote{promote-c=true skip-thread=true}, cse))' \
// RUN: %s | FileCheck %s --check-prefix=PROMOTEC
// Single tile per workgroup means no subview ops for promotion.
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 32], [16, 16, 16], [0, 0, 32]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [64, 2, 1]>
module {
- func.func @matmul_f16_32x32x32() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @matmul_f16_32x32x32() attributes {translation_info = #translation} {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
@@ -65,7 +64,6 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 32, 32], [1, 16, 16, 16], [0, 0, 0, 32]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
@@ -73,7 +71,7 @@
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [64, 2, 1]>
module {
- func.func @generic_batch_matmul_f16_32x128x512x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @generic_batch_matmul_f16_32x128x512x64() attributes {translation_info = #translation} {
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
@@ -175,7 +173,6 @@
// Cooperative matrix fusable elementwise ops do not need promote C.
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 32, 32], [1, 16, 16, 16], [0, 0, 0, 32]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
@@ -183,7 +180,7 @@
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [64, 2, 1]>
module {
- func.func @generic_batch_matmul_f16_32x128x512x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @generic_batch_matmul_f16_32x128x512x64() attributes {translation_info = #translation} {
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
@@ -255,14 +252,13 @@
// No need to promote C if there is no fused element wise ops.
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 32, 32], [1, 16, 16, 16], [0, 0, 0, 32]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [2147483647, 65535, 65535], cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [64, 2, 1]>
module {
- func.func @generic_batch_matmul_f16_32x128x512x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @generic_batch_matmul_f16_32x128x512x64() attributes {translation_info = #translation} {
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
@@ -336,13 +332,12 @@
// No need to promote again with allocations from bufferization.
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 128], [1, 32, 64], [0, 0, 0, 32], [1, 16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_NV_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1]>
module {
- func.func @batch_matmul_f16_1x64x128x512() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @batch_matmul_f16_1x64x128x512() attributes {translation_info = #translation} {
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
@@ -408,14 +403,13 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_NV_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<(d0, d1) -> (d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1]>
module {
- func.func @matmul_f16_f512x4096x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @matmul_f16_f512x4096x64() attributes {translation_info = #translation} {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
@@ -494,14 +488,13 @@
// Transposed+broadcasted elementwise ops does not need promoting C matrix.
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_NV_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<(d0, d1) -> (d0)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1]>
module {
- func.func @matmul_f16_f512x4096x64() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @matmul_f16_f512x4096x64() attributes {translation_info = #translation} {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
@@ -580,14 +573,13 @@
// Inlined large constant array needs promoting C matrix.
#config = #iree_codegen.lowering_config<tile_sizes = [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR], [SPV_NV_cooperative_matrix]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>]>>}>
#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<(d0, d1) -> (d0)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [128, 2, 1]>
module {
- func.func @matmul_f16_128x262144x2304() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @matmul_f16_128x262144x2304() attributes {translation_info = #translation} {
%c128 = arith.constant 128 : index
%c262144 = arith.constant 262144 : index
%c96565312 = arith.constant 96565312 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
index 41c0e7c..c9a6289 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
@@ -1,14 +1,13 @@
-// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(func.func(iree-spirv-tile-and-promote, cse))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --iree-gpu-test-target=pascal@vulkan --pass-pipeline='builtin.module(func.func(iree-spirv-tile-and-promote, cse))' %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [16, 4], [0, 0, 32]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [65535, 65535, 65535]>>}>
#map = affine_map<()[s0] -> (s0 * 128)>
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
#map2 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1]>
module {
- func.func @matmul_f32_256x1024x128() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @matmul_f32_256x1024x128() attributes {translation_info = #translation} {
%c1024 = arith.constant 1024 : index
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
@@ -107,13 +106,12 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 256], [1, 8, 8], [0, 0, 0, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float16], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 256)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1]>
module {
- func.func @batch_matmul_16x1024x1024x80() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @batch_matmul_16x1024x1024x80() attributes {translation_info = #translation} {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1024 = arith.constant 1024 : index
@@ -169,12 +167,11 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 512, 8], [1, 8, 4], [0, 0, 0, 16]]>
-#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>>}>
#map = affine_map<()[s0] -> (s0 * 512)>
#map1 = affine_map<()[s0] -> (s0 * 8)>
#translation = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [2, 64, 1]>
module {
- func.func @batch_matmul_f32_16x4096x40x4096() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb, translation_info = #translation} {
+ func.func @batch_matmul_f32_16x4096x40x4096() attributes {translation_info = #translation} {
%c16 = arith.constant 16 : index
%c4096 = arith.constant 4096 : index
%c40 = arith.constant 40 : index
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index 40e4b91..77c5f34 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -1,119 +1,75 @@
-// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-to-cooperative-ops, iree-codegen-generic-vectorization, iree-spirv-vectorize-to-cooperative-ops, iree-codegen-optimize-tensor-insert-extract-slices, canonicalize, cse)))))' \
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=volta@vulkan \
+// RUN: --pass-pipeline='builtin.module(func.func(iree-spirv-tile-to-cooperative-ops, iree-codegen-generic-vectorization, iree-spirv-vectorize-to-cooperative-ops, iree-codegen-optimize-tensor-insert-extract-slices, canonicalize, cse))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 32], [16, 16], [0, 0, 32], [16, 16, 16]]>
-#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>,
- #hal.descriptor_set.binding<3, storage_buffer>,
- #hal.descriptor_set.binding<4, storage_buffer>
- ]>
-]>
-hal.executable public @matmul_256x1024x128_div_add {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
- hal.executable.export public @matmul_256x1024x128_div_add layout(#pipeline_layout) attributes {
- translation_info = #translation,
- workgroup_size = [32 : index, 1 : index, 1 : index]
- } {
- ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
- %c1 = arith.constant 1 : index
- %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg0]
- %1 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg1]
- hal.return %0, %1, %c1 : index, index, index
- }
- builtin.module {
- func.func @matmul_256x1024x128_div_add() {
- %cst = arith.constant 0.000000e+00 : f16
- %c0 = arith.constant 0 : index
- %c32 = arith.constant 32 : index
- %c1024 = arith.constant 1024 : index
- %0 = gpu.thread_id x
- %1 = gpu.thread_id y
- %2 = gpu.thread_id z
- %alloc = memref.alloc() : memref<32x32xf16, 3>
- %alloc_0 = memref.alloc() : memref<32x32xf16, 3>
- %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<256x1024xf16>
- %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1024x128xf16>
- %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xf16>
- %6 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xf16>
- %7 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xf16>
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
- %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
- %subview = memref.subview %7[%8, %9] [32, 32] [1, 1] : memref<256x128xf16> to memref<32x32xf16, strided<[128, 1], offset: ?>>
- %subview_1 = memref.subview %3[%8, 0] [32, 1024] [1, 1] : memref<256x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
- %subview_2 = memref.subview %4[0, %9] [1024, 32] [1, 1] : memref<1024x128xf16> to memref<1024x32xf16, strided<[128, 1], offset: ?>>
- linalg.fill {__internal_linalg_transform__ = "workgroup_memory"} ins(%cst : f16) outs(%subview : memref<32x32xf16, strided<[128, 1], offset: ?>>)
- scf.for %arg0 = %c0 to %c1024 step %c32 {
- %subview_5 = memref.subview %subview_1[0, %arg0] [32, 32] [1, 1] : memref<32x1024xf16, strided<[1024, 1], offset: ?>> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
- %subview_6 = memref.subview %subview_2[%arg0, 0] [32, 32] [1, 1] : memref<1024x32xf16, strided<[128, 1], offset: ?>> to memref<32x32xf16, strided<[128, 1], offset: ?>>
- gpu.barrier
- %subview_7 = memref.subview %alloc[%c0, %c0] [32, 32] [1, 1] : memref<32x32xf16, 3> to memref<32x32xf16, strided<[32, 1], offset: ?>, 3>
- %10 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %11 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %12 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %13 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %subview_8 = memref.subview %subview_5[%10, %11] [1, 8] [1, 1] : memref<32x32xf16, strided<[1024, 1], offset: ?>> to memref<1x8xf16, strided<[1024, 1], offset: ?>>
- %subview_9 = memref.subview %subview_7[%12, %13] [1, 8] [1, 1] : memref<32x32xf16, strided<[32, 1], offset: ?>, 3> to memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
- %14 = vector.transfer_read %subview_8[%c0, %c0], %cst {in_bounds = [true, true]} : memref<1x8xf16, strided<[1024, 1], offset: ?>>, vector<1x8xf16>
- vector.transfer_write %14, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
- %subview_10 = memref.subview %alloc_0[%c0, %c0] [32, 32] [1, 1] : memref<32x32xf16, 3> to memref<32x32xf16, strided<[32, 1], offset: ?>, 3>
- %15 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %16 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %17 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %18 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %subview_11 = memref.subview %subview_6[%15, %16] [1, 8] [1, 1] : memref<32x32xf16, strided<[128, 1], offset: ?>> to memref<1x8xf16, strided<[128, 1], offset: ?>>
- %subview_12 = memref.subview %subview_10[%17, %18] [1, 8] [1, 1] : memref<32x32xf16, strided<[32, 1], offset: ?>, 3> to memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
- %19 = vector.transfer_read %subview_11[%c0, %c0], %cst {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>>, vector<1x8xf16>
- vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
- gpu.barrier
- linalg.matmul {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config}
- ins(%alloc, %alloc_0 : memref<32x32xf16, 3>, memref<32x32xf16, 3>) outs(%subview : memref<32x32xf16, strided<[128, 1], offset: ?>>)
- }
- %subview_3 = memref.subview %5[%8, %9] [32, 32] [1, 1] : memref<256x128xf16> to memref<32x32xf16, strided<[128, 1], offset: ?>>
- %subview_4 = memref.subview %6[%8, %9] [32, 32] [1, 1] : memref<256x128xf16> to memref<32x32xf16, strided<[128, 1], offset: ?>>
- linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- }
- ins(%subview_3, %subview_4 : memref<32x32xf16, strided<[128, 1], offset: ?>>, memref<32x32xf16, strided<[128, 1], offset: ?>>)
- outs(%subview : memref<32x32xf16, strided<[128, 1], offset: ?>>)
- attrs = {__internal_linalg_transform__ = "workgroup_memory"} {
- ^bb0(%in: f16, %in_5: f16, %out: f16):
- %10 = arith.divf %out, %in : f16
- %11 = arith.addf %10, %in_5 : f16
- linalg.yield %11 : f16
- }
- return
- }
- }
+#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size = [32, 1, 1]>
+builtin.module {
+func.func @matmul_256x1024x128_div_add() attributes {translation_info = #translation} {
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %c1024 = arith.constant 1024 : index
+ %0 = gpu.thread_id x
+ %1 = gpu.thread_id y
+ %2 = gpu.thread_id z
+ %alloc = memref.alloc() : memref<32x32xf16, 3>
+ %alloc_0 = memref.alloc() : memref<32x32xf16, 3>
+ %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<256x1024xf16>
+ %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1024x128xf16>
+ %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xf16>
+ %6 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xf16>
+ %7 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xf16>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+ %subview = memref.subview %7[%8, %9] [32, 32] [1, 1] : memref<256x128xf16> to memref<32x32xf16, strided<[128, 1], offset: ?>>
+ %subview_1 = memref.subview %3[%8, 0] [32, 1024] [1, 1] : memref<256x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
+ %subview_2 = memref.subview %4[0, %9] [1024, 32] [1, 1] : memref<1024x128xf16> to memref<1024x32xf16, strided<[128, 1], offset: ?>>
+ linalg.fill {__internal_linalg_transform__ = "workgroup_memory"} ins(%cst : f16) outs(%subview : memref<32x32xf16, strided<[128, 1], offset: ?>>)
+ scf.for %arg0 = %c0 to %c1024 step %c32 {
+ %subview_5 = memref.subview %subview_1[0, %arg0] [32, 32] [1, 1] : memref<32x1024xf16, strided<[1024, 1], offset: ?>> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
+ %subview_6 = memref.subview %subview_2[%arg0, 0] [32, 32] [1, 1] : memref<1024x32xf16, strided<[128, 1], offset: ?>> to memref<32x32xf16, strided<[128, 1], offset: ?>>
+ gpu.barrier
+ %subview_7 = memref.subview %alloc[%c0, %c0] [32, 32] [1, 1] : memref<32x32xf16, 3> to memref<32x32xf16, strided<[32, 1], offset: ?>, 3>
+ %10 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %11 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %12 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %13 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %subview_8 = memref.subview %subview_5[%10, %11] [1, 8] [1, 1] : memref<32x32xf16, strided<[1024, 1], offset: ?>> to memref<1x8xf16, strided<[1024, 1], offset: ?>>
+ %subview_9 = memref.subview %subview_7[%12, %13] [1, 8] [1, 1] : memref<32x32xf16, strided<[32, 1], offset: ?>, 3> to memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
+ %14 = vector.transfer_read %subview_8[%c0, %c0], %cst {in_bounds = [true, true]} : memref<1x8xf16, strided<[1024, 1], offset: ?>>, vector<1x8xf16>
+ vector.transfer_write %14, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
+ %subview_10 = memref.subview %alloc_0[%c0, %c0] [32, 32] [1, 1] : memref<32x32xf16, 3> to memref<32x32xf16, strided<[32, 1], offset: ?>, 3>
+ %15 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %16 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %17 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %18 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %subview_11 = memref.subview %subview_6[%15, %16] [1, 8] [1, 1] : memref<32x32xf16, strided<[128, 1], offset: ?>> to memref<1x8xf16, strided<[128, 1], offset: ?>>
+ %subview_12 = memref.subview %subview_10[%17, %18] [1, 8] [1, 1] : memref<32x32xf16, strided<[32, 1], offset: ?>, 3> to memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
+ %19 = vector.transfer_read %subview_11[%c0, %c0], %cst {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>>, vector<1x8xf16>
+ vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[32, 1], offset: ?>, 3>
+ gpu.barrier
+ linalg.matmul {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config}
+ ins(%alloc, %alloc_0 : memref<32x32xf16, 3>, memref<32x32xf16, 3>) outs(%subview : memref<32x32xf16, strided<[128, 1], offset: ?>>)
}
+ %subview_3 = memref.subview %5[%8, %9] [32, 32] [1, 1] : memref<256x128xf16> to memref<32x32xf16, strided<[128, 1], offset: ?>>
+ %subview_4 = memref.subview %6[%8, %9] [32, 32] [1, 1] : memref<256x128xf16> to memref<32x32xf16, strided<[128, 1], offset: ?>>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%subview_3, %subview_4 : memref<32x32xf16, strided<[128, 1], offset: ?>>, memref<32x32xf16, strided<[128, 1], offset: ?>>)
+ outs(%subview : memref<32x32xf16, strided<[128, 1], offset: ?>>)
+ attrs = {__internal_linalg_transform__ = "workgroup_memory"} {
+ ^bb0(%in: f16, %in_5: f16, %out: f16):
+ %10 = arith.divf %out, %in : f16
+ %11 = arith.addf %10, %in_5 : f16
+ linalg.yield %11 : f16
+ }
+ return
+}
}
// CHECK: #[[$MAP_Y:.+]] = affine_map<()[s0] -> (s0 * 16)>
@@ -172,121 +128,78 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 32, 32], [1, 16, 16], [0, 0, 0, 32], [1, 16, 16, 16]]>
-#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>,
- #hal.descriptor_set.binding<3, storage_buffer>,
- #hal.descriptor_set.binding<4, storage_buffer>
- ]>
-]>
-hal.executable public @matmul_256x1024x128_div_add {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
- hal.executable.export public @matmul_256x1024x128_div_add layout(#pipeline_layout) attributes {
- translation_info = #translation,
- workgroup_size = [32 : index, 1 : index, 1 : index]
- } {
- ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
- %c1 = arith.constant 1 : index
- %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg0]
- %1 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg1]
- hal.return %0, %1, %c1 : index, index, index
+#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size=[32, 1, 1]>
+builtin.module {
+ func.func @matmul_256x1024x128_div_add() attributes {translation_info = #translation} {
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %c512 = arith.constant 512 : index
+ %c1 = arith.constant 1 : index
+ %0 = gpu.thread_id x
+ %1 = gpu.thread_id y
+ %2 = gpu.thread_id z
+ %alloc = memref.alloc() : memref<1x32x32xf16, 3>
+ %alloc_0 = memref.alloc() : memref<1x32x32xf16, 3>
+ %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<16x128x512xf16>
+ memref.assume_alignment %3, 64 : memref<16x128x512xf16>
+ %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<16x512x256xf16>
+ memref.assume_alignment %4, 64 : memref<16x512x256xf16>
+ %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<16x128x256xf16>
+ memref.assume_alignment %5, 64 : memref<16x128x256xf16>
+ %6 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<16x128x256xf16>
+ memref.assume_alignment %6, 64 : memref<16x128x256xf16>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+ %subview = memref.subview %6[%workgroup_id_z, %7, %8] [1, 32, 32] [1, 1, 1] : memref<16x128x256xf16> to memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>
+ %subview_1 = memref.subview %3[%workgroup_id_z, %7, 0] [1, 32, 512] [1, 1, 1] : memref<16x128x512xf16> to memref<1x32x512xf16, strided<[65536, 512, 1], offset: ?>>
+ %subview_2 = memref.subview %4[%workgroup_id_z, 0, %8] [1, 512, 32] [1, 1, 1] : memref<16x512x256xf16> to memref<1x512x32xf16, strided<[131072, 256, 1], offset: ?>>
+ linalg.fill {__internal_linalg_transform__ = "workgroup_memory"}
+ ins(%cst : f16) outs(%subview : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
+ scf.for %arg0 = %c0 to %c512 step %c32 {
+ %subview_4 = memref.subview %subview_1[0, 0, %arg0] [1, 32, 32] [1, 1, 1] : memref<1x32x512xf16, strided<[65536, 512, 1], offset: ?>> to memref<1x32x32xf16, strided<[65536, 512, 1], offset: ?>>
+ %subview_5 = memref.subview %subview_2[0, %arg0, 0] [1, 32, 32] [1, 1, 1] : memref<1x512x32xf16, strided<[131072, 256, 1], offset: ?>> to memref<1x32x32xf16, strided<[131072, 256, 1], offset: ?>>
+ gpu.barrier
+ %subview_6 = memref.subview %alloc[%c0, %c0, %c0] [1, 32, 32] [1, 1, 1] : memref<1x32x32xf16, 3> to memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3>
+ %9 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %10 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %11 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %12 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %subview_7 = memref.subview %subview_4[0, %9, %10] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[65536, 512, 1], offset: ?>> to memref<1x1x8xf16, strided<[65536, 512, 1], offset: ?>>
+ %subview_8 = memref.subview %subview_6[0, %11, %12] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3> to memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
+ %13 = vector.transfer_read %subview_7[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8xf16, strided<[65536, 512, 1], offset: ?>>, vector<1x1x8xf16>
+ vector.transfer_write %13, %subview_8[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
+ %subview_9 = memref.subview %alloc_0[%c0, %c0, %c0] [1, 32, 32] [1, 1, 1] : memref<1x32x32xf16, 3> to memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3>
+ %14 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %15 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %16 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %17 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %subview_10 = memref.subview %subview_5[0, %14, %15] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[131072, 256, 1], offset: ?>> to memref<1x1x8xf16, strided<[131072, 256, 1], offset: ?>>
+ %subview_11 = memref.subview %subview_9[0, %16, %17] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3> to memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
+ %18 = vector.transfer_read %subview_10[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8xf16, strided<[131072, 256, 1], offset: ?>>, vector<1x1x8xf16>
+ vector.transfer_write %18, %subview_11[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
+ gpu.barrier
+ linalg.batch_matmul {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config}
+ ins(%alloc, %alloc_0 : memref<1x32x32xf16, 3>, memref<1x32x32xf16, 3>) outs(%subview : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
}
- builtin.module {
- func.func @matmul_256x1024x128_div_add() {
- %cst = arith.constant 0.000000e+00 : f16
- %c0 = arith.constant 0 : index
- %c32 = arith.constant 32 : index
- %c512 = arith.constant 512 : index
- %c1 = arith.constant 1 : index
- %0 = gpu.thread_id x
- %1 = gpu.thread_id y
- %2 = gpu.thread_id z
- %alloc = memref.alloc() : memref<1x32x32xf16, 3>
- %alloc_0 = memref.alloc() : memref<1x32x32xf16, 3>
- %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<16x128x512xf16>
- memref.assume_alignment %3, 64 : memref<16x128x512xf16>
- %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<16x512x256xf16>
- memref.assume_alignment %4, 64 : memref<16x512x256xf16>
- %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<16x128x256xf16>
- memref.assume_alignment %5, 64 : memref<16x128x256xf16>
- %6 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<16x128x256xf16>
- memref.assume_alignment %6, 64 : memref<16x128x256xf16>
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_id_z = hal.interface.workgroup.id[2] : index
- %7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
- %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
- %subview = memref.subview %6[%workgroup_id_z, %7, %8] [1, 32, 32] [1, 1, 1] : memref<16x128x256xf16> to memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>
- %subview_1 = memref.subview %3[%workgroup_id_z, %7, 0] [1, 32, 512] [1, 1, 1] : memref<16x128x512xf16> to memref<1x32x512xf16, strided<[65536, 512, 1], offset: ?>>
- %subview_2 = memref.subview %4[%workgroup_id_z, 0, %8] [1, 512, 32] [1, 1, 1] : memref<16x512x256xf16> to memref<1x512x32xf16, strided<[131072, 256, 1], offset: ?>>
- linalg.fill {__internal_linalg_transform__ = "workgroup_memory"}
- ins(%cst : f16) outs(%subview : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
- scf.for %arg0 = %c0 to %c512 step %c32 {
- %subview_4 = memref.subview %subview_1[0, 0, %arg0] [1, 32, 32] [1, 1, 1] : memref<1x32x512xf16, strided<[65536, 512, 1], offset: ?>> to memref<1x32x32xf16, strided<[65536, 512, 1], offset: ?>>
- %subview_5 = memref.subview %subview_2[0, %arg0, 0] [1, 32, 32] [1, 1, 1] : memref<1x512x32xf16, strided<[131072, 256, 1], offset: ?>> to memref<1x32x32xf16, strided<[131072, 256, 1], offset: ?>>
- gpu.barrier
- %subview_6 = memref.subview %alloc[%c0, %c0, %c0] [1, 32, 32] [1, 1, 1] : memref<1x32x32xf16, 3> to memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3>
- %9 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %10 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %11 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %12 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %subview_7 = memref.subview %subview_4[0, %9, %10] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[65536, 512, 1], offset: ?>> to memref<1x1x8xf16, strided<[65536, 512, 1], offset: ?>>
- %subview_8 = memref.subview %subview_6[0, %11, %12] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3> to memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
- %13 = vector.transfer_read %subview_7[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8xf16, strided<[65536, 512, 1], offset: ?>>, vector<1x1x8xf16>
- vector.transfer_write %13, %subview_8[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
- %subview_9 = memref.subview %alloc_0[%c0, %c0, %c0] [1, 32, 32] [1, 1, 1] : memref<1x32x32xf16, 3> to memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3>
- %14 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %15 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %16 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %17 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %subview_10 = memref.subview %subview_5[0, %14, %15] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[131072, 256, 1], offset: ?>> to memref<1x1x8xf16, strided<[131072, 256, 1], offset: ?>>
- %subview_11 = memref.subview %subview_9[0, %16, %17] [1, 1, 8] [1, 1, 1] : memref<1x32x32xf16, strided<[1024, 32, 1], offset: ?>, 3> to memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
- %18 = vector.transfer_read %subview_10[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8xf16, strided<[131072, 256, 1], offset: ?>>, vector<1x1x8xf16>
- vector.transfer_write %18, %subview_11[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, memref<1x1x8xf16, strided<[1024, 32, 1], offset: ?>, 3>
- gpu.barrier
- linalg.batch_matmul {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config}
- ins(%alloc, %alloc_0 : memref<1x32x32xf16, 3>, memref<1x32x32xf16, 3>) outs(%subview : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
- }
- %subview_3 = memref.subview %5[%workgroup_id_z, %7, %8] [1, 32, 32] [1, 1, 1] : memref<16x128x256xf16> to memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>
- linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%subview_3 : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
- outs(%subview : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
- attrs = {__internal_linalg_transform__ = "workgroup_memory"} {
- ^bb0(%in: f16, %out: f16):
- %9 = arith.divf %out, %in : f16
- linalg.yield %9 : f16
- }
- return
- }
+ %subview_3 = memref.subview %5[%workgroup_id_z, %7, %8] [1, 32, 32] [1, 1, 1] : memref<16x128x256xf16> to memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%subview_3 : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
+ outs(%subview : memref<1x32x32xf16, strided<[32768, 256, 1], offset: ?>>)
+ attrs = {__internal_linalg_transform__ = "workgroup_memory"} {
+ ^bb0(%in: f16, %out: f16):
+ %9 = arith.divf %out, %in : f16
+ linalg.yield %9 : f16
}
+ return
}
}
+
// CHECK: #[[$MAP_Y:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK: #[[$MAP_X:.+]] = affine_map<()[s0] -> ((s0 floordiv 32) * 16)>
@@ -347,113 +260,69 @@
// -----
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 32], [16, 16], [0, 0, 32], [16, 16, 16]]>
-#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>,
- #hal.descriptor_set.binding<2, storage_buffer>,
- #hal.descriptor_set.binding<3, storage_buffer>,
- #hal.descriptor_set.binding<4, storage_buffer>
- ]>
-]>
-hal.executable public @matmul_256x1024x128_mixed_signedness_int8 {
- hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb", {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.6,
- [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixKHR],
- [SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, NVIDIA:DiscreteGPU,
- #spirv.resource_limits<
- cooperative_matrix_properties_khr = [
- #spirv.coop_matrix_props_khr<
- a_type = i8, b_type = i8, c_type = i32, k_size = 32,
- m_size = 8, n_size = 8, result_type = i32, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f16, k_size = 16,
- m_size = 16, n_size = 16, result_type = f16, acc_sat = false, scope = <Subgroup>>,
- #spirv.coop_matrix_props_khr<
- a_type = f16, b_type = f16, c_type = f32, k_size = 16,
- m_size = 16, n_size = 16, result_type = f32, acc_sat = false, scope = <Subgroup>>
- ],
- max_compute_shared_memory_size = 49152,
- max_compute_workgroup_invocations = 1024,
- max_compute_workgroup_size = [2147483647, 65535, 65535],
- subgroup_size = 32>
- >}>) {
- hal.executable.export public @matmul_256x1024x128_mixed_signedness_int8 layout(#pipeline_layout) attributes {
- translation_info = #translation,
- workgroup_size = [32 : index, 1 : index, 1 : index]
- } {
- ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
- %c1 = arith.constant 1 : index
- %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg0]
- %1 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg1]
- hal.return %0, %1, %c1 : index, index, index
+#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize workgroup_size=[32, 1, 1]>
+builtin.module {
+func.func @matmul_256x1024x128_mixed_signedness_int8() {
+ %cst = arith.constant 0 : i32
+ %cst_i8 = arith.constant 0 : i8
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %c1024 = arith.constant 1024 : index
+ %0 = gpu.thread_id x
+ %1 = gpu.thread_id y
+ %2 = gpu.thread_id z
+ %alloc = memref.alloc() : memref<32x32xi8, 3>
+ %alloc_0 = memref.alloc() : memref<32x32xi8, 3>
+ %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<256x1024xi8>
+ %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1024x128xi8>
+ %7 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xi32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+ %subview = memref.subview %7[%8, %9] [32, 32] [1, 1] : memref<256x128xi32> to memref<32x32xi32, strided<[128, 1], offset: ?>>
+ %subview_1 = memref.subview %3[%8, 0] [32, 1024] [1, 1] : memref<256x1024xi8> to memref<32x1024xi8, strided<[1024, 1], offset: ?>>
+ %subview_2 = memref.subview %4[0, %9] [1024, 32] [1, 1] : memref<1024x128xi8> to memref<1024x32xi8, strided<[128, 1], offset: ?>>
+ linalg.fill {__internal_linalg_transform__ = "workgroup_memory"} ins(%cst : i32) outs(%subview : memref<32x32xi32, strided<[128, 1], offset: ?>>)
+ scf.for %arg0 = %c0 to %c1024 step %c32 {
+ %subview_5 = memref.subview %subview_1[0, %arg0] [32, 32] [1, 1] : memref<32x1024xi8, strided<[1024, 1], offset: ?>> to memref<32x32xi8, strided<[1024, 1], offset: ?>>
+ %subview_6 = memref.subview %subview_2[%arg0, 0] [32, 32] [1, 1] : memref<1024x32xi8, strided<[128, 1], offset: ?>> to memref<32x32xi8, strided<[128, 1], offset: ?>>
+ gpu.barrier
+ %subview_7 = memref.subview %alloc[%c0, %c0] [32, 32] [1, 1] : memref<32x32xi8, 3> to memref<32x32xi8, strided<[32, 1], offset: ?>, 3>
+ %10 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %11 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %12 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %13 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %subview_8 = memref.subview %subview_5[%10, %11] [1, 8] [1, 1] : memref<32x32xi8, strided<[1024, 1], offset: ?>> to memref<1x8xi8, strided<[1024, 1], offset: ?>>
+ %subview_9 = memref.subview %subview_7[%12, %13] [1, 8] [1, 1] : memref<32x32xi8, strided<[32, 1], offset: ?>, 3> to memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+ %14 = vector.transfer_read %subview_8[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<1x8xi8, strided<[1024, 1], offset: ?>>, vector<1x8xi8>
+ vector.transfer_write %14, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xi8>, memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+ %subview_10 = memref.subview %alloc_0[%c0, %c0] [32, 32] [1, 1] : memref<32x32xi8, 3> to memref<32x32xi8, strided<[32, 1], offset: ?>, 3>
+ %15 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %16 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %17 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+ %18 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+ %subview_11 = memref.subview %subview_6[%15, %16] [1, 8] [1, 1] : memref<32x32xi8, strided<[128, 1], offset: ?>> to memref<1x8xi8, strided<[128, 1], offset: ?>>
+ %subview_12 = memref.subview %subview_10[%17, %18] [1, 8] [1, 1] : memref<32x32xi8, strided<[32, 1], offset: ?>, 3> to memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+ %19 = vector.transfer_read %subview_11[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<1x8xi8, strided<[128, 1], offset: ?>>, vector<1x8xi8>
+ vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xi8>, memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+ gpu.barrier
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]
}
- builtin.module {
- func.func @matmul_256x1024x128_mixed_signedness_int8() {
- %cst = arith.constant 0 : i32
- %cst_i8 = arith.constant 0 : i8
- %c0 = arith.constant 0 : index
- %c32 = arith.constant 32 : index
- %c1024 = arith.constant 1024 : index
- %0 = gpu.thread_id x
- %1 = gpu.thread_id y
- %2 = gpu.thread_id z
- %alloc = memref.alloc() : memref<32x32xi8, 3>
- %alloc_0 = memref.alloc() : memref<32x32xi8, 3>
- %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<256x1024xi8>
- %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1024x128xi8>
- %7 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xi32>
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
- %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
- %subview = memref.subview %7[%8, %9] [32, 32] [1, 1] : memref<256x128xi32> to memref<32x32xi32, strided<[128, 1], offset: ?>>
- %subview_1 = memref.subview %3[%8, 0] [32, 1024] [1, 1] : memref<256x1024xi8> to memref<32x1024xi8, strided<[1024, 1], offset: ?>>
- %subview_2 = memref.subview %4[0, %9] [1024, 32] [1, 1] : memref<1024x128xi8> to memref<1024x32xi8, strided<[128, 1], offset: ?>>
- linalg.fill {__internal_linalg_transform__ = "workgroup_memory"} ins(%cst : i32) outs(%subview : memref<32x32xi32, strided<[128, 1], offset: ?>>)
- scf.for %arg0 = %c0 to %c1024 step %c32 {
- %subview_5 = memref.subview %subview_1[0, %arg0] [32, 32] [1, 1] : memref<32x1024xi8, strided<[1024, 1], offset: ?>> to memref<32x32xi8, strided<[1024, 1], offset: ?>>
- %subview_6 = memref.subview %subview_2[%arg0, 0] [32, 32] [1, 1] : memref<1024x32xi8, strided<[128, 1], offset: ?>> to memref<32x32xi8, strided<[128, 1], offset: ?>>
- gpu.barrier
- %subview_7 = memref.subview %alloc[%c0, %c0] [32, 32] [1, 1] : memref<32x32xi8, 3> to memref<32x32xi8, strided<[32, 1], offset: ?>, 3>
- %10 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %11 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %12 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %13 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %subview_8 = memref.subview %subview_5[%10, %11] [1, 8] [1, 1] : memref<32x32xi8, strided<[1024, 1], offset: ?>> to memref<1x8xi8, strided<[1024, 1], offset: ?>>
- %subview_9 = memref.subview %subview_7[%12, %13] [1, 8] [1, 1] : memref<32x32xi8, strided<[32, 1], offset: ?>, 3> to memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
- %14 = vector.transfer_read %subview_8[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<1x8xi8, strided<[1024, 1], offset: ?>>, vector<1x8xi8>
- vector.transfer_write %14, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xi8>, memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
- %subview_10 = memref.subview %alloc_0[%c0, %c0] [32, 32] [1, 1] : memref<32x32xi8, 3> to memref<32x32xi8, strided<[32, 1], offset: ?>, 3>
- %15 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %16 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %17 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- %18 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- %subview_11 = memref.subview %subview_6[%15, %16] [1, 8] [1, 1] : memref<32x32xi8, strided<[128, 1], offset: ?>> to memref<1x8xi8, strided<[128, 1], offset: ?>>
- %subview_12 = memref.subview %subview_10[%17, %18] [1, 8] [1, 1] : memref<32x32xi8, strided<[32, 1], offset: ?>, 3> to memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
- %19 = vector.transfer_read %subview_11[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<1x8xi8, strided<[128, 1], offset: ?>>, vector<1x8xi8>
- vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xi8>, memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
- gpu.barrier
- linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel", "reduction"]
- }
- ins(%alloc, %alloc_0 : memref<32x32xi8, 3>, memref<32x32xi8, 3>) outs(%subview : memref<32x32xi32, strided<[128, 1], offset: ?>>)
- attrs = {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config} {
- ^bb0(%in: i8, %in_5: i8, %out: i32):
- %20 = arith.extui %in : i8 to i32
- %21 = arith.extsi %in_5 : i8 to i32
- %22 = arith.muli %20, %21 : i32
- %23 = arith.addi %22, %out : i32
- linalg.yield %23 : i32
- }
- }
- return
- }
+ ins(%alloc, %alloc_0 : memref<32x32xi8, 3>, memref<32x32xi8, 3>) outs(%subview : memref<32x32xi32, strided<[128, 1], offset: ?>>)
+ attrs = {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config} {
+ ^bb0(%in: i8, %in_5: i8, %out: i32):
+ %20 = arith.extui %in : i8 to i32
+ %21 = arith.extsi %in_5 : i8 to i32
+ %22 = arith.muli %20, %21 : i32
+ %23 = arith.addi %22, %out : i32
+ linalg.yield %23 : i32
}
}
+ return
+}
}
// CHECK: #[[$MAP_Y:.+]] = affine_map<()[s0] -> (s0 * 16)>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
index 2faab55..105012f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
@@ -27,10 +27,11 @@
// the target env. We expect the conv to follow the inner product lowering.
func.func @nwc_conv_1d_dot_prod(%input: tensor<1x7x3xi8>, %filter: tensor<1x3x4xi8>) -> tensor<1x4x4xi32> attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
- [DotProduct, DotProductInputAll, DotProductInput4x8Bit],
- [SPV_KHR_integer_dot_product]>,
- #spirv.resource_limits<>> } {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32|int16|int8, storage = b32|b16|b8, subgroup = none, dot = dp4xi8toi32, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+} {
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
%init = tensor.empty() : tensor<1x4x4xi32>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index 47aab36..c2da6a5 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -268,10 +268,11 @@
// the target env. We expect the matmul to follow the inner product lowering.
func.func @matmul_4x4x4_i8_to_i32_dot_prod(%lhs: tensor<4x4xi8>, %rhs : tensor<4x4xi8>) -> tensor<4x4xi32> attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
- [DotProduct, DotProductInputAll, DotProductInput4x8Bit],
- [SPV_KHR_integer_dot_product]>,
- #spirv.resource_limits<>> } {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32|int16|int8, storage = b32|b16|b8, subgroup = none, dot = dp4xi8toi32, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+} {
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
%init = tensor.empty() : tensor<4x4xi32>
@@ -326,10 +327,11 @@
// the target env. We expect the matmul to follow the inner product lowering.
func.func @matmul_4x16x4_i8_to_i32_dot_prod(%lhs: tensor<4x16xi8>, %rhs : tensor<16x4xi8>) -> tensor<4x4xi32> attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
- [DotProduct, DotProductInputAll, DotProductInput4x8Bit],
- [SPV_KHR_integer_dot_product]>,
- #spirv.resource_limits<>> } {
+ iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+ compute = fp32|int32|int16|int8, storage = b32|b16|b8, subgroup = none, dot = dp4xi8toi32, mma = [],
+ subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+} {
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
%init = tensor.empty() : tensor<4x4xi32>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index fc43fa4..dbf3b89 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -38,8 +38,10 @@
static llvm::cl::opt<std::string> clTestTarget(
"iree-gpu-test-target",
llvm::cl::desc(
- "The target for IR LIT tests; the interpretation depends on the target "
- "API. e.g., \"gfx942\" for HIP, \"sm_80\" for CUDA"),
+ "The target for IR LIT tests. Format is '<arch>:<feature>@<api>', "
+ "where <feature> and <api> are optional; e.g., "
+ "'gfx942:+sramecc,-xnack@hip'. If <api> is missing, it will be deduced "
+ "from <arch>; e.g., 'gfx*' defaults to HIP, 'sm_*' defaults to CUDA"),
llvm::cl::init(""));
namespace mlir::iree_compiler {
@@ -956,41 +958,57 @@
// GPU Target Information
//===----------------------------------------------------------------------===//
+static IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context) {
+ if (clTestTarget.empty())
+ return nullptr;
+
+ auto [archAndFeatures, backend] = StringRef(clTestTarget).split("@");
+ if (backend.empty()) {
+ // Guess what the target API is based on common scheme. This does not work
+ // for cases like "ampere" which can be accepted by both CUDA and Vulkan;
+ // it's very limited. So it's targeting common cases to make writing tests
+ // simpler.
+ if (StringRef(clTestTarget).starts_with("sm_"))
+ backend = "cuda";
+ else if (StringRef(clTestTarget).starts_with("gfx"))
+ backend = "hip";
+ else if (StringRef(clTestTarget).starts_with("adreno"))
+ backend = "vulkan";
+ else if (StringRef(clTestTarget).starts_with("apple"))
+ backend = "vulkan";
+ else if (StringRef(clTestTarget).starts_with("valhall"))
+ backend = "vulkan";
+ }
+ auto [arch, features] = StringRef(archAndFeatures).split(':');
+ // Use the target specified in the command line for testing purposes.
+ return IREE::GPU::getFullTarget(backend, arch, features, context);
+}
+
IREE::GPU::TargetAttr getGPUTargetAttr(IREE::HAL::ExecutableTargetAttr target) {
if (auto config = target.getConfiguration()) {
if (auto attr = config.getAs<IREE::GPU::TargetAttr>("iree.gpu.target"))
return attr;
}
- if (!clTestTarget.empty()) {
- auto [arch, features] = StringRef(clTestTarget).split(':');
- // Use the target specified in the command line for testing purposes.
- return IREE::GPU::getFullTarget(target.getBackend(), arch, features,
- target.getContext());
- }
-
- return nullptr;
+ return getCLGPUTarget(target.getContext());
}
IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op) {
if (auto target = IREE::HAL::ExecutableTargetAttr::lookup(op)) {
return getGPUTargetAttr(target);
}
- if (!clTestTarget.empty()) {
- // Guess what the target API is based on common scheme. This does not work
- // for cases like "ampere" which can be accepted by both CUDA and Vulkan.
- // So it's very limited. However, it makes writing tests simpler. Maybe we
- // should consider making it explicit in the clTestTarget what API we are
- // targeting.
- StringRef backend;
- if (StringRef(clTestTarget).starts_with("sm_"))
- backend = "cuda";
- else if (StringRef(clTestTarget).starts_with("gfx"))
- backend = "rocm";
- auto [arch, features] = StringRef(clTestTarget).split(':');
- // Use the target specified in the command line for testing purposes.
- return IREE::GPU::getFullTarget(backend, arch, features, op->getContext());
- }
- return nullptr;
+ return getCLGPUTarget(op->getContext());
+}
+
+std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func,
+ bool pickLargest) {
+ // First try to see if there is a subgroup size chosen in the CodeGen pipeline
+ // configuration.
+ if (std::optional<int64_t> subgroupSize = getSubgroupSize(func))
+ return subgroupSize.value();
+ // Then try to find the subgroup size from the target description.
+ if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func))
+ return target.getPreferredSubgroupSize(pickLargest);
+ return std::nullopt;
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index 19bb3c7..74ef772 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -152,6 +152,12 @@
/// if found. Returns null TargetAttr othersise.
IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op);
+/// Returns the GPU subgroup size chosen for the current CodeGen pipeline if
+/// exists; otherwise returns the subgroup size from the GPU target description.
+/// Returns std::nullopt if none found.
+std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func,
+ bool pickLargest);
+
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_