Merge pull request #8379 from matthias-springer/assert_align
Generate memref.assume_alignment ops
diff --git a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index f6ecd57..5fb20f8 100644
--- a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -79,7 +79,7 @@
%m = hal.interface.constant.load[0] : index
%n = hal.interface.constant.load[1] : index
%k = hal.interface.constant.load[2] : index
- %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %k}
+ %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%m, %k}
%rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%k, %n}
%result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readwrite:?x?xf32>{%m, %n}
%wg_id_y = hal.interface.workgroup.id[1] : index
@@ -114,7 +114,8 @@
// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
-// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32)
+// CHECK-DAG: memref.assume_alignment %[[LHS]], 32
// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
diff --git a/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 626dea9..e50ff3e 100644
--- a/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -259,6 +259,11 @@
subspanOp->getLoc(), memRefType, subspanOp.set(), subspanOp.binding(),
subspanOp.type(), subspanOp.byte_offset(), subspanOp.dynamic_dims(),
subspanOp.alignmentAttr());
+ if (subspanOp.alignment()) {
+ rewriter.create<memref::AssumeAlignmentOp>(
+ subspanOp->getLoc(), baseBuffer,
+ subspanOp.alignment()->getZExtValue());
+ }
flowState.subspan_to_buffer[tensor] = baseBuffer;
}