blob: f4871a838af29e8ffae08edd106125f0c51f3e9b [file] [log] [blame]
// RUN: iree-dialects-opt -split-input-file -iree-linalg-ext-to-loops %s | FileCheck %s
func @sort_1d(%arg0: memref<128xi32>) {
iree_linalg_ext.sort dimension(0)
outs(%arg0 : memref<128xi32>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
%0 = arith.cmpi sgt, %arg2, %arg3 : i32
iree_linalg_ext.yield %0 : i1
}
return
}
// CHECK-LABEL: func @sort_1d
// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C127:.+]] = arith.constant 127 : index
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C127]] step %[[C1]]
// CHECK: %[[T1:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[ARG2]]]
// CHECK: %[[V2:.+]] = memref.load %[[BUF]][%[[T1]]]
// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[V1]], %[[V2]] : i32
// CHECK: scf.if %[[COND]] {
// CHECK: } else {
// CHECK: %[[T2:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
// CHECK: memref.store %[[V2]], %[[BUF]][%[[ARG2]]]
// CHECK: memref.store %[[V1]], %[[BUF]][%[[T2]]]
// CHECK: }
// -----
func @sort_2d(%arg0: memref<16x32xi32>) {
iree_linalg_ext.sort dimension(0)
outs(%arg0 : memref<16x32xi32>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
%0 = arith.cmpi sgt, %arg2, %arg3 : i32
iree_linalg_ext.yield %0 : i1
}
return
}
// CHECK-LABEL: func @sort_2d
// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C15]] step %[[C1]]
// CHECK: %[[T1:.+]] = arith.addi %[[ARG3]], %[[C1]] : index
// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[ARG3]], %[[ARG2]]]
// CHECK: %[[V2:.+]] = memref.load %[[BUF]][%[[T1]], %[[ARG2]]]
// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[V1]], %[[V2]] : i32
// CHECK: scf.if %[[COND]] {
// CHECK: } else {
// CHECK: %[[T2:.+]] = arith.addi %[[ARG3]], %[[C1]] : index
// CHECK: memref.store %[[V2]], %[[BUF]][%[[ARG3]], %[[ARG2]]]
// CHECK: memref.store %[[V1]], %[[BUF]][%[[T2]], %[[ARG2]]]
// CHECK: }
// -----
func @sort_multi(%arg0: memref<128xf32>, %arg1: memref<128xi32>) {
iree_linalg_ext.sort
dimension(0)
outs(%arg0, %arg1 : memref<128xf32>, memref<128xi32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: i32, %arg5: i32): // no predecessors
%0 = arith.cmpf ogt, %arg2, %arg3 : f32
iree_linalg_ext.yield %0 : i1
}
return
}
// CHECK-LABEL: func @sort_multi
// CHECK-SAME: %[[BUF1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUF2:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C127:.+]] = arith.constant 127 : index
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C127]] step %[[C1]]
// CHECK: %[[T1:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
// CHECK: %[[V1:.+]] = memref.load %[[BUF1]][%[[ARG2]]]
// CHECK: %[[V2:.+]] = memref.load %[[BUF1]][%[[T1]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUF2]][%[[ARG2]]]
// CHECK: %[[V4:.+]] = memref.load %[[BUF2]][%[[T1]]]
// CHECK: %[[COND:.+]] = arith.cmpf ogt, %[[V1]], %[[V2]] : f32
// CHECK: scf.if %[[COND]] {
// CHECK: } else {
// CHECK: %[[T2:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
// CHECK: memref.store %[[V2]], %[[BUF1]][%[[ARG2]]]
// CHECK: memref.store %[[V1]], %[[BUF1]][%[[T2]]]
// CHECK: memref.store %[[V4]], %[[BUF2]][%[[ARG2]]]
// CHECK: memref.store %[[V3]], %[[BUF2]][%[[T2]]]
// CHECK: }
// -----
func @scatter_update_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
}
return
}
// CHECK-LABEL: func @scatter_update_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
// -----
func @scatter_add_scalar_2D(
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
%updates: memref<3xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
iree_linalg_ext.yield %0 : i32
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32>
// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32>
// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32>
// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]
// -----
func @scatter_update_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
}
return
}
// CHECK: func @scatter_update_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK: %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
// CHECK: memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
// CHECK: }
// CHECK: }
// -----
func @scatter_add_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
iree_linalg_ext.yield %0 : i32
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32>
// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]]
// -----
func @scatter_add_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
iree_linalg_ext.yield %0 : i32
}
return
}
// CHECK: func @scatter_add_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK: %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// CHECK: %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
// CHECK: memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// -----
func @scatter_update_scalar_dynamic_1D(
%original: memref<?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
outs(%original : memref<?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
}
return
}
// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x1xi32>
// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
// -----
func @scatter_add_scalar_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
%updates: memref<?xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
iree_linalg_ext.yield %0 : i32
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x2xi32>
// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<?x2xi32>
// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<?x?xi32>
// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]
// -----
func @scatter_update_slice_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?x?xi32>) {
iree_linalg_ext.scatter unique_indices(true)
ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
}
return
}
// CHECK: func @scatter_update_slice_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
// CHECK-DAG: %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// -----
func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
iree_linalg_ext.scatter
unique_indices(true)
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
outs(%arg0 : memref<2x64x12xf32>) {
^bb0(%arg3: f32, %arg4: f32):
iree_linalg_ext.yield %arg4 : f32
}
return
}
// CHECK-LABEL: @scatter_partial_slices
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant
// CHECK-DAG: %[[C1:.+]] = arith.constant
// CHECK-DAG: %[[C2:.+]] = arith.constant
// CHECK-DAG: %[[C12:.+]] = arith.constant
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
// CHECK-NEXT: %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
// CHECK-NEXT: %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
// CHECK-NEXT: %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
// CHECK-NEXT: %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
// CHECK-NEXT: %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
// CHECK-NEXT: %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
// CHECK-NEXT: %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
// CHECK-NEXT: %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
// CHECK-NEXT: %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
// CHECK-NEXT: memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
// -----
func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) {
%stage = arith.constant 1 : index
iree_linalg_ext.fft
ins(%stage: index)
outs(%real, %imag: memref<16xf32>, memref<16xf32>)
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK: func @fft_1D
// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[COEFF:.+]] = arith.constant -3.14159274 : f32
// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C2]]
// CHECK-DAG: %[[HM:.+]] = arith.shrsi %[[C2]], %[[C1]] : index
// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[K]]] [%[[HM]]] [1]
// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[K]]] [%[[HM]]] [1]
// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[HM]] : index
// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[R_OFFSET]]] [%[[HM]]] [1]
// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[R_OFFSET]]] [%[[HM]]] [1]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
// CHECK: ^bb0(%[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32)
//
// Compute exp coeff.
// CHECK: %[[J_IDX:.+]] = linalg.index 0 : index
// CHECK: %[[J_I32:.+]] = arith.index_cast %[[J_IDX]] : index to i32
// CHECK: %[[J_F32:.+]] = arith.sitofp %[[J_I32]] : i32 to f32
// CHECK: %[[EXP_COEF:.+]] = arith.mulf %[[J_F32]], %[[COEFF]] : f32
// CHECK: %[[W_REAL:.+]] = math.cos %[[EXP_COEF]]
// CHECK: %[[W_IMAG:.+]] = math.sin %[[EXP_COEF]]
//
// Compute "t = w * a[k + j + mh]" by expanding
// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
// CHECK-DAG: %[[XU:.+]] = arith.mulf %[[W_REAL]], %[[R_REAL]]
// CHECK-DAG: %[[YV:.+]] = arith.mulf %[[W_IMAG]], %[[R_IMAG]]
// CHECK-DAG: %[[XV:.+]] = arith.mulf %[[W_REAL]], %[[R_IMAG]]
// CHECK-DAG: %[[YU:.+]] = arith.mulf %[[W_IMAG]], %[[R_REAL]]
// CHECK: %[[T_REAL:.+]] = arith.subf %[[XU]], %[[YV]]
// CHECK: %[[T_IMAG:.+]] = arith.addf %[[XV]], %[[YU]]
//
// Compute the results.
// u = a[k + j];
// a[k + j] = u + t;
// a[k + j + mh] = u - t;
// CHECK: %[[RES1:.+]] = arith.addf %[[L_REAL]], %[[T_REAL]]
// CHECK: %[[RES2:.+]] = arith.addf %[[L_IMAG]], %[[T_IMAG]]
// CHECK: %[[RES3:.+]] = arith.subf %[[L_REAL]], %[[T_REAL]]
// CHECK: %[[RES4:.+]] = arith.subf %[[L_IMAG]], %[[T_IMAG]]
// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
// -----
func @fft_2D(%real: memref<?x16xf32>, %imag: memref<?x16xf32>) {
%stage = arith.constant 2 : index
iree_linalg_ext.fft
ins(%stage: index)
outs(%real, %imag: memref<?x16xf32>, memref<?x16xf32>)
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @fft_2D(
// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref<?x16xf32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C4]]
// CHECK-DAG: %[[HM:.+]] = arith.shrsi %[[C4]], %[[C1]] : index
// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[HM]] : index
// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
//
// The computation is bascially the same, and they are
// checked above. Here only checks the different part.
// CHECK: %{{.+}} = linalg.index 1 : index
// -----
func @fft_2D_coef_buf(%real: memref<?x16xf32>, %imag: memref<?x16xf32>,
%coef_real: memref<1xf32>, %coef_imag: memref<1xf32>) {
%stage = arith.constant 1 : index
iree_linalg_ext.fft
ins(%stage, %coef_real, %coef_imag: index, memref<1xf32>, memref<1xf32>)
outs(%real, %imag: memref<?x16xf32>, memref<?x16xf32>)
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @fft_2D_coef_buf
// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref<?x16xf32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C2]]
// CHECK-DAG: %[[HM:.+]] = arith.shrsi %[[C2]], %[[C1]] : index
// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[HM]] : index
// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[COEF_REAL]], %[[COEF_IMAG]]
// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
// CHECK: ^bb0(%[[W_REAL:.+]]: f32, %[[W_IMAG:.+]]: f32, %[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32)
// Compute "t = w * a[k + j + mh]" by expanding
// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
// CHECK-DAG: %[[XU:.+]] = arith.mulf %[[W_REAL]], %[[R_REAL]]
// CHECK-DAG: %[[YV:.+]] = arith.mulf %[[W_IMAG]], %[[R_IMAG]]
// CHECK-DAG: %[[XV:.+]] = arith.mulf %[[W_REAL]], %[[R_IMAG]]
// CHECK-DAG: %[[YU:.+]] = arith.mulf %[[W_IMAG]], %[[R_REAL]]
// CHECK: %[[T_REAL:.+]] = arith.subf %[[XU]], %[[YV]]
// CHECK: %[[T_IMAG:.+]] = arith.addf %[[XV]], %[[YU]]
//
// Compute the results.
// u = a[k + j];
// a[k + j] = u + t;
// a[k + j + mh] = u - t;
// CHECK: %[[RES1:.+]] = arith.addf %[[L_REAL]], %[[T_REAL]]
// CHECK: %[[RES2:.+]] = arith.addf %[[L_IMAG]], %[[T_IMAG]]
// CHECK: %[[RES3:.+]] = arith.subf %[[L_REAL]], %[[T_REAL]]
// CHECK: %[[RES4:.+]] = arith.subf %[[L_IMAG]], %[[T_IMAG]]
// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
// -----
func @reverse_dim_0(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>) {
iree_linalg_ext.reverse
dimensions(dense<0> : tensor<1xi64>)
ins(%arg0 : memref<?x?xi32>)
outs(%arg1 : memref<?x?xi32>)
return
}
// CHECK-LABEL: func @reverse_dim_0
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %arg0, %c0 : memref<?x?xi32>
// CHECK-DAG: %[[D1:.+]] = memref.dim %arg0, %c1 : memref<?x?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]]
// CHECK: %[[T0:.+]] = memref.dim %[[IN]], %[[C0]]
// CHECK: %[[T1:.+]] = arith.subi %[[T0]], %[[C1]] : index
// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index
// CHECK: %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]]
// CHECK: memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref<?x?xi32>
func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
iree_linalg_ext.scan dimension(0) inclusive(true)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
iree_linalg_ext.yield %sum : i32
}
return
}
// CHECK-LABEL: func @scan_1d_inclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK: scf.if %[[COND]] {
// CHECK: %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK: memref.store %[[V1]], %[[BUFO]][%[[ARG1]]]
// CHECK: } else {
// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK: memref.store %[[V4]], %[[ACC]][]
// CHECK: }
// -----
func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
iree_linalg_ext.scan dimension(0) inclusive(false)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
iree_linalg_ext.yield %sum : i32
}
return
}
// CHECK-LABEL: func @scan_1d_exclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK: scf.if %[[COND]] {
// CHECK: %[[V0:.+]] = memref.load %[[ACC]][] : memref<i32>
// CHECK: memref.store %[[V0]], %[[BUFO]][%[[ARG1]]]
// CHECK: } else {
// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]]
// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK: memref.store %[[V4]], %[[ACC]][]
// CHECK: }
// -----
func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
%t0 = memref.alloc() : memref<32xi32>
iree_linalg_ext.scan dimension(0) inclusive(true)
ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
iree_linalg_ext.yield %sum : i32
}
return
}
// CHECK-LABEL: func @scan_2d
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<32xi32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK: scf.if %[[COND]] {
// CHECK: %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK: memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK: } else {
// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK: memref.store %[[V4]], %[[ACC]][%[[ARG2]]]
// CHECK: }