// blob: 0ddf5d1d2f7a8f8d23f473a3e5508c5271b5811e [file] [log] [blame]
// RUN: iree-dialects-opt %s -linalg-interp-transforms --split-input-file | FileCheck %s
// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s0) + s1, s0)>
// CHECK-DAG: #[[$ID1_MAP:.*]] = affine_map<(d0) -> (d0)>
// Affine map aliases referenced by @static_tile.
// Ceiling division of d0 by s0; yields the iteration count of the parallel op.
#map0 = affine_map<(d0)[s0] -> (d0 ceildiv s0)>
// d0 scaled by s0; yields the starting offset of an iteration's chunk.
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
// Difference d0 - d1; yields the number of elements left past an offset.
#map2 = affine_map<(d0, d1) -> (d0 - d1)>
// Forwards both inputs as two results; fed to affine.min to clamp the chunk
// length.
#map3 = affine_map<(d0, d1) -> (d0, d1)>
// 1-D identity; indexing map for both operands of the linalg.generic.
#map4 = affine_map<(d0) -> (d0)>
module {
// Multiplies a dynamically sized 1-D buffer by a constant, one chunk per
// parallel iteration. The FileCheck directives expect the
// iree_linalg_ext.in_parallel op to be replaced by an async group whose tasks
// are created inside an scf.for loop and awaited after the loop.
//
//   %arg0: chunk size; each iteration handles at most this many elements.
//   %arg1: input buffer.
//   %arg2: output buffer; its dimension 0 determines the iteration count.
// CHECK-LABEL: func @static_tile
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
// CHECK-SAME: %[[IN:[0-9a-z]+]]: memref<?xf32>
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<?xf32>
func @static_tile(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
  // Scale factor applied to every input element.
  %cst = arith.constant 4.200000e+01 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Iteration count: ceildiv(dim(%arg2, 0), %arg0).
  %0 = memref.dim %arg2, %c0 : memref<?xf32>
  %1 = affine.apply #map0(%0)[%arg0]
  // Expected output: one async.execute per loop iteration, each token added to
  // a single group that is awaited once the loop has finished.
  // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
  // CHECK: %[[group:.*]] = async.create_group {{.*}}: !async.group
  // CHECK: scf.for %[[IV:.*]] = {{.*}}
  // CHECK: %[[token:.*]] = async.execute {
  // CHECK: subview
  // CHECK: subview
  // CHECK: linalg.generic
  // CHECK: async.yield
  // CHECK: }
  // CHECK: async.add_to_group %[[token]], %[[group]] : !async.token
  // CHECK: }
  // CHECK: async.await_all %[[group]]
  iree_linalg_ext.in_parallel %1 -> () {
    ^bb0(%arg3: index): // no predecessors
      // Offset of this chunk: iteration index times chunk size.
      %3 = affine.apply #map1(%arg3)[%arg0]
      // Elements remaining between the offset and the end of the output.
      %4 = affine.apply #map2(%0, %3)
      // Chunk length: min(remaining, chunk size), so the last chunk may be
      // shorter than the others.
      %5 = affine.min #map3(%4, %arg0)
      // Both subviews pass the stride as the dynamic value %c1 so that it
      // agrees with the dynamic stride (`strides:[?]`) in the declared result
      // types.
      %6 = memref.subview %arg2[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, offset:?, strides:[?]>
      %7 = memref.subview %arg1[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, offset:?, strides:[?]>
      // Element-wise: out = in * %cst. The incoming output element (%arg5) is
      // not read.
      linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]}
        ins(%7 : memref<?xf32, offset:?, strides:[?]>) outs(%6 : memref<?xf32, offset:?, strides:[?]>) {
        ^bb0(%arg4: f32, %arg5: f32): // no predecessors
          %9 = arith.mulf %arg4, %cst : f32
          linalg.yield %9 : f32
      }
      // Left empty: the in_parallel op declares no results (`-> ()`).
      iree_linalg_ext.perform_concurrently {
      }
  }
  return
}
// PDL pattern that matches any iree_linalg_ext.in_parallel operation, whatever
// its operands and result types, and names "iree_linalg_transform.apply" as the
// rewriter for the matched op.
pdl.pattern @match_iree_linalg_ext_in_parallel : benefit(1) {
  // Unconstrained operand and result-type ranges: any arity is accepted.
  %any_operands = operands
  %any_result_types = types
  %in_parallel = operation "iree_linalg_ext.in_parallel"(%any_operands : !pdl.range<value>) -> (%any_result_types : !pdl.range<type>)
  rewrite %in_parallel with "iree_linalg_transform.apply"
}
// Transform script: select the ops matched by
// @match_iree_linalg_ext_in_parallel, then apply
// rewrite_iree_linalg_ext_in_parallel_to_async to them.
iree_linalg_transform.sequence {
  %target = match @match_iree_linalg_ext_in_parallel
  // The handle produced by the rewrite is not used further.
  %rewritten = rewrite_iree_linalg_ext_in_parallel_to_async %target
}
}