blob: 11dc49f758a7aaa653272b64443a89f510f1ab52 [file]
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d5, d1, d6, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d5, d1, d6, d3)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d5, d6)>
#map5 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
#map6 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2)>
func.func @gather_attention() {
%0 = util.unfoldable_constant dense<1.000000e+00> : tensor<32x4x2x32xf16>
%1 = flow.tensor.dynamic_constant dense<5.000000e-01> : tensor<2x4x16x32xf16> -> tensor<?x4x16x32xf16>
%2 = flow.tensor.dynamic_constant dense<1.500000e+00> : tensor<2x4x16x32xf16> -> tensor<?x4x16x32xf16>
%3 = flow.tensor.dynamic_constant dense<1> : tensor<32x2xi64> -> tensor<32x?xi64>
%4 = flow.tensor.dynamic_constant dense<1.500000e+00> : tensor<32x4x2x2x16xf16> -> tensor<32x4x2x?x16xf16>
%c1 = arith.constant 1 : index
%dim = tensor.dim %3, %c1 : tensor<32x?xi64>
%5 = tensor.empty(%dim) : tensor<32x?x4x16x32xf16>
%6 = iree_linalg_ext.gather dimension_map = [0] ins(%1, %3 : tensor<?x4x16x32xf16>, tensor<32x?xi64>) outs(%5 : tensor<32x?x4x16x32xf16>) -> tensor<32x?x4x16x32xf16>
%7 = iree_linalg_ext.gather dimension_map = [0] ins(%2, %3 : tensor<?x4x16x32xf16>, tensor<32x?xi64>) outs(%5 : tensor<32x?x4x16x32xf16>) -> tensor<32x?x4x16x32xf16>
%cst = arith.constant 1.767580e-01 : f16
%8 = tensor.empty() : tensor<32x4x2x32xf32>
%9 = tensor.empty() : tensor<32x4x2xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant -3.40282347E+38 : f32
%cst_2 = arith.constant 0.000000e+00 : f32
%10 = linalg.fill ins(%cst_0 : f32) outs(%8 : tensor<32x4x2x32xf32>) -> tensor<32x4x2x32xf32>
%11 = linalg.fill ins(%cst_1 : f32) outs(%9 : tensor<32x4x2xf32>) -> tensor<32x4x2xf32>
%12 = linalg.fill ins(%cst_2 : f32) outs(%9 : tensor<32x4x2xf32>) -> tensor<32x4x2xf32>
%13:3 = iree_linalg_ext.online_attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5, #map6, #map6]} ins(%0, %6, %7, %cst, %4 : tensor<32x4x2x32xf16>, tensor<32x?x4x16x32xf16>, tensor<32x?x4x16x32xf16>, f16, tensor<32x4x2x?x16xf16>) outs(%10, %11, %12 : tensor<32x4x2x32xf32>, tensor<32x4x2xf32>, tensor<32x4x2xf32>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<32x4x2x32xf32>, tensor<32x4x2xf32>, tensor<32x4x2xf32>
%out_e = tensor.empty() : tensor<32x4x2x32xf32>
%cst_one = arith.constant 1.000000e+00 : f32
%normalized = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%13#0, %13#2 : tensor<32x4x2x32xf32>, tensor<32x4x2xf32>)
outs(%out_e : tensor<32x4x2x32xf32>) {
^bb0(%a: f32, %s: f32, %_: f32):
%inv = arith.divf %cst_one, %s : f32
%v = arith.mulf %a, %inv : f32
linalg.yield %v : f32
} -> tensor<32x4x2x32xf32>
check.expect_almost_eq_const(%normalized, dense<1.500000e+00> : tensor<32x4x2x32xf32>) : tensor<32x4x2x32xf32>
return
}