[spirv] Make default configuration consider all partitioned loops (#7464) This commit revises the default linalg op configuration to consider all partitioned loops when deciding workgroup size and tiling schemes. Previously we only consider the innermost partitioned loop and distribute that to 1-D workgroup. This helps to fill the GPU for the cases where we have small innermost dimensions, like 244x224x3. Under the previous configuration, we will just distribute to a workgorup of, like, 64x1x1, which starves the GPU and pushes all the work to effectively 3 threads. With the current configuration this would be distributed to a workgroup of 1x32x2.
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD index 86950b3..f774694 100644 --- a/integrations/tensorflow/e2e/keras/layers/BUILD +++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -172,6 +172,7 @@ "MaxPool2D", "Conv3D", #TODO(#5150): Enable the test. "Conv3DTranspose", #TODO(#5150): Enable the test. + "UpSampling3D", #TODO(#7481): rank-reducing memref folder generates invalid linalg.copy ops for SwiftShader. ], "target_backends": "iree_vulkan", },
diff --git a/integrations/tensorflow/e2e/keras/layers/CMakeLists.txt b/integrations/tensorflow/e2e/keras/layers/CMakeLists.txt index 053ec55..4b76d63 100644 --- a/integrations/tensorflow/e2e/keras/layers/CMakeLists.txt +++ b/integrations/tensorflow/e2e/keras/layers/CMakeLists.txt
@@ -66,6 +66,7 @@ ",,MaxPool2D,,,,iree_vulkan" ",,Conv3D,,,,iree_vulkan" ",,Conv3DTranspose,,,,iree_vulkan" + ",,UpSampling3D,,,,iree_vulkan" ) iree_e2e_cartesian_product_test_suite(
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index d2bb4e1..3692cf6 100644 --- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -11,12 +11,14 @@ #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" @@ -214,9 +216,9 @@ int64_t residualTilingFactor = (bestThreadM + bestThreadK) * bestThreadN; SmallVector<int64_t, 3> workgroupSize(3, 1); // (X, Y, Z) - SmallVector<int64_t> workgroupTileSizes(2 + isBM, 0); // (B, M, N, K) - SmallVector<int64_t> invocationTileSizes(2 + isBM, 0); // (B, M, N, K) - SmallVector<int64_t> reductionTileSizes(3 + isBM, 0); // (B, M, N, K) + SmallVector<int64_t> workgroupTileSizes(2 + isBM, 0); // ([B,] M, N) + SmallVector<int64_t> invocationTileSizes(2 + isBM, 0); // ([B,] M, N) + SmallVector<int64_t> reductionTileSizes(3 + isBM, 0); // ([B,] M, N, K) if (isBM) workgroupTileSizes[0] = invocationTileSizes[0] = 1; @@ -287,8 +289,8 @@ // FFT Default Configuration //===----------------------------------------------------------------------===// -static LogicalResult setOpConfig(spirv::ResourceLimitsAttr limits, - linalg_ext::FftOp op) { +static LogicalResult setFftOpConfig(spirv::ResourceLimitsAttr limits, + linalg_ext::FftOp op) { const int64_t subgroupSize = limits.subgroup_size().getValue().getSExtValue(); auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistribute; @@ -319,24 +321,88 @@ } //===----------------------------------------------------------------------===// -// Default Configuration +// Everything Default Configuration //===----------------------------------------------------------------------===// static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits, Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "Using default config for op: " << *op << "\n"); + FuncOp funcOp = op->getParentOfType<FuncOp>(); auto partitionedLoops = getPartitionedLoops(op); + + // Special case for not tiled ops. if (partitionedLoops.empty()) { + // No tiled loops means we cannot tile (and distribute) at all. Use just one + // single thread to run everything. auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize; std::array<int64_t, 3> workgroupSize = {1, 1, 1}; - auto funcOp = op->getParentOfType<FuncOp>(); return setOpConfigAndEntryPointFnTranslation(funcOp, op, {}, {}, pipeline, workgroupSize); } - const int64_t subgroupSize = limits.subgroup_size().getValue().getSExtValue(); - int64_t numElementsPerWorkgroup = subgroupSize; - int64_t numElementsPerThread = 1; - auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistribute; + const int subgroupSize = limits.subgroup_size().getValue().getSExtValue(); + const unsigned loopDepth = partitionedLoops.back() + 1; + + // Configurations we need to decide. + std::array<int64_t, 3> workgroupSize; + SmallVector<int64_t> workgroupTileSizes; + SmallVector<int64_t> threadTileSizes; + + // Initialize the configuration. + auto initConfiguration = [&]() { + workgroupSize = {subgroupSize, 1, 1}; + workgroupTileSizes.resize(loopDepth, 0); + threadTileSizes.resize(loopDepth, 0); + + // Initialize tiling along all partitioned loops with size 1. + for (int64_t loopIndex : partitionedLoops) { + workgroupTileSizes[loopIndex] = threadTileSizes[loopIndex] = 1; + } + // Override the innermost dimension to distribute to threads in a subgroup. + workgroupTileSizes.back() = subgroupSize; + threadTileSizes.back() = 1; + }; + + // Special case for non-linalg ops. + auto linalgOp = dyn_cast<linalg::LinalgOp>(op); + if (!linalgOp || linalgOp.getNumOutputs() != 1) { + auto pipeline = + IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistribute; + + initConfiguration(); + TileSizesListType tileSizes; + tileSizes.push_back(workgroupTileSizes); + tileSizes.push_back(threadTileSizes); + + return setOpConfigAndEntryPointFnTranslation(funcOp, op, tileSizes, {}, + pipeline, workgroupSize); + } + + // Common case for all linalg ops. + + // The core idea is to distribute the partitioned loops to the workgroup + // dimensions. The goal is to fill up the GPU as much as possible, which means + // 1) distributing to as many threads as possible, and 2) avoid assigning too + // many threads to handle out-of-bound elements (thus idle). + + SmallVector<TiledLoopInfo> tiledLoopInfo = getTiledLoopInfo(funcOp); + // The number of linalg implicit loops to partition and tiled loops + // surrounding the op should match. Otherwise, something is incorrect. + assert(partitionedLoops.size() == tiledLoopInfo.size()); + + // The upper bound for each implicit loop: 0 - untiled, negative - dynamic. + SmallVector<int64_t> loopBounds(loopDepth, 0); + // tiledLoopInfo uses the reverse order of partitionedLoops. + for (auto pair : llvm::zip(llvm::reverse(partitionedLoops), tiledLoopInfo)) { + unsigned loopIndex = std::get<0>(pair); + const TiledLoopInfo &loopInfo = std::get<1>(pair); + Optional<int64_t> attrValue = getConstantIntValue(loopInfo.ub); + if (attrValue) { + loopBounds[loopIndex] = *attrValue; + } else { + loopBounds[loopIndex] = ShapedType::kDynamicSize; + } + } // Returns true if the given `operand` has 32-bit element type. auto has32BitElementType = [](Value operand) { @@ -346,71 +412,130 @@ return elementType.isa<FloatType>() || elementType.isInteger(32); }; - if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) { - bool vectorize = false; - auto outputShape = getUntiledResultShape(linalgOp, 0); + // Whether we can try to use the vectorization pipeline. + bool vectorizable = + !linalgOp.hasIndexSemantics() && + // Skip vectorization for non-minor identity inputs as it generates + // vector.transfer_read ops with permutation maps that we currently + // cannot lower. + // TODO: Remove this restriction once the lowering of the permutation + // map is supported in core. + llvm::all_of(linalgOp.getIndexingMaps(), + [](AffineMap &map) { return map.isMinorIdentity(); }) && + // TODO: Lowering of integers other than i32 may require emulation. + // This is currently not supported for vector operation. + llvm::all_of(linalgOp->getOperands(), has32BitElementType) && + llvm::none_of(getUntiledResultShape(linalgOp, 0), ShapedType::isDynamic); - if (!linalgOp.hasIndexSemantics() && - // Skip vectorization for non-minor identity inputs as it generates - // vector.transfer_read ops with permutation maps that we currently - // cannot lower. - // TODO: Remove this restriction once the lowering of the permutation - // map is supported in core. - llvm::all_of(linalgOp.getIndexingMaps(), - [](AffineMap &map) { return map.isMinorIdentity(); }) && - // TODO(thomasraoux): Lowering of integers other than i32 may require - // emulation. This is currently not supported for vector operation. - // Re-enable this when the bug is fixed on SPIR-V lowering side. - llvm::all_of(linalgOp->getOperands(), has32BitElementType) && - llvm::all_of(outputShape, - [](int64_t dim) { return !ShapedType::isDynamic(dim); })) { - vectorize = true; + LLVM_DEBUG({ + llvm::dbgs() << "Linalg op " << linalgOp << "\n partitioned loops: ["; + llvm::interleaveComma(partitionedLoops, llvm::dbgs()); + llvm::dbgs() << "]\n loop bounds: ["; + llvm::interleaveComma(loopBounds, llvm::dbgs()); + llvm::dbgs() << "]\n"; + }); + + // Distribute workload to the given `numThreads` by allowing a potental loss. + auto distributeToThreads = [&](int64_t numThreads, + Optional<int64_t> lossFactor = llvm::None) { + LLVM_DEBUG(llvm::dbgs() << "\nLoss factor: " << lossFactor << "\n"); + initConfiguration(); + + // Scan from the innermost shape dimension and try to deduce the + // configuration for the corresponding GPU workgroup dimension. + for (int shapeDim = loopDepth - 1, wgDim = 0; shapeDim >= 0; --shapeDim) { + LLVM_DEBUG({ + llvm::dbgs() << "Remaining threads: " << numThreads << "\n"; + llvm::dbgs() << "Shape dim #" << shapeDim << "="; + llvm::dbgs() << loopBounds[shapeDim] << "\n" + << "Workgroup dim #" << wgDim << "\n"; + }); + // Skip dynamic/untiled/size-1 dimensions. + if (loopBounds[shapeDim] <= 1) continue; + + // Try to find some power of two that can devide the current shape dim + // size. This vector keeps the candidate tile sizes. + SmallVector<int64_t, 8> candidates; + + // For the inner most workgroup dim, try to see if we can have 4 + // elements per thread. This enables vectorization. + if (vectorizable && wgDim == 0) candidates.push_back(4 * numThreads); + // Try all power of two numbers upto the subgroup size. + for (unsigned i = numThreads; i >= 1; i >>= 1) { + candidates.push_back(i); + } + LLVM_DEBUG({ + llvm::dbgs() << "Candidates tile sizes: ["; + llvm::interleaveComma(candidates, llvm::dbgs()); + llvm::dbgs() << "]\n"; + }); + + for (int64_t candidate : candidates) { + if (loopBounds[shapeDim] % candidate != 0) { + if (!lossFactor) continue; + // Skip this candidate if it causes many threads to be idle. + int64_t idleThreads = candidate - (loopBounds[shapeDim] % candidate); + if (idleThreads > candidate / *lossFactor) continue; + } + LLVM_DEBUG(llvm::dbgs() << "Chosen Candiate " << candidate << "\n"); + + // Found a suitable candidate. Try to let each thread handle 4 + // elements if this is the workgroup x dimension. + workgroupTileSizes[shapeDim] = candidate; + if (vectorizable && wgDim == 0 && candidate % 4 == 0) { + threadTileSizes[shapeDim] = 4; + workgroupSize[wgDim++] = candidate / 4; + assert(numThreads % (candidate / 4) == 0); + numThreads /= candidate / 4; + } else { + if (wgDim == 0) vectorizable = false; + threadTileSizes[shapeDim] = 1; + workgroupSize[wgDim++] = candidate; + assert(numThreads % candidate == 0); + numThreads /= candidate; + } + assert(numThreads >= 1); + break; + } + + // Check if we have distributed all threads in this subgroup all used + // up all distribution dims. + if (numThreads == 1 || wgDim > 3) break; } + return numThreads; + }; - SmallVector<int64_t, 4> candidateTileSizes; - if (vectorize) candidateTileSizes.push_back(4 * subgroupSize); - candidateTileSizes.push_back(subgroupSize); + // First try to see if we can use up all threads without any loss. + if (distributeToThreads(subgroupSize) != 1) { + // Otherwise, allow larger and larger loss factor. - for (int64_t size : candidateTileSizes) { - if (outputShape.back() % size != 0) continue; - numElementsPerWorkgroup = size; - break; - } + // Threads for distribution Use 32 at least. + int64_t numThreads = std::max(subgroupSize, 32); + // We can tolerate (1 / lossFactor) of threads in the workgroup to be idle. + int64_t lossFactor = 32; - if (numElementsPerWorkgroup <= subgroupSize || - outputShape.back() % numElementsPerWorkgroup != 0) { - vectorize = false; - } - - if (vectorize) { - numElementsPerThread = numElementsPerWorkgroup / subgroupSize; - pipeline = IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize; + for (; lossFactor >= 1; lossFactor >>= 1) { + if (distributeToThreads(numThreads, lossFactor) == 1) break; } } - std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1}; - - unsigned loopDepth = partitionedLoops.back() + 1; - SmallVector<int64_t> workgroupTileSize(loopDepth, 0); - SmallVector<int64_t> threadTileSize(loopDepth, 0); - - // Tiling along partitioned loops with size 1. - for (int64_t loopIndex : partitionedLoops) { - workgroupTileSize[loopIndex] = threadTileSize[loopIndex] = 1; - } - // Overwrite the configuration for the innermost dimension. - workgroupTileSize.back() = numElementsPerWorkgroup; - threadTileSize.back() = numElementsPerThread; + auto pipeline = + vectorizable + ? IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize + : IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistribute; TileSizesListType tileSizes; - tileSizes.push_back(workgroupTileSize); - tileSizes.push_back(threadTileSize); + tileSizes.push_back(workgroupTileSizes); + tileSizes.push_back(threadTileSizes); - return setOpConfigAndEntryPointFnTranslation(op->getParentOfType<FuncOp>(), - op, tileSizes, {}, pipeline, - workgroupSize); + return setOpConfigAndEntryPointFnTranslation(funcOp, op, tileSizes, {}, + pipeline, workgroupSize); } +//===----------------------------------------------------------------------===// +// Configuration Dispatcher +//===----------------------------------------------------------------------===// + /// Sets the CodeGen configuration as attributes to the given `rootOp` if it's a /// known Linalg matmul/convolution op with good configurations. static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv, @@ -451,8 +576,6 @@ // If unsuccessful, try to tile and distribute. return setDefaultOpConfig(limits, op); }) - .Case<linalg_ext::FftOp>( - [limits](auto op) { return setOpConfig(limits, op); }) .Case<linalg::Conv2DNhwcHwcfOp, linalg::DepthwiseConv2DNhwOp>( [limits](auto op) { // Try to tile and vectorize first. It's common to see 32 threads @@ -465,9 +588,12 @@ // If unsuccessful, try to tile and distribute. return setDefaultOpConfig(limits, op); }) - .Case<linalg::GenericOp>([limits](auto op) { - // If generic op has reduction iterator types, it is a root as - // well. Just set the default configuration, which marks it as a root. + .Case<linalg_ext::FftOp>( + [limits](linalg_ext::FftOp op) { return setFftOpConfig(limits, op); }) + .Case<linalg::GenericOp>([limits](linalg::GenericOp op) { + // If a generic op has reduction iterator types, it can be treated as a + // root op for configuration as well. Use the default configuration, + // which will mark it as a root. if (op.getNumLoops() != op.getNumParallelLoops()) { return setDefaultOpConfig(limits, op); }
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD index c57ce23..2d61206 100644 --- a/iree/compiler/Codegen/SPIRV/test/BUILD +++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -21,9 +21,9 @@ [ "config_adreno_conv.mlir", "config_adreno_matmul.mlir", + "config_default_linalg_ext_ops.mlir", + "config_default_linalg_ops.mlir", "config_default_matmul.mlir", - "config_linalg_ext_ops.mlir", - "config_linalg_ops.mlir", "config_mali_conv.mlir", "config_mali_matmul.mlir", "config_nvidia_matmul_cooperative_ops.mlir",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt index 6fbb762..400fe1f 100644 --- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt +++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -16,9 +16,9 @@ SRCS "config_adreno_conv.mlir" "config_adreno_matmul.mlir" + "config_default_linalg_ext_ops.mlir" + "config_default_linalg_ops.mlir" "config_default_matmul.mlir" - "config_linalg_ext_ops.mlir" - "config_linalg_ops.mlir" "config_mali_conv.mlir" "config_mali_matmul.mlir" "config_nvidia_matmul_cooperative_ops.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/config_linalg_ext_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir similarity index 98% rename from iree/compiler/Codegen/SPIRV/test/config_linalg_ext_ops.mlir rename to iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir index 57ea3d8..48113c2 100644 --- a/iree/compiler/Codegen/SPIRV/test/config_linalg_ext_ops.mlir +++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
@@ -4,7 +4,7 @@ hal.interface.binding @s0b0_rw_external, set=0, binding=0, type="StorageBuffer", access="Read|Write" } hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { - spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, @@ -54,7 +54,7 @@ hal.interface.binding @s0b1_xw_external, set=0, binding=1, type="StorageBuffer", access="Write|Discard" } hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { - spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, @@ -122,7 +122,7 @@ hal.interface.binding @s0b1_rw_external, set=0, binding=1, type="StorageBuffer", access="Read|Write" } hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirvfb", { - spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, @@ -172,7 +172,7 @@ hal.interface.binding @s0b1_rw_external, set=0, binding=1, type="StorageBuffer", access="Read|Write" } hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirvfb", { - spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>,
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir new file mode 100644 index 0000000..3f361e1 --- /dev/null +++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
@@ -0,0 +1,282 @@ +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true}))' %s | IreeFileCheck %s + +hal.executable @tensor_insert { + hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { + max_compute_shared_memory_size = 32768 : i32, + max_compute_workgroup_invocations = 512 : i32, + max_compute_workgroup_size = dense<512> : vector<3xi32>, + subgroup_size = 16 : i32}> + }> { + hal.executable.entry_point @tensor_insert_slice attributes {interface = @io, ordinal = 0 : index} + builtin.module { + builtin.func @tensor_insert_slice() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:?x?xi32> + %1 = hal.interface.load.constant offset = 0 : index + %2 = hal.interface.load.constant offset = 1 : index + %3 = hal.interface.binding.subspan @io::@s0b1_xw_external[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %4 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y] + %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y] + %d0 = hal.interface.load.constant offset = 2 : index + %d1 = hal.interface.load.constant offset = 2 : index + scf.for %arg0 = %4 to %d0 step %5 { + %6 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0] + %7 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x] + scf.for %arg1 = %7 to %d1 step %8 { + %9 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1] + %10 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%6, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32> + %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%1] + %12 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%2] + flow.dispatch.tensor.store %10, %3, offsets = [%11, %12], sizes = [%6, %9], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32> + } + } + return + } + hal.interface @io attributes {push_constants = 2 : index, sym_visibility = "private"} { + hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @s0b1_xw_external, set=0, binding=1, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [64, 1]> +// CHECK: hal.executable.entry_point public @tensor_insert_slice +// CHECK-SAME: translation.info = #[[TRANSLATION]] +// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[NWGSX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] +// CHECK: hal.return %[[NWGSX]], %[[ARG1]], %[[C1]] + +// ----- + +hal.executable @tensor_insert { + hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { + max_compute_shared_memory_size = 32768 : i32, + max_compute_workgroup_invocations = 512 : i32, + max_compute_workgroup_size = dense<512> : vector<3xi32>, + subgroup_size = 16 : i32}> + }> { + hal.executable.entry_point @tensor_insert_slice attributes {interface = @io, ordinal = 0 : index} + builtin.module { + builtin.func @tensor_insert_slice() { + %c0 = arith.constant 0 : index + %d0 = hal.interface.load.constant offset = 0 : index + %d1 = hal.interface.load.constant offset = 1 : index + %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<?x?xi32>{%d0, %d1} + %1 = hal.interface.binding.subspan @io::@s0b1_xw_external[%c0] : memref<?x?xi32>{%d0, %d1} + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y] + %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y] + scf.for %arg0 = %2 to %d0 step %3 { + %4 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0] + %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x] + %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x] + scf.for %arg1 = %5 to %d1 step %6 { + %7 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1] + %8 = memref.subview %0[%arg0, %arg1] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> + %9 = affine.apply affine_map<(d0) -> (d0 + 4)>(%arg0) + %10 = affine.apply affine_map<(d0) -> (d0 + 3)>(%arg1) + %11 = memref.subview %1[%9, %10] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> + linalg.copy(%8, %11) : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> + } + } + return + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 16], [1, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 1]> +// CHECK: hal.executable.entry_point public @tensor_insert_slice +// CHECK-SAME: translation.info = #[[TRANSLATION]] +// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[NWGSX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] +// CHECK: hal.return %[[NWGSX]], %[[ARG1]], %[[C1]] +// CHECK: linalg.copy +// CHECK-SAME: lowering.config = #[[CONFIG]] + +// ----- + +hal.executable @tensor_insert { + hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { + max_compute_shared_memory_size = 32768 : i32, + max_compute_workgroup_invocations = 512 : i32, + max_compute_workgroup_size = dense<512> : vector<3xi32>, + subgroup_size = 64 : i32}> + }> { + hal.executable.entry_point @copy attributes {interface = @io, ordinal = 0 : index} + builtin.module { + builtin.func @copy() { + %c0 = arith.constant 0 : index + %c224 = arith.constant 224 : index + %c3 = arith.constant 3 : index + %0 = hal.interface.binding.subspan @io::@s0b0_rw_external[%c0] : memref<1x225x225x3xf32> + %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<1x224x224x3xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_z, %workgroup_id_z] + %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_z, %workgroup_count_z] + scf.for %arg0 = %2 to %c224 step %3 { + %4 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 224)>(%arg0)[%workgroup_size_z] + %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y] + scf.for %arg1 = %5 to %c224 step %6 { + %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 224)>(%arg1)[%workgroup_size_y] + %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x] + %9 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x] + scf.for %arg2 = %8 to %c3 step %9 { + %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg2)[%workgroup_size_x] + %11 = memref.subview %1[0, %arg0, %arg1, %arg2] [1, %4, %7, %10] [1, 1, 1, 1] : memref<1x224x224x3xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 150528 + s0 + d1 * 672 + d2 * 3 + d3)>> + %12 = memref.subview %0[0, %arg0, %arg1, %arg2] [1, %4, %7, %10] [1, 1, 1, 1] : memref<1x225x225x3xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 151875 + s0 + d1 * 675 + d2 * 3 + d3)>> + linalg.copy(%11, %12) : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 150528 + s0 + d1 * 672 + d2 * 3 + d3)>>, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 151875 + s0 + d1 * 675 + d2 * 3 + d3)>> + } + } + } + return + } + hal.interface @io attributes {sym_visibility = "private"} { + hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @s0b1_xw_external, set=0, binding=1, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 2, 32, 1], [0, 1, 1, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> +// CHECK-DAG: #[[MAP_Z:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [1, 32, 2]> + +// CHECK: hal.executable.entry_point public @copy +// CHECK-SAME: translation.info = #[[TRANSLATION]] +// CHECK-NEXT: (%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index) +// CHECK-DAG: %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]] +// CHECK-DAG: %[[Z_COUNT:.+]] = affine.apply #[[MAP_Z]]()[%[[Z]]] +// CHECK: hal.return %[[X]], %[[Y_COUNT]], %[[Z_COUNT]] + +// CHECK: linalg.copy +// CHECK-SAME: lowering.config = #[[CONFIG]] + +// ----- + +#map0 = affine_map<()[s0, s1] -> (s0 * s1)> +#map1 = affine_map<(d0) -> (d0 * 12)> +#map2 = affine_map<(d0)[s0] -> (s0 * 12, d0 * -12 + 24)> +#map3 = affine_map<(d0)[s0] -> (s0, -d0 + 8)> +#map4 = affine_map<(d0)[s0] -> (s0, -d0 + 2)> +#map5 = affine_map<(d0)[s0] -> (-d0 + 2, s0)> +#map6 = affine_map<(d0)[s0] -> (-d0 + 8, s0)> +#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +hal.executable @avg_pool { + hal.interface public @io { + hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b1_xw_external, set=0, binding=1, type="StorageBuffer", access="Write|Discard" + } + hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { + max_compute_shared_memory_size = 32768 : i32, + max_compute_workgroup_invocations = 512 : i32, + max_compute_workgroup_size = dense<512> : vector<3xi32>, + subgroup_size = 32 : i32}> + }> { + hal.executable.entry_point public @avg_pool attributes {interface = @io, ordinal = 0 : index} + builtin.module { + func @avg_pool() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:1x24x24x8xf32> + %1 = hal.interface.binding.subspan @io::@s0b1_xw_external[%c0] : !flow.dispatch.tensor<writeonly:1x2x2x8xf32> + %2 = linalg.init_tensor [12, 12] : tensor<12x12xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c2 step %4 { + %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c2 step %6 { + %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c8 step %8 { + %9 = affine.apply #map1(%arg0) + %10 = affine.min #map2(%arg0)[%workgroup_size_z] + %11 = affine.apply #map1(%arg1) + %12 = affine.min #map2(%arg1)[%workgroup_size_y] + %13 = affine.min #map3(%arg2)[%workgroup_size_x] + %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x24x24x8xf32> -> tensor<1x?x?x?xf32> + %15 = affine.min #map4(%arg0)[%workgroup_size_z] + %16 = affine.min #map4(%arg1)[%workgroup_size_y] + %17 = affine.min #map5(%arg0)[%workgroup_size_z] + %18 = affine.min #map5(%arg1)[%workgroup_size_y] + %19 = affine.min #map6(%arg2)[%workgroup_size_x] + %20 = linalg.init_tensor [1, %17, %18, %19] : tensor<1x?x?x?xf32> + %21 = linalg.fill(%cst, %20) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32> + %22 = linalg.pooling_nhwc_sum {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<12> : vector<2xi64>} ins(%14, %2 : tensor<1x?x?x?xf32>, tensor<12x12xf32>) outs(%21 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32> + flow.dispatch.tensor.store %22, %1, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %15, %16, %13], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x2x2x8xf32> + } + } + } + return + } + hal.interface private @io { + hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b1_xw_external, set=0, binding=1, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 2, 2, 0, 0, 8], [0, 1, 1, 0, 0, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [8, 2, 2]> + +// CHECK: hal.executable.entry_point public @avg_pool +// CHECK-SAME: translation.info = #[[TRANSLATION]] +// CHECK-NEXT: (%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index) +// CHECK-DAG: %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]] +// CHECK-DAG: %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]] +// CHECK-DAG: %[[Z_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Z]]] +// CHECK: hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]] + +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir index 10dc64f..aaa6935 100644 --- a/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir +++ b/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -13,7 +13,7 @@ max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, - subgroup_size = 4 : i32}> + subgroup_size = 32 : i32}> }> { hal.executable.entry_point public @batch_matmul_1x3x32 attributes {interface = @io, ordinal = 0 : index} builtin.module { @@ -74,12 +74,12 @@ } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 1, 4], [1, 1, 1]{{\]}}, native_vector_size = []> -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [4, 1, 1]> +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 1, 32], [1, 1, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [32, 1, 1]> // CHECK: hal.executable.entry_point public @batch_matmul_1x3x32 // CHECK-SAME: translation.info = #[[TRANSLATION]] -// CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [32 : index, 1 : index, 1 : index] // CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index): // CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP]]()[%[[X]]] // CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y]], %[[Z]] @@ -103,7 +103,7 @@ max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, - subgroup_size = 4 : i32}> + subgroup_size = 64 : i32}> }> { hal.executable.entry_point public @matmul_64x16 attributes {interface = @io, ordinal = 0 : index} builtin.module { @@ -152,17 +152,191 @@ } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 4], [1, 1]{{\]}}, native_vector_size = []> -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [4, 1]> +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4, 16], [1, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> +// CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 4]> // CHECK: hal.executable.entry_point public @matmul_64x16 // CHECK-SAME: translation.info = #[[TRANSLATION]] -// CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [16 : index, 4 : index, 1 : index] // CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index): // CHECK-NEXT: %[[ONE:.+]] = arith.constant 1 : index -// CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP]]()[%[[X]]] -// CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y]], %[[ONE]] +// CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]] +// CHECK-NEXT: %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]] +// CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]] // CHECK: func @matmul_64x16() // CHECK: linalg.matmul // CHECK-SAME: lowering.config = #[[CONFIG]] + +// ----- + +// Odd N that forbids vectorization. + +hal.executable @matmul_400x273 { + hal.interface public @io { + hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { + max_compute_shared_memory_size = 16384 : i32, + max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, + subgroup_size = 64 : i32}> + }> { + hal.executable.entry_point public @matmul_400x273 attributes {interface = @io, ordinal = 0 : index} + builtin.module { + func @matmul_400x273() { + %c0 = arith.constant 0 : index + %c11775744 = arith.constant 11775744 : index + %cst = arith.constant 0.000000e+00 : f32 + %c400 = arith.constant 400 : index + %c273 = arith.constant 273 : index + %0 = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c11775744] : !flow.dispatch.tensor<readonly:273xf32> + %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:400x576xf32> + %2 = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : !flow.dispatch.tensor<readonly:576x273xf32> + %3 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:400x273xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg0 = %4 to %c400 step %5 { + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg1 = %6 to %c273 step %7 { + %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 273)>(%arg1)[%workgroup_size_x] + %9 = flow.dispatch.tensor.load %0, offsets = [%arg1], sizes = [%8], strides = [1] : !flow.dispatch.tensor<readonly:273xf32> -> tensor<?xf32> + %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 400)>(%arg0)[%workgroup_size_y] + %11 = linalg.init_tensor [%10, %8] : tensor<?x?xf32> + %12 = affine.min affine_map<(d0)[s0] -> (-d0 + 400, s0)>(%arg0)[%workgroup_size_y] + %13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%12, 576], strides = [1, 1] : !flow.dispatch.tensor<readonly:400x576xf32> -> tensor<?x576xf32> + %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 273, s0)>(%arg1)[%workgroup_size_x] + %15 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [576, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:576x273xf32> -> tensor<576x?xf32> + %16 = linalg.init_tensor [%12, %14] : tensor<?x?xf32> + %17 = linalg.fill(%cst, %16) : f32, tensor<?x?xf32> -> tensor<?x?xf32> + %18 = linalg.matmul ins(%13, %15 : tensor<?x576xf32>, tensor<576x?xf32>) outs(%17 : tensor<?x?xf32>) -> tensor<?x?xf32> + %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %18 : tensor<?xf32>, tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) { + ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors + %20 = arith.addf %arg2, %arg3 : f32 + linalg.yield %20 : f32 + } -> tensor<?x?xf32> + flow.dispatch.tensor.store %19, %3, offsets = [%arg0, %arg1], sizes = [%10, %8], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:400x273xf32> + } + } + return + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[2, 32], [1, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> +// CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [32, 2]> + +// CHECK: hal.executable.entry_point public @matmul_400x273 +// CHECK-SAME: translation.info = #[[TRANSLATION]] +// CHECK-SAME: workgroup_size = [32 : index, 2 : index, 1 : index] +// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index): +// CHECK-NEXT: %[[ONE:.+]] = arith.constant 1 : index +// CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]] +// CHECK-NEXT: %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]] +// CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]] + +// CHECK: func @matmul_400x273() +// CHECK: linalg.matmul +// CHECK-SAME: lowering.config = #[[CONFIG]] + +// ----- + +// Odd M and non-4-multiplier N + +hal.executable @matmul_25x546 { + hal.interface public @io { + hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, { + max_compute_shared_memory_size = 16384 : i32, + max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, + subgroup_size = 64 : i32}> + }> { + hal.executable.entry_point public @matmul_25x546 attributes {interface = @io, ordinal = 0 : index} + builtin.module { + func @matmul_25x546() { + %c0 = arith.constant 0 : index + %c15842560 = arith.constant 15842560 : index + %cst = arith.constant 0.000000e+00 : f32 + %c25 = arith.constant 25 : index + %c546 = arith.constant 546 : index + %0 = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c15842560] : !flow.dispatch.tensor<readonly:546xf32> + %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:25x512xf32> + %2 = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : !flow.dispatch.tensor<readonly:512x546xf32> + %3 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:25x546xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg0 = %4 to %c25 step %5 { + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg1 = %6 to %c546 step %7 { + %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 546)>(%arg1)[%workgroup_size_x] + %9 = flow.dispatch.tensor.load %0, offsets = [%arg1], sizes = [%8], strides = [1] : !flow.dispatch.tensor<readonly:546xf32> -> tensor<?xf32> + %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 25)>(%arg0)[%workgroup_size_y] + %11 = linalg.init_tensor [%10, %8] : tensor<?x?xf32> + %12 = affine.min affine_map<(d0)[s0] -> (-d0 + 25, s0)>(%arg0)[%workgroup_size_y] + %13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%12, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:25x512xf32> -> tensor<?x512xf32> + %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 546, s0)>(%arg1)[%workgroup_size_x] + %15 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [512, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x546xf32> -> tensor<512x?xf32> + %16 = linalg.init_tensor [%12, %14] : tensor<?x?xf32> + %17 = linalg.fill(%cst, %16) : f32, tensor<?x?xf32> -> tensor<?x?xf32> + %18 = linalg.matmul ins(%13, %15 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%17 : tensor<?x?xf32>) -> tensor<?x?xf32> + %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %18 : tensor<?xf32>, tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) attrs = {__internal_linalg_transform__ = "workgroup"} { + ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors + %20 = arith.addf %arg2, %arg3 : f32 + linalg.yield %20 : f32 + } -> tensor<?x?xf32> + flow.dispatch.tensor.store %19, %3, offsets = [%arg0, %arg1], sizes = [%10, %8], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:25x546xf32> + } + } + return + } + hal.interface private @io { + hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[32, 2], [1, 1]{{\]}}, native_vector_size = []> +// CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)> +// CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)> +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [2, 32]> +// CHECK: hal.executable.entry_point public @matmul_25x546 +// CHECK-SAME: translation.info = #[[TRANSLATION]] +// CHECK-SAME: workgroup_size = [2 : index, 32 : index, 1 : index] +// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index): +// CHECK-NEXT: %[[ONE:.+]] = arith.constant 1 : index +// CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]] +// CHECK-NEXT: %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]] +// CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]] + +// CHECK: func @matmul_25x546() +// CHECK: linalg.matmul +// CHECK-SAME: lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_linalg_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_linalg_ops.mlir deleted file mode 100644 index 4ed7441..0000000 --- a/iree/compiler/Codegen/SPIRV/test/config_linalg_ops.mlir +++ /dev/null
@@ -1,115 +0,0 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true}))' %s | IreeFileCheck %s - -hal.executable @tensor_insert { - hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { - spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, { - max_compute_shared_memory_size = 32768 : i32, - max_compute_workgroup_invocations = 512 : i32, - max_compute_workgroup_size = dense<512> : vector<3xi32>, - subgroup_size = 16 : i32}> - }> { - hal.executable.entry_point @tensor_insert_slice attributes {interface = @io, ordinal = 0 : index} - builtin.module { - builtin.func @tensor_insert_slice() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:?x?xi32> - %1 = hal.interface.load.constant offset = 0 : index - %2 = hal.interface.load.constant offset = 1 : index - %3 = hal.interface.binding.subspan @io::@s0b1_xw_external[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32> - %workgroup_size_x = hal.interface.workgroup.size[0] : index - %workgroup_size_y = hal.interface.workgroup.size[1] : index - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %workgroup_count_x = hal.interface.workgroup.count[0] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %workgroup_count_y = hal.interface.workgroup.count[1] : index - %4 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y] - %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y] - %d0 = hal.interface.load.constant offset = 2 : index - %d1 = hal.interface.load.constant offset = 2 : index - scf.for %arg0 = %4 to %d0 step %5 { - %6 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0] - %7 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x] - %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x] - scf.for %arg1 = %7 to %d1 step %8 { - %9 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1] - %10 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%6, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32> - %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%1] - %12 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%2] - flow.dispatch.tensor.store %10, %3, offsets = [%11, %12], sizes = [%6, %9], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32> - } - } - return - } - hal.interface @io attributes {push_constants = 2 : index, sym_visibility = "private"} { - hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @s0b1_xw_external, set=0, binding=1, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [64, 1]> -// CHECK: hal.executable.entry_point public @tensor_insert_slice -// CHECK-SAME: translation.info = #[[TRANSLATION]] -// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[NWGSX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK: hal.return %[[NWGSX]], %[[ARG1]], %[[C1]] - -// ----- - -hal.executable @tensor_insert { - hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { - spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, ARM:IntegratedGPU, { - max_compute_shared_memory_size = 32768 : i32, - max_compute_workgroup_invocations = 512 : i32, - max_compute_workgroup_size = dense<512> : vector<3xi32>, - subgroup_size = 16 : i32}> - }> { - hal.executable.entry_point @tensor_insert_slice attributes {interface = @io, ordinal = 0 : index} - builtin.module { - builtin.func @tensor_insert_slice() { - %c0 = arith.constant 0 : index - %d0 = hal.interface.load.constant offset = 0 : index - %d1 = hal.interface.load.constant offset = 1 : index - %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<?x?xi32>{%d0, %d1} - %1 = hal.interface.binding.subspan @io::@s0b1_xw_external[%c0] : memref<?x?xi32>{%d0, %d1} - %workgroup_size_x = hal.interface.workgroup.size[0] : index - %workgroup_size_y = hal.interface.workgroup.size[1] : index - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %workgroup_count_x = hal.interface.workgroup.count[0] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %workgroup_count_y = hal.interface.workgroup.count[1] : index - %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y] - %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y] - scf.for %arg0 = %2 to %d0 step %3 { - %4 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0] - %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x] - %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x] - scf.for %arg1 = %5 to %d1 step %6 { - %7 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1] - %8 = memref.subview %0[%arg0, %arg1] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> - %9 = affine.apply affine_map<(d0) -> (d0 + 4)>(%arg0) - %10 = affine.apply affine_map<(d0) -> (d0 + 3)>(%arg1) - %11 = memref.subview %1[%9, %10] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> - linalg.copy(%8, %11) : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>> - } - } - return - } - } - } -} -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 16], [1, 1]{{\]}}, native_vector_size = []> -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)> -// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 1]> -// CHECK: hal.executable.entry_point public @tensor_insert_slice -// CHECK-SAME: translation.info = #[[TRANSLATION]] -// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[NWGSX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]] -// CHECK: hal.return %[[NWGSX]], %[[ARG1]], %[[C1]] -// CHECK: linalg.copy -// CHECK-SAME: lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp index 58a1e04..7be2706 100644 --- a/iree/compiler/Codegen/Utils/Utils.cpp +++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -429,6 +429,7 @@ static Optional<TiledLoopInfo> isTiledLoop(MLIRContext *context, scf::ForOp forOp) { TiledLoopInfo loopInfo; + loopInfo.tiledLoop = forOp; auto lbApplyOp = forOp.lowerBound().getDefiningOp<AffineApplyOp>(); if (!lbApplyOp) { return llvm::None;
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/test/smoketest.mlir index 269bb35..dc126ce 100644 --- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/test/smoketest.mlir
@@ -1,7 +1,5 @@ // RUN: iree-opt -split-input-file -iree-hal-transformation-pipeline %s | IreeFileCheck %s -#map = affine_map<(d0) -> (d0)> - module attributes { hal.device.targets = [ #hal.device.target<"metal", { @@ -14,21 +12,18 @@ ] } { -flow.executable @add_dispatch_0 { - flow.dispatch.entry @add_dispatch_0 attributes { - workgroup_rank = 3 : index - } - builtin.module { - func @add_dispatch_0(%arg0: !flow.dispatch.tensor<readonly:16xf32>, %arg1: !flow.dispatch.tensor<readonly:16xf32>, %arg2: !flow.dispatch.tensor<writeonly:16xf32>) { - %0 = linalg.init_tensor [16] : tensor<16xf32> +flow.executable @reduce_dispatch { + flow.dispatch.entry @reduce_dispatch attributes {workgroup_rank = 3 : index} + builtin.module { + func @reduce_dispatch(%arg0: !flow.dispatch.tensor<readonly:16xf32>, %arg1: !flow.dispatch.tensor<writeonly:f32>) { + %0 = linalg.init_tensor [] : tensor<f32> %1 = flow.dispatch.tensor.load %arg0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32> - %2 = flow.dispatch.tensor.load %arg1, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors - %4 = arith.addf %arg3, %arg4 : f32 + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor<f32>) { + ^bb0(%arg2: f32, %arg3: f32): + %4 = arith.addf %arg2, %arg3 : f32 linalg.yield %4 : f32 - } -> tensor<16xf32> - flow.dispatch.tensor.store %3, %arg2, offsets=[], sizes=[], strides=[] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:16xf32> + } -> tensor<f32> + flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32> return } }
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir index d3d6598..04421d2 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir
@@ -1,7 +1,5 @@ // RUN: iree-opt -split-input-file -iree-hal-transformation-pipeline %s | IreeFileCheck %s -#map = affine_map<(d0) -> (d0)> - module attributes { hal.device.targets = [ #hal.device.target<"vulkan", { @@ -14,21 +12,18 @@ ] } { -flow.executable @add_dispatch_0 { - flow.dispatch.entry @add_dispatch_0 attributes { - workgroup_rank = 3 : index - } - builtin.module { - func @add_dispatch_0(%arg0: !flow.dispatch.tensor<readonly:16xf32>, %arg1: !flow.dispatch.tensor<readonly:16xf32>, %arg2: !flow.dispatch.tensor<writeonly:16xf32>) { - %0 = linalg.init_tensor [16] : tensor<16xf32> +flow.executable @reduce_dispatch { + flow.dispatch.entry @reduce_dispatch attributes {workgroup_rank = 3 : index} + builtin.module { + func @reduce_dispatch(%arg0: !flow.dispatch.tensor<readonly:16xf32>, %arg1: !flow.dispatch.tensor<writeonly:f32>) { + %0 = linalg.init_tensor [] : tensor<f32> %1 = flow.dispatch.tensor.load %arg0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32> - %2 = flow.dispatch.tensor.load %arg1, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors - %4 = arith.addf %arg3, %arg4 : f32 + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor<f32>) { + ^bb0(%arg2: f32, %arg3: f32): + %4 = arith.addf %arg2, %arg3 : f32 linalg.yield %4 : f32 - } -> tensor<16xf32> - flow.dispatch.tensor.store %3, %arg2, offsets=[], sizes=[], strides=[] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:16xf32> + } -> tensor<f32> + flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32> return } }
diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir b/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir index 0b9ac53..f4fa245 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir +++ b/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
@@ -13,26 +13,19 @@ // MALI: #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}> // TURINGT4: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f16, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f32, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[1024, 1024, 64]> : vector<3xi32>, subgroup_size = 32 : i32}> // AMD5700XT: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, AMD:DiscreteGPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 65536 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<1024> : vector<3xi32>, subgroup_size = 64 : i32}> -#map0 = affine_map<(d0) -> (d0)> -flow.executable @simpleMath_dispatch_0 { - flow.dispatch.entry @simpleMath_dispatch_0 attributes {workgroup_rank = 3 : index} + +flow.executable @reduce_dispatch { + flow.dispatch.entry @reduce_dispatch attributes {workgroup_rank = 3 : index} builtin.module { - func @simpleMath_dispatch_0(%arg0: !flow.dispatch.tensor<readonly:4xf32>, %arg1: !flow.dispatch.tensor<writeonly:4xf32>) { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = flow.dispatch.tensor.load %arg0, offsets = [%c0], sizes = [%c4], strides = [%c1] : !flow.dispatch.tensor<readonly:4xf32> -> tensor<4xf32> - %1 = linalg.init_tensor [4] : tensor<4xf32> - %2 = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%0 : tensor<4xf32>) - outs(%1 : tensor<4xf32>) { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors - %3 = arith.addf %arg3, %arg3 : f32 - linalg.yield %3 : f32 - } -> tensor<4xf32> - flow.dispatch.tensor.store %2, %arg1, offsets = [%c0], sizes = [%c4], strides = [%c1] : tensor<4xf32> -> !flow.dispatch.tensor<writeonly:4xf32> + func @reduce_dispatch(%arg0: !flow.dispatch.tensor<readonly:16xf32>, %arg1: !flow.dispatch.tensor<writeonly:f32>) { + %0 = linalg.init_tensor [] : tensor<f32> + %1 = flow.dispatch.tensor.load %arg0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor<f32>) { + ^bb0(%arg2: f32, %arg3: f32): + %4 = arith.addf %arg2, %arg3 : f32 + linalg.yield %4 : f32 + } -> tensor<f32> + flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32> return } }
diff --git a/iree/test/e2e/tosa_ops/mul.mlir b/iree/test/e2e/tosa_ops/mul.mlir index 640b2da..a1e498b 100644 --- a/iree/test/e2e/tosa_ops/mul.mlir +++ b/iree/test/e2e/tosa_ops/mul.mlir
@@ -14,11 +14,15 @@ return } +// TODO: The following generates tosa.ApplyScale ops that leaks to backends. +// Sizes like tensor<4xi32> will trigger vectorization on the SPIR-V backend. +// But we cannot vectorize tosa.ApplyScale ops. + func @tensor_int_shifted() { - %0 = util.unfoldable_constant dense<[1, 0, 3, 4]> : tensor<4xi32> - %1 = util.unfoldable_constant dense<[5, 6, -3, 8]> : tensor<4xi32> - %result = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - check.expect_eq_const(%result, dense<[3, 0, -4, 16]> : tensor<4xi32>) : tensor<4xi32> + %0 = util.unfoldable_constant dense<[1, 0, 3, 4, 4]> : tensor<5xi32> + %1 = util.unfoldable_constant dense<[5, 6, -3, 8, 8]> : tensor<5xi32> + %result = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32> + check.expect_eq_const(%result, dense<[3, 0, -4, 16, 16]> : tensor<5xi32>) : tensor<5xi32> return }