Enable alwaysAliasingWithDest in IREE comprehensive bufferization. (#8594)
Fixes https://github.com/google/iree/issues/8500
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 6858248..2278c81 100644
--- a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -103,7 +103,7 @@
options.memCpyFn = memCpyFn;
options.testAnalysisOnly = testAnalysisOnly;
options.printConflicts = printConflicts;
- options.alwaysAliasingWithDest = false;
+ options.alwaysAliasingWithDest = true;
addPostAnalysisTransformations(options);
if (failed(bufferization::runOneShotBufferize(moduleOp, options))) {
diff --git a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index b9d84ee..830af58 100644
--- a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s --iree-codegen-iree-comprehensive-bufferize -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: iree-opt %s --iree-codegen-iree-comprehensive-bufferize -canonicalize -cse -canonicalize -split-input-file | FileCheck %s
func @matmul() {
%c0 = arith.constant 0 : index
@@ -142,3 +142,60 @@
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]]
// CHECK-SAME: outs(%[[RESULT_TILE]]
+
+// -----
+
+func @elementwise() {
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x10xf32>
+ %c512 = arith.constant 512 : index
+ %c64 = arith.constant 64 : index
+ %c10 = arith.constant 10 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c512) alignment(64) : !flow.dispatch.tensor<readonly:1x10xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c64) alignment(64) : !flow.dispatch.tensor<writeonly:1x10xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %2 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
+ %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_x]
+ scf.for %arg0 = %2 to %c10 step %3 {
+ %4 = affine.min affine_map<(d0) -> (4, -d0 + 10)>(%arg0)
+ %5 = flow.dispatch.tensor.load %0, offsets = [0, %arg0], sizes = [1, %4], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x10xf32> -> tensor<1x?xf32>
+ %6 = flow.dispatch.tensor.load %1, offsets = [0, %arg0], sizes = [1, %4], strides = [1, 1] : !flow.dispatch.tensor<writeonly:1x10xf32> -> tensor<1x?xf32>
+ %7 = scf.for %arg1 = %c0 to %4 step %c4 iter_args(%arg2 = %6) -> (tensor<1x?xf32>) {
+ %8 = affine.min affine_map<(d0, d1) -> (4, -d0 + d1)>(%arg1, %4)
+ %9 = tensor.extract_slice %5[0, %arg1] [1, %8] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+ %10 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg1, %arg0)
+ %11 = tensor.extract_slice %cst[0, %10] [1, %8] [1, 1] : tensor<1x10xf32> to tensor<1x?xf32>
+ %12 = tensor.extract_slice %arg2[0, %arg1] [1, %8] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
+ %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%9, %11 : tensor<1x?xf32>, tensor<1x?xf32>)
+ outs(%12 : tensor<1x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %15 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x?xf32>
+ %14 = tensor.insert_slice %13 into %arg2[0, %arg1] [1, %8] [1, 1] : tensor<1x?xf32> into tensor<1x?xf32>
+ scf.yield %14 : tensor<1x?xf32>
+ }
+ flow.dispatch.tensor.store %7, %1, offsets = [0, %arg0], sizes = [1, %4], strides = [1, 1] : tensor<1x?xf32> -> !flow.dispatch.tensor<writeonly:1x10xf32>
+ }
+ return
+}
+// CHECK: func @elementwise()
+// CHECK-DAG: %[[GLB_CST:.+]] = memref.get_global @__constant_1x10xf32 : memref<1x10xf32>
+// CHECK-DAG: %[[IN_BUF:.+]] = hal.interface.binding.subspan set(0) binding(0) {{.+}} : memref<1x10xf32>
+// CHECK-DAG: %[[OUT_BUF:.+]] = hal.interface.binding.subspan set(0) binding(1) {{.+}} : memref<1x10xf32>
+// CHECK: scf.for
+// CHECK-DAG: %[[SUB_IN1:.+]] = memref.subview %[[IN_BUF]]
+// CHECK-DAG: %[[SUB_OUT1:.+]] = memref.subview %[[OUT_BUF]]
+// CHECK: scf.for
+// CHECK-DAG: %[[SUB_IN2:.+]] = memref.subview %[[SUB_IN1]]
+// CHECK-DAG: %[[SUB_CST:.+]] = memref.subview %[[GLB_CST]]
+// CHECK-DAG: %[[SUB_OUT2:.+]] = memref.subview %[[SUB_OUT1]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SUB_IN2]], %[[SUB_CST]]
+// CHECK-SAME: outs(%[[SUB_OUT2]]