[Codegen] Avoid setting anchors for reads used directly by contractions (#16499)
Transfer_read ops used directly by a contraction (i.e. without a copy to
shared memory in between) should take the layout of the contraction op.
This is common for cases where the initial values of the accumulator in
a `linalg.matmul` are read from memory instead of just being a zero fill.
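
For reference, a minimal hypothetical sketch of the IR pattern this targets
(illustrative names and shapes, plain memref types instead of the
strided/HAL-annotated ones in the updated test): the accumulator read feeds
vector.contract directly, so it should pick up the contraction's accumulator
layout rather than getting its own transfer_read anchor.

  func.func @acc_read_feeds_contract(%a: memref<16x32xf16>, %b: memref<32x16xf16>,
                                     %out: memref<16x16xf32>) -> vector<16x16xf32> {
    %c0 = arith.constant 0 : index
    %f0 = arith.constant 0.0 : f32
    %h0 = arith.constant 0.0 : f16
    // Accumulator initial values read straight from memory; no staging
    // through shared memory before the contraction, so this read is in the
    // contraction's backward slice and should not get its own anchor.
    %acc = vector.transfer_read %out[%c0, %c0], %f0 {in_bounds = [true, true]}
        : memref<16x16xf32>, vector<16x16xf32>
    %lhs = vector.transfer_read %a[%c0, %c0], %h0 {in_bounds = [true, true]}
        : memref<16x32xf16>, vector<16x32xf16>
    %rhs = vector.transfer_read %b[%c0, %c0], %h0 {in_bounds = [true, true]}
        : memref<32x16xf16>, vector<32x16xf16>
    %res = vector.contract {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
        %lhs, %rhs, %acc : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
    return %res : vector<16x16xf32>
  }
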
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index 61aed35..6fd788d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -142,6 +143,23 @@
void setTransferReadAnchor(MLIRContext *context,
VectorLayoutAnalysis &analysis,
vector::TransferReadOp transfer) {
+
+ // Get the forward slice of the transfer to approximate whether it will
+ // instead take the layout of a contraction. Transfer_read ops used directly
+ // by a contraction (i.e. without a copy to shared memory in between) should
+ // take the layout of the contraction op. This is common for cases where the
+ // initial values of the accumulator in a linalg.matmul are read from memory
+ // instead of just being a zero fill.
+ SetVector<Operation *> forwardSlice;
+ ForwardSliceOptions options;
+ getForwardSlice(transfer.getResult(), &forwardSlice, options);
+
+ if (llvm::any_of(forwardSlice, [](Operation *op) {
+ return llvm::isa<vector::ContractionOp>(op);
+ })) {
+ return;
+ }
+
// TODO: Support masking.
if (transfer.getMask()) {
return;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir
index 7d75f81..219d1ad 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir
@@ -71,11 +71,13 @@
%alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%cst = arith.constant 0.000000e+00 : f16
- %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+ %cst_f32 = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
- %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) {
+ %init_acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]}
+ : memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x16xf32>
+ %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %init_acc) -> (vector<16x16xf32>) {
%6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf16>
%7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x16xf16>
vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
@@ -100,9 +102,11 @@
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func.func @matmul_256x256x256
-// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32>
// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
+// CHECK: %[[INIT_READ:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<4x1xf32>
+// CHECK: %[[INIT_TRANSP:.+]] = vector.transpose %[[INIT_READ]], [1, 0]
+// CHECK: %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_TRANSP]]
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>)
// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16>