[Codegen] Remove old attention transformations (#18740)
These transformations on iree_linalg_ext.attention have been replaced by
transformations on iree_linalg_ext.online_attention. The old attention
transformation path has been deprecated for a long time, and it is time
to delete it.
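For reference, the path being deleted was driven by the two transform ops
removed below, as used in the deleted test specs:

    %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention =
        transform.iree.tile_attention %attention {tile_size = 64} : (!transform.any_op) -> (...)
    %scale_q, %fill_op, %first_matmul, ..., %second_matmul =
        transform.iree.decompose_tiled_attention %blocked_attention {tile_size = 64} : (!transform.any_op) -> (...)

The supported path rewrites iree_linalg_ext.attention into
iree_linalg_ext.online_attention and decomposes that instead; see the
createConvertAttentionToOnlineAttentionPass change in VMVX/Passes.cpp below.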
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 925106a..522404d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -21,8 +21,6 @@
"amdgpu_chained_matmul.mlir",
"amdgpu_contraction_distribution.mlir",
"amdgpu_set_anchor_layouts.mlir",
- "attention.mlir",
- "attention_mfma.mlir",
"conv_pipeline_test_cuda.mlir",
"conv_pipeline_test_rocm.mlir",
"convert_to_nvvm.mlir",
@@ -82,8 +80,6 @@
# tensor_dialect_*_spec is an MLIR file that specifies a
# transformation; it needs to be included as data.
exclude = [
- "attention_mfma_transform_spec.mlir",
- "attention_transform_spec.mlir",
"transform_dialect_codegen_bufferize_spec.mlir",
"transform_dialect_codegen_foreach_to_gpu_spec.mlir",
"transform_dialect_codegen_vector_distribution_spec.mlir",
@@ -92,8 +88,6 @@
),
cfg = "//compiler:lit.cfg.py",
data = [
- "attention_mfma_transform_spec.mlir",
- "attention_transform_spec.mlir",
"transform_dialect_codegen_bufferize_spec.mlir",
"transform_dialect_codegen_foreach_to_gpu_spec.mlir",
"transform_dialect_codegen_vector_distribution_spec.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 0c27964..b707dfb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -17,8 +17,6 @@
"amdgpu_chained_matmul.mlir"
"amdgpu_contraction_distribution.mlir"
"amdgpu_set_anchor_layouts.mlir"
- "attention.mlir"
- "attention_mfma.mlir"
"cast_address_space_function.mlir"
"cast_type_to_fit_mma.mlir"
"config_matvec.mlir"
@@ -77,8 +75,6 @@
FileCheck
iree-opt
DATA
- attention_mfma_transform_spec.mlir
- attention_transform_spec.mlir
transform_dialect_codegen_bufferize_spec.mlir
transform_dialect_codegen_foreach_to_gpu_spec.mlir
transform_dialect_codegen_vector_distribution_spec.mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
deleted file mode 100644
index 3f2c090..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ /dev/null
@@ -1,175 +0,0 @@
-// RUN: iree-opt %s --pass-pipeline='builtin.module(iree-transform-dialect-interpreter{library-file-name=%p/attention_transform_spec.mlir})' \
-// RUN: --iree-gpu-test-target=sm_60 | \
-// RUN: FileCheck --check-prefix=CHECK %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-func.func @_attention_dispatch_0() {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 1.250000e-01 : f16
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
- %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
- %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
- %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
- %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
- %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
- %7 = tensor.empty() : tensor<192x1024x64xf16>
- %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
- flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
- return
-}
-
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<None workgroup_size = [4, 8, 4] subgroup_size = 32>
-// CHECK: func.func @_attention_dispatch_0()
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<-1.000000e+30> : vector<32xf32>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
-// CHECK-DAG: %[[CST_2:.+]] = arith.constant dense<0.000000e+00> : vector<32x128xf32>
-// CHECK-DAG: %[[CST_3:.+]] = arith.constant dense<1.000000e+00> : vector<64x32xf32>
-// CHECK-DAG: %[[CST_4:.+]] = arith.constant 0.000000e+00 : f16
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
-// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK-DAG: %[[CST_5:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:   %[[CST_6:.+]] = arith.constant dense<1.802980e-01> : vector<128x64xf16>
-// CHECK: %[[D0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64)
-// CHECK-SAME: offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: memref.assume_alignment %[[D0]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[D1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64)
-// CHECK-SAME: offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: memref.assume_alignment %[[D1]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[D2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64)
-// CHECK-SAME: offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: memref.assume_alignment %[[D2]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[D3:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(3) alignment(64)
-// CHECK-SAME: offset(%[[C0]]) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: memref.assume_alignment %[[D3]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[WORKGROUP_ID_X:.+]] = hal.interface.workgroup.id[0] : index
-// CHECK: %[[WORKGROUP_ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-// CHECK-DAG: %[[D4:.+]] = affine.apply #[[MAP]]()[%[[WORKGROUP_ID_Y]]]
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[D0]][%[[WORKGROUP_ID_X]], %[[D4]], 0] [1, 128, 64] [1, 1, 1] :
-// CHECK-SAME: memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[SUBVIEW_6:.+]] = memref.subview %[[D3]][%[[WORKGROUP_ID_X]], %[[D4]], 0] [1, 128, 64] [1, 1, 1] :
-// CHECK-SAME: memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[ALLOC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<1x128x64xf16,
-// CHECK-SAME: #[[GPU:.+]].address_space<workgroup>>
-// CHECK: gpu.barrier
-// CHECK: linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel",
-// CHECK-SAME: "parallel"]} ins(%[[SUBVIEW]] : memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-// CHECK-SAME: outs(%[[ALLOC]] : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK: ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK: linalg.yield %[[IN]] : f16
-// CHECK: }
-// CHECK: gpu.barrier
-// CHECK: %[[ALLOC_7:.+]] = memref.alloc() {alignment = 64 : i64} : memref<1x128x64xf16,
-// CHECK-SAME: #[[GPU]].address_space<workgroup>>
-// CHECK: gpu.barrier
-// CHECK: linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel",
-// CHECK-SAME: "parallel"]} ins(%[[SUBVIEW_6]] : memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-// CHECK-SAME: outs(%[[ALLOC_7]] : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK: ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK: linalg.yield %[[IN]] : f16
-// CHECK: }
-// CHECK: gpu.barrier
-// CHECK-DAG: %[[D5:.+]] = gpu.thread_id x
-// CHECK-DAG: %[[D6:.+]] = gpu.thread_id y
-// CHECK-DAG: %[[D7:.+]] = gpu.thread_id z
-// CHECK-DAG: %[[D8:.+]] = affine.apply #[[MAP2]]()[%[[D5]], %[[D6]], %[[D7]]]
-// CHECK: %[[D9:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D8]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
-// CHECK-SAME: true]} : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>, vector<32x64xf16>
-// CHECK: %[[D11:.+]]:3 = scf.for %[[ARG0:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C128]]
-// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9_]+]] = %[[CST_0]], %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST_1]],
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<32xf32>, vector<32xf32>, vector<32x64xf32>) {
-// CHECK: %[[SUBVIEW_8:.+]] = memref.subview %[[D1]][%[[WORKGROUP_ID_X]], %[[ARG0]], 0] [1, 128, 64] [1, 1, 1]
-// CHECK-SAME: : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[SUBVIEW_9:.+]] = memref.subview %[[D2]][%[[WORKGROUP_ID_X]], %[[ARG0]], 0] [1, 128, 64] [1, 1, 1]
-// CHECK-SAME: : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK: %[[ALLOC_12:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16, #gpu.address_space<workgroup>>
-// CHECK: vector.transfer_write %[[CST_6:.+]], %[[ALLOC_12]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<128x64xf16>, memref<128x64xf16, #gpu.address_space<workgroup>>
-// CHECK: %[[ALLOC_10:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16,
-// CHECK-SAME: #[[GPU]].address_space<workgroup>>
-// CHECK: gpu.barrier
-// CHECK: linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME: ins(%[[SUBVIEW_8]] : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%[[ALLOC_10]] :
-// CHECK-SAME: memref<128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK: ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK: linalg.yield %[[IN]] : f16
-// CHECK: }
-// CHECK: gpu.barrier
-// CHECK: %[[ALLOC_11:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16,
-// CHECK-SAME: #[[GPU]].address_space<workgroup>>
-// CHECK: gpu.barrier
-// CHECK: linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME: ins(%[[SUBVIEW_9]] : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%[[ALLOC_11]] :
-// CHECK-SAME: memref<128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK: ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK: linalg.yield %[[IN]] : f16
-// CHECK: }
-// CHECK: gpu.barrier
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC_12]][%[[D8]], %[[C0]]], %{{.+}} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<32x64xf16>
-// CHECK: %[[MUL:.+]] = arith.mulf %[[D9]], %[[READ]] : vector<32x64xf16>
-// CHECK: %[[D13:.+]] = vector.transfer_read %[[ALLOC_10]][%[[C0]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
-// CHECK-SAME: true]} : memref<128x64xf16, #[[GPU]].address_space<workgroup>>, vector<128x64xf16>
-// CHECK: %[[D10:.+]] = arith.extf %[[MUL]] : vector<32x64xf16> to vector<32x64xf32>
-// CHECK: %[[D14:.+]] = arith.extf %[[D13]] : vector<128x64xf16> to vector<128x64xf32>
-// CHECK: %[[D15:.+]] = vector.contract {indexing_maps = [#[[MAP4]], #[[MAP5]], #[[MAP6]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel", "reduction"], kind = #[[VECTOR:.+]].kind<add>} %[[D10]], %[[D14]],
-// CHECK-SAME: %[[CST_2]] : vector<32x64xf32>, vector<128x64xf32> into vector<32x128xf32>
-// CHECK: %[[D16:.+]] = vector.multi_reduction <maximumf>, %[[D15]], %[[ARG1]] [1] : vector<32x128xf32> to
-// CHECK-SAME: vector<32xf32>
-// CHECK: %[[D17:.+]] = vector.broadcast %[[D16]] : vector<32xf32> to vector<128x32xf32>
-// CHECK: %[[D18:.+]] = vector.transpose %[[D17]], [1, 0] : vector<128x32xf32> to vector<32x128xf32>
-// CHECK: %[[D19:.+]] = arith.subf %[[D15]], %[[D18]] : vector<32x128xf32>
-// CHECK: %[[D20:.+]] = math.exp2 %[[D19]] : vector<32x128xf32>
-// CHECK: %[[D21:.+]] = arith.subf %[[ARG1]], %[[D16]] : vector<32xf32>
-// CHECK: %[[D22:.+]] = math.exp2 %[[D21]] : vector<32xf32>
-// CHECK: %[[D23:.+]] = arith.mulf %[[D22]], %[[ARG2]] : vector<32xf32>
-// CHECK: %[[D24:.+]] = vector.multi_reduction <add>, %[[D20]], %[[D23]] [1] : vector<32x128xf32> to
-// CHECK-SAME: vector<32xf32>
-// CHECK: %[[D29:.+]] = arith.truncf %[[D20]] : vector<32x128xf32> to vector<32x128xf16>
-// CHECK: %[[D31:.+]] = vector.broadcast %[[D22]] : vector<32xf32> to vector<64x32xf32>
-// CHECK: %[[D33:.+]] = vector.transpose %[[D31]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
-// CHECK: %[[D34:.+]] = arith.mulf %[[D33]], %[[ARG3]] : vector<32x64xf32>
-// CHECK: %[[D36:.+]] = arith.extf %[[D29]] : vector<32x128xf16> to vector<32x128xf32>
-// CHECK: %[[D35:.+]] = vector.transfer_read %[[ALLOC_11]][%[[C0]], %[[C0]]], %[[CST_4]]
-// CHECK-SAME: {in_bounds = [true, true], permutation_map = #[[MAP7]]} : memref<128x64xf16, #[[GPU]].address_space<workgroup>>, vector<64x128xf16>
-// CHECK: %[[D37:.+]] = arith.extf %[[D35]] : vector<64x128xf16> to vector<64x128xf32>
-// CHECK: %[[D39:.+]] = vector.contract {indexing_maps = [#[[MAP4]], #[[MAP5]], #[[MAP6]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel", "reduction"], kind = #[[VECTOR]].kind<add>} %[[D36]], %[[D37]], %[[D34]] :
-// CHECK-SAME: vector<32x128xf32>, vector<64x128xf32> into vector<32x64xf32>
-// CHECK: scf.yield %[[D16]], %[[D24]], %[[D39]] : vector<32xf32>, vector<32xf32>, vector<32x64xf32>
-// CHECK: }
-// CHECK: %[[DSCALE1:.+]] = vector.broadcast %[[D11]]#1 : vector<32xf32> to vector<64x32xf32>
-// CHECK: %[[DSCALE2:.+]] = arith.divf %[[CST_3]], %[[DSCALE1]] : vector<64x32xf32>
-// CHECK: %[[DSCALE3:.+]] = vector.transpose %[[DSCALE2]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
-// CHECK: %[[DSCALE4:.+]] = arith.mulf %[[DSCALE3]], %[[D11]]#2 : vector<32x64xf32>
-// CHECK: %[[D12:.+]] = arith.truncf %[[DSCALE4]] : vector<32x64xf32> to vector<32x64xf16>
-// CHECK: vector.transfer_write %[[D12]], %[[ALLOC_7]][%[[C0]], %[[D8]], %[[C0]]] {in_bounds = [true, true]} :
-// CHECK-SAME: vector<32x64xf16>, memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>
-// CHECK: gpu.barrier
-// CHECK: linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel",
-// CHECK-SAME: "parallel"]} ins(%[[ALLOC_7]] : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>)
-// CHECK-SAME: outs(%[[SUBVIEW_6]] : memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
-// CHECK: ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK: linalg.yield %[[IN]] : f16
-// CHECK: }
-// CHECK: gpu.barrier
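For orientation, the deleted test above exercised the full old pipeline. With
the indexing maps in the payload (d0 = batch, d1 = query rows, d2 = Q/K head
dimension, d3 = key/value rows, d4 = output columns), iree_linalg_ext.attention
computes, per batch element,

    O = softmax(scale * Q * K^T) * V

and the CHECK lines verify its tiled form: the scf.for over the sequence
dimension carries the running row-wise max, the running row-wise sum, and the
accumulator as iter_args, which is exactly the flash-attention recurrence
implemented by the C++ helpers deleted further below.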
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
deleted file mode 100644
index 109b107..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
+++ /dev/null
@@ -1,37 +0,0 @@
-// RUN: iree-opt %s --pass-pipeline='builtin.module(iree-transform-dialect-interpreter{library-file-name=%p/attention_mfma_transform_spec.mlir})' \
-// RUN: --iree-gpu-test-target=gfx908 | \
-// RUN: FileCheck --check-prefix=CHECK %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-func.func @attention_dispatch_0_attention_16x16384x128xf16() {
- %c0 = arith.constant 0 : index
- %scale = arith.constant 0.08838834764 : f16
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
- %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
- %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
- %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
- %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
- %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
- %7 = tensor.empty() : tensor<16x16384x128xf16>
- %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
- flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
- return
-}
- // CHECK-NOT: vector.contract
- // CHECK-NOT: iree_vector_ext.to_simd
- // CHECK-NOT: iree_vector_ext.to_simt
- // CHECK-COUNT-8: vector.load {{.*}} : memref<16x16384x128xf16, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
- // CHECK: scf.for {{.*}} = %c0 to %c16384 step %c64 {{.*}} -> (vector<2xf32>, vector<2xf32>, vector<8x2x4xf32>)
- // CHECK-COUNT-16: vector.load {{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, vector<8xf16>
- // CHECK-COUNT-128: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32}
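A sanity check on the counts above, under the tiling in the deleted spec that
follows (a 1x128 workgroup tile, head dimension 128, and workgroup_dims =
[64, 4, 1], i.e. 256 threads): the Q tile is 128x128 f16 = 16384 elements, or
64 per thread, which matches the CHECK-COUNT-8 loads of vector<8xf16>. The 128
amdgpu.mfma ops per iteration are consistent with each of the four warps
computing a 32-row slice of both 16x16x16-tiled matmuls (32x128 times 128x64
and 32x64 times 64x128, 64 MFMAs each); that breakdown is an inference from
the spec, not something the test asserts directly.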
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
deleted file mode 100644
index 9261f2e..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
+++ /dev/null
@@ -1,203 +0,0 @@
-#layout = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%variant_op: !transform.any_op) {
- // Get attention op
- // ==========================================
- %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-
- // Tile and distribute to workgroups
- // ==========================================
- %tiled_attention, %forall_grid =
- transform.structured.tile_using_forall %attention tile_sizes [1, 128]
- ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- // transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
-
- // Tile batch dimensions of attention
- // ==========================================
- %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %top_level_func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %top_level_func : !transform.any_op
-
- // Promote query and output operands
- // ==========================================
- //%attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- //%promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 3]
- // : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
- // Tile and decompose attention
- // ==========================================
- %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 {tile_size = 64} :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
- = transform.iree.decompose_tiled_attention %blocked_attention {tile_size = 64} :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
- !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
- // Promote key and value operands
- // ==========================================
- %promoted_first_matmul, %alloc0 = transform.iree.promote_operands %first_matmul [1]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %promoted_second_matmul, %alloc1 = transform.iree.promote_operands %second_matmul [1]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // Tile and fuse attention ops
- // ==========================================
- %tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-
- %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.apply_cse to %func : !transform.any_op
-
- %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.apply_cse to %func : !transform.any_op
-
- %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
- // Distribute fills
- // ==========================================
-
- // Get all fills that haven't been distributed to warps.
- %fills = transform.include @get_undistributed_fills failures(propagate) (%variant_op) : (!transform.any_op) -> !transform.any_op
-
- %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // Distribute last_truncate and fuse final_scaling into it
- // ==========================================
- %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- // Vectorize function
- // ==========================================
- transform.apply_patterns to %func {
- transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
- transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- } : !transform.any_op
- %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> (!transform.any_op)
-
- // Bufferization
- // ==========================================
- transform.apply_patterns to %func_3 {
- transform.apply_patterns.tensor.reassociative_reshape_folding
- transform.apply_patterns.canonicalization
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_3 : !transform.any_op
- transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
- transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op
- %func_4 = transform.iree.bufferize { target_gpu } %func_3 : (!transform.any_op) -> (!transform.any_op)
-
-    // Step 5. Pre-process the contract and transfer ops to put them in the right form.
- // ===========================================================================
- %func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func_2 {
- transform.apply_patterns.vector.fold_arith_extension
- } : !transform.any_op
-
- // Step 6. Post-bufferization vector distribution
- // ===========================================================================
- %func_7 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
- transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> ()
-
- transform.apply_patterns to %func_7 {
- transform.apply_patterns.memref.fold_memref_alias_ops
- } : !transform.any_op
- transform.iree.apply_licm %func_7 : !transform.any_op
- transform.apply_patterns to %func_7 {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_7 : !transform.any_op
- %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7
- : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func_8 {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_8 : !transform.any_op
- transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> ()
-
- // Apply chained matmul optimization.
- transform.apply_registered_pass "iree-amdgpu-prepare-chained-matmul" to %func_8 : (!transform.any_op) -> (!transform.any_op)
-
- // Get the vector.contract ops.
- %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
- transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 { read_layout_indices = array<i64: 0, 1> } : !transform.any_op, !transform.any_param
- transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param
-
- %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %distribute_func_2 = transform.iree.amdgpu_distribute_vectors %distribute_func test_conversion : (!transform.any_op) -> !transform.any_op
-
- transform.apply_patterns to %distribute_func_2 {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %distribute_func_2 : !transform.any_op
-
- // Distribute shared memory copies
- // ==========================================
- %func_10 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.gpu_distribute_shared_memory_copy %func_10 : (!transform.any_op) -> ()
- transform.apply_patterns to %func_10 {
- transform.apply_patterns.memref.fold_memref_alias_ops
- transform.apply_patterns.canonicalization
- transform.apply_patterns.linalg.tiling_canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_10 : !transform.any_op
-
- transform.yield
- }
- transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
- transform.match.operation_name %arg0 ["linalg.fill"] : !transform.any_op
- %0 = transform.get_parent_op %arg0 {allow_empty_results, nth_parent = 2 : i64, op_name = "scf.forall"} : (!transform.any_op) -> !transform.any_op
- transform.match.operation_empty %0 : !transform.any_op
- transform.yield %arg0 : !transform.any_op
- }
- transform.named_sequence @get_undistributed_fills(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
- %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.yield %0 : !transform.any_op
- }
-} //// module
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
deleted file mode 100644
index 2fc82bc..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
+++ /dev/null
@@ -1,158 +0,0 @@
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%variant_op: !transform.any_op) {
- // Get attention op
- // ==========================================
- %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-
- // Tile and distribute to workgroups
- // ==========================================
- %tiled_attention, %forall_grid =
- transform.structured.tile_using_forall %attention tile_sizes [1, 128]
- ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- // transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
-
- // Tile batch dimensions of attention
- // ==========================================
- %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %top_level_func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %top_level_func : !transform.any_op
-
- // Promote query and output operands
- // ==========================================
- %attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 4]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
- // Tile and decompose attention
- // ==========================================
- %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
- = transform.iree.decompose_tiled_attention %blocked_attention :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
- !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
- // Promote key and value operands
- // ==========================================
- %promoted_first_matmul, %alloc0 = transform.iree.promote_operands %first_matmul [1]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %promoted_second_matmul, %alloc1 = transform.iree.promote_operands %second_matmul [1]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // Tile and fuse attention ops
- // ==========================================
- %tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-
- %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.apply_cse to %func : !transform.any_op
-
- %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.apply_cse to %func : !transform.any_op
-
- %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- // Distribute fills
- // ==========================================
- %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill : !transform.any_op
- %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // Distribute last_truncate and fuse final_scaling into it
- // ==========================================
- %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- transform.apply_patterns to %func {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func : !transform.any_op
-
- // Vectorize function
- // ==========================================
- transform.apply_patterns to %func {
- transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
- transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- } : !transform.any_op
- %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> (!transform.any_op)
-
- // Bufferization
- // ==========================================
- transform.apply_patterns to %func_3 {
- transform.apply_patterns.tensor.reassociative_reshape_folding
- transform.apply_patterns.canonicalization
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_3 : !transform.any_op
- transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
- transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op
- %func_4 = transform.iree.bufferize { target_gpu } %func_3 : (!transform.any_op) -> (!transform.any_op)
-
-    // Step 5. Pre-process the contract and transfer ops to put them in the right form.
- // ===========================================================================
- %func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func_2 {
- transform.apply_patterns.iree.prepare_vector_to_mma
- } : !transform.any_op
-
- // Step 6. Post-bufferization vector distribution
- // ===========================================================================
- %func_7 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
- transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [4, 8, 4] subgroup_size = 32 sync_after_distribution = false : (!transform.any_op) -> ()
-
- transform.apply_patterns to %func_7 {
- transform.apply_patterns.memref.fold_memref_alias_ops
- } : !transform.any_op
- transform.iree.apply_licm %func_7 : !transform.any_op
- transform.apply_patterns to %func_7 {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_7 : !transform.any_op
- %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7
- : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func_8 {
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.apply_cse to %func_8 : !transform.any_op
- transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> ()
- transform.yield
- }
-} //// module
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
index ef169fd..78e3ea5 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -43,7 +43,8 @@
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());
- funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
+ funcPassManager.addPass(
+ IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
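The new pass takes over the role tile_attention played in this pipeline: it
rewrites iree_linalg_ext.attention into iree_linalg_ext.online_attention, which
carries the running row-wise max and sum alongside the accumulator so the
reduction dimension can be tiled, and DecomposeAttentionPass then lowers the
online op to matmuls and elementwise ops. A hedged sketch of the target form
(illustrative only; the exact IR is not shown in this diff):

    // Running accumulator, row max, and row sum as extra results (sketch).
    %out, %row_max, %row_sum = iree_linalg_ext.online_attention {indexing_maps = [...]}
        ins(%q, %k, %v, %scale : ...)
        outs(%acc_init, %max_init, %sum_init : ...) -> tensor<...>, tensor<...>, tensor<...>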
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index b350053..46c24e4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -27,30 +27,6 @@
// TileAndDecomposeAttention
//===---------------------------------------------------------------------===//
-DiagnosedSilenceableFailure LinalgExt::TileAttentionOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- SmallVector<Operation *> ops;
- LinalgExt::tileAttention(attentionOp, ops, rewriter, getTileSize());
- for (auto op : ops) {
- results.push_back(op);
- }
- return DiagnosedSilenceableFailure::success();
-}
-
-DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- SmallVector<Operation *> ops;
- LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter, getTileSize());
- for (auto op : ops) {
- results.push_back(op);
- }
- return DiagnosedSilenceableFailure::success();
-}
-
DiagnosedSilenceableFailure LinalgExt::DecomposeAggregateOp::applyToOne(
transform::TransformRewriter &rewriter,
linalg::AggregatedOpInterface aggregateOp,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
index 9e761c0..4417c80 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
@@ -12,81 +12,6 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
-def TileAttentionOp : Op<Transform_Dialect, "iree.tile_attention",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{
- Target iree_linalg_ext.attention ops and tile them.
- This transform consumes the target handle and produces a result handle.
- }];
-
- let arguments = (
- ins TransformHandleTypeInterface:$target,
- OptionalAttr<I64Attr>:$tile_size
- );
- let results = (outs Variadic<TransformHandleTypeInterface>:$result);
-
- let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
- let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
- let builders = [
- OpBuilder<(ins "Value":$target)>
- ];
-
- let assemblyFormat = [{
- $target attr-dict `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
-def DecomposeTiledAttentionOp : Op<Transform_Dialect, "iree.decompose_tiled_attention",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{
- Target iree_linalg_ext.attention ops and decompose them.
- This transform consumes the target handle and produces a result handle.
- }];
-
- let arguments = (
- ins TransformHandleTypeInterface:$target,
- OptionalAttr<I64Attr>:$tile_size
- );
- let results = (outs Variadic<TransformHandleTypeInterface>:$result);
-
- let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
- let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
- let builders = [
- OpBuilder<(ins "Value":$target)>
- ];
-
- let assemblyFormat = [{
- $target attr-dict `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
-
def DecomposeAggregateOp : Op<Transform_Dialect, "iree.decompose_aggregate_op",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index b1ff512..9f08e2a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -7,9 +7,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
@@ -19,335 +16,14 @@
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
namespace {
-
-// Computes a reduction along the rows of a 2d tensor of shape MxN
-// to produce a tensor of shape M
-template <typename T>
-static Value computeRowwiseReduction(Value a, Value output, Location loc,
- OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- SmallVector<utils::IteratorType> iteratorTypes{
- utils::IteratorType::parallel, utils::IteratorType::reduction};
- AffineMap id = AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- AffineExpr d0, d1;
- bindDims(builder.getContext(), d0, d1);
- // (d0, d1) -> (d0)
- auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
- SmallVector<AffineMap> indexingMaps{id, rowMap};
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), a, output, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<T>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static Value computePartialSoftmax(Value qkTranspose, Value currentMax,
- Location loc, OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- AffineExpr d0, d1;
- bindDims(builder.getContext(), d0, d1);
- // (d0, d1) -> (d0)
- auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
- SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, qkTranspose.getType(), ValueRange{currentMax}, qkTranspose,
- indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
- Value result = b.create<math::Exp2Op>(loc, diff);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-/// Return the scale factor for the new softmax maximum and add the generic to
-/// the provided list of operations.
-static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
- OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- SmallVector<utils::IteratorType> iteratorTypes(1,
- utils::IteratorType::parallel);
- auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
- SmallVector<AffineMap> indexingMaps(2, identityMap);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, oldMax.getType(), newMax, oldMax, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
- Value weight = b.create<math::Exp2Op>(loc, diff);
- b.create<linalg::YieldOp>(loc, weight);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static Value updateAndScale(Value scaleFactor, Value oldSum, Location loc,
- OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- SmallVector<utils::IteratorType> iteratorTypes(1,
- utils::IteratorType::parallel);
- auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
- SmallVector<AffineMap> indexingMaps(2, identityMap);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, oldSum.getType(), scaleFactor, oldSum, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value scaledOldSum = b.create<arith::MulFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, scaledOldSum);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static Value scalePartialSoftmax(Value softmax, Value inverseNewSum,
- Location loc, OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- AffineExpr d0, d1;
- bindDims(builder.getContext(), d0, d1);
- // (d0, d1) -> (d0)
- auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
- SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, softmax.getType(), ValueRange{inverseNewSum}, softmax, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::MulFOp>(loc, args[1], args[0]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static Value scaleAccumulator(Value accumulator, Value scaleFactor,
- Location loc, OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- AffineExpr d0, d1;
- bindDims(builder.getContext(), d0, d1);
- // (d0, d1) -> (d0)
- auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
- SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, accumulator.getType(), scaleFactor, accumulator, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static Value computeQKTranspose(Value query, Value key, Value output,
- Value zero, Location loc, OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- auto fillOp = builder.create<linalg::FillOp>(loc, ValueRange{zero}, output);
- ops.push_back(fillOp);
- Value acc = fillOp.result();
- auto matmulOp = builder.create<linalg::MatmulTransposeBOp>(
- loc, output.getType(), ValueRange{query, key}, acc);
- ops.push_back(matmulOp);
- return matmulOp.getResult(0);
-}
-
-static Value truncateToF16(Value input, Value output,
- SmallVectorImpl<Operation *> &ops,
- OpBuilder &builder, Location loc) {
- AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- SmallVector<AffineMap> indexingMaps{identityMap, identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), ValueRange{input}, output, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::TruncFOp>(loc, b.getF16Type(), args[0]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static std::tuple<Value, Value, Value>
-createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
- Value outputSlice, Value maxSlice, Value sumSlice,
- OpFoldResult sequenceTileLength,
- OpFoldResult keyValueTileLength, OpFoldResult headDimension,
- Type elementType, SmallVectorImpl<Operation *> &ops,
- bool transposeV, Location loc, OpBuilder &builder) {
-
- Type f32Type = builder.getF32Type();
- // Compute matmul(q, transpose(k))
- Value zero =
- builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(f32Type));
- SmallVector<OpFoldResult> resultShape{sequenceTileLength, keyValueTileLength};
- Value emptySquare =
- builder.create<tensor::EmptyOp>(loc, resultShape, f32Type);
- Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
- zero, loc, builder, ops);
-
- // Compute current statistics
- Value newMax = computeRowwiseReduction<arith::MaximumFOp>(
- qkTranspose, maxSlice, loc, builder, ops);
- Value partialSoftmax =
- computePartialSoftmax(qkTranspose, newMax, loc, builder, ops);
- Value scaleFactor = computeScaleFactor(maxSlice, newMax, loc, builder, ops);
- Value scaledOldSum = updateAndScale(scaleFactor, sumSlice, loc, builder, ops);
- Value newSum = computeRowwiseReduction<arith::AddFOp>(
- partialSoftmax, scaledOldSum, loc, builder, ops);
- if (elementType.isF16()) {
- Value empty =
- builder.create<tensor::EmptyOp>(loc, resultShape, builder.getF16Type());
- partialSoftmax = truncateToF16(partialSoftmax, empty, ops, builder, loc);
- }
-
- // Update accumulator
- Value scaledAcc =
- scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
-
- // Compute matmul(softmax, v)
- Operation *matmulOp;
- if (transposeV) {
- matmulOp = builder.create<linalg::MatmulTransposeBOp>(
- loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
- scaledAcc);
- } else {
- matmulOp = builder.create<linalg::MatmulOp>(
- loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
- scaledAcc);
- }
- ops.push_back(matmulOp);
- Value result = matmulOp->getResult(0);
- return std::make_tuple(result, newMax, newSum);
-}
-
-static Value scaleQuery(Value querySlice, Value scale, RewriterBase &rewriter) {
- ShapedType queryType = cast<ShapedType>(querySlice.getType());
- Location loc = querySlice.getLoc();
-
- // Create a fill op for scale.
- SmallVector<OpFoldResult> queryDims =
- tensor::getMixedSizes(rewriter, loc, querySlice);
- Value empty = rewriter.create<tensor::EmptyOp>(loc, queryDims,
- queryType.getElementType());
- auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{scale}, empty)
- .getResult(0);
-
- // Create a generic op to multiply the query by the scale.
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto identityMap =
- AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
- SmallVector<AffineMap> indexingMaps(2, identityMap);
- auto scaleOp = rewriter.create<linalg::GenericOp>(
- loc, TypeRange{fillOp.getType()}, ValueRange{querySlice},
- ValueRange{fillOp}, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return scaleOp.getResult(0);
-}
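
For reference, scaleQuery is sound because the softmax scale is a scalar, so it commutes with the matmul (the comment inside decomposeTiledAttention below restates this):

    \operatorname{softmax}\big(s\,(Q K^{\top})\big)\,V \;=\; \operatorname{softmax}\big((s\,Q)\,K^{\top}\big)\,V

Scaling Q once hoists the multiply out of the reduction loop.
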
-
-} // namespace
-
-/// This is an implementation of flash attention which
-/// is a tiled and fused implementation of the attention operator.
-/// The attention operator computes:
-/// matmul(softmax(matmul(Q, transpose(K))), V)
-/// where: Q is the query matrix [B x N x d]
-/// K is the key matrix [B x S x d]
-/// V is the value matrix [B x S x d]
-///
-/// The core algorithm is as follows:
-/// For each element in B,
-/// 1. Load a tile from the Q matrix of size T x d -> q
-/// 2. Initialize statistics: running_sum, running_max
-/// 3. for i = 0 to S with step T
-/// a. Load a tile from the K matrix of size T x d -> k
-/// b. Load a tile from the V matrix of size T x d -> v
-/// c. Compute matmul_transpose_b(q, k) -> qkT
-/// d. Compute max(max(qkT) along rows, old_max) -> new_max
-/// e. Compute current estimate of softmax: exp(qkT - new_max) -> s
-/// f. Compute product of fixup and old_sum -> fsum
-/// g. Compute sum(sum(s) along rows, fsum) -> new_sum
-/// h. Compute 1.0 / new_sum -> inv_new_sum
-/// i. Compute softmax = softmax * inv_new_sum
-/// j. Truncate softmax to fp16
-/// k. Compute fsum * inv_new_sum * accumulator -> new_accumulator
-/// l. Compute matmul(s, v), accumulating into new_accumulator
-///
-/// Decompose tiled iree_linalg_ext.attention op.
-/// TODO: Adopt decomposeOperation with this.
-void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
- SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter,
- std::optional<uint64_t> tileSize) {
- Location loc = tiledAttnOp.getLoc();
- Value keySlice = tiledAttnOp.getKey();
- Value valueSlice = tiledAttnOp.getValue();
- Value querySlice = tiledAttnOp.getQuery();
- Value tiledResult = tiledAttnOp.getOutput();
- Value max = *tiledAttnOp.getMax();
- Value sum = *tiledAttnOp.getSum();
-
- OpBuilder::InsertionGuard withinScfLoop(rewriter);
- rewriter.setInsertionPointAfter(tiledAttnOp);
- SmallVector<OpFoldResult> queryDimValues =
- tensor::getMixedSizes(rewriter, loc, querySlice);
- OpFoldResult headDimension = queryDimValues[1];
- OpFoldResult sequenceTileLength = queryDimValues[0];
- OpFoldResult keyValueTileLength =
- tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
-
- Type elementType = tiledAttnOp.getQueryType().getElementType();
-
- // Since we use exp2 for attention instead of the original exp, we have to
- // multiply the scale by log2(e). We use exp2 instead of exp as most GPUs
- // have better support for exp2.
- Value scale = tiledAttnOp.getScale();
- Value log2e = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(elementType, M_LOG2E));
- scale = rewriter.create<arith::MulFOp>(loc, scale, log2e);
-
- // In the original algorithm, the scaling is done after the softmax:
- // softmax(Q @ K.T * scale) @ V
- //
- // But, it is mathematically equivalent to do it on Q first and then multiply
- // it by K.T. This just allows us to do the scaling once, instead of each
- // iteration of the loop.
- querySlice = scaleQuery(querySlice, scale, rewriter);
- ops.push_back(querySlice.getDefiningOp());
-
- auto [result, newMax, newSum] = createAttentionBody(
- keySlice, valueSlice, querySlice, tiledResult, max, sum,
- sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
- tiledAttnOp.isTransposeV(), loc, rewriter);
-
- rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
-}
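
Summarizing createAttentionBody above as equations (exp2 appears because decomposeTiledAttention folds log2(e) into the scale; normalization is deferred to applyFinalScaling in TileAttention.cpp), with \tilde{q} the pre-scaled query slice and x = \tilde{q}\,k^{\top} the score tile:

    \begin{aligned}
    m_{new} &= \max\big(m_{old}, \operatorname{rowmax}(x)\big) \\
    p &= \operatorname{exp2}(x - m_{new}), \qquad f = \operatorname{exp2}(m_{old} - m_{new}) \\
    \ell_{new} &= f\,\ell_{old} + \operatorname{rowsum}(p) \\
    acc_{new} &= f\,acc_{old} + p\,v
    \end{aligned}

The substitution e^{z} = 2^{\,z \log_2 e} is what justifies folding log2(e) into the scale.
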
-
-namespace {
struct DecomposeAttentionPass final
: impl::DecomposeAttentionPassBase<DecomposeAttentionPass> {
using impl::DecomposeAttentionPassBase<
DecomposeAttentionPass>::DecomposeAttentionPassBase;
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<
- affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
- linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
+ registry.insert<IREE::LinalgExt::IREELinalgExtDialect,
+ linalg::LinalgDialect, tensor::TensorDialect>();
}
void runOnOperation() override;
};
@@ -356,14 +32,6 @@
void DecomposeAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
- std::optional<uint64_t> optionalTileSize{std::nullopt};
- if (tileSize.hasValue()) {
- optionalTileSize = tileSize.getValue();
- }
- getOperation().walk([&](AttentionOp attnOp) {
- SmallVector<Operation *> ops;
- decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
- });
getOperation().walk([&](OnlineAttentionOp onlineAtt) {
rewriter.setInsertionPoint(onlineAtt);
FailureOr<SmallVector<Value>> results =
@@ -375,4 +43,5 @@
rewriter.replaceOp(onlineAtt, results.value());
});
}
+
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index af7716c..1e858df 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -34,16 +34,6 @@
RewritePatternSet &patterns,
std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);
-IREE::LinalgExt::AttentionOp
-tileAttention(IREE::LinalgExt::AttentionOp attnOp,
- SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
- std::optional<uint64_t> tileSize = std::nullopt);
-
-void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
- SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter,
- std::optional<uint64_t> tileSize = std::nullopt);
-
void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops,
RewriterBase &rewriter);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 6454c5a..3995fba 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -83,24 +83,10 @@
];
}
-def TileAttentionPass :
- InterfacePass<"iree-linalg-ext-tile-attention", "mlir::FunctionOpInterface"> {
- let summary =
- "Tile the attention op along the reduction dimension";
- let options = [
- Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
- "Tile size for sequential for loop in attention">,
- ];
-}
-
def DecomposeAttentionPass :
InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> {
let summary =
"Decomposes attention op into a sequence of linalg ops";
- let options = [
- Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
- "Tile size for sequential for loop in attention">,
- ];
}
def ConvertAttentionToOnlineAttentionPass :
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 8089cd0..472a073 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -16,196 +16,11 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
-#define GEN_PASS_DEF_TILEATTENTIONPASS
#define GEN_PASS_DEF_CONVERTATTENTIONTOONLINEATTENTIONPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
namespace {
-static Value truncateToF16(Value input, Value output,
- SmallVectorImpl<Operation *> &ops,
- OpBuilder &builder, Location loc) {
- AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- SmallVector<AffineMap> indexingMaps{identityMap, identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), ValueRange{input}, output, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::TruncFOp>(loc, b.getF16Type(), args[0]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
-
-static Value applyFinalScaling(Value result, Value newSum, Location loc,
- OpBuilder &builder,
- SmallVectorImpl<Operation *> &ops) {
- AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(2, builder.getContext());
- AffineExpr d0, d1;
- bindDims(builder.getContext(), d0, d1);
- // (d0, d1) -> (d0)
- auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
- SmallVector<AffineMap> indexingMaps = {rowMap, identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(2,
- utils::IteratorType::parallel);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, result.getType(), newSum, result, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value one = b.create<arith::ConstantOp>(
- loc, b.getFloatAttr(args[0].getType(), 1.0));
- Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
- Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- ops.push_back(genericOp);
- return genericOp.getResult(0);
-}
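
In equations, applyFinalScaling performs the single deferred normalization of the running accumulator by the running sum:

    O_{ij} = \frac{Acc_{ij}}{\ell_i}
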
-
-static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
- Value step, Value ub, ValueRange args,
- Location loc, OpBuilder &builder) {
- SmallVector<Value> lbs{lb};
- SmallVector<Value> steps{step};
- SmallVector<Value> ubs{ub};
- scf::LoopNest loopNest = scf::buildLoopNest(
- builder, loc, lbs, ubs, steps, args,
- [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
- ValueRange iterArgs) -> scf::ValueVector { return iterArgs; });
- for (scf::ForOp loop : loopNest.loops) {
- ivs.push_back(loop.getInductionVar());
- }
- return loopNest;
-}
-
-static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
- ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
- OpFoldResult headDimension, Type elementType,
- Location loc, OpBuilder &builder,
- bool swapLastTwoDims = false) {
- auto one = builder.getIndexAttr(1);
- auto zero = builder.getIndexAttr(0);
- SmallVector<OpFoldResult> strides(keyShape.size(), one);
- SmallVector<OpFoldResult> sizes(keyShape.size(), one);
- SmallVector<OpFoldResult> offsets(keyShape.size(), zero);
- sizes[1] = keyValueTileLength;
- sizes[2] = headDimension;
- if (!ivs.empty()) {
- offsets[1] = ivs[0];
- }
- SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
- if (swapLastTwoDims) {
- std::swap(sizes[1], sizes[2]);
- std::swap(offsets[1], offsets[2]);
- std::swap(tensorShape[0], tensorShape[1]);
- }
- auto tensorType = RankedTensorType::get(tensorShape, elementType);
- Value keySlice = builder.create<tensor::ExtractSliceOp>(
- loc, tensorType, key, offsets, sizes, strides);
- return keySlice;
-}
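
As a worked example (matching the deleted tile_attention.mlir test further down): for a tensor<1x1024x64xf32> key, tile size 32 and head dimension 64, induction value %iv yields offsets [0, %iv, 0], sizes [1, 32, 64], strides [1, 1, 1], i.e. a rank-reduced tensor<32x64xf32> slice; with swapLastTwoDims set (transposed V), the sizes become [1, 64, 32] and the offsets [0, 0, %iv], giving a tensor<64x32xf32> slice.
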
-
-static Value extractOrInsertOutputSlice(Value src, Value dst,
- ArrayRef<int64_t> queryShape,
- OpFoldResult sequenceTileLength,
- OpFoldResult headDimension,
- Location loc, OpBuilder &builder) {
- auto one = builder.getIndexAttr(1);
- auto zero = builder.getIndexAttr(0);
- SmallVector<OpFoldResult> strides(3, one);
- SmallVector<OpFoldResult> sizes = {one, sequenceTileLength, headDimension};
- SmallVector<OpFoldResult> offsets(3, zero);
- Value slice;
- if (!dst) {
- SmallVector<int64_t> accShape{queryShape[1], queryShape[2]};
- Type elementType = cast<ShapedType>(src.getType()).getElementType();
- auto tensorType = RankedTensorType::get(accShape, elementType);
- slice = builder.create<tensor::ExtractSliceOp>(loc, tensorType, src,
- offsets, sizes, strides);
- } else {
- slice = builder.create<tensor::InsertSliceOp>(loc, src, dst, offsets, sizes,
- strides);
- }
- return slice;
-}
-
-static Value extractOutputSlice(Value src, ArrayRef<int64_t> queryShape,
- OpFoldResult sequenceTileLength,
- OpFoldResult headDimension, Location loc,
- OpBuilder &builder) {
- return extractOrInsertOutputSlice(src, {}, queryShape, sequenceTileLength,
- headDimension, loc, builder);
-}
-
-static Value insertOutputSlice(Value src, Value dst,
- OpFoldResult sequenceTileLength,
- OpFoldResult headDimension, Location loc,
- OpBuilder &builder) {
- return extractOrInsertOutputSlice(src, dst, {}, sequenceTileLength,
- headDimension, loc, builder);
-}
-
-static SmallVector<AffineMap>
-getTileAttentionIndexingMaps(RewriterBase &rewriter, int64_t tiledInputRank,
- bool transposeV) {
- MLIRContext *ctx = rewriter.getContext();
- AffineExpr m, k1, k2, n;
- bindDims(ctx, m, k1, k2, n);
-
- AffineMap qMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
- AffineMap kMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
- AffineMap vMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
- AffineMap sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {}, ctx);
- AffineMap rMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
- AffineMap maxMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m}, ctx);
- AffineMap sumMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m}, ctx);
-
- if (transposeV) {
- SmallVector<AffineExpr> vDims(vMap.getResults());
- std::swap(vDims[0], vDims[1]);
- vMap = AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), vDims, ctx);
- }
-
- SmallVector<AffineMap> attentionMaps = {qMap, kMap, vMap, sMap,
- rMap, maxMap, sumMap};
- // Add batches to standard attention indexing maps.
- int64_t numBatches = tiledInputRank - 2;
- for (AffineMap &map : attentionMaps) {
- map = map.shiftDims(numBatches);
- for (int batch : llvm::seq<int>(numBatches)) {
- map = map.insertResult(rewriter.getAffineDimExpr(batch), batch);
- }
- }
-
- return attentionMaps;
-}
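
Concretely, for rank-2 tiled operands (tiledInputRank == 2, so numBatches == 0 and no batch dimensions are inserted), the maps built above work out to:

    Q:       (m, k1, k2, n) -> (m, k1)
    K:       (m, k1, k2, n) -> (k2, k1)
    V:       (m, k1, k2, n) -> (k2, n)    // (n, k2) when V is transposed
    scale:   (m, k1, k2, n) -> ()
    result:  (m, k1, k2, n) -> (m, n)
    max/sum: (m, k1, k2, n) -> (m)
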
-
-struct TileAttentionPass final
- : impl::TileAttentionPassBase<TileAttentionPass> {
- using impl::TileAttentionPassBase<TileAttentionPass>::TileAttentionPassBase;
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<
- affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
- linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
- }
- TileAttentionPass() = default;
- TileAttentionPass(bool onlyTile, uint64_t tileSize) {
- this->tileSize = tileSize;
- }
- TileAttentionPass(const TileAttentionPass &pass) { tileSize = pass.tileSize; }
- void runOnOperation() override;
-};
-
struct ConvertAttentionToOnlineAttentionPass final
: impl::ConvertAttentionToOnlineAttentionPassBase<
ConvertAttentionToOnlineAttentionPass> {
@@ -219,140 +34,6 @@
} // namespace
-/// Tile iree_linalg_ext.attention.
-/// TODO: Adopt getTiledImplementation with this.
-IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
- SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter,
- std::optional<uint64_t> tileSize) {
- Location loc = attnOp.getLoc();
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(attnOp);
-
- Value query = attnOp.getQuery();
- ShapedType queryType = attnOp.getQueryType();
- Type elementType = queryType.getElementType();
- ArrayRef<int64_t> queryShape = queryType.getShape();
- SmallVector<OpFoldResult> queryDimValues =
- tensor::getMixedSizes(rewriter, loc, query);
- OpFoldResult headDimension = queryDimValues[2];
- OpFoldResult sequenceTileLength = queryDimValues[1];
- OpFoldResult keyValueTileLength = sequenceTileLength;
- SmallVector<int64_t> keyShape{queryShape};
- if (tileSize) {
- keyValueTileLength = rewriter.getIndexAttr(tileSize.value());
- for (auto [idx, val] : llvm::enumerate(attnOp.getKeyType().getShape())) {
- keyShape[idx] = idx == 1 ? tileSize.value() : val;
- }
- }
-
- Value key = attnOp.getKey();
- Value value = attnOp.getValue();
- SmallVector<OpFoldResult> keyDimValues =
- tensor::getMixedSizes(rewriter, loc, key);
- OpFoldResult sequenceLength = keyDimValues[1];
-
- // Create output accumulator
- Value output = attnOp.getOutput();
- Type f32Type = rewriter.getF32Type();
- SmallVector<OpFoldResult> accShape{queryDimValues[1], queryDimValues[2]};
- Value accumulatorF32 =
- rewriter.create<tensor::EmptyOp>(loc, accShape, f32Type);
-
- // Create accumulator, max and sum statistics
- Value outputSlice = extractOutputSlice(output, queryShape, sequenceTileLength,
- headDimension, loc, rewriter);
- Value zeroF32 =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(f32Type));
- auto accumulatorFill =
- rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, accumulatorF32);
- accumulatorF32 = accumulatorFill.result();
- ops.push_back(accumulatorFill);
-
- Value largeNegativeF32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getFloatAttr(f32Type, -1.0e+30));
- SmallVector<OpFoldResult> dims{sequenceTileLength};
- Value max = rewriter.create<tensor::EmptyOp>(loc, dims, f32Type);
- auto maxFill =
- rewriter.create<linalg::FillOp>(loc, ValueRange{largeNegativeF32}, max);
- Value negativeMax = maxFill.result();
- ops.push_back(maxFill);
- Value sum = rewriter.create<tensor::EmptyOp>(loc, dims, f32Type);
- auto sumFill = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, sum);
- Value zeroSum = sumFill.result();
- ops.push_back(sumFill);
-
- // Construct sequential loop
- SmallVector<Value> ivs;
- Value zeroValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- scf::LoopNest loopNest = createLoopNest(
- ivs, zeroValue,
- getValueOrCreateConstantIndexOp(rewriter, loc, keyValueTileLength),
- getValueOrCreateConstantIndexOp(rewriter, loc, sequenceLength),
- ValueRange({accumulatorF32, negativeMax, zeroSum}), loc, rewriter);
- ops.push_back(loopNest.loops.back());
-
- Value iterArgResult = loopNest.loops.back().getRegionIterArg(0);
- Value iterArgMax = loopNest.loops.back().getRegionIterArg(1);
- Value iterArgSum = loopNest.loops.back().getRegionIterArg(2);
-
- OpBuilder::InsertionGuard guardSecondLoop(rewriter);
- rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
-
- // Extract slices
- Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
- headDimension, elementType, loc, rewriter);
- Value valueSlice =
- extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
- elementType, loc, rewriter, attnOp.isTransposeV());
- Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
- headDimension, elementType, loc, rewriter);
-
- Value scale = attnOp.getScale();
-
- int64_t tiledInputRank = cast<ShapedType>(querySlice.getType()).getRank();
- SmallVector<AffineMap> tiledIndexingMaps = getTileAttentionIndexingMaps(
- rewriter, tiledInputRank, attnOp.isTransposeV());
-
- auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
- attnOp.getLoc(),
- SmallVector<Type>{accumulatorF32.getType(), sum.getType(), max.getType()},
- querySlice, keySlice, valueSlice, scale,
- SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum},
- rewriter.getAffineMapArrayAttr(tiledIndexingMaps));
-
- Value tiledResult = tiledAttentionOp.getResult(0);
- Value newMax = tiledAttentionOp.getResult(1);
- Value newSum = tiledAttentionOp.getResult(2);
-
- if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
- loopNest.loops.back().getBody()->getTerminator())) {
- OpBuilder::InsertionGuard yieldGuard(rewriter);
- rewriter.setInsertionPoint(yieldOp);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(
- yieldOp, ValueRange{tiledResult, newMax, newSum});
- }
-
- OpBuilder::InsertionGuard yieldGuard(rewriter);
- rewriter.setInsertionPointAfter(loopNest.loops.back());
-
- loopNest.results[0] = applyFinalScaling(
- loopNest.results[0], loopNest.results[2], loc, rewriter, ops);
-
- if (elementType.isF16()) {
- loopNest.results[0] =
- truncateToF16(loopNest.results[0], outputSlice, ops, rewriter, loc);
- }
- loopNest.results[0] =
- insertOutputSlice(loopNest.results[0], output, sequenceTileLength,
- headDimension, loc, rewriter);
-
- rewriter.replaceOp(attnOp, loopNest.results[0]);
- ops.push_back(tiledAttentionOp);
-
- return tiledAttentionOp;
-}
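
For readers of this removal, a hypothetical self-contained scalar reference of what the deleted tileAttention + decomposeTiledAttention pair computed (single batch; assumes tile divides S, as the generated loop does; names are illustrative, not IREE APIs):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Reference for O = softmax(scale * Q K^T) V with running (max, sum)
    // statistics. Q: N x d, K and V: S x d, O: N x d. Assumes S % tile == 0.
    void tiledAttentionRef(const float *Q, const float *K, const float *V,
                           float scale, float *O, int N, int S, int d, int tile) {
      std::vector<float> acc(N * d, 0.0f);      // unnormalized accumulator
      std::vector<float> m(N, -1.0e30f);        // running rowwise max
      std::vector<float> l(N, 0.0f);            // running rowwise sum
      std::vector<float> x(N * tile), p(N * tile);
      // Fold log2(e) into the scale so exp2 can be used instead of exp.
      const float s2 = scale * 1.4426950408889634f;
      for (int kv = 0; kv < S; kv += tile) {
        // x = (s2 * Q) @ K_tile^T (query pre-scaling + matmul_transpose_b).
        for (int i = 0; i < N; ++i)
          for (int j = 0; j < tile; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c)
              dot += s2 * Q[i * d + c] * K[(kv + j) * d + c];
            x[i * tile + j] = dot;
          }
        for (int i = 0; i < N; ++i) {
          float mNew = m[i];                    // rowwise max reduction
          for (int j = 0; j < tile; ++j)
            mNew = std::max(mNew, x[i * tile + j]);
          const float f = std::exp2(m[i] - mNew); // computeScaleFactor
          float rowSum = 0.0f;
          for (int j = 0; j < tile; ++j) {      // computePartialSoftmax
            p[i * tile + j] = std::exp2(x[i * tile + j] - mNew);
            rowSum += p[i * tile + j];
          }
          l[i] = f * l[i] + rowSum;             // updateAndScale + add reduction
          m[i] = mNew;
          for (int c = 0; c < d; ++c)           // scaleAccumulator
            acc[i * d + c] *= f;
        }
        for (int i = 0; i < N; ++i)             // acc += p @ V_tile
          for (int c = 0; c < d; ++c) {
            float dot = 0.0f;
            for (int j = 0; j < tile; ++j)
              dot += p[i * tile + j] * V[(kv + j) * d + c];
            acc[i * d + c] += dot;
          }
      }
      for (int i = 0; i < N; ++i)               // applyFinalScaling
        for (int c = 0; c < d; ++c)
          O[i * d + c] = acc[i * d + c] / l[i];
    }
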
-
void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops,
RewriterBase &rewriter) {
@@ -462,19 +143,6 @@
rewriter.replaceOp(attnOp, genericOp);
}
-void TileAttentionPass::runOnOperation() {
- MLIRContext *context = &getContext();
- IRRewriter rewriter(context);
- std::optional<uint64_t> optionalTileSize{std::nullopt};
- if (tileSize.hasValue()) {
- optionalTileSize = tileSize.getValue();
- }
- getOperation().walk([&](AttentionOp attnOp) {
- SmallVector<Operation *> ops;
- tileAttention(attnOp, ops, rewriter, optionalTileSize);
- });
-}
-
void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index 0d88661..efe463a6 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -21,14 +21,12 @@
"convert_to_loops.mlir",
"convert_to_online_attention.mlir",
"decompose_aggregate_op.mlir",
- "decompose_attention.mlir",
"decompose_im2col.mlir",
"decompose_online_attention.mlir",
"decompose_winograd.mlir",
"distribution.mlir",
"pad_contraction_to_block_size.mlir",
"split_reduction.mlir",
- "tile_attention.mlir",
"tiling.mlir",
],
include = ["*.mlir"],
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 1268444..3288c14 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -19,14 +19,12 @@
"convert_to_loops.mlir"
"convert_to_online_attention.mlir"
"decompose_aggregate_op.mlir"
- "decompose_attention.mlir"
"decompose_im2col.mlir"
"decompose_online_attention.mlir"
"decompose_winograd.mlir"
"distribution.mlir"
"pad_contraction_to_block_size.mlir"
"split_reduction.mlir"
- "tile_attention.mlir"
"tiling.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
deleted file mode 100644
index 19d6a6b..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
+++ /dev/null
@@ -1,350 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-tile-attention{tileSize=32},iree-linalg-ext-decompose-attention{tileSize=32}),cse)" %s | FileCheck %s
-
-func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
- %0 = tensor.empty() : tensor<1x1024x64xf32>
- %scale = arith.constant 0.05 : f32
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
- return %1 : tensor<1x1024x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @attention
-// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<1x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
-// CHECK: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
-// CHECK-SAME: tensor<1024x64xf32>
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
-// CHECK: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
-// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf32> to tensor<32x64xf32>
-// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf32> to tensor<32x64xf32>
-// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf32> to tensor<1024x64xf32>
-// CHECK: %[[SCALE_Q:.+]] = linalg.generic {{.+}} ins(%[[EXTRACTED_SLICE_2]] : tensor<1024x64xf32>)
-// CHECK: %[[D8:.+]] = tensor.empty() : tensor<1024x32xf32>
-// CHECK: %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x32xf32>) ->
-// CHECK-SAME: tensor<1024x32xf32>
-// CHECK: %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
-// CHECK-SAME: tensor<1024x64xf32>, tensor<32x64xf32>) outs(%[[D9]] : tensor<1024x32xf32>) -> tensor<1024x32xf32>
-// CHECK: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D10]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x32xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
-// CHECK: linalg.yield %[[D19]] : f32
-// CHECK: } -> tensor<1024x32xf32>
-// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK: %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK: linalg.yield %[[D19]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D12]] : tensor<1024x32xf32>) outs(%[[D14]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<1024x64xf32>
-// CHECK: %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_1]] : tensor<1024x32xf32>,
-// CHECK-SAME: tensor<32x64xf32>) outs(%[[D16]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK: scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
-// CHECK: }
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
-// CHECK-SAME: {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// CHECK: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D9]] : f32
-// CHECK: } -> tensor<1024x64xf32>
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1024x64xf32> into tensor<1x1024x64xf32>
-// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
-// CHECK: }
-
-// -----
-
-func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<?x?x?xf32> {
- %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
- %scale = arith.constant 0.05 : f32
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
- return %1 : tensor<?x?x?xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @attention
-// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index) -> tensor<?x?x?xf32> {
-// CHECK: %[[D0:.+]] = tensor.empty(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[D1:.+]] = tensor.empty(%[[DIM]], %[[DIM_0]]) : tensor<?x?xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-DAG: %[[CST_2:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK: %[[D3:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
-// CHECK: %[[D4:.+]] = linalg.fill ins(%[[CST_2]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[DIM_1]] step %[[C32]]
-// CHECK-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG8:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1, 1,
-// CHECK-SAME: 1] : tensor<?x?x?xf32> to tensor<32x?xf32>
-// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1,
-// CHECK-SAME: 1, 1] : tensor<?x?x?xf32> to tensor<32x?xf32>
-// CHECK: %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
-// CHECK-SAME: 1] : tensor<?x?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[EXTRACTED_SLICE_4]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[SCALE_Q:.+]] = linalg.generic
-// CHECK: %[[D8:.+]] = tensor.empty(%[[DIM_5]]) : tensor<?x32xf32>
-// CHECK: %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<?x32xf32>) -> tensor<?x32xf32>
-// CHECK: %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
-// CHECK-SAME: tensor<?x?xf32>, tensor<32x?xf32>) outs(%[[D9]] : tensor<?x32xf32>) -> tensor<?x32xf32>
-// CHECK: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D10]] : tensor<?x32xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D11]] : tensor<?xf32>) outs(%[[D10]] : tensor<?x32xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
-// CHECK: linalg.yield %[[D19]] : f32
-// CHECK: } -> tensor<?x32xf32>
-// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D11]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK: %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK: linalg.yield %[[D19]] : f32
-// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG9]] : tensor<?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D12]] : tensor<?x32xf32>) outs(%[[D14]] : tensor<?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG7]] : tensor<?x?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_3]] : tensor<?x32xf32>,
-// CHECK-SAME: tensor<32x?xf32>) outs(%[[D16]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
-// CHECK: }
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
-// CHECK: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D9]] : f32
-// CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
-// CHECK-SAME: [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
-// CHECK: return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
-// CHECK: }
-
-// -----
-
-func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
- %0 = tensor.empty() : tensor<1x1024x64xf16>
- %scale = arith.constant 0.05 : f16
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
- return %1 : tensor<1x1024x64xf16>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @attention_f16
-// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
-// CHECK: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
-// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
-// CHECK-SAME: tensor<1024x64xf32>
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
-// CHECK: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
-// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
-// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf16> to tensor<32x64xf16>
-// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf16> to tensor<32x64xf16>
-// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// CHECK: %[[SCALE_Q:.+]] = linalg.generic
-// CHECK: %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
-// CHECK: %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
-// CHECK-SAME: tensor<1024x32xf32>
-// CHECK: %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE_1]] :
-// CHECK-SAME: tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
-// CHECK-SAME: tensor<1024x32xf32>
-// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D11]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D21]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x32xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK: %[[D22:.+]] = math.exp2 %[[D21]] : f32
-// CHECK: linalg.yield %[[D22]] : f32
-// CHECK: } -> tensor<1024x32xf32>
-// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK: %[[D22]] = math.exp2 %[[D21]] : f32
-// CHECK: linalg.yield %[[D22]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D21]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D15]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D21]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D17:.+]] = tensor.empty() : tensor<1024x32xf16>
-// CHECK: %[[D18:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D17]] : tensor<1024x32xf16>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// CHECK: %[[D21]] = arith.truncf %[[IN]] : f32 to f16
-// CHECK: linalg.yield %[[D21]] : f16
-// CHECK: } -> tensor<1024x32xf16>
-// CHECK: %[[D19:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D21]] : f32
-// CHECK: } -> tensor<1024x64xf32>
-// CHECK: %[[D20:.+]] = linalg.matmul ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x32xf16>,
-// CHECK-SAME: tensor<32x64xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK: scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
-// CHECK: }
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
-// CHECK-SAME: {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// CHECK: %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D10]] : f32
-// CHECK: } -> tensor<1024x64xf32>
-// CHECK: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// CHECK: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
-// CHECK: linalg.yield %[[D9]] : f16
-// CHECK: } -> tensor<1024x64xf16>
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
-// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
-// CHECK: }
-
-// -----
-
-// transpose_V is detected through the indexing maps.
-
-func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
- %0 = tensor.empty() : tensor<1x1024x64xf16>
- %scale = arith.constant 0.05 : f16
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
- return %1 : tensor<1x1024x64xf16>
-}
-
-// CHECK-LABEL: func.func @attention_transpose_v
-// There should be two matmul_transpose_b ops for the transpose_v variant
-// instead of only one.
-// CHECK: linalg.matmul_transpose_b
-// CHECK-NOT: linalg.matmul
-// CHECK: linalg.matmul_transpose_b
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
deleted file mode 100644
index 51d3c51..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
+++ /dev/null
@@ -1,165 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-tile-attention),cse)" %s | FileCheck %s --check-prefix=CHECK
-
-// TODO: These tests should be moved to tiling.mlir when PartialReductionOpInterface is implemented for attention op.
-
-func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
- %0 = tensor.empty() : tensor<1x1024x64xf32>
- %scale = arith.constant 0.05 : f32
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
- return %1 : tensor<1x1024x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-LABEL: func.func @attention
-// CHECK-SAME: (%[[QUERY:.+]]: tensor<1x1024x64xf32>, %[[KEY:.+]]: tensor<1x1024x64xf32>, %[[VALUE:.+]]: tensor<1x1024x64xf32>)
-// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
-// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant 5.000000e-02 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>)
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
-// CHECK: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>)
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C1024]] step %[[C1024]]
-// CHECK-SAME: iter_args(%[[ARG4:.+]] = %[[D2]], %[[ARG5:.+]] = %[[D4]], %[[ARG6:.+]] = %[[D5]])
-// CHECK: %[[K_S:.+]] = tensor.extract_slice %[[KEY]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1]
-// CHECK: %[[V_S:.+]] = tensor.extract_slice %[[VALUE]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1]
-// CHECK: %[[Q_S:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, 1024, 64] [1, 1, 1]
-// CHECK: %[[ATT:.+]]:3 = iree_linalg_ext.attention
-// CHECK-SAME: ins(%[[Q_S]], %[[K_S]], %[[V_S]], %[[CST_1]]
-// CHECK-SAME: outs(%[[ARG4]], %[[ARG5]], %[[ARG6]]
-// CHECK: scf.yield %[[ATT]]#0, %[[ATT]]#1, %[[ATT]]#2
-// CHECK: }
-// CHECK: %[[D7:.+]] = linalg.generic
-// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME: ins(%[[D6]]#2 : tensor<1024xf32>)
-// CHECK-SAME: outs(%[[D6]]#0 : tensor<1024x64xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00
-// CHECK: %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]]
-// CHECK: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]]
-// CHECK: linalg.yield %[[D9]]
-// CHECK: } -> tensor<1024x64xf32>
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]]
-// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
-// CHECK: }
-
-// -----
-
-func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<?x?x?xf32> {
- %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
- %scale = arith.constant 0.05 : f32
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
- return %1 : tensor<?x?x?xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-LABEL: func.func @attention(
-// CHECK-SAME: %[[QUERY:.+]]: tensor<?x?x?xf32>, %[[KEY:.+]]: tensor<?x?x?xf32>, %[[VALUE:.+]]: tensor<?x?x?xf32>,
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index, %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index)
-// CHECK: %[[D0:.+]] = tensor.empty(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[QUERY]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[QUERY]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[KEY]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[D1:.+]] = tensor.empty(%[[DIM]], %[[DIM_0]]) : tensor<?x?xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-DAG: %[[CST_2:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK: %[[D3:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
-// CHECK: %[[D4:.+]] = linalg.fill ins(%[[CST_2]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[DIM_1]] step %[[DIM]]
-// CHECK-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG8:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[K_S:.+]] = tensor.extract_slice %[[KEY]][0, %[[ARG6]], 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1]
-// CHECK: %[[V_S:.+]] = tensor.extract_slice %[[VALUE]][0, %[[ARG6]], 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1]
-// CHECK: %[[Q_S:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1]
-// CHECK: %[[ATT:.+]]:3 = iree_linalg_ext.attention
-// CHECK-SAME: ins(%[[Q_S]], %[[K_S]], %[[V_S]], %{{[a-z0-9]+}}
-// CHECK-SAME: outs(%[[ARG7]], %[[ARG8]], %[[ARG9]]
-// CHECK: scf.yield %[[ATT]]#0, %[[ATT]]#1, %[[ATT]]#2
-// CHECK: }
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_3:.+]] = arith.constant 1.000000e+00
-// CHECK: %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]]
-// CHECK: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]]
-// CHECK: linalg.yield %[[D9]]
-// CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
-// CHECK-SAME: [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
-// CHECK: return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
-// CHECK: }
-
-// -----
-
-func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
- %0 = tensor.empty() : tensor<1x1024x64xf16>
- %scale = arith.constant 0.05 : f16
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
- return %1 : tensor<1x1024x64xf16>
-}
-
-// CHECK-LABEL: @attention_f16
-
-// CHECK: scf.for
-// CHECK: iree_linalg_ext.attention
-// CHECK-SAME: ins(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<1024x64xf16>, tensor<1024x64xf16>, tensor<1024x64xf16>, f16
-// CHECK-SAME: outs(%{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
-// CHECK: scf.yield
-
-// CHECK: linalg.generic
-
-// CHECK: %[[TRUNCED:.+]] = linalg.generic
-// CHECK: arith.truncf
-// CHECK: } -> tensor<1024x64xf16>
-
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[TRUNCED]]
-// CHECK: return %[[INSERTED_SLICE]]
-// CHECK: }
-
-// -----
-
-// transpose_V is detected through the indexing maps.
-
-func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
- %0 = tensor.empty() : tensor<1x1024x64xf16>
- %scale = arith.constant 0.05 : f16
- %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
- ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
- return %1 : tensor<1x1024x64xf16>
-}
-// CHECK-LABEL: func.func @attention_transpose_v
-// CHECK: scf.for
-// CHECK: iree_linalg_ext.attention
-// CHECK: scf.yield