blob: e3debb518db3a8537bb1f5246297514cf4078a8a [file] [log] [blame]
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(torch-iree-tm-tensor-to-linalg-ext))" %s | FileCheck %s
func.func @scatter_update(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true)
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %orig: i32): // no predecessors
tm_tensor.yield %update: i32
} -> tensor<8xi32>
return %0 : tensor<8xi32>
}
// CHECK-LABEL: func.func @scatter_update(
// CHECK-SAME: %[[ORIGINAL:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES:.*]]: tensor<3xi32>) -> tensor<8xi32> {
// CHECK: %[[SCATTER:.*]] = iree_linalg_ext.scatter
// CHECK-SAME: dimension_map = [0]
// CHECK-SAME: unique_indices(true)
// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] : tensor<3xi32>, tensor<3x1xi32>)
// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<8xi32>) {
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
// CHECK: iree_linalg_ext.yield %[[ARG0]] : i32
// CHECK: } -> tensor<8xi32>
// CHECK: return %[[SCATTER:.*]] : tensor<8xi32>
// -----
func.func @scatter_add(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(true)
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %orig: i32): // no predecessors
%add = arith.addi %orig, %update: i32
tm_tensor.yield %add: i32
} -> tensor<8xi32>
return %0 : tensor<8xi32>
}
// CHECK-LABEL: func.func @scatter_add(
// CHECK-SAME: %[[ORIGINAL:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES:.*]]: tensor<3xi32>) -> tensor<8xi32> {
// CHECK: %[[SCATTER:.*]] = iree_linalg_ext.scatter
// CHECK-SAME: unique_indices(true)
// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] : tensor<3xi32>, tensor<3x1xi32>)
// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<8xi32>) {
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG0]] : i32
// CHECK: iree_linalg_ext.yield %[[ADD]] : i32
// CHECK: } -> tensor<8xi32>
// CHECK: return %[[SCATTER:.*]] : tensor<8xi32>
// -----