// RUN: iree-opt --split-input-file --torch-iree-tm-tensor-to-linalg-ext %s | FileCheck %s

// Expected failure tracked by https://github.com/openxla/iree/issues/14916.
// XFAIL: *
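
// Query, key, and value all share the shape 5x2x3x4: the two leading batch
// dimensions are expected to be collapsed into a single batch dimension of
// size 10, the op lowered to iree_linalg_ext.attention on the collapsed
// operands, and the result expanded back to the original shape.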
func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %arg2: tensor<5x2x3x4xf32>, %arg3: tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xf32>) {
  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>) outs(%arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
  return %0 : tensor<5x2x3x4xf32>
}

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
// CHECK-SAME: %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
// CHECK: %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]] : tensor<10x3x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>) outs(%[[EMPTY]] : tensor<10x3x4xf32>) -> tensor<10x3x4xf32>
// CHECK: %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x3x4xf32> into tensor<5x2x3x4xf32>
// CHECK: return %[[RET]] : tensor<5x2x3x4xf32>

// -----
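
// Same lowering, but the query sequence length (8) differs from the key/value
// sequence length (3); only the leading 5x2 batch dimensions are collapsed.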
func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %arg2: tensor<5x2x3x4xf32>, %arg3: tensor<5x2x8x4xf32>) -> (tensor<5x2x8x4xf32>) {
  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>) outs(%arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
  return %0 : tensor<5x2x8x4xf32>
}

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
// CHECK-SAME: %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
// CHECK: %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x8x4xf32> into tensor<10x8x4xf32>
// CHECK: %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]] : tensor<10x8x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>) outs(%[[EMPTY]] : tensor<10x8x4xf32>) -> tensor<10x8x4xf32>
// CHECK: %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x8x4xf32> into tensor<5x2x8x4xf32>
// CHECK: return %[[RET]] : tensor<5x2x8x4xf32>

// -----
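
// Inputs already have a single leading batch dimension, so no
// collapse_shape/expand_shape pair is needed; the op lowers directly to
// iree_linalg_ext.attention.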
func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2: tensor<1x3x4xf32>, %arg3: tensor<1x3x4xf32>) -> (tensor<1x3x4xf32>) {
  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>) outs(%arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>,
// CHECK-SAME: %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<1x3x4xf32>