[spirv] Fix C matrix promotion with bufferization allocations (#11418)
For cases like linalg.matmul + arith.extf, bufferization inserts an
allocation because the two ops have different element types. In such cases
the C matrix is already in workgroup memory, so there is no need to promote
it again. We do, however, need to attach the proper marker so that we can
insert barriers and let the following steps distribute to SIMT threads.
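
A minimal sketch of the kind of IR this addresses (shapes, names, and the 2-D
simplification are illustrative; the full 3-D batch_matmul case is in the new
tile_and_promote_cooperative_matrix.mlir test below):

  func.func @matmul_extf(%lhs: memref<64x512xf16>, %rhs: memref<512x128xf16>,
                         %out: memref<64x128xf32>) {
    %cst = arith.constant 0.0 : f16
    // Bufferization materializes this workgroup (memory space 3) allocation
    // because the matmul element type (f16) differs from the f32 output.
    %alloc = memref.alloc() : memref<64x128xf16, 3>
    linalg.fill ins(%cst : f16) outs(%alloc : memref<64x128xf16, 3>)
    linalg.matmul ins(%lhs, %rhs : memref<64x512xf16>, memref<512x128xf16>)
                  outs(%alloc : memref<64x128xf16, 3>)
    linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%alloc : memref<64x128xf16, 3>)
        outs(%out : memref<64x128xf32>) {
      ^bb0(%in: f16, %o: f32):
        %e = arith.extf %in : f16 to f32
        linalg.yield %e : f32
    }
    return
  }

With this change, doPromoteCMatrix detects that %alloc already lives in
workgroup memory and only marks the trailing generic op instead of promoting
the C matrix a second time.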
Co-authored-by: Thomas <thomasraoux@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 8096a2e..29a7372 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -82,18 +82,11 @@
static LogicalResult gpuCopyFn(OpBuilder &builder, Location loc, Value from,
Value to) {
- Optional<unsigned> workgroupMemorySpace =
- spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass::Workgroup);
auto fromType = from.getType().cast<MemRefType>();
auto toType = to.getType().cast<MemRefType>();
- bool needsBarrier = false;
- if (auto attr = fromType.getMemorySpace().dyn_cast_or_null<IntegerAttr>()) {
- if (attr.getInt() == *workgroupMemorySpace) needsBarrier = true;
- }
- if (auto attr = toType.getMemorySpace().dyn_cast_or_null<IntegerAttr>()) {
- if (attr.getInt() == *workgroupMemorySpace) needsBarrier = true;
- }
+ bool needsBarrier =
+ isInWorkgroupMemory(fromType) || isInWorkgroupMemory(toType);
if (needsBarrier) builder.create<gpu::BarrierOp>(loc);
Operation *copy = builder.create<memref::CopyOp>(loc, from, to);
if (needsBarrier) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
index ae1c6a3..bfd7bfa 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -13,7 +13,6 @@
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Common/GPUPatterns.h"
-#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
@@ -27,7 +26,6 @@
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -137,6 +135,8 @@
void runOnOperation() override;
private:
+ /// Promotes the C matrix to shared memory when necessary. Returns success
+ /// if no error occurs.
LogicalResult doPromoteCMatrix(func::FuncOp funcOp) const;
// Whether to promote C matrix to use shared memory.
@@ -292,28 +292,60 @@
MLIRContext *context = funcOp.getContext();
if (!promoteCMatrix) return success();
- // If there are no fused elementwise ops, we can avoid promoting C matrix.
SmallVector<Operation *> computeOps;
if (failed(getComputeOps(funcOp, computeOps)))
return funcOp.emitError("failed to get compute ops");
- unsigned count = 0;
- for (Operation *op : computeOps) {
- if (!isa<linalg::FillOp>(op)) ++count;
- }
- if (count <= 1) return success();
+ SmallVector<Operation *> linalgOps;
+ for (Operation *op : computeOps) {
+ if (isa<linalg::FillOp>(op)) continue; // Don't care
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ linalgOps.push_back(linalgOp);
+ } else {
+ return funcOp.emitError("unknown compute op ") << *op;
+ }
+ }
+
+ if (linalgOps.size() > 2) {
+ return funcOp.emitError("unhandled multiple matmul/generic cases");
+ }
+
+ // If there are no fused elementwise ops, we can avoid promoting C matrix.
+ if (linalgOps.size() <= 1) return success();
+
+ linalg::LinalgOp matmulOp = linalgOps.front();
+ auto genericOp = cast<linalg::GenericOp>(*linalgOps.back());
+
+ auto matmulType =
+ matmulOp.getDpsInitOperand(0)->get().getType().cast<MemRefType>();
+ if (isInWorkgroupMemory(matmulType)) {
+ // The matmul output is already in shared memory. This can happen when
+ // bufferization decides an allocation is needed, e.g., for matmul +
+ // arith.extf, where the two ops have different element types. In such cases
+ // there is no need to promote or propagate the shared memory copy anymore;
+ // just mark the following generic op for distribution accordingly.
+ setMarker(genericOp, getCopyToWorkgroupMemoryMarker());
+ return success();
+ }
+
+ // Finally, promote the C matrix.
RewritePatternSet patterns(context);
populateContractPromotionPatterns(patterns, {2});
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return failure();
}
- propagateSharedMemoryCopy(funcOp);
-
LLVM_DEBUG({
llvm::dbgs() << "--- After promoting C matrix ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
+
+ propagateSharedMemoryCopy(funcOp);
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After propagating shared memory copy ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
index 981ce4e..bb65600 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -15,9 +15,9 @@
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Support/LogicalResult.h"
@@ -37,6 +37,15 @@
return config.getAs<spirv::TargetEnvAttr>(spirv::getTargetEnvAttrName());
}
+/// Returns true if the given MemRef is in workgroup memory.
+bool isInWorkgroupMemory(MemRefType memrefType) {
+ Optional<unsigned> workgroupMemorySpace =
+ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass::Workgroup);
+ if (auto attr = memrefType.getMemorySpace().dyn_cast_or_null<IntegerAttr>())
+ if (attr.getInt() == *workgroupMemorySpace) return true;
+ return false;
+}
+
llvm::Optional<int> getSPIRVSubgroupSize(func::FuncOp funcOp) {
auto moduleOp = funcOp->getParentOfType<ModuleOp>();
llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps =
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
index f920886..3ade347 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
@@ -32,6 +32,9 @@
/// Returns the attribute name carrying information about distribution.
const char *getSPIRVDistributeAttrName();
+/// Returns true if the given MemRef is in workgroup memory.
+bool isInWorkgroupMemory(MemRefType memrefType);
+
/// Returns the tile sizes at the given `tilingLevel` for compute ops in
/// `funcOp`.
FailureOr<SmallVector<int64_t>> getSPIRVTileSize(func::FuncOp funcOp,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir
index c486354..a781df8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_promote_cooperative_matrix.mlir
@@ -370,3 +370,110 @@
// PROMOTEC-SAME: __internal_linalg_transform__ = "workgroup_memory"
// PROMOTEC-NOT: gpu.barrier
// PROMOTEC-NOT: memref.copy
+
+// -----
+
+// No need to promote the C matrix again when bufferization already created the allocation.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+#config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 128], [1, 32, 64], [0, 0, 0, 32], [1, 16, 16, 16]]>
+
+hal.executable @batch_matmul_f16_1x64x128x512 {
+ hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.6,
+ [Shader, Float16, StorageBuffer16BitAccess, StorageUniform16, CooperativeMatrixNV],
+ [SPV_NV_cooperative_matrix]>, AMD:DiscreteGPU,
+ #spirv.resource_limits<
+ cooperative_matrix_properties_nv = [
+ #spirv.coop_matrix_props<
+ a_type = f16, b_type = f16, c_type = f16, k_size = 16,
+ m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>
+ ],
+ max_compute_shared_memory_size = 65536,
+ max_compute_workgroup_invocations = 1024,
+ max_compute_workgroup_size = [1024, 1024, 1024],
+ subgroup_size = 64>
+ >}> {
+ hal.executable.export public @batch_matmul_f16_1x64x128x512 ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>,
+ workgroup_size = [128 : index, 2 : index, 1 : index]
+ }
+ builtin.module {
+ func.func @batch_matmul_f16_1x64x128x512() {
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<1x4096x512xf16>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<1x512x4096xf16>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<1x4096x4096xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
+ %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
+ scf.for %arg0 = %3 to %c4096 step %4 {
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
+ scf.for %arg1 = %5 to %c4096 step %6 {
+ %subview = memref.subview %2[0, %arg0, %arg1] [1, 64, 128] [1, 1, 1] : memref<1x4096x4096xf32> to memref<1x64x128xf32, strided<[16777216, 4096, 1], offset: ?>>
+ %subview_0 = memref.subview %0[0, %arg0, 0] [1, 64, 512] [1, 1, 1] : memref<1x4096x512xf16> to memref<1x64x512xf16, strided<[2097152, 512, 1], offset: ?>>
+ %subview_1 = memref.subview %1[0, 0, %arg1] [1, 512, 128] [1, 1, 1] : memref<1x512x4096xf16> to memref<1x512x128xf16, strided<[2097152, 4096, 1], offset: ?>>
+ %alloc = memref.alloc() {alignment = 128 : i64} : memref<1x64x128xf16, 3>
+ linalg.fill ins(%cst : f16) outs(%alloc : memref<1x64x128xf16, 3>)
+ linalg.batch_matmul {lowering_config = #config}
+ ins(%subview_0, %subview_1 : memref<1x64x512xf16, strided<[2097152, 512, 1], offset: ?>>, memref<1x512x128xf16, strided<[2097152, 4096, 1], offset: ?>>)
+ outs(%alloc : memref<1x64x128xf16, 3>)
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%alloc : memref<1x64x128xf16, 3>)
+ outs(%subview : memref<1x64x128xf32, strided<[16777216, 4096, 1], offset: ?>>) {
+ ^bb0(%in: f16, %out: f32):
+ %7 = arith.extf %in : f16 to f32
+ linalg.yield %7 : f32
+ }
+ }
+ }
+ return
+ }
+ }
+ }
+}
+
+// PROMOTEC-LABEL: func.func @batch_matmul_f16_1x64x128x512()
+
+// PROMOTEC-DAG: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<1x64x32xf16, 3>
+// PROMOTEC-DAG: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<1x32x128xf16, 3>
+// PROMOTEC-DAG: %[[C_ALLOC:.+]] = memref.alloc() {alignment = 128 : i64} : memref<1x64x128xf16, 3>
+
+// PROMOTEC: linalg.fill
+// PROMOTEC-SAME: __internal_linalg_transform__ = "workgroup_memory"
+// PROMOTEC-SAME: outs(%[[C_ALLOC]]
+
+// PROMOTEC: scf.for %{{.+}} = %c0 to %c512 step %c32 {
+// PROMOTEC: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][0, 0, 0] [%c1, %c64, %c32]
+// PROMOTEC: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, 0, 0] [%c1, %c32, %c128]
+// PROMOTEC: gpu.barrier
+// PROMOTEC: memref.copy %{{.+}}, %[[LHS_VIEW]]
+// PROMOTEC-SAME: __internal_linalg_transform__ = "copy_to_workgroup_memory"
+// PROMOTEC: memref.copy %{{.+}}, %[[RHS_VIEW]]
+// PROMOTEC-SAME: __internal_linalg_transform__ = "copy_to_workgroup_memory"
+// PROMOTEC: gpu.barrier
+// PROMOTEC: linalg.batch_matmul
+// PROMOTEC-SAME: __internal_linalg_transform__ = "workgroup_memory"
+// PROMOTEC-SAME: ins(%[[LHS_VIEW]], %[[RHS_VIEW]]
+// PROMOTEC-SAME: outs(%[[C_ALLOC]]
+// PROMOTEC: }
+// PROMOTEC: gpu.barrier
+// PROMOTEC: linalg.generic
+// PROMOTEC: ins(%[[C_ALLOC]]
+// PROMOTEC-SAME: __internal_linalg_transform__ = "copy_to_workgroup_memory"
+// PROMOTEC: gpu.barrier