| // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" %s | FileCheck %s |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
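// Bubbling the trailing expand_shape up through the attention op should expand every tensor operand along the split batch dimension (20 -> 2x10) and rewrite the indexing maps to the expanded iteration space.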
| util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> { |
| %0 = tensor.empty() : tensor<20x4096x64xf16> |
| %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<20x4096x64xf16> |
| %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16> |
| util.return %expanded : tensor<2x10x4096x64xf16> |
| } |
| |
// CHECK-LABEL: util.func public @attention_static(
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> |
| // CHECK-SAME: %[[ARG3:.+]]: f16) |
| // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16> |
| // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16] |
| // CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16] |
| // CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : |
| // CHECK-SAME: outs(%[[EMPTY]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> |
| #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
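// Same as @attention_static, but the mask operand must also be expanded along the split batch dimension.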
| util.func public @attention_static_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x4096x64xf16> { |
| %0 = tensor.empty() : tensor<20x4096x64xf16> |
| %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<20x4096x64xf16> |
| %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16> |
| util.return %expanded : tensor<2x10x4096x64xf16> |
| } |
| |
// CHECK-LABEL: util.func public @attention_static_masked(
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>) |
| // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16> |
| // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16] |
| // CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16] |
| // CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64] |
| // CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 1024] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] : |
| // CHECK-SAME: outs(%[[EMPTY]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
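// The expand_shape splits every result dimension (batch 20 -> 2x10, M 4096 -> 2048x2, N 64 -> 2x32); each operand is expanded only along the dimensions its indexing map refers to.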
| util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x2048x2x2x32xf16> { |
| %0 = tensor.empty() : tensor<20x4096x64xf16> |
| %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<20x4096x64xf16> |
| %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16> |
| util.return %expanded : tensor<2x10x2048x2x2x32xf16> |
| } |
| |
// CHECK-LABEL: util.func public @attention_expand_all(
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> |
| // CHECK-SAME: %[[ARG3:.+]]: f16) |
| // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16> |
| // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16> |
| // CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16> |
| // CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16> |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)> |
| // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : |
| // CHECK-SAME: outs(%[[EMPTY]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> |
| #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
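// Masked variant of @attention_expand_all; the mask is expanded along the split batch and M dimensions, like the query.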
| util.func public @attention_expand_all_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x2048x2x2x32xf16> { |
| %0 = tensor.empty() : tensor<20x4096x64xf16> |
| %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4: tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<20x4096x64xf16> |
| %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16> |
| util.return %expanded : tensor<2x10x2048x2x2x32xf16> |
| } |
| |
// CHECK-LABEL: util.func public @attention_expand_all_masked(
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>) |
| // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16> |
| // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16> |
| // CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16> |
| // CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16> |
| // CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 1024] : tensor<20x4096x1024xf16> into tensor<2x10x2048x2x1024xf16> |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)> |
| // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] : |
| // CHECK-SAME: outs(%[[EMPTY]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
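// Dynamic-shape variant: the expanded batch size has to be recomputed per operand with tensor.dim and a divide-by-2 before each expand_shape.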
| util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16) -> tensor<2x?x?x?xf16> { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c2 = arith.constant 2 : index |
| %c3 = arith.constant 3 : index |
| %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16> |
| %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16> |
| %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16> |
| %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16> |
| %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16> |
| %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16> |
| %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<?x?x?xf16> |
| %split = arith.divsi %d0, %c2 : index |
%expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, %split, %d1, %d4]
| : tensor<?x?x?xf16> into tensor<2x?x?x?xf16> |
| util.return %expanded : tensor<2x?x?x?xf16> |
| } |
| |
// CHECK-LABEL: util.func public @attention_dynamic(
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> |
| // CHECK-SAME: %[[ARG3:.+]]: f16) |
| // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index |
| // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index |
| // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index |
| // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] |
| // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] |
| // CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] |
| // CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]] |
// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]], %[[C2]]
| // CHECK-DAG: %[[VAL:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]] |
| // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16> |
| // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]] |
| // CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]] |
| // CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]] |
| // CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]] |
| // CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]] |
| // CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]] |
| // CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]] |
| // CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]] |
| // CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]] |
| // CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : |
| // CHECK-SAME: outs(%[[EMPTY]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> |
| #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
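// Masked dynamic variant; the mask gets its own tensor.dim/divui chain before it is expanded.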
| util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>) -> tensor<2x?x?x?xf16> { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c2 = arith.constant 2 : index |
| %c3 = arith.constant 3 : index |
| %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16> |
| %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16> |
| %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16> |
| %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16> |
| %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16> |
| %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16> |
| %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%0 : tensor<?x?x?xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<?x?x?xf16> |
| %split = arith.divsi %d0, %c2 : index |
%expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, %split, %d1, %d4]
| : tensor<?x?x?xf16> into tensor<2x?x?x?xf16> |
| util.return %expanded : tensor<2x?x?x?xf16> |
| } |
| |
// CHECK-LABEL: util.func public @attention_dynamic_masked(
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>) |
| // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index |
| // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index |
| // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index |
| // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] |
| // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] |
| // CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] |
| // CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]] |
// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]], %[[C2]]
| // CHECK-DAG: %[[VAL:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]] |
| // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16> |
| // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]] |
| // CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]] |
| // CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]] |
| // CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]] |
| // CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]] |
| // CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]] |
| // CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]] |
| // CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]] |
| // CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]] |
| // CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]] |
| // CHECK-DAG: %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]] |
| // CHECK-DAG: %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]] |
| // CHECK-DAG: %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]] |
| // CHECK-DAG: %[[SPLIT3:.+]] = arith.divui %[[D10]], %[[C2]] |
| // CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] : |
| // CHECK-SAME: outs(%[[EMPTY]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
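// The collapse_shape ops feeding all attention operands sink below the op, leaving a single collapse_shape of the expanded attention result.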
| util.func public @sink_through_attention(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) { |
| %13 = tensor.empty() : tensor<4x32x64x128xf16> |
| %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %17 = tensor.empty() : tensor<128x64x128xf16> |
| %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<128x64x128xf16> |
| util.return %18 : tensor<128x64x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @sink_through_attention |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : |
| // CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| // CHECK: util.return %[[RET]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> |
| #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
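// Masked variant of @sink_through_attention; the collapsed mask sinks as well.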
| util.func public @sink_through_attention_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16, %3 : tensor<4x32x64x64xf16>) -> (tensor<128x64x128xf16>) { |
| %13 = tensor.empty() : tensor<4x32x64x128xf16> |
| %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %collapsed_15 = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<4x32x64x64xf16> into tensor<128x64x64xf16> |
| %17 = tensor.empty() : tensor<128x64x128xf16> |
| %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst, %collapsed_15 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<128x64x128xf16> |
| util.return %18 : tensor<128x64x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @sink_through_attention_masked |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]] : |
| // CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| // CHECK: util.return %[[RET]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
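// Only the query is collapsed here, so sinking the collapse means expanding the key and value to the uncollapsed shape and collapsing the result instead.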
| util.func public @sink_single_collapse(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) { |
| %13 = tensor.empty() : tensor<4x32x64x128xf16> |
| %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %17 = tensor.empty() : tensor<128x64x128xf16> |
| %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<128x64x128xf16> |
| util.return %18 : tensor<128x64x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @sink_single_collapse |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-DAG: %[[EXPANDED1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128] |
| // CHECK-DAG: %[[EXPANDED2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]] : |
| // CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| // CHECK: util.return %[[RET]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> |
| #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
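// Masked variant of @sink_single_collapse; the mask is expanded alongside the key and value.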
| util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16, %3 : tensor<128x64x64xf16>) -> (tensor<128x64x128xf16>) { |
| %13 = tensor.empty() : tensor<4x32x64x128xf16> |
| %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| %17 = tensor.empty() : tensor<128x64x128xf16> |
| %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %1, %2, %cst, %3 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<128x64x128xf16> |
| util.return %18 : tensor<128x64x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @sink_single_collapse_masked |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: |
| // CHECK-DAG: %[[EXPANDED1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128] |
| // CHECK-DAG: %[[EXPANDED2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128] |
| // CHECK-DAG: %[[EXPANDED3:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 64] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> |
| // CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]], %[[EXPANDED3]] : |
| // CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> |
| // CHECK: util.return %[[RET]] |
| |
| // ----- |
| |
| #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| #map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> |
| #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| |
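// The collapsed dims map to the K2 (reduction) dimension of the value operand, so the collapse must not sink through the attention op and stays in place.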
| util.func public @dont_sink_through_k2(%0 : tensor<128x64x1x1x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) { |
| %13 = tensor.empty() : tensor<4x32x64x128xf16> |
| %collapsed_12 = tensor.collapse_shape %0 [[0], [1, 2, 3], [4]] : tensor<128x64x1x1x128xf16> into tensor<128x64x128xf16> |
| %17 = tensor.empty() : tensor<128x64x128xf16> |
| %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%2, %1, %collapsed_12, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) { |
| ^bb0(%score: f16): |
| iree_linalg_ext.yield %score: f16 |
| } -> tensor<128x64x128xf16> |
| util.return %18 : tensor<128x64x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @dont_sink_through_k2 |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG3:.+]]: f16 |
| // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] |
| // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention |
| // CHECK-SAME: indexing_maps = |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> ()> |
| // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> |
| // CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] : |
| // CHECK: util.return %[[ATTENTION]] |
| |
| // ----- |
| |
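// A collapse_shape on the scatter updates is folded away by expanding the indices along the matching batch dimensions.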
util.func public @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) { |
| ^bb0(%arg7: f16, %arg8: f16): |
| iree_linalg_ext.yield %arg7 : f16 |
| } -> tensor<?x2x16x4x128xf16> |
| util.return %1 : tensor<?x2x16x4x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_collapse_updates |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK: %[[INDICES:.+]] = tensor.expand_shape |
| // CHECK-SAME: tensor<?x1xi32> into tensor<4x?x1xi32> |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[ARG0]], %[[INDICES]] |
| // CHECK-SAME: outs(%[[ARG2]] |
| // CHECK: util.return %[[SCATTER]] |
| |
| // ----- |
| |
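// Like @scatter_collapse_updates, but the updates collapse merges four batch dims; the indices are expanded to the matching 4x?x2x2 batch shape.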
util.func public @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x16x4x128xf16>
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) { |
| ^bb0(%arg7: f16, %arg8: f16): |
| iree_linalg_ext.yield %arg7 : f16 |
| } -> tensor<10x16x4x128xf16> |
| util.return %1 : tensor<10x16x4x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_collapse_updates_partial |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-DAG: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x2x2x1xi32> |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[ARG0]], %[[INDICES]] |
| // CHECK-SAME: outs(%[[ARG2]] |
| // CHECK: util.return %[[SCATTER]] |
| |
| // ----- |
| |
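// A collapse_shape on the scatter destination bubbles up by expanding the updates to the uncollapsed shape; the scatter result is collapsed back afterwards.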
util.func public @scatter_collapse_original(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
| %collapsed = tensor.collapse_shape %arg2 [[0], [1, 2], [3, 4], [5, 6]] : tensor<?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16> |
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) { |
| ^bb0(%arg6: f16, %arg7: f16): |
| iree_linalg_ext.yield %arg6 : f16 |
| } -> tensor<?x32x8x128xf16> |
| util.return %1 : tensor<?x32x8x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_collapse_original |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] |
| // CHECK-SAME: tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16> |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] |
| // CHECK-SAME: outs(%[[ARG2]] |
| // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]] |
| // CHECK: util.return %[[COLLAPSE]] |
| |
| // ----- |
| |
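// The destination collapse merges the indexed dimension (?x80x2 -> ?), so it cannot move past the scatter and the IR is left unchanged.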
util.func public @scatter_original_noop(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x80x2x32x8x128xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
| %collapsed = tensor.collapse_shape %arg2 [[0, 1, 2], [3], [4], [5]] : tensor<?x80x2x32x8x128xf16> into tensor<?x32x8x128xf16> |
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) { |
| ^bb0(%arg6: f16, %arg7: f16): |
| iree_linalg_ext.yield %arg6 : f16 |
| } -> tensor<?x32x8x128xf16> |
| util.return %1 : tensor<?x32x8x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_original_noop |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG2]] |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] |
| // CHECK-SAME: outs(%[[COLLAPSE]] |
// CHECK: util.return %[[SCATTER]]

| // ----- |
| |
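// Partial variant of @scatter_collapse_original: the leading group of the collapse contains the indexed dimension and must stay collapsed, so only the trailing groups of the destination are re-expanded.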
util.func public @scatter_collapse_original_partial(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<5x?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
| %collapsed = tensor.collapse_shape %arg2 [[0, 1], [2, 3], [4, 5], [6, 7]] : tensor<5x?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16> |
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) { |
| ^bb0(%arg6: f16, %arg7: f16): |
| iree_linalg_ext.yield %arg6 : f16 |
| } -> tensor<?x32x8x128xf16> |
| util.return %1 : tensor<?x32x8x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_collapse_original_partial |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16> |
| // TODO(IanWood1): fix this so the collapse folds with the expand |
| // CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape {{.*}} tensor<?x32x8x128xf16> into tensor<?x2x16x4x2x64x2xf16> |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] |
| // CHECK-SAME: outs(%[[ORIGINAL]] |
| // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]] |
| // CHECK: util.return %[[COLLAPSE]] |
| |
| // ----- |
| |
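// A collapse_shape on the indices batch dimensions is folded by expanding the updates batch dimension to match.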
util.func public @scatter_collapse_indices(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
%collapsed = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<4x?x1xi32> into tensor<?x1xi32>
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) { |
| ^bb0(%arg7: f16, %arg8: f16): |
| iree_linalg_ext.yield %arg7 : f16 |
| } -> tensor<?x2x16x4x128xf16> |
| util.return %1 : tensor<?x2x16x4x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_collapse_indices |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16> |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] |
| // CHECK-SAME: outs(%[[ARG2]] |
| // CHECK: util.return %[[SCATTER]] |
| |
| // ----- |
| |
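// Partial variant: only the leading batch group of the indices collapse folds into an updates expansion; the trailing unit dims remain collapsed on the indices.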
util.func public @scatter_collapse_indices_partial(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
%collapsed = tensor.collapse_shape %arg1 [[0, 1], [2, 3]] : tensor<4x?x1x1xi32> into tensor<?x1xi32>
| %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) { |
| ^bb0(%arg7: f16, %arg8: f16): |
| iree_linalg_ext.yield %arg7 : f16 |
| } -> tensor<?x2x16x4x128xf16> |
| util.return %1 : tensor<?x2x16x4x128xf16> |
| } |
| |
| // CHECK-LABEL: util.func public @scatter_collapse_indices_partial |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16> |
// CHECK-DAG: %[[INDICES:.+]] = tensor.collapse_shape {{.*}} tensor<4x?x1x1xi32> into tensor<4x?x1xi32>
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
| // CHECK-SAME: outs(%[[ARG2]] |
| // CHECK: util.return %[[SCATTER]] |
| |
| // ----- |
| |
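// An expand_shape of the scatter result bubbles up through both the updates and the destination because it only splits non-indexed dimensions.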
| util.func public @scatter_collapse(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x4x32xf16> { |
| %c0 = arith.constant 0 : index |
| %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) { |
| ^bb0(%arg3: f16, %arg4: f16): |
| iree_linalg_ext.yield %arg3 : f16 |
| } -> tensor<?x2x16x4x128xf16> |
| %dim = tensor.dim %arg0, %c0 : tensor<?x2x16x4x128xf16> |
| %expanded = tensor.expand_shape %0 [[0], [1], [2], [3], [4, 5]] output_shape [%dim, 2, 16, 4, 4, 32] : tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16> |
| util.return %expanded : tensor<?x2x16x4x4x32xf16> |
| } |
| // CHECK-LABEL: util.func public @scatter_collapse |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16> |
| // CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape %[[ARG2]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16> |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] |
| // CHECK-SAME: outs(%[[ORIGINAL]] |
| // CHECK: util.return %[[SCATTER]] |
| |
| // ----- |
| |
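// The expand splits the indexed dimension (128 -> 4x4x4x2), so it cannot bubble up through the scatter and stays on the result.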
| util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x1xi32>, %arg2: tensor<128xf16>) -> tensor<4x4x4x2xf16> { |
| %c0 = arith.constant 0 : index |
| %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<10xf16>, tensor<10x1xi32>) outs(%arg2 : tensor<128xf16>) { |
| ^bb0(%arg3: f16, %arg4: f16): |
| iree_linalg_ext.yield %arg3 : f16 |
| } -> tensor<128xf16> |
%expanded = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [4, 4, 4, 2] : tensor<128xf16> into tensor<4x4x4x2xf16>
| util.return %expanded : tensor<4x4x4x2xf16> |
| } |
| // CHECK-LABEL: util.func public @scatter_collapse_noop |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: |
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: |
| // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter |
| // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] |
| // CHECK-SAME: outs(%[[ARG2]] |
| // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SCATTER]] |
| // CHECK: util.return %[[EXPANDED]] |