[spirv] Add a pipeline to use workgroup memory (#8425)
This commit adds a pipeline for using workgroup memory in
SPIR-V CodeGen. It moves the existing LLVM GPU passes into
the Common/ directory and reuses them there. This pipeline
is meant to support desktop/server-class GPUs, where using
shared memory can give nice performance gains. Along the
way, it adds CodeGen configuration stubs for NVIDIA/AMD
GPUs; they are not tuned yet, just enough to get started.
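The new pipeline is only selected when promotion can actually help. A minimal sketch of the gating condition, mirroring the KernelConfig.cpp hunk below (names as in the patch):

```cpp
// Sketch: pick the workgroup-memory pipeline only when a workgroup spans
// more than one subgroup; otherwise plain vectorization is used.
int64_t totalThreads =
    workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
auto pipeline =
    (useWorkgroupMemory && totalThreads > subgroupSize)
        ? IREE::Codegen::DispatchLoweringPassPipeline::
              SPIRVVectorizeWithWorkgroupMemory
        : IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize;
```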
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index d681159..f4783b4 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -43,6 +43,8 @@
"FoldAffineMinInDistributedLoops.cpp",
"FoldTensorExtractOpPass.cpp",
"ForOpCanonicalizationPass.cpp",
+ "GPUDistributeSharedMemoryCopy.cpp",
+ "GPUPipelining.cpp",
"IREEComprehensiveBufferizePass.cpp",
"InsertDistributionInfoPass.cpp",
"LinalgBufferizePass.cpp",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index c613a6b..80ef1e8 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -36,6 +36,8 @@
"FoldAffineMinInDistributedLoops.cpp"
"FoldTensorExtractOpPass.cpp"
"ForOpCanonicalizationPass.cpp"
+ "GPUDistributeSharedMemoryCopy.cpp"
+ "GPUPipelining.cpp"
"IREEComprehensiveBufferizePass.cpp"
"InsertDistributionInfoPass.cpp"
"LinalgBufferizePass.cpp"
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp b/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
similarity index 91%
rename from iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
rename to iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
index c0702f5..0499715 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
+++ b/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
@@ -8,12 +8,12 @@
#include <numeric>
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
@@ -195,9 +195,9 @@
namespace {
-class LLVMGPUDistributeSharedMemoryCopyPass
- : public LLVMGPUDistributeSharedMemoryCopyBase<
- LLVMGPUDistributeSharedMemoryCopyPass> {
+class GPUDistributeSharedMemoryCopyPass
+ : public GPUDistributeSharedMemoryCopyBase<
+ GPUDistributeSharedMemoryCopyPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect, scf::SCFDialect>();
}
@@ -220,20 +220,8 @@
copiesToWorkgroupMem, [flatWorkgroupSize](linalg::GenericOp copyOp) {
auto shape =
copyOp.getOperand(0).getType().cast<MemRefType>().getShape();
- // Verify that each dimension of the shape can be distributed on the
- // threads
- int64_t threadsAvailable = flatWorkgroupSize;
- for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
- int64_t numElementPerThread =
- dim.index() == 0 ? targetVectorSize : 1;
- int64_t numThreads = dim.value() / numElementPerThread;
- if (numThreads == 0) return false;
- numThreads = std::min(numThreads, threadsAvailable);
- if (threadsAvailable % numThreads != 0) return false;
- threadsAvailable = threadsAvailable / numThreads;
- if (threadsAvailable == 1) break;
- }
- return threadsAvailable == 1;
+ return canPerformVectorAccessUsingAllThreads(shape, flatWorkgroupSize,
+ targetVectorSize);
});
if (isAligned) {
// Step 1. Vectorize the shared memory copy.
@@ -292,8 +280,8 @@
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
-createLLVMGPUDistributeSharedMemoryCopy() {
- return std::make_unique<LLVMGPUDistributeSharedMemoryCopyPass>();
+createGPUDistributeSharedMemoryCopy() {
+ return std::make_unique<GPUDistributeSharedMemoryCopyPass>();
}
} // namespace iree_compiler
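For reference, a minimal sketch of the shared helper this hunk switches to, reconstructed from the inline logic deleted above (the real declaration lives in Codegen/Utils/GPUUtils.h and may differ in detail):

```cpp
// Returns true if `shape` can be copied with every thread doing vector
// accesses of `vectorSize` elements on the innermost dimension.
static bool canPerformVectorAccessUsingAllThreads(
    llvm::ArrayRef<int64_t> shape, int64_t threadCount, int64_t vectorSize) {
  int64_t threadsAvailable = threadCount;
  for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
    // Only the innermost dimension is accessed in vector chunks.
    int64_t numElementPerThread = dim.index() == 0 ? vectorSize : 1;
    int64_t numThreads = dim.value() / numElementPerThread;
    if (numThreads == 0) return false;
    numThreads = std::min(numThreads, threadsAvailable);
    if (threadsAvailable % numThreads != 0) return false;
    threadsAvailable = threadsAvailable / numThreads;
    if (threadsAvailable == 1) break;
  }
  // The copy is distributable only if it uses every available thread.
  return threadsAvailable == 1;
}
```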
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp b/iree/compiler/Codegen/Common/GPUPipelining.cpp
similarity index 94%
rename from iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
rename to iree/compiler/Codegen/Common/GPUPipelining.cpp
index 8e6355c..1ff2865 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
+++ b/iree/compiler/Codegen/Common/GPUPipelining.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -83,9 +82,8 @@
}
namespace {
-struct LLVMGPUPipeliningPass
- : public LLVMGPUPipeliningBase<LLVMGPUPipeliningPass> {
- LLVMGPUPipeliningPass(unsigned depth) : depth(depth) {}
+struct GPUPipeliningPass : public GPUPipeliningBase<GPUPipeliningPass> {
+ GPUPipeliningPass(unsigned depth) : depth(depth) {}
void runOnOperation() override {
auto funcOp = getOperation();
MLIRContext* context = &getContext();
@@ -155,9 +153,9 @@
};
} // namespace
-std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUPipeliningPass(
+std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
unsigned depth) {
- return std::make_unique<LLVMGPUPipeliningPass>(depth);
+ return std::make_unique<GPUPipeliningPass>(depth);
}
} // namespace iree_compiler
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index d6abde2..b78d822 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -24,6 +24,7 @@
"canonicalize_interface_load_store.mlir",
"convert_to_destination_passing_style.mlir",
"dead_alloc.mlir",
+ "distribute_gpu_shared_memory.mlir",
"flatten_memref_subspan.mlir",
"fold_affine_min_in_distributed_loops.mlir",
"fold_tensor_extract_op.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index f45b163..7e9b53b 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -19,6 +19,7 @@
"canonicalize_interface_load_store.mlir"
"convert_to_destination_passing_style.mlir"
"dead_alloc.mlir"
+ "distribute_gpu_shared_memory.mlir"
"flatten_memref_subspan.mlir"
"fold_affine_min_in_distributed_loops.mlir"
"fold_tensor_extract_op.mlir"
diff --git a/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir b/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
similarity index 91%
rename from iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
rename to iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
index dde4b51..51cbc0f 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
+++ b/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-distribute-shared-memory-copy))))' -cse %s | FileCheck %s
+// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-gpu-distribute-shared-memory-copy))))' -cse %s | FileCheck %s
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
@@ -49,10 +49,10 @@
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
// CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
- linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
- ins(%m0 : memref<64x16xf32>)
- outs(%sm0 : memref<64x16xf32, 3>)
+ ins(%m0 : memref<64x16xf32>)
+ outs(%sm0 : memref<64x16xf32, 3>)
attrs= {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
^bb0(%arg4: f32, %s: f32): // no predecessors
linalg.yield %arg4 : f32
@@ -65,10 +65,10 @@
// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
// CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
- linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
- ins(%m1 : memref<256x4xf32>)
- outs(%sm1 : memref<256x4xf32, 3>)
+ ins(%m1 : memref<256x4xf32>)
+ outs(%sm1 : memref<256x4xf32, 3>)
attrs= {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
^bb0(%arg4: f32, %s: f32): // no predecessors
linalg.yield %arg4 : f32
@@ -82,10 +82,10 @@
// CHECK: vector.transfer_write %[[R5]], %{{.*}}[%c1, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
// CHECK: vector.transfer_write %[[R6]], %{{.*}}[%c2, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
- linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
- ins(%m2 : memref<3x512xf32>)
- outs(%sm2 : memref<3x512xf32, 3>)
+ ins(%m2 : memref<3x512xf32>)
+ outs(%sm2 : memref<3x512xf32, 3>)
attrs= {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
^bb0(%arg4: f32, %s: f32): // no predecessors
linalg.yield %arg4 : f32
diff --git a/iree/compiler/Codegen/Dialect/LoweringConfig.td b/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 9293a89..e4108de 100644
--- a/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -42,6 +42,8 @@
: I32EnumAttrCase<"SPIRVVectorize", 12>;
def SPIRV_VectorizeToCooperativeOps
: I32EnumAttrCase<"SPIRVVectorizeToCooperativeOps", 13>;
+def SPIRV_VectorizeWithWorkgroupMemory
+ : I32EnumAttrCase<"SPIRVVectorizeWithWorkgroupMemory", 14>;
def None
: I32EnumAttrCase<"None", 0xff>;
@@ -56,7 +58,8 @@
CPU_BufferOpsTileAndVectorize, Linalg_TransformInterpCodegen,
LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize, LLVMGPU_MatmulSimt,
LLVMGPU_MatmulTensorCore, SPIRV_Distribute, SPIRV_Vectorize,
- SPIRV_VectorizeToCooperativeOps, None]> {
+ SPIRV_VectorizeToCooperativeOps, SPIRV_VectorizeWithWorkgroupMemory,
+ None]> {
let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
// Don't generate a C++ class! We want to use the AttrDef
let genSpecializedAttr = 0;
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index 2bb5615..038d634 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -17,14 +17,11 @@
"ConvertToNVVM.cpp",
"ConvertToROCDL.cpp",
"KernelConfig.cpp",
- "LLVMGPUDistributeSharedMemoryCopy.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUMultiBuffering.cpp",
- "LLVMGPUPipelining.cpp",
"LLVMGPUReduceBankConflicts.cpp",
"LLVMGPUTensorCoreVectorization.cpp",
"LLVMGPUTileAndDistribute.cpp",
- "LLVMGPUUtils.cpp",
"LLVMGPUVectorLowering.cpp",
"LLVMGPUVectorToGPU.cpp",
"LLVMGPUVectorization.cpp",
@@ -34,7 +31,6 @@
hdrs = [
"ConvertToLLVM.h",
"KernelConfig.h",
- "LLVMGPUUtils.h",
],
deps = [
"//iree/compiler/Codegen:PassHeaders",
diff --git a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 586059d..2abeefd 100644
--- a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -16,20 +16,16 @@
HDRS
"ConvertToLLVM.h"
"KernelConfig.h"
- "LLVMGPUUtils.h"
SRCS
"ConvertToLLVM.cpp"
"ConvertToNVVM.cpp"
"ConvertToROCDL.cpp"
"KernelConfig.cpp"
- "LLVMGPUDistributeSharedMemoryCopy.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUMultiBuffering.cpp"
- "LLVMGPUPipelining.cpp"
"LLVMGPUReduceBankConflicts.cpp"
"LLVMGPUTensorCoreVectorization.cpp"
"LLVMGPUTileAndDistribute.cpp"
- "LLVMGPUUtils.cpp"
"LLVMGPUVectorLowering.cpp"
"LLVMGPUVectorToGPU.cpp"
"LLVMGPUVectorization.cpp"
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 828069c..8f0c24f 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -8,10 +8,10 @@
#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -178,38 +178,6 @@
return success();
}
-static Optional<Value> allocateWorkgroupMemory(
- OpBuilder &b, memref::SubViewOp subview,
- ArrayRef<Value> boundingSubViewSize, DataLayout &layout) {
- OpBuilder::InsertionGuard guard(b);
- func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
- if (!funcOp) {
- subview.emitError("expected op to be within std.func");
- return llvm::None;
- }
-
- // The bounding subview size is expected to be constant. This specified the
- // shape of the allocation.
- SmallVector<int64_t, 2> shape = llvm::to_vector<2>(
- llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t {
- APInt value;
- if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue();
- return -1;
- }));
- if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {};
- MemRefType allocType =
- MemRefType::get(shape, subview.getType().getElementType(), {},
- gpu::GPUDialect::getWorkgroupAddressSpace());
- b.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
- Value buffer = b.create<memref::AllocOp>(funcOp.getLoc(), allocType);
- return buffer;
-}
-
-static LogicalResult deallocateWorkgroupMemory(OpBuilder &b, Value buffer) {
- // Nothing to do.
- return success();
-}
-
static void populatePromotionPatterns(MLIRContext *context,
RewritePatternSet &patterns) {
patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.cpp
deleted file mode 100644
index c7ab1aa..0000000
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
-
-#include "mlir/Dialect/GPU/Passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
- mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
- llvm::ArrayRef<int64_t> workgroupSize) {
- assert(numDims <= kNumGPUDims);
- llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
- std::array<gpu::Dimension, kNumGPUDims> dimAttr{
- gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
- mlir::Type indexType = builder.getIndexType();
- for (unsigned i = 0; i < numDims; ++i) {
- procInfo[numDims - 1 - i] = {
- builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]),
- builder.create<mlir::arith::ConstantOp>(
- loc, builder.getIndexAttr(workgroupSize[i]))};
- }
- return procInfo;
-}
-
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
- mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
- llvm::ArrayRef<int64_t> numSubgroups) {
- assert(numDims <= kNumGPUDims);
- llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
- std::array<gpu::Dimension, kNumGPUDims> dimAttr{
- gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
- mlir::Type indexType = builder.getIndexType();
- for (unsigned i = 0; i < numDims; ++i) {
- mlir::Value subgroupId =
- builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]);
- if (i == 0) {
- mlir::AffineExpr d0 = builder.getAffineDimExpr(0);
- subgroupId = mlir::makeComposedAffineApply(
- builder, loc, d0.floorDiv(builder.getAffineConstantExpr(kWarpSize)),
- {subgroupId});
- }
- procInfo[numDims - 1 - i] = {
- subgroupId, builder.create<mlir::arith::ConstantOp>(
- loc, builder.getIndexAttr(numSubgroups[i]))};
- }
- return procInfo;
-}
-
-std::array<int64_t, 3> getWorkgroupSize(mlir::func::FuncOp funcOp) {
- std::array<int64_t, 3> workgroupSize;
- auto entryPointOp = mlir::iree_compiler::getEntryPoint(funcOp);
- llvm::Optional<mlir::ArrayAttr> workgroupSizeAttr =
- entryPointOp.workgroup_size();
- assert(workgroupSizeAttr.hasValue());
- for (auto it : llvm::enumerate(workgroupSizeAttr.getValue())) {
- workgroupSize[it.index()] =
- it.value().cast<mlir::IntegerAttr>().getValue().getZExtValue();
- }
- return workgroupSize;
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h b/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h
deleted file mode 100644
index aaca6bd..0000000
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_LLVMGPUUTILS_H_
-#define IREE_COMPILER_CODEGEN_LLVMGPU_LLVMGPUUTILS_H_
-
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-static constexpr int32_t kNumGPUDims = 3;
-static constexpr int32_t kWarpSize = 32;
-
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
- mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
- llvm::ArrayRef<int64_t> workgroupSize);
-
-/// Compute subgroup ID. CUDA doesn't have a subgroupId equivalent so we are are
-/// computing the subgroup ID based on the threadID.
-/// When tiling to warp we assume each warp is full and we pick a workgroup
-/// size so that `workgroupSize.x % warpSize == 0`. This is why we can have
-/// warpId = { threadId.x / warpSize, threadId.y, threadId.z }.
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
- mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
- llvm::ArrayRef<int64_t> numSubgroups);
-
-/// return the workgroup size associated to the funcOp entry point.
-std::array<int64_t, 3> getWorkgroupSize(mlir::func::FuncOp funcOp);
-
-} // namespace iree_compiler
-} // namespace mlir
-#endif // IREE_COMPILER_CODEGEN_LLVMGPU_LLVMGPUUTILS_H_
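The subgroup-ID comment above survives the move to GPUUtils. A standalone illustration of the mapping it describes (values are hypothetical; kWarpSize = 32 as in the deleted header):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int64_t kWarpSize = 32;

// warpId = { threadId.x / warpSize, threadId.y, threadId.z }, valid when
// workgroupSize.x % kWarpSize == 0 so every warp is full.
std::array<int64_t, 3> warpIdFromThreadId(std::array<int64_t, 3> tid) {
  return {tid[0] / kWarpSize, tid[1], tid[2]};
}

int main() {
  // E.g. with workgroupSize = {64, 2, 1}, thread {37, 1, 0} sits in warp
  // {1, 1, 0}: only the x dimension is folded by the warp size.
  assert((warpIdFromThreadId({37, 1, 0}) == std::array<int64_t, 3>{1, 1, 0}));
  return 0;
}
```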
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index a70bc33..a1224ed 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -88,7 +88,7 @@
// Distribute linalg onto threads within the workgroup.
pm.addNestedPass<func::FuncOp>(createLLVMGPUTileAndDistribute());
pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
- pm.addNestedPass<func::FuncOp>(createLLVMGPUDistributeSharedMemoryCopy());
+ pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
@@ -106,7 +106,7 @@
pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
// Pipeline memory operations.
- pm.addNestedPass<func::FuncOp>(createLLVMGPUPipeliningPass());
+ pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass());
}
void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm) {
@@ -118,7 +118,7 @@
if (pipelineDepth > 1)
pm.addNestedPass<func::FuncOp>(createLLVMGPUMultiBuffering(pipelineDepth));
pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
- pm.addNestedPass<func::FuncOp>(createLLVMGPUDistributeSharedMemoryCopy());
+ pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
@@ -142,7 +142,7 @@
pm.addPass(createCSEPass());
// Pipeline memory operations.
- pm.addNestedPass<func::FuncOp>(createLLVMGPUPipeliningPass(pipelineDepth));
+ pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass(pipelineDepth));
}
void addGPUSimpleDistributePassPipeline(OpPassManager &pm) {
diff --git a/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
index c4854ea..e832517 100644
--- a/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Linalg/Passes.h"
namespace mlir {
diff --git a/iree/compiler/Codegen/LLVMGPU/test/BUILD b/iree/compiler/Codegen/LLVMGPU/test/BUILD
index 5011132..36d1d95 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -22,7 +22,6 @@
"convert_to_nvvm.mlir",
"convert_to_rocdl.mlir",
"distribute_to_thread.mlir",
- "distribute_wg_copy.mlir",
"gpu_set_num_workgroups.mlir",
"nvvm_pipeline_test.mlir",
"reduce_bank_conflicts.mlir",
diff --git a/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 7069e4f..547de8a 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -17,7 +17,6 @@
"convert_to_nvvm.mlir"
"convert_to_rocdl.mlir"
"distribute_to_thread.mlir"
- "distribute_wg_copy.mlir"
"gpu_set_num_workgroups.mlir"
"illegal_configuration.mlir"
"legalize.mlir"
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index a89bbb6..b84ec6b 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -153,6 +153,15 @@
/// Creates a pass to convert memref.copy to linalg op.
std::unique_ptr<OperationPass<func::FuncOp>> createMemrefCopyToLinalgPass();
+/// Convert GPU shared memory copies to distributed
+/// transfer_read/transfer_write.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createGPUDistributeSharedMemoryCopy();
+
+/// Apply software pipelining.
+std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
+ unsigned depth = 1);
+
/// Converts vector ops to gpu dialect.
std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
unsigned swizzleLogTile = 0);
@@ -352,14 +361,6 @@
/// Lower vector ops before conversion to LLVM.
std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUVectorLoweringPass();
-/// Convert shared memory copies to distributed transfer_read/transfer_write.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLLVMGPUDistributeSharedMemoryCopy();
-
-/// Apply software pipelining.
-std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUPipeliningPass(
- unsigned depth = 1);
-
/// Apply multi-buffering transformation.
std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUMultiBuffering(
unsigned numBuffers = 5);
@@ -391,6 +392,12 @@
/// performs distribution to threads with vectorization.
void addSPIRVTileAndVectorizeToCooperativeOpsPassPipeline(OpPassManager &pm);
+/// Pass pipeline to lower IREE HAL executables with workgroup tiled and
+/// distributed Linalg ops to SPIR-V scalar and vector code. Additionally
+/// performs distribution to threads with vectorization and promotion to use
+/// workgroup memory.
+void addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline(OpPassManager &pm);
+
/// Pass to perform the final conversion to SPIR-V dialect.
///
/// This pass converts remaining interface ops into SPIR-V global variables,
@@ -411,6 +418,10 @@
/// Pass to tile and distribute Linalg ops with buffer semantics to invocations.
std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndDistributePass();
+/// Pass to promote Linalg ops with buffer semantics to use workgroup memory and
+/// then tile to invocations.
+std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndPromotePass();
+
/// Pass to tile Linalg ops with buffer semantics to subgroups and vectorize to
/// vector ops suitable for lowering to SPIR-V cooperative ops.
std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index d4f23b6..05e8fcb 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -163,6 +163,17 @@
"mlir::iree_compiler::createMemrefCopyToLinalgPass()";
}
+def GPUDistributeSharedMemoryCopy :
+ Pass<"iree-gpu-distribute-shared-memory-copy", "func::FuncOp"> {
+ let summary = "Pass to distribute shared memory copies to threads.";
+ let constructor = "mlir::iree_compiler::createGPUDistributeSharedMemoryCopy()";
+}
+
+def GPUPipelining : Pass<"iree-gpu-pipelining", "func::FuncOp"> {
+ let summary = "Pass to do software pipelining.";
+ let constructor = "mlir::iree_compiler::createGPUPipeliningPass()";
+}
+
def WorkGroupSwizzle :
Pass<"iree-workgroup-swizzle", "func::FuncOp"> {
let summary = "swizzle the workgroup ids for better cache reuse";
@@ -293,18 +304,6 @@
let constructor = "mlir::iree_compiler::createLLVMGPUVectorLoweringPass()";
}
-def LLVMGPUDistributeSharedMemoryCopy :
- Pass<"iree-llvmgpu-distribute-shared-memory-copy", "func::FuncOp"> {
- let summary = "Pass to distribute shared memory copies to threads.";
- let constructor = "mlir::iree_compiler::createLLVMGPUDistributeSharedMemoryCopy()";
-}
-
-def LLVMGPUPipelining :
- Pass<"iree-llvmgpu-pipelining", "func::FuncOp"> {
- let summary = "Pass to do software pipelining.";
- let constructor = "mlir::iree_compiler::createLLVMGPUPipeliningPass()";
-}
-
def LLVMGPUMultiBuffering :
Pass<"iree-llvmgpu-multi-buffering", "func::FuncOp"> {
let summary = "Pass to do multi buffering.";
@@ -366,6 +365,13 @@
"mlir::iree_compiler::createSPIRVTileAndVectorizeToCooperativeOpsPass()";
}
+def SPIRVTileAndPromote : Pass<"iree-spirv-tile-and-promote", "func::FuncOp"> {
+ let summary = "Promote tiled Linalg ops with buffer semantics to use "
+ "workgroup memory and then tile to invocations";
+ let constructor =
+ "mlir::iree_compiler::createSPIRVTileAndPromotePass()";
+}
+
def SPIRVVectorize : Pass<"iree-spirv-vectorize", "func::FuncOp"> {
let summary = "Vectorize Linalg ops with buffer semantics";
let constructor = "mlir::iree_compiler::createSPIRVVectorizePass()";
diff --git a/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
new file mode 100644
index 0000000..4fb689a
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -0,0 +1,52 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- AMDConfig.cpp - AMD CodeGen Configurations -------------------------===//
+//
+// This file contains CodeGen configurations for AMD GPUs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
+#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+
+#define DEBUG_TYPE "iree-spirv-amd-config"
+
+namespace mlir {
+namespace iree_compiler {
+namespace detail {
+
+// RDNA architecture:
+// https://gpuopen.com/wp-content/uploads/2019/08/RDNA_Architecture_public.pdf
+//
+// Workgroup Processor (WGP) is the block for workgroups in RDNA; it has its own
+// instruction/constant cache, L0 cache x2, Local Data Share (LDS, a.k.a. shared
+// memory), SALU x4, SIMD32 x4.
+//
+// * 1024 registers per SIMD32
+// * 128KB LDS per WGP
+// * Max 20 waves per SIMD32
+// * Max 64KB LDS per workgroup
+
+LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv,
+ Operation *rootOp) {
+ int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
+ if (auto matmulOp = dyn_cast<linalg::MatmulOp>(rootOp)) {
+ std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 8};
+ std::array<int64_t, 3> threadMNK = {8, 4, 32};
+ return setMatmulOpConfig(matmulOp, subgroupSize, workgroupXY, threadMNK,
+ /*useWorkgroupMemory=*/true);
+ }
+ return success();
+}
+
+} // namespace detail
+} // namespace iree_compiler
+} // namespace mlir
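A back-of-the-envelope check that the stub config above fits the LDS limits it quotes, assuming a Vulkan-reported subgroup size of 64 and f32 operands (both assumptions, not asserted by the patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t subgroupSize = 64;
  const int64_t wgX = subgroupSize / 2, wgY = 8;     // workgroupXY = {32, 8}
  const int64_t tM = 8, tN = 4, tK = 32;             // threadMNK
  const int64_t M = wgY * tM, N = wgX * tN, K = tK;  // workgroup tile 64x128, K = 32
  const int64_t ldsBytes = (M * K + K * N) * 4;      // promoted LHS + RHS tiles
  // 64*32 + 32*128 = 6144 floats = 24 KB, well under the 64 KB per-workgroup
  // LDS limit listed above.
  std::printf("%lld KB LDS\n", static_cast<long long>(ldsBytes / 1024));
  return 0;
}
```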
diff --git a/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
index d4511cd..bfbd2f4 100644
--- a/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
@@ -29,10 +29,10 @@
Operation *rootOp) {
int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
return TypeSwitch<Operation *, LogicalResult>(rootOp)
- .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([](auto op) {
- std::array<int64_t, 2> workgroupXY = {32, 2};
+ .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([subgroupSize](auto op) {
+ std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 2};
std::array<int64_t, 3> threadMNK = {16, 4, 4};
- return setMatmulOpConfig(op, workgroupXY, threadMNK);
+ return setMatmulOpConfig(op, subgroupSize, workgroupXY, threadMNK);
})
.Case<linalg::Conv2DNhwcHwcfOp>([subgroupSize](auto op) {
return setConvOpConfig(op, subgroupSize,
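This hunk is a generalization rather than a retuning: with the common Adreno subgroup size of 64 (an assumption), subgroupSize / 2 reproduces the previously hard-coded workgroup X of 32, and the analogous Mali change below likewise yields the old {8, 2} for a subgroup size of 16.

```cpp
// Sanity check of the refactored workgroup X sizes (subgroup sizes assumed).
static_assert(64 / 2 == 32, "Adreno: matches the old {32, 2} workgroupXY");
static_assert(16 / 2 == 8, "Mali: matches the old {8, 2} workgroupXY");
```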
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index 12115b0..e2543c5 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -13,6 +13,7 @@
cc_library(
name = "SPIRV",
srcs = [
+ "AMDConfig.cpp",
"AdrenoConfig.cpp",
"ConvertToSPIRVPass.cpp",
"KernelConfig.cpp",
@@ -25,6 +26,7 @@
"SPIRVLowerExecutableTargetPass.cpp",
"SPIRVTile.cpp",
"SPIRVTileAndDistribute.cpp",
+ "SPIRVTileAndPromote.cpp",
"SPIRVTileAndVectorizeToCooperativeOps.cpp",
"SPIRVVectorToCooperativeOps.cpp",
"SPIRVVectorize.cpp",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index cea4c3b..e222de5 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -18,6 +18,7 @@
"MemorySpace.h"
"Utils.h"
SRCS
+ "AMDConfig.cpp"
"AdrenoConfig.cpp"
"ConvertToSPIRVPass.cpp"
"KernelConfig.cpp"
@@ -30,6 +31,7 @@
"SPIRVLowerExecutableTargetPass.cpp"
"SPIRVTile.cpp"
"SPIRVTileAndDistribute.cpp"
+ "SPIRVTileAndPromote.cpp"
"SPIRVTileAndVectorizeToCooperativeOps.cpp"
"SPIRVVectorToCooperativeOps.cpp"
"SPIRVVectorize.cpp"
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index d8e6e2c..6dd8f80 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -6,6 +6,9 @@
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
+#include <functional>
+#include <numeric>
+
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
@@ -169,9 +172,10 @@
namespace detail {
-LogicalResult setMatmulOpConfig(linalg::LinalgOp op,
+LogicalResult setMatmulOpConfig(linalg::LinalgOp op, int64_t subgroupSize,
std::array<int64_t, 2> bestWorkgroupSizeXY,
- std::array<int64_t, 3> bestThreadTileSizeMNK) {
+ std::array<int64_t, 3> bestThreadTileSizeMNK,
+ bool useWorkgroupMemory) {
auto lhsType = op.inputs()[0].getType().cast<ShapedType>();
auto elementBits = lhsType.getElementType().getIntOrFloatBitWidth();
if (elementBits != 16 && elementBits != 32) return success();
@@ -261,11 +265,20 @@
}
if (reductionTileSizes[2 + isBM] == 0) return success();
- auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize;
+ auto totalThreads =
+ std::accumulate(workgroupSize.begin(), workgroupSize.end(), 1,
+ std::multiplies<int64_t>());
+ auto pipeline =
+ (useWorkgroupMemory && totalThreads > subgroupSize)
+ ? IREE::Codegen::DispatchLoweringPassPipeline::
+ SPIRVVectorizeWithWorkgroupMemory
+ : IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize;
+
TileSizesListType tileSizes;
tileSizes.push_back(workgroupTileSizes);
tileSizes.push_back(invocationTileSizes);
tileSizes.push_back(reductionTileSizes);
+
return setOpConfigAndEntryPointFnTranslation(
op->getParentOfType<func::FuncOp>(), op, tileSizes, pipeline,
workgroupSize);
@@ -512,6 +525,9 @@
// First try to find a proper CodeGen configuration to tile and vectorize for
// the current target architecture.
switch (targetEnv.getVendorID()) {
+ case spirv::Vendor::AMD:
+ result = detail::setAMDCodeGenConfig(targetEnv, rootOp);
+ break;
case spirv::Vendor::ARM:
result = detail::setMaliCodeGenConfig(targetEnv, rootOp);
break;
@@ -534,10 +550,12 @@
spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
return TypeSwitch<Operation *, LogicalResult>(rootOp)
.Case<linalg::BatchMatmulOp, linalg::MatmulOp>([limits](auto op) {
- // Try to tile and vectorize first.
+ // Try to tile and vectorize first. It's common to see 32 threads
+ // per subgroup for GPUs.
std::array<int64_t, 2> workgroupXY = {32, 2};
std::array<int64_t, 3> threadMNK = {8, 8, 4};
- auto result = detail::setMatmulOpConfig(op, workgroupXY, threadMNK);
+ auto result = detail::setMatmulOpConfig(op, /*subgroupSize=*/32,
+ workgroupXY, threadMNK);
if (failed(result)) return result;
if (getLoweringConfig(op)) return result;
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.h b/iree/compiler/Codegen/SPIRV/KernelConfig.h
index a11c0aa..3d27354 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.h
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.h
@@ -35,9 +35,10 @@
/// Sets CodeGen configurations via attributes to the given matmul `linalgOp`
/// with the given best workgroup size and tile size hints.
-LogicalResult setMatmulOpConfig(linalg::LinalgOp linalgOp,
+LogicalResult setMatmulOpConfig(linalg::LinalgOp linalgOp, int64_t subgroupSize,
std::array<int64_t, 2> bestWorkgroupSizeXY,
- std::array<int64_t, 3> bestThreadTileSizeMNK);
+ std::array<int64_t, 3> bestThreadTileSizeMNK,
+ bool useWorkgroupMemory = false);
/// Sets CodeGen configuration for GPUs from a specific vendor.
///
@@ -51,6 +52,8 @@
LogicalResult setAdrenoCodeGenConfig(const spirv::TargetEnv &targetEnv,
Operation *rootOp);
+LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv,
+ Operation *rootOp);
LogicalResult setMaliCodeGenConfig(const spirv::TargetEnv &targetEnv,
Operation *rootOp);
LogicalResult setNVIDIACodeGenConfig(const spirv::TargetEnv &targetEnv,
diff --git a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index 9411ec5..caf65c9 100644
--- a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -29,8 +29,8 @@
Operation *rootOp) {
int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
return TypeSwitch<Operation *, LogicalResult>(rootOp)
- .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([](auto op) {
- std::array<int64_t, 2> workgroupXY = {8, 2};
+ .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([subgroupSize](auto op) {
+ std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 2};
std::array<int64_t, 3> threadMNK;
auto inputType = op.inputs()[0].getType().template cast<ShapedType>();
if (inputType.getElementType().isF16()) {
@@ -38,7 +38,7 @@
} else {
threadMNK = {6, 4, 4};
}
- return setMatmulOpConfig(op, workgroupXY, threadMNK);
+ return setMatmulOpConfig(op, subgroupSize, workgroupXY, threadMNK);
})
.Case<linalg::Conv2DNhwcHwcfOp>([subgroupSize](auto op) {
bool hasPaddedInput =
diff --git a/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp b/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
index c4b121a..bd7aeb4 100644
--- a/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
@@ -109,10 +109,48 @@
workgroupSize);
}
+// Volta architecture:
+// https://docs.nvidia.com/cuda/volta-tuning-guide/index.html#sm-occupancy
+//
+// * 64K 32-bit registers per SM
+// * 96KB shared memory per SM
+// * Max 32 thread blocks per SM
+// * Max 64 concurrent warps per SM
+// * Max 255 registers per thread
+
+// Turing architecture:
+// https://docs.nvidia.com/cuda/turing-tuning-guide/index.html#sm-occupancy
+//
+// * 64K 32-bit registers per SM
+// * 64KB shared memory per SM
+// * Max 16 thread blocks per SM
+// * Max 32 concurrent warps per SM
+// * Max 255 registers per thread
+
+// Ampere architecture:
+// https://docs.nvidia.com/cuda/ampere-tuning-guide/index.html#sm-occupancy
+//
+// * 64K 32-bit registers per SM
+// * 164KB/96KB shared memory for compute capability 8.0/8.6
+// * Max 32/16 thread blocks per SM for compute capability 8.0/8.6
+// * Max 64 concurrent warps per SM
+// * Max 255 registers per thread
+
+// Note that the above numbers are from CUDA docs; for Vulkan the driver can
+// expose slightly different numbers, e.g., max shared memory size is smaller.
+
LogicalResult setNVIDIACodeGenConfig(const spirv::TargetEnv &targetEnv,
Operation *rootOp) {
+ int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(rootOp)) {
- return setOpConfig(targetEnv, matmulOp);
+ // First try to see if we can use tensor cores.
+ if (failed(setOpConfig(targetEnv, matmulOp))) return failure();
+ if (getLoweringConfig(rootOp)) return success();
+
+ std::array<int64_t, 2> workgroupXY = {subgroupSize, 8};
+ std::array<int64_t, 3> threadMNK = {16, 4, 32};
+ return setMatmulOpConfig(matmulOp, subgroupSize, workgroupXY, threadMNK,
+ /*useWorkgroupMemory=*/true);
}
return success();
}
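The same budget check for the NVIDIA fallback config above, assuming subgroup size 32 and f32 operands (assumptions; the test target below reports these limits):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t subgroupSize = 32;
  const int64_t wgX = subgroupSize, wgY = 8;         // 256 threads = 8 warps
  const int64_t tM = 16, tN = 4, tK = 32;            // threadMNK
  const int64_t M = wgY * tM, N = wgX * tN, K = tK;  // workgroup tile 128x128, K = 32
  const int64_t bytes = (M * K + K * N) * 4;         // promoted LHS + RHS tiles
  // 128*32 + 32*128 = 8192 floats = 32 KB, within the 48 KB shared memory the
  // pipeline test's target environment reports (49152 bytes).
  std::printf("%lld KB shared memory\n", static_cast<long long>(bytes / 1024));
  return 0;
}
```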
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index d8e5c53..ec83565 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -201,6 +201,37 @@
pm.addNestedPass<func::FuncOp>(createSPIRVVectorToCooperativeOpsPass());
}
+void addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline(
+ OpPassManager &pm) {
+ pm.addNestedPass<func::FuncOp>(createInsertDistributionInfoPass());
+ pm.addNestedPass<func::FuncOp>(createTileAndDistributeToWorkgroupsPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+
+ addLinalgBufferizePasses(pm, gpuAllocationFunction);
+
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+
+ // Tile and distribute to GPU invocations.
+ pm.addNestedPass<func::FuncOp>(createSPIRVTileAndPromotePass());
+ pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
+ pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+
+ pm.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
+
+ pm.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+
+ // pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass());
+
+ addLoopMaterializationPasses(pm);
+}
+
void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createInsertDistributionInfoPass());
pm.addNestedPass<func::FuncOp>(createTileAndDistributeToWorkgroupsPass());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index c65e1e3..dd06392 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -107,6 +107,10 @@
SPIRVVectorizeToCooperativeOps:
addSPIRVTileAndVectorizeToCooperativeOpsPassPipeline(nestedModulePM);
break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::
+ SPIRVVectorizeWithWorkgroupMemory:
+ addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline(nestedModulePM);
+ break;
default:
variantOp.emitOpError("Unsupported pipeline on GPU target.");
return signalPassFailure();
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
index 5bbbb2e..256114a 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
new file mode 100644
index 0000000..d409a81
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -0,0 +1,284 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- SPIRVTileAndPromote.cpp --------------------------------------------===//
+//
+// This pass promotes Linalg ops with buffer semantics to use workgroup
+// memory and then tiles them to invocations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-tile-and-promote"
+
+namespace mlir {
+namespace iree_compiler {
+
+//====---------------------------------------------------------------------===//
+// Reduction tiling patterns
+//====---------------------------------------------------------------------===//
+
+static void populateTilingReductionPatterns(
+ RewritePatternSet &patterns, linalg::LinalgTransformationFilter filter) {
+ auto getTileSizeFn = [&](OpBuilder &builder, Operation *op) {
+ return getTileSizes(builder, op, 2);
+ };
+
+ auto tilingOptions = linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizeComputationFunction(getTileSizeFn);
+ linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::MatmulOp>::insert(
+ patterns, tilingOptions, filter);
+}
+
+//===----------------------------------------------------------------------===//
+// Invocation tiling patterns
+//===----------------------------------------------------------------------===//
+
+static void populateTilingToInvocationPatterns(
+ RewritePatternSet &patterns, linalg::LinalgTransformationFilter filter) {
+ linalg::TileSizeComputationFunction getTileSizeFn = [&](OpBuilder &builder,
+ Operation *op) {
+ return getTileSizes(builder, op, 1);
+ };
+
+ auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
+ builder, loc, parallelLoopRanges.size());
+ };
+ linalg::LinalgLoopDistributionOptions distributionOptions;
+ distributionOptions.procInfo = getThreadProcInfoFn;
+ distributionOptions.distributionMethod = {
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic}};
+
+ auto tilingOptions = linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizeComputationFunction(getTileSizeFn)
+ .setDistributionOptions(distributionOptions);
+
+ linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::FillOp,
+ linalg::GenericOp,
+ linalg::MatmulOp>::insert(patterns, tilingOptions,
+ filter);
+}
+
+//===----------------------------------------------------------------------===//
+// Promotion patterns
+//===----------------------------------------------------------------------===//
+
+static const char promoteLHSMarker[] = "promote_lhs";
+static const char promoteRHSMarker[] = "promote_rhs";
+static const char promoteBothMarker[] = "promote_lhs_and_rhs";
+
+LogicalResult copyToWorkgroupMemory(OpBuilder &builder, Value src, Value dst) {
+ Operation *copyOp = builder.create<memref::CopyOp>(src.getLoc(), src, dst);
+ setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
+ return success();
+}
+
+static void populatePromotionPatterns(RewritePatternSet &patterns,
+ StringAttr replaceMarker) {
+ MLIRContext *context = patterns.getContext();
+ auto baseOptions =
+ linalg::LinalgPromotionOptions()
+ .setAllocationDeallocationFns(allocateWorkgroupMemory,
+ deallocateWorkgroupMemory)
+ .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
+ .setUseFullTileBuffers({false, false});
+ auto promoteLHSOptions = baseOptions.setOperandsToPromote({0});
+ auto promoteRHSOptions = baseOptions.setOperandsToPromote({1});
+ auto promoteBothOptions = baseOptions.setOperandsToPromote({0, 1});
+
+ linalg::LinalgTransformationFilter promoteLHSFilter(
+ {StringAttr::get(context, promoteLHSMarker)}, replaceMarker);
+ linalg::LinalgTransformationFilter promoteRHSFilter(
+ {StringAttr::get(context, promoteRHSMarker)}, replaceMarker);
+ linalg::LinalgTransformationFilter promoteBothFilter(
+ {StringAttr::get(context, promoteBothMarker)}, replaceMarker);
+
+ patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
+ linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>(
+ patterns.getContext(), promoteLHSOptions, promoteLHSFilter);
+ patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
+ linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>(
+ patterns.getContext(), promoteRHSOptions, promoteRHSFilter);
+ patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
+ linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>(
+ patterns.getContext(), promoteBothOptions, promoteBothFilter);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct SPIRVTileAndPromotePass final
+ : public SPIRVTileAndPromoteBase<SPIRVTileAndPromotePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<gpu::GPUDialect>();
+ }
+
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void SPIRVTileAndPromotePass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ func::FuncOp funcOp = getOperation();
+ auto entryPointOp = getEntryPoint(funcOp);
+ if (!entryPointOp) return;
+
+ { // Tile reduction dimensions.
+ RewritePatternSet tilingPatterns(context);
+ linalg::LinalgTransformationFilter filter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, getWorkgroupKTiledMarker()));
+ populateTilingReductionPatterns(tilingPatterns, filter);
+ if (failed(
+ applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)))) {
+ funcOp.emitOpError() << "failed tiling reduction";
+ return signalPassFailure();
+ }
+
+ RewritePatternSet patterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ funcOp.emitOpError() << "failed canonicalization after tiling reduction";
+ return signalPassFailure();
+ }
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling reduction dimensions ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ auto workgroupSize = llvm::to_vector<4>(llvm::map_range(
+ entryPointOp.workgroup_size().getValue(),
+ [&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+ int64_t flatWorkgroupSize =
+ workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
+ auto subgroupSize =
+ getSPIRVTargetEnvAttr(funcOp).getResourceLimits().subgroup_size();
+
+ funcOp.walk([&](Operation *op) {
+ if (isa<linalg::FillOp, linalg::GenericOp>(op)) {
+ op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
+ StringAttr::get(context, getWorkgroupMemoryMarker()));
+ } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp>(op)) {
+ auto lhsShape = op->getOperand(0).getType().cast<ShapedType>().getShape();
+ auto rhsShape = op->getOperand(1).getType().cast<ShapedType>().getShape();
+ bool canPromoteLHS =
+ canPerformVectorAccessUsingAllThreads(lhsShape, flatWorkgroupSize, 4);
+ bool canPromoteRHS =
+ canPerformVectorAccessUsingAllThreads(rhsShape, flatWorkgroupSize, 4);
+ StringAttr promoteMarker =
+ StringAttr::get(context, getWorkgroupMemoryMarker());
+ if (canPromoteLHS && canPromoteRHS) {
+ promoteMarker = StringAttr::get(context, promoteBothMarker);
+ } else if (canPromoteLHS) {
+ promoteMarker = StringAttr::get(context, promoteLHSMarker);
+ } else if (canPromoteRHS) {
+ promoteMarker = StringAttr::get(context, promoteRHSMarker);
+ }
+ op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
+ promoteMarker);
+ }
+ return WalkResult::advance();
+ });
+
+  // Only promote to workgroup memory if there are multiple warps.
+ if (flatWorkgroupSize > subgroupSize.getInt()) {
+ RewritePatternSet promotionPatterns(&getContext());
+ auto replaceMarker = StringAttr::get(context, getWorkgroupMemoryMarker());
+ populatePromotionPatterns(promotionPatterns, replaceMarker);
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(promotionPatterns)))) {
+ return signalPassFailure();
+ }
+
+    // Insert barriers before and after copies to workgroup memory, but skip
+    // the barrier between two back-to-back copies to workgroup memory.
+ OpBuilder builder(&getContext());
+ funcOp.walk([&builder](memref::CopyOp copyOp) {
+ if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) {
+ Operation *prevOp = copyOp->getPrevNode();
+ if (!prevOp || !hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) {
+ builder.setInsertionPoint(copyOp);
+ builder.create<gpu::BarrierOp>(copyOp.getLoc());
+ }
+ Operation *nextOp = copyOp->getNextNode();
+ if (!nextOp || !hasMarker(nextOp, getCopyToWorkgroupMemoryMarker())) {
+ builder.setInsertionPointAfter(copyOp);
+ builder.create<gpu::BarrierOp>(copyOp.getLoc());
+ }
+ }
+ });
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After promotion ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ { // Tile and distribute to invocations.
+ RewritePatternSet tilingPatterns(&getContext());
+ linalg::LinalgTransformationFilter filter(
+ {StringAttr::get(context, getWorkgroupMemoryMarker())}, llvm::None);
+ populateTilingToInvocationPatterns(tilingPatterns, filter);
+ if (failed(
+ applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)))) {
+ funcOp.emitOpError() << "failed tiling and distributing to invocations";
+ return signalPassFailure();
+ }
+
+ RewritePatternSet patterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ populateFoldAffineMinInDistributedLoopsPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ // TODO(#4759): This does not converge after the max number of iterations.
+ // It indicates that some pattern upstream is generating ops even when the
+ // pattern failed to match. Not related to correctness, but would be good
+ // to figure out and fix.
+ // return signalPassFailure();
+ }
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling to invocations ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndPromotePass() {
+ return std::make_unique<SPIRVTileAndPromotePass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
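Worked instance of the marker decision above, using the helper sketched earlier and tile shapes matching the new pipeline test (assumed: after reduction tiling, the matmul sees a 128x32 LHS tile and a 32x128 RHS tile with a 256-thread workgroup):

```cpp
// LHS 128x32: inner dim 32/4 = 8 lanes, 256/8 = 32 threads left, outer dim
// 128 consumes them -> promotable. RHS 32x128 works out symmetrically, so
// the matmul gets the promote_lhs_and_rhs marker.
bool both =
    canPerformVectorAccessUsingAllThreads({128, 32}, 256, /*vectorSize=*/4) &&
    canPerformVectorAccessUsingAllThreads({32, 128}, 256, /*vectorSize=*/4);
```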
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 4e0bc80..fd2ae6b 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -225,6 +225,7 @@
// tensors and hoist them out of loop nests. So after it we have
// loop-carried vectors, not loop-carried tensors anymore.
linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
+ linalg::hoistRedundantVectorTransfers(funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After hoisting vector transfers ---\n";
diff --git a/iree/compiler/Codegen/SPIRV/Utils.cpp b/iree/compiler/Codegen/SPIRV/Utils.cpp
index 7b91b13..b3c1dc4 100644
--- a/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
diff --git a/iree/compiler/Codegen/SPIRV/Utils.h b/iree/compiler/Codegen/SPIRV/Utils.h
index c213d56..edcfe26 100644
--- a/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/iree/compiler/Codegen/SPIRV/Utils.h
@@ -20,8 +20,6 @@
namespace mlir {
namespace iree_compiler {
-static constexpr int kNumGPUDims = 3;
-
/// Given an operation, return the `spv.target_env` attribute.
spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op);
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD
index 2fa72eb..7016e86 100644
--- a/iree/compiler/Codegen/SPIRV/test/BUILD
+++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -32,10 +32,12 @@
"create_fast_slow_path.mlir",
"distribute_to_invocations.mlir",
"pipeline_matmul_cooperative_ops.mlir",
+ "pipeline_matmul_promotion.mlir",
"pipeline_matmul_vectorization.mlir",
"tile_and_distribute.mlir",
"tile_and_distribute_scatter.mlir",
"tile_and_distribute_sort.mlir",
+ "tile_and_promote_matmul.mlir",
"tile_and_vectorize_batch_matmul.mlir",
"tile_and_vectorize_conv.mlir",
"tile_and_vectorize_matmul.mlir",
@@ -47,10 +49,6 @@
"vectorize_tensor_pad.mlir",
],
include = ["*.mlir"],
- # TODO(b/203528778) reenable
- exclude = [
- "promote_workgroup_memory.mlir",
- ],
),
tools = [
"//iree/tools:iree-opt",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 2714f0b..9ee9691 100644
--- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -27,10 +27,12 @@
"create_fast_slow_path.mlir"
"distribute_to_invocations.mlir"
"pipeline_matmul_cooperative_ops.mlir"
+ "pipeline_matmul_promotion.mlir"
"pipeline_matmul_vectorization.mlir"
"tile_and_distribute.mlir"
"tile_and_distribute_scatter.mlir"
"tile_and_distribute_sort.mlir"
+ "tile_and_promote_matmul.mlir"
"tile_and_vectorize_batch_matmul.mlir"
"tile_and_vectorize_conv.mlir"
"tile_and_vectorize_matmul.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
new file mode 100644
index 0000000..bf27da5
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
@@ -0,0 +1,62 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-spirv-pipeline))' %s | FileCheck %s
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<3, storage_buffer>
+ ]>
+]>
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+hal.executable @matmul_128x256x64 {
+ hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+ spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU,
+ {max_compute_shared_memory_size = 49152 : i32,
+ max_compute_workgroup_invocations = 1024 : i32,
+ max_compute_workgroup_size = dense<[65535, 65535, 65535]> : vector<3xi32>,
+ subgroup_size = 32 : i32}>}> {
+ hal.executable.entry_point public @matmul_128x256x64 ordinal(0) layout(#executable_layout)
+ builtin.module {
+ func.func @matmul_128x256x64() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:64x256xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x256xf32>
+ %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<128x64xf32>
+ %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:64x256xf32> -> tensor<64x256xf32>
+ %6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<128x256xf32>
+ %7 = linalg.init_tensor [128, 256] : tensor<128x256xf32>
+ %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x256xf32>) -> tensor<128x256xf32>
+ %9 = linalg.matmul ins(%4, %5 : tensor<128x64xf32>, tensor<64x256xf32>) outs(%8 : tensor<128x256xf32>) -> tensor<128x256xf32>
+ %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%9, %6 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%7 : tensor<128x256xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %11 = arith.divf %arg0, %arg1 : f32
+ linalg.yield %11 : f32
+ } -> tensor<128x256xf32>
+ flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : tensor<128x256xf32> -> !flow.dispatch.tensor<writeonly:128x256xf32>
+ return
+ }
+ }
+ }
+}
+
+// CHECK: spv.GlobalVariable @{{.+}} : !spv.ptr<!spv.struct<(!spv.array<1024 x vector<4xf32>, stride=16>)>, Workgroup>
+// CHECK: spv.GlobalVariable @{{.+}} : !spv.ptr<!spv.struct<(!spv.array<1024 x vector<4xf32>, stride=16>)>, Workgroup>
+
+// CHECK-LABEL: spv.func @matmul_128x256x64
+
+// CHECK: spv.mlir.loop
+// CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+// CHECK-COUNT-4: spv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
+// CHECK-COUNT-4: spv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
+// CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+// CHECK-COUNT-512: spv.GLSL.Fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
+// CHECK: spv.mlir.merge
+// CHECK-COUNT-16: spv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
+// CHECK-COUNT-16: spv.FDiv %{{.+}}, %{{.+}} : vector<4xf32>
+// CHECK-COUNT-16: spv.Store "StorageBuffer" %{{.+}}, %{{.+}} : vector<4xf32>
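+
+// Note: the two workgroup globals above are 1024 x vector<4xf32> each, i.e.
+// 16 KiB apiece, comfortably under the declared 48 KiB shared memory limit.
+// The control barriers bracket the staging copies from storage buffers into
+// workgroup memory before the vectorized FMAs consume them.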
diff --git a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
deleted file mode 100644
index 252ceb3..0000000
--- a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
+++ /dev/null
@@ -1,135 +0,0 @@
-// TODO(antiagainst): Fix promotion to workgroup and enable the test.
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' | FileCheck %s
-
-hal.executable private @matmul_promote_workgroup_memory {
- hal.interface @io {
- hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
- hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
- hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
- }
- hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @matmul_promote_workgroup_memory interface(@io) {
- workgroup_size = [16: index, 8: index, 1: index]
- }
- builtin.module {
- func.func @matmul_promote_workgroup_memory() {
- %c32 = arith.constant 32 : index
- %c50 = arith.constant 50 : index
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<25x50xf32>
- %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<50x75xf32>
- %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<25x75xf32>
- %3 = hal.interface.workgroup.id[0] : index
- %4 = hal.interface.workgroup.id[1] : index
- scf.for %arg0 = %c0 to %c50 step %c32 {
- %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%4]
- %6 = affine.min affine_map<()[s0] -> (8, s0 * -8 + 25)>()[%4]
- %7 = affine.min affine_map<(d0) -> (32, -d0 + 50)>(%arg0)
- %8 = memref.subview %0[%5, %arg0] [%6, %7] [1, 1] : memref<25x50xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 50 + s0 + d1)>>
- %9 = affine.min affine_map<(d0) -> (32, -d0 + 50)>(%arg0)
- %10 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%3]
- %11 = affine.min affine_map<()[s0] -> (16, s0 * -16 + 75)>()[%3]
- %12 = memref.subview %1[%arg0, %10] [%9, %11] [1, 1] : memref<50x75xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>
- %13 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%4]
- %14 = affine.min affine_map<()[s0] -> (8, s0 * -8 + 25)>()[%4]
- %15 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%3]
- %16 = affine.min affine_map<()[s0] -> (16, s0 * -16 + 75)>()[%3]
- %17 = memref.subview %2[%13, %15] [%14, %16] [1, 1] : memref<25x75xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>
- linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering_config = {tileSizes = [[8, 16, 32], [], [1, 1, 0]]}}
- ins(%8, %12 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 50 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
- outs(%17 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
- }
- return
- }
- hal.interface private @io {
- hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
- hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
- hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
- }
- }
- }
-}
-
-// CHECK-LABEL: func @matmul_promote_workgroup_memory()
-// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external
-// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external
-// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external
-// CHECK-DAG: %[[ALLOC1:.+]] = memref.alloc() : memref<8x32xf32, 3>
-// CHECK-DAG: %[[ALLOC2:.+]] = memref.alloc() : memref<32x16xf32, 3>
-// CHECK: scf.for
-// CHECK: %[[ARG0SV:.+]] = memref.subview %[[ARG0]]
-// CHECK: %[[ARG1SV:.+]] = memref.subview %[[ARG1]]
-// CHECK: %[[RET0SV:.+]] = memref.subview %[[RET0]]
-// CHECK: %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC1]]
-// CHECK: %[[SUBVIEW2:.+]] = memref.subview %[[ALLOC2]]
-// CHECK: linalg.generic(%[[ARG0SV]], %[[SUBVIEW1]])
-// CHECK-SAME: "copy_to_workgroup_memory"
-// CHECK: linalg.generic(%[[ARG1SV]], %[[SUBVIEW2]])
-// CHECK-SAME: "copy_to_workgroup_memory"
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK-DAG: memref.subview %[[SUBVIEW1]]
-// CHECK-DAG: memref.subview %[[SUBVIEW2]]
-// CHECK-DAG: memref.subview %[[RET0SV]]
-
-// -----
-
-hal.executable private @conv_promote_workgroup_memory {
- hal.interface @io {
- hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
- hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
- hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
- }
- hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
- hal.executable.entry_point @conv_promote_workgroup_memory interface(@io) {
- workgroup_size = [32: index, 4: index, 1: index]
- }
- builtin.module {
- func.func @conv_promote_workgroup_memory() {
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<3x4x6x14xf32>
- %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<2x15x14x6xf32>
- %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<2x13x11x14xf32>
- %3 = hal.interface.workgroup.id[0] : index
- %4 = hal.interface.workgroup.id[1] : index
- %5 = hal.interface.workgroup.id[2] : index
- %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
- %7 = affine.min affine_map<()[s0] -> (6, s0 * -4 + 15)>()[%4]
- %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
- %9 = affine.min affine_map<()[s0] -> (35, s0 * -32 + 14)>()[%3]
- %10 = memref.subview %1[%5, %6, %8, 0] [1, %7, %9, 6] [1, 1, 1, 1] : memref<2x15x14x6xf32> to memref<1x?x?x6xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1260 + s0 + d1 * 84 + d2 * 6 + d3)>>
- %11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
- %12 = affine.min affine_map<()[s0] -> (4, s0 * -4 + 13)>()[%4]
- %13 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
- %14 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 11)>()[%3]
- %15 = memref.subview %2[%5, %11, %13, 0] [1, %12, %14, 14] [1, 1, 1, 1] : memref<2x13x11x14xf32> to memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>
- linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", lowering_config = {tileSizes = [[0, 1, 4, 32], [], [0, 1, 1, 1]]}, dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
- ins(%10, %0 : memref<1x?x?x6xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1260 + s0 + d1 * 84 + d2 * 6 + d3)>>, memref<3x4x6x14xf32>)
- outs(%15 : memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>)
- return
- }
- hal.interface private @io {
- hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
- hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
- hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
- }
- }
- }
-}
-
-// CHECK-LABEL: func @conv_promote_workgroup_memory()
-// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external
-// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external
-// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external
-// CHECK-DAG: %[[ALLOC1:.+]] = memref.alloc() : memref<1x6x35x6xf32, 3>
-// CHECK: %[[ARG1SV:.+]] = memref.subview %[[ARG1]]
-// CHECK: %[[RET0SV:.+]] = memref.subview %[[RET0]]
-// CHECK: %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC1]]
-// CHECK: linalg.generic(%[[ARG1SV]], %[[SUBVIEW1]])
-// CHECK-SAME: "copy_to_workgroup_memory"
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK-DAG: memref.subview %[[SUBVIEW1]]
-// CHECK-DAG: memref.subview %[[ARG0]]
-// CHECK-DAG: memref.subview %[[RET0SV]]
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
new file mode 100644
index 0000000..548c498
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
@@ -0,0 +1,130 @@
+// RUN: iree-opt -split-input-file -mlir-print-local-scope -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-promote))))' -cse %s | FileCheck %s
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<3, storage_buffer>
+ ]>
+]>
+#config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [16, 4], [0, 0, 32]]>
+
+hal.executable @matmul_256x1024x128 {
+ hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+ spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU,
+ {max_compute_shared_memory_size = 49152 : i32,
+ max_compute_workgroup_invocations = 1024 : i32,
+ max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+ subgroup_size = 32 : i32}>}> {
+ hal.executable.entry_point public @matmul_256x1024x128 ordinal(0) layout(#executable_layout) {
+ translation_info = #iree_codegen.translation_info<SPIRVVectorizeWithWorkgroupMemory, workload_per_wg = [128, 128]>,
+ workgroup_size = [32 : index, 8 : index, 1 : index]
+ }
+ builtin.module {
+ func.func @matmul_256x1024x128() {
+ %c1024 = arith.constant 1024 : index
+ %c256 = arith.constant 256 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<256x128xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x1024xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<256x1024xf32>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : memref<256x1024xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_y]
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_y]
+ scf.for %arg0 = %4 to %c256 step %5 {
+ %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
+ scf.for %arg1 = %6 to %c1024 step %7 {
+ %8 = memref.subview %2[%arg0, %arg1] [128, 128] [1, 1] : memref<256x1024xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+ %9 = memref.subview %0[%arg0, 0] [128, 128] [1, 1] : memref<256x128xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>>
+ %10 = memref.subview %1[0, %arg1] [128, 128] [1, 1] : memref<128x1024xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+ %11 = memref.subview %3[%arg0, %arg1] [128, 128] [1, 1] : memref<256x1024xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+ linalg.fill {lowering_config = #config}
+ ins(%cst : f32) outs(%11 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+ linalg.matmul {lowering_config = #config}
+ ins(%9, %10 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>>, memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+ outs(%11 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%11, %8 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+ outs(%11 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+ attrs = {lowering_config = #config} {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+ %12 = arith.divf %arg2, %arg3 : f32
+ linalg.yield %12 : f32
+ }
+ }
+ }
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func @matmul_256x1024x128()
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+
+// CHECK-DAG: %[[MEM_A:.+]] = memref.alloc() : memref<128x32xf32, 3>
+// CHECK-DAG: %[[MEM_B:.+]] = memref.alloc() : memref<32x128xf32, 3>
+
+// CHECK-DAG: %[[BUFFER_A:.+]] = hal.interface.binding.subspan set(0) binding(0) {{.+}} : memref<256x128xf32>
+// CHECK-DAG: %[[BUFFER_B:.+]] = hal.interface.binding.subspan set(0) binding(1) {{.+}} : memref<128x1024xf32>
+// CHECK-DAG: %[[BUFFER_C:.+]] = hal.interface.binding.subspan set(0) binding(3) {{.+}} : memref<256x1024xf32>
+// CHECK-DAG: %[[BUFFER_D:.+]] = hal.interface.binding.subspan set(0) binding(2) {{.+}} : memref<256x1024xf32>
+
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[D:.+]] = memref.subview %[[BUFFER_D]]
+// CHECK: %[[A:.+]] = memref.subview %[[BUFFER_A]]
+// CHECK: %[[B:.+]] = memref.subview %[[BUFFER_B]]
+// CHECK: %[[C:.+]] = memref.subview %[[BUFFER_C]]
+// CHECK: %[[T_ID_X:.+]] = gpu.thread_id x
+// CHECK: %[[T_DIM_X:.+]] = gpu.block_dim x
+// CHECK: %[[T_ID_Y:.+]] = gpu.thread_id y
+// CHECK: %[[T_DIM_Y:.+]] = gpu.block_dim y
+// CHECK: %[[T_OFFSET_Y:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%[[T_ID_Y]]]
+// CHECK: %[[T_SIZE_Y:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%[[T_DIM_Y]]]
+
+// CHECK: scf.for %[[T_IV_Y:.+]] =
+// CHECK: scf.for %[[T_IV_X:.+]] =
+// CHECK: %[[VIEW_C:.+]] = memref.subview %[[C]][%[[T_IV_Y]], %[[T_IV_X]]] [16, 4] [1, 1]
+// CHECK: linalg.fill
+// CHECK-SAME: outs(%[[VIEW_C]]
+
+// CHECK: scf.for %[[T_IV_Y:.+]] = %[[C0]] to %[[C128]] step %[[C32]] {
+// CHECK: %[[VIEW_A:.+]] = memref.subview %[[A]][0, %[[T_IV_Y]]] [128, 32]
+// CHECK: %[[VIEW_B:.+]] = memref.subview %[[B]][%[[T_IV_Y]], 0] [32, 128]
+
+// CHECK: gpu.barrier
+// CHECK: memref.copy %[[VIEW_A]], %[[MEM_A]]
+// CHECK-SAME: __internal_linalg_transform__ = "copy_to_workgroup_memory"
+// CHECK: memref.copy %[[VIEW_B]], %[[MEM_B]]
+// CHECK-SAME: __internal_linalg_transform__ = "copy_to_workgroup_memory"
+// CHECK: gpu.barrier
+
+// CHECK: scf.for %[[T_IV_Y:.+]] =
+// CHECK: scf.for %[[T_IV_X:.+]] =
+// CHECK: %[[VIEW_A:.+]] = memref.subview %[[MEM_A]][%[[T_IV_Y]], 0] [16, 32]
+// CHECK: %[[VIEW_B:.+]] = memref.subview %[[MEM_B]][0, %[[T_IV_X]]] [32, 4]
+// CHECK: %[[VIEW_C:.+]] = memref.subview %[[C]][%[[T_IV_Y]], %[[T_IV_X]]] [16, 4]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[VIEW_A]], %[[VIEW_B]]
+// CHECK-SAME: outs(%[[VIEW_C]]
+
+// CHECK: scf.for %[[T_IV_Y:.+]] =
+// CHECK: scf.for %[[T_IV_X:.+]] =
+// CHECK: %[[VIEW_C:.+]] = memref.subview %[[C]][%[[T_IV_Y]], %[[T_IV_X]]]
+// CHECK: %[[VIEW_D:.+]] = memref.subview %[[D]][%[[T_IV_Y]], %[[T_IV_X]]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[VIEW_C]], %[[VIEW_D]]
+// CHECK-SAME: outs(%[[VIEW_C]]
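+
+// Note: with tile sizes [[128, 128], [16, 4], [0, 0, 32]] and a 32x8x1
+// workgroup, each workgroup promotes a 128x32 LHS tile and a 32x128 RHS tile
+// (16 KiB each), and each of the 256 threads computes a 16x4 output tile.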
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index 83067fb..0279fcd 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -55,10 +55,10 @@
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK: scf.for %[[IV_Y:.+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %{{.+}}[%[[IV_Y]], 0] [1, 4]
+// CHECK: %[[LHS_VECTOR:.+]] = vector.transfer_read %[[LHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]]
// CHECK: scf.for %[[IV_X:.+]] = %[[C0]] to %[[C128]] step %[[C4]] iter_args(%[[ACC_TILE:.+]] =
-// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %{{.+}}[%[[IV_Y]], 0] [1, 4]
// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %{{.+}}[0, %[[IV_X]]] [4, 4]
-// CHECK: %[[LHS_VECTOR:.+]] = vector.transfer_read %[[LHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]]
// CHECK: %[[RHS_0_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]]
// CHECK: %[[RHS_1_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C1]], %[[C0]]], %[[PAD]]
// CHECK: %[[RHS_2_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C2]], %[[C0]]], %[[PAD]]
diff --git a/iree/compiler/Codegen/Utils/BUILD b/iree/compiler/Codegen/Utils/BUILD
index 19bddad..e7510ff 100644
--- a/iree/compiler/Codegen/Utils/BUILD
+++ b/iree/compiler/Codegen/Utils/BUILD
@@ -15,10 +15,12 @@
cc_library(
name = "Utils",
srcs = [
+ "GPUUtils.cpp",
"MarkerUtils.cpp",
"Utils.cpp",
],
hdrs = [
+ "GPUUtils.h",
"MarkerUtils.h",
"Utils.h",
],
@@ -29,6 +31,7 @@
"//iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
diff --git a/iree/compiler/Codegen/Utils/CMakeLists.txt b/iree/compiler/Codegen/Utils/CMakeLists.txt
index de6f31c..ce45f4e 100644
--- a/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -14,14 +14,17 @@
NAME
Utils
HDRS
+ "GPUUtils.h"
"MarkerUtils.h"
"Utils.h"
SRCS
+ "GPUUtils.cpp"
"MarkerUtils.cpp"
"Utils.cpp"
DEPS
IREELinalgExtDialect
LLVMSupport
+ MLIRGPUOps
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
diff --git a/iree/compiler/Codegen/Utils/GPUUtils.cpp b/iree/compiler/Codegen/Utils/GPUUtils.cpp
new file mode 100644
index 0000000..09efafd
--- /dev/null
+++ b/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -0,0 +1,118 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
+ mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
+ llvm::ArrayRef<int64_t> workgroupSize) {
+ assert(numDims <= kNumGPUDims);
+ llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
+ std::array<gpu::Dimension, kNumGPUDims> dimAttr{
+ gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
+ mlir::Type indexType = builder.getIndexType();
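+  // GPU dimension x is the fastest varying, while `procInfo` is consumed
+  // outermost loop first, so store the IDs in reverse: x maps to the
+  // innermost distributed loop.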
+ for (unsigned i = 0; i < numDims; ++i) {
+ procInfo[numDims - 1 - i] = {
+ builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]),
+ builder.create<mlir::arith::ConstantOp>(
+ loc, builder.getIndexAttr(workgroupSize[i]))};
+ }
+ return procInfo;
+}
+
+llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+ mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
+ llvm::ArrayRef<int64_t> numSubgroups) {
+ assert(numDims <= kNumGPUDims);
+ llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
+ std::array<gpu::Dimension, kNumGPUDims> dimAttr{
+ gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
+ mlir::Type indexType = builder.getIndexType();
+ for (unsigned i = 0; i < numDims; ++i) {
+ mlir::Value subgroupId =
+ builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]);
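+    // Only the x dimension packs multiple warps per workgroup, so derive the
+    // subgroup ID there as threadId.x / warpSize; along y and z the thread ID
+    // already is the subgroup ID.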
+ if (i == 0) {
+ mlir::AffineExpr d0 = builder.getAffineDimExpr(0);
+ subgroupId = mlir::makeComposedAffineApply(
+ builder, loc, d0.floorDiv(builder.getAffineConstantExpr(kWarpSize)),
+ {subgroupId});
+ }
+ procInfo[numDims - 1 - i] = {
+ subgroupId, builder.create<mlir::arith::ConstantOp>(
+ loc, builder.getIndexAttr(numSubgroups[i]))};
+ }
+ return procInfo;
+}
+
+std::array<int64_t, 3> getWorkgroupSize(mlir::func::FuncOp funcOp) {
+ std::array<int64_t, 3> workgroupSize;
+ auto entryPointOp = mlir::iree_compiler::getEntryPoint(funcOp);
+ llvm::Optional<mlir::ArrayAttr> workgroupSizeAttr =
+ entryPointOp.workgroup_size();
+ assert(workgroupSizeAttr.hasValue());
+ for (auto it : llvm::enumerate(workgroupSizeAttr.getValue())) {
+ workgroupSize[it.index()] =
+ it.value().cast<mlir::IntegerAttr>().getValue().getZExtValue();
+ }
+ return workgroupSize;
+}
+
+bool canPerformVectorAccessUsingAllThreads(ArrayRef<int64_t> shape,
+ int64_t threadCount,
+ int64_t vectorSize) {
+  // Verify that each dimension of the shape can be distributed across the
+  // available threads.
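+  // For example, shape [128, 32] with 256 threads and vector size 4: the
+  // innermost dimension takes 32 / 4 = 8 threads, the remaining 32 threads
+  // evenly divide the outer dimension of 128, so this returns true.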
+ int64_t threadsAvailable = threadCount;
+ for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
+ int64_t numElementPerThread = dim.index() == 0 ? vectorSize : 1;
+ int64_t numThreads = dim.value() / numElementPerThread;
+ if (numThreads == 0) return false;
+ numThreads = std::min(numThreads, threadsAvailable);
+ if (threadsAvailable % numThreads != 0) return false;
+ threadsAvailable = threadsAvailable / numThreads;
+ if (threadsAvailable == 1) break;
+ }
+ return threadsAvailable == 1;
+}
+
+Optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
+ memref::SubViewOp subview,
+ ArrayRef<Value> sizeBounds,
+ DataLayout &) {
+ OpBuilder::InsertionGuard guard(builder);
+
+ func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
+ if (!funcOp) return llvm::None;
+
+ // The subview size bounds are expected to be constant; they specify the shape
+ // of the allocation.
+ SmallVector<int64_t, 2> shape;
+ for (Value bound : sizeBounds) {
+ APInt value;
+ if (!matchPattern(bound, m_ConstantInt(&value))) return llvm::None;
+ shape.push_back(value.getSExtValue());
+ }
+
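+  // Hoist the allocation to the function entry block so every workgroup
+  // buffer is created once, ahead of any loops.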
+ builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
+ auto type = MemRefType::get(shape, subview.getType().getElementType(), {},
+ gpu::GPUDialect::getWorkgroupAddressSpace());
+ Value buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+ return buffer;
+}
+
+LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/) {
+ return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
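
For context, allocateWorkgroupMemory and deallocateWorkgroupMemory are shaped
to match Linalg's promotion callbacks. A minimal sketch of wiring them up,
where the copyToWorkgroupMemory callback and the promoted operand indices are
illustrative assumptions rather than part of this change:

  static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src,
                                             Value dst) {
    // Emit a plain copy; a real pass would also attach a marker such as
    // "copy_to_workgroup_memory" so later passes can distribute it.
    b.create<memref::CopyOp>(src.getLoc(), src, dst);
    return success();
  }

  linalg::LinalgPromotionOptions options =
      linalg::LinalgPromotionOptions()
          .setAllocationDeallocationFns(allocateWorkgroupMemory,
                                        deallocateWorkgroupMemory)
          .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
          .setOperandsToPromote({0, 1});  // promote the matmul LHS and RHS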
diff --git a/iree/compiler/Codegen/Utils/GPUUtils.h b/iree/compiler/Codegen/Utils/GPUUtils.h
new file mode 100644
index 0000000..49593c3
--- /dev/null
+++ b/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -0,0 +1,53 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
+#define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
+
+#include "iree/compiler/Codegen/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static constexpr int32_t kNumGPUDims = 3;
+static constexpr int32_t kWarpSize = 32;
+
+llvm::SmallVector<linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
+ OpBuilder &builder, Location loc, unsigned numDims,
+ llvm::ArrayRef<int64_t> workgroupSize);
+
+/// Compute the subgroup ID. CUDA doesn't have a subgroupId equivalent, so we
+/// compute the subgroup ID from the thread ID.
+/// When tiling to warps we assume each warp is full and pick a workgroup
+/// size such that `workgroupSize.x % warpSize == 0`. This is why we can have
+/// warpId = { threadId.x / warpSize, threadId.y, threadId.z }.
+llvm::SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+ OpBuilder &builder, Location loc, unsigned numDims,
+ llvm::ArrayRef<int64_t> numSubgroups);
+
+/// Return the workgroup size associated to the funcOp entry point.
+std::array<int64_t, 3> getWorkgroupSize(func::FuncOp funcOp);
+
+/// Return true if we can use all threads to perform vectorized load/store of
+/// the given `shape`.
+bool canPerformVectorAccessUsingAllThreads(ArrayRef<int64_t> shape,
+ int64_t threadCount,
+ int64_t vectorSize);
+
+/// Allocate GPU workgroup memory matching the given `subview`. If there are
+/// dynamic dimensions, the bounds are in `sizeBounds`.
+Optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
+ memref::SubViewOp subview,
+ ArrayRef<Value> sizeBounds,
+ DataLayout &);
+
+/// Deallocate GPU workgroup memory behind `buffer`.
+LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value buffer);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
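
Similarly, getGPUThreadIdsAndCounts is designed to serve as a Linalg loop
distribution callback. A hedged sketch of typical usage inside a tiling pass
(the surrounding pattern setup is assumed and not shown here):

  std::array<int64_t, 3> workgroupSize = getWorkgroupSize(funcOp);
  linalg::LinalgLoopDistributionOptions distributionOptions;
  distributionOptions.procInfo = [workgroupSize](OpBuilder &b, Location loc,
                                                 ArrayRef<Range> ranges) {
    return getGPUThreadIdsAndCounts(b, loc, ranges.size(), workgroupSize);
  };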