[CUDA] Make sure all linalg ops are distributed to threads (#8548)
Fix a bug where some convolution ops (e.g. linalg.conv_2d_nchw_fchw) were not
distributed to threads: the second level of tiling was only registered for a
hard-coded list of named linalg ops, so any op outside that list was silently
skipped.
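
The fix replaces the op-specific TilingPatterns list with the generic
linalg::LinalgTilingPattern, which matches every op implementing the LinalgOp
interface. For reference, the registration boils down to the sketch below
(simplified; the wrapper function name is illustrative, and `tilingOptions`
and the filter `f` stand for the options and LinalgTransformationFilter
already built in the pass):

    // Illustrative sketch only: `populateThreadTilingPatterns` is a made-up
    // wrapper name, not the actual function in LLVMGPUTileAndDistribute.cpp.
    static void populateThreadTilingPatterns(
        MLIRContext *context, RewritePatternSet &patterns,
        linalg::LinalgTilingOptions tilingOptions,
        linalg::LinalgTransformationFilter f) {
      // The generic pattern matches any LinalgOp (matmul, fill, every
      // convolution/pooling variant, ...), so ops such as
      // linalg.conv_2d_nchw_fchw no longer fall through untiled.
      patterns.insert<linalg::LinalgTilingPattern,
                      IREE::LinalgExt::TiledOpInterfaceTilingPattern>(
          context, tilingOptions, f);
    }

Keying the patterns on the interface rather than an explicit op list means
newly added named linalg ops get distributed to threads without further
changes to this pass.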
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 8f579b6..9cd86f6 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -167,13 +167,8 @@
// FFT doesn't support second level of tiling yet.
return success(!isa<IREE::LinalgExt::FftOp>(op));
}).setMatchByDefault();
- linalg::TilingPatterns<
- linalg::MatmulOp, linalg::FillOp, linalg::BatchMatmulOp,
- linalg::GenericOp, linalg::Conv2DNhwcHwcfOp,
- linalg::DepthwiseConv2DNhwcHwcOp, linalg::DepthwiseConv2DNhwcHwcmOp,
- linalg::PoolingNhwcMaxOp, linalg::PoolingNhwcMinOp,
- linalg::PoolingNhwcSumOp>::insert(patterns, tilingOptions, f);
- patterns.insert<IREE::LinalgExt::TiledOpInterfaceTilingPattern>(
+ patterns.insert<linalg::LinalgTilingPattern,
+ IREE::LinalgExt::TiledOpInterfaceTilingPattern>(
context, tilingOptions, f);
}
diff --git a/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir b/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
index aba5248..847e138 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
@@ -293,3 +293,70 @@
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.*}} : memref<1000xf32>) outs(%{{.*}} : memref<f32>)
// CHECK-SAME: lowering_config = #[[CONFIG]]
+
+// -----
+
+#translation = #iree_codegen.translation_info<LLVMGPUVectorize, workload_per_wg = [256, 1, 1]>
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable private @conv_dispatch {
+ hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.entry_point @conv_dispatch layout(#executable_layout) {
+ translation_info = #translation,
+ workgroup_size = [64 : index, 1 : index, 1 : index]
+ }
+ builtin.module {
+ func @conv_dispatch() {
+ %c56 = arith.constant 56 : index
+ %c64 = arith.constant 64 : index
+ %c802816 = arith.constant 802816 : index
+ %c41664 = arith.constant 41664 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<1x64x56x56xf32>
+ memref.assume_alignment %0, 64 : memref<1x64x56x56xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c41664) alignment(64) : memref<64x64x1x1xf32>
+ memref.assume_alignment %1, 64 : memref<64x64x1x1xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c802816) alignment(64) : memref<1x64x56x56xf32>
+ memref.assume_alignment %2, 64 : memref<1x64x56x56xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ scf.for %arg0 = %workgroup_id_z to %c64 step %workgroup_count_z {
+ scf.for %arg1 = %workgroup_id_y to %c56 step %workgroup_count_y {
+ %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x]
+ %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x]
+ scf.for %arg2 = %3 to %c56 step %4 {
+ %5 = affine.min affine_map<(d0) -> (256, -d0 + 56)>(%arg2)
+ %6 = memref.subview %0[0, 0, %arg1, %arg2] [1, 64, 1, %5] [1, 1, 1, 1] : memref<1x64x56x56xf32> to memref<1x64x1x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 200704 + s0 + d1 * 3136 + d2 * 56 + d3)>>
+ %7 = memref.subview %1[%arg0, 0, 0, 0] [1, 64, 1, 1] [1, 1, 1, 1] : memref<64x64x1x1xf32> to memref<1x64x1x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 + d2 + d3)>>
+ %8 = memref.subview %2[0, %arg0, %arg1, %arg2] [1, 1, 1, %5] [1, 1, 1, 1] : memref<1x64x56x56xf32> to memref<1x1x1x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 200704 + s0 + d1 * 3136 + d2 * 56 + d3)>>
+ linalg.fill(%cst, %8) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 1, 256, 4, 4, 4]]>} : f32, memref<1x1x1x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 200704 + s0 + d1 * 3136 + d2 * 56 + d3)>>
+ linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 1, 256, 4, 4, 4]]>, strides = dense<1> : vector<2xi64>} ins(%6, %7 : memref<1x64x1x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 200704 + s0 + d1 * 3136 + d2 * 56 + d3)>>, memref<1x64x1x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 + d2 + d3)>>) outs(%8 : memref<1x1x1x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 200704 + s0 + d1 * 3136 + d2 * 56 + d3)>>)
+ }
+ }
+ }
+ return
+ }
+ }
+ }
+}
+
+// Check that the convolution is distributed.
+// CHECK-LABEL: func @conv_dispatch
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: linalg.fill
+// CHECK: scf.for
+// CHECK: linalg.conv_2d_nchw_fchw