// Type aliases for the 8x64 -> 8 row-reduction exercised below.
!in_tensor_t = tensor<8x64xf32>
!out_tensor_t = tensor<8xf32>

// Elementwise x4 scaling of the input followed by a sum-reduction over the
// inner dimension (d1). With --input="8x64xf32=1" every output element is
// 4 * 64 = 256, which is what the EXEC FileCheck lines at the bottom expect.
func.func @reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
  // -0.0 is the additive identity for f32 (preserves signed zeros).
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  %2 = tensor.empty() : !in_tensor_t
  // First generic: pointwise %arg * 4, computed as (x + x) + (x + x).
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
      ^bb0(%arg3: f32, %arg4: f32):
        %4 = arith.addf %arg3, %arg3 : f32
        %5 = arith.addf %4, %4 : f32
        linalg.yield %5 : f32
    } -> !in_tensor_t

  // Second generic: sum-reduce over d1 into the zero-filled 8xf32 accumulator.
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
      ^bb0(%arg3: f32, %arg4: f32):
        %4 = arith.addf %arg3, %arg4 : f32
        linalg.yield %4 : f32
    } -> !out_tensor_t

  return %6 : !out_tensor_t
}

// RUN: iree-opt %s --iree-hal-target-backends=cuda \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
// RUN: FileCheck %s --check-prefix=DISPATCH

// RUN: iree-opt %s --iree-hal-target-backends=cuda \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/%S_codegen_spec.mlir | \
// RUN: FileCheck %s

// RUN: iree-compile %s --iree-hal-target-backends=cuda | \
// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="8x64xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC

// Check that both generics ended up in the same region.
// DISPATCH: hal.executable.variant
// DISPATCH: linalg.fill
// DISPATCH-NOT: hal.executable.variant
// DISPATCH: linalg.generic
// DISPATCH-NOT: hal.executable.variant
// DISPATCH: linalg.generic

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 128 : i64} : memref<1x2xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x
// CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y
// CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id z

// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>

// Distributed reduction: everyone loads, does the elementwise then 5 xor + addf expected
// CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
// CHECK: arith.addf
// CHECK: arith.addf
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf

// CHECK: %[[RES:.*]] = arith.addf %{{.*}}

// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
// CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
// CHECK: gpu.barrier

// Last part is not distributed atm and is only run by threadIdx.x == 0 and threadIdx.y == 0.
// CHECK: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
// TODO: cond eq 0 and cond ult 1 do not CSE atm.
// CHECK: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1
// CHECK: scf.if %[[CONXANDYARE0]] {
// CHECK: vector.transfer_read
// CHECK: vector.reduction <add>
// CHECK: vector.transfer_write
// CHECK: gpu.barrier
// CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<1x2xf32, 3>

// EXEC: result[0]: hal.buffer_view
// EXEC-NEXT: 8xf32=256 256 256 256 256 256 256 256