[LinalgExt] Do not decompose attention op with manual analysis. (#16525)

Re-enables the TileAndDecomposeAttention pass in the LLVMCPU and LLVMGPU
pipelines, which had been commented out because decomposition was failing
during bufferization (#16421). The scale-factor computation in the
decomposition previously reused `oldMax` as the init operand of its
linalg.generic; it now creates a fresh tensor.empty for the init instead,
which works around the bufferization failure. Lit tests are updated for
the new init operand and for the workgroup allocation it bufferizes to on
the LLVMGPU path.
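As a rough sketch of the IR change (shapes and SSA names here are
illustrative, not taken from the actual tests), the generic's out operand
switches from the running max to a freshly materialized init:

    // Before: the running max doubles as the init operand.
    %scale = linalg.generic ...
        ins(%new_max : tensor<1024xf32>)
        outs(%old_max : tensor<1024xf32>) {...} -> tensor<1024xf32>

    // After: a fresh init is materialized for the result.
    %init = tensor.empty() : tensor<1024xf32>
    %scale = linalg.generic ...
        ins(%new_max : tensor<1024xf32>)
        outs(%init : tensor<1024xf32>) {...} -> tensor<1024xf32>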
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 4107c55..4bff8f0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -100,10 +100,8 @@
createFuseTensorPadWithConsumerPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConcretizePadResultShapePass());
- // TODO(#16421) Decomposition is failing during bufferization.
- // Disable till fixed
- // nestedModulePM.addNestedPass<func::FuncOp>(
- // IREE::LinalgExt::createTileAndDecomposeAttentionPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ IREE::LinalgExt::createTileAndDecomposeAttentionPass());
nestedModulePM.addNestedPass<func::FuncOp>(
IREE::LinalgExt::createTileAndDecomposeWinogradTransformPass());
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index cff5a3e..ded6cb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -113,9 +113,8 @@
nestedModulePM.addNestedPass<func::FuncOp>(
createConvertToDestinationPassingStylePass(
useWARForCooperativeMatrixCodegen));
- // TODO(#16421): Disable decomposition due to failure in bufferization.
- // nestedModulePM.addNestedPass<func::FuncOp>(
- // IREE::LinalgExt::createTileAndDecomposeAttentionPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ IREE::LinalgExt::createTileAndDecomposeAttentionPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index ce450ba..48bd221 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -132,7 +132,10 @@
// CHECK: %[[D18:.+]] = vector.transpose %[[D17]], [1, 0] : vector<128x32xf32> to vector<32x128xf32>
// CHECK: %[[D19:.+]] = arith.subf %[[D15]], %[[D18]] : vector<32x128xf32>
// CHECK: %[[D20:.+]] = math.exp2 %[[D19]] : vector<32x128xf32>
-// CHECK: %[[D21:.+]] = arith.subf %[[ARG1]], %[[D16]] : vector<32xf32>
+// CHECK: %[[ALLOC_12:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32xf32,
+// CHECK-SAME: #[[GPU]].address_space<workgroup>>
+// CHECK: %[[READ_ALLOC_12:.+]] = vector.transfer_read %[[ALLOC_12]]
+// CHECK: %[[D21:.+]] = arith.subf %[[READ_ALLOC_12]], %[[D16]] : vector<32xf32>
// CHECK: %[[D22:.+]] = math.exp2 %[[D21]] : vector<32xf32>
// CHECK: %[[D23:.+]] = arith.mulf %[[D22]], %[[ARG2]] : vector<32xf32>
// CHECK: %[[D24:.+]] = vector.multi_reduction <add>, %[[D20]], %[[D23]] [1] : vector<32x128xf32> to
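Note on the CHECK changes above: with the fresh init, the subtraction no
longer reads the iteration argument %[[ARG1]]; instead the tensor.empty
bufferizes to a workgroup allocation that is read back into a vector,
roughly (the %c0 index and %pad padding value are illustrative):

    %alloc = memref.alloc() {alignment = 64 : i64}
        : memref<32xf32, #gpu.address_space<workgroup>>
    %read = vector.transfer_read %alloc[%c0], %pad
        : memref<32xf32, #gpu.address_space<workgroup>>, vector<32xf32>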
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 539f20b..c2232c8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -18,6 +18,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
+#include "mlir/IR/TypeUtilities.h"

namespace mlir {
namespace iree_compiler {
@@ -79,12 +80,20 @@
static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
OpBuilder &builder,
SmallVectorImpl<Operation *> &ops) {
+  // Create a tensor.empty op to serve as the init operand. This works
+  // around a failure in bufferization; for details see
+  // https://github.com/openxla/iree/issues/16421
+ SmallVector<OpFoldResult> initShape =
+ tensor::getMixedSizes(builder, loc, newMax);
+ Value init = builder.create<tensor::EmptyOp>(
+ loc, initShape, getElementTypeOrSelf(oldMax.getType()));
+
SmallVector<utils::IteratorType> iteratorTypes(1,
utils::IteratorType::parallel);
auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
SmallVector<AffineMap> indexingMaps(2, identityMap);
auto genericOp = builder.create<linalg::GenericOp>(
- loc, oldMax.getType(), newMax, oldMax, indexingMaps, iteratorTypes,
+ loc, init.getType(), newMax, init, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
Value weight = b.create<math::Exp2Op>(loc, diff);
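For dynamic shapes, tensor::getMixedSizes returns constants for static
dimensions and emits tensor.dim for dynamic ones, so the init materializes
as, roughly (SSA names illustrative):

    %dim = tensor.dim %new_max, %c0 : tensor<?xf32>
    %init = tensor.empty(%dim) : tensor<?xf32>

The new %[[DIM6]]/%[[INIT]] captures in the updated lit tests below check
for this pattern.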
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 3339eb5..3531fb1 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -53,7 +53,7 @@
// TILESIZE: linalg.yield %[[D19]] : f32
// TILESIZE: } -> tensor<1024x32xf32>
// TILESIZE: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// TILESIZE: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
// TILESIZE: %[[D19]] = math.exp2 %[[D18]] : f32
@@ -183,7 +183,7 @@
// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<1024x1024xf32>
// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18:.+]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
@@ -280,8 +280,10 @@
// TILESIZE: %[[D19:.+]] = math.exp2 %[[D18]] : f32
// TILESIZE: linalg.yield %[[D19]] : f32
// TILESIZE: } -> tensor<?x32xf32>
+// TILESIZE: %[[DIM6:.+]] = tensor.dim %[[D11]], %[[C0]]
+// TILESIZE: %[[INIT:.+]] = tensor.empty(%[[DIM6]])
// TILESIZE: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME: ins(%[[D11]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// TILESIZE-SAME: ins(%[[D11]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>) {
// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// TILESIZE: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
// TILESIZE: %[[D19]] = math.exp2 %[[D18]] : f32
@@ -413,8 +415,10 @@
// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<?x?xf32>
+// CHECK: %[[DIM6:.+]] = tensor.dim %[[D10]], %[[C0]]
+// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM6]])
// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D10]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// CHECK-SAME: ins(%[[D10]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D19]] = math.exp2 %[[D18]] : f32
@@ -511,7 +515,7 @@
// TILESIZE: linalg.yield %[[D22]] : f32
// TILESIZE: } -> tensor<1024x32xf32>
// TILESIZE: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// TILESIZE: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
// TILESIZE: %[[D22]] = math.exp2 %[[D21]] : f32
@@ -663,7 +667,7 @@
// CHECK: linalg.yield %[[D22]] : f32
// CHECK: } -> tensor<1024x1024xf32>
// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D22]] = math.exp2 %[[D21]] : f32
@@ -774,7 +778,7 @@
// TILESIZE: linalg.yield %[[D22]] : f32
// TILESIZE: } -> tensor<1024x32xf32>
// TILESIZE: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// TILESIZE: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
// TILESIZE: %[[D22]] = math.exp2 %[[D21]] : f32
@@ -926,7 +930,7 @@
// CHECK: linalg.yield %[[D22]] : f32
// CHECK: } -> tensor<1024x1024xf32>
// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D22]] = math.exp2 %[[D21]] : f32