Adding SubrangeOperandOpInterface to better fold util.buffer.subspan. (#11340)
This will allow ops outside of the util dialect to have subranges
updated by generic canonicalization patterns.
Progress on #11027, VMVX ops need to implement these interfaces.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp
index 945d309..a3ddc56 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp
@@ -34,30 +34,32 @@
return false;
}
-/// Returns the offset, in bytes, of an index within a linearized dense buffer.
+/// Returns the offset, in bytes, of an index within a linearized dense buffer
+/// and the element length accessed.
/// Expects that the |memrefValue| has been linearized already.
-static Value getBufferOffset(Location loc, Value memrefValue,
- ValueRange indices,
- ConversionPatternRewriter &rewriter) {
- auto memrefType = memrefValue.getType().cast<ShapedType>();
- if (memrefType.getRank() == 0) {
- // Rank 0 buffers (like memref<i32>) have only a single valid offset at 0.
- return rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
- }
- assert(memrefType.getRank() == 1 && "memrefs should have been flattened");
-
+static std::pair<Value, Value> getBufferOffsetAndLength(
+ Location loc, Value memrefValue, ValueRange indices,
+ ConversionPatternRewriter &rewriter) {
// Element type byte length as the base. Note that this is the unconverted
// element type. Since these are storage types within a buffer, they are
// not subject to general type conversion (i.e. a general type converter
// may elect to represent all i8 registers as i32, but this does not mean
// that all memrefs are widened from i8 to i32).
+ auto memrefType = memrefValue.getType().cast<ShapedType>();
auto elementType = memrefType.getElementType();
auto elementSize =
rewriter.createOrFold<IREE::Util::SizeOfOp>(loc, elementType);
+ if (memrefType.getRank() == 0) {
+ // Rank 0 buffers (like memref<i32>) have only a single valid offset at 0.
+ return {rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0), elementSize};
+ }
+ assert(memrefType.getRank() == 1 && "memrefs should have been flattened");
+
// Rank 1 memrefs are just offset by their element width by the offset.
auto elementCount = indices.front();
- return rewriter.create<arith::MulIOp>(loc, elementSize, elementCount);
+ return {rewriter.create<arith::MulIOp>(loc, elementSize, elementCount),
+ elementSize};
}
/// Pattern to lower operations that become a no-ops at this level.
@@ -202,10 +204,11 @@
auto newType = getTypeConverter()->convertType(oldType);
auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
loadOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref());
- auto byteOffset = getBufferOffset(loadOp.getLoc(), loadOp.getMemref(),
- loadOp.getIndices(), rewriter);
+ auto [byteOffset, byteLength] = getBufferOffsetAndLength(
+ loadOp.getLoc(), loadOp.getMemref(), loadOp.getIndices(), rewriter);
Value loaded = rewriter.create<IREE::Util::BufferLoadOp>(
- loadOp.getLoc(), oldType, adaptor.getMemref(), memRefSize, byteOffset);
+ loadOp.getLoc(), oldType, adaptor.getMemref(), memRefSize, byteOffset,
+ byteLength);
if (newType != oldType) {
// Since the BufferLoadOp semantics include its result type (i.e. a load
// of an i8 is different than a load of an i32), in the presence of type
@@ -235,8 +238,8 @@
}
auto memRefSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
storeOp.getLoc(), rewriter.getIndexType(), adaptor.getMemref());
- auto byteOffset = getBufferOffset(storeOp.getLoc(), storeOp.getMemref(),
- storeOp.getIndices(), rewriter);
+ auto [byteOffset, byteLength] = getBufferOffsetAndLength(
+ storeOp.getLoc(), storeOp.getMemref(), storeOp.getIndices(), rewriter);
Value newValue = adaptor.getValue();
if (newValue.getType() != storeOp.getValue().getType()) {
// In combination with type conversion, the elemental type may change,
@@ -252,7 +255,8 @@
.getResult(0);
}
rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>(
- storeOp, newValue, adaptor.getMemref(), memRefSize, byteOffset);
+ storeOp, newValue, adaptor.getMemref(), memRefSize, byteOffset,
+ byteLength);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
index c12384f..9fa8ee7 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/memref_ops.mlir
@@ -78,10 +78,10 @@
func.func @load_store_f32(%buffer: memref<?xf32>, %idx0: index, %idx1: index) -> f32 {
// CHECK: %[[BUFFER_SIZE:.+]] = util.buffer.size %[[BUFFER]]
// CHECK: %[[IDX0_BYTES:.+]] = arith.muli %[[IDX0]], %c4
- // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX0_BYTES]]] : !util.buffer{%[[BUFFER_SIZE]]} -> f32
+ // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX0_BYTES]] for %c4] : !util.buffer{%[[BUFFER_SIZE]]} -> f32
%0 = memref.load %buffer[%idx0] : memref<?xf32>
// CHECK: %[[IDX1_BYTES:.+]] = arith.muli %[[IDX1]], %c4
- // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[IDX1_BYTES]]] : f32 -> !util.buffer{%[[BUFFER_SIZE]]}
+ // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[IDX1_BYTES]] for %c4] : f32 -> !util.buffer{%[[BUFFER_SIZE]]}
memref.store %0, %buffer[%idx1] : memref<?xf32>
// CHECK: return %[[VALUE]] : f32
return %0 : f32
@@ -101,7 +101,7 @@
%0 = memref.get_global @__constant_f32 : memref<2xf32>
// CHECK: %[[BUFFER_SIZE:.+]] = util.buffer.size %[[BUFFER]]
// CHECK: %[[IDX_BYTES:.+]] = arith.muli %[[IDX]], %c4
- // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX_BYTES]]] : !util.buffer{%[[BUFFER_SIZE]]} -> f32
+ // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[IDX_BYTES]] for %c4] : !util.buffer{%[[BUFFER_SIZE]]} -> f32
%1 = memref.load %0[%idx] : memref<2xf32>
// CHECK: return %[[VALUE]] : f32
return %1 : f32
@@ -115,10 +115,10 @@
// CHECK-DAG: %[[SZ:.*]] = util.buffer.size %[[BUFFER]]
// CHECK-DAG: %[[OFS0:.*]] = arith.muli %[[IDX0]], %[[C2]] : index
// CHECK-DAG: %[[UCST0:.*]] = builtin.unrealized_conversion_cast %[[VALUE]] : i32 to i16
- // CHECK: util.buffer.store %[[UCST0]], %[[BUFFER]][%[[OFS0]]] : i16 -> !util.buffer{%[[SZ]]}
+ // CHECK: util.buffer.store %[[UCST0]], %[[BUFFER]][%[[OFS0]] for %[[C2]]] : i16 -> !util.buffer{%[[SZ]]}
memref.store %value, %buffer[%idx0] : memref<?xi16>
// CHECK: %[[OFS1:.*]] = arith.muli %[[IDX1]], %[[C2]] : index
- // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]]] : !util.buffer{%[[SZ]]} -> i16
+ // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]] for %c2] : !util.buffer{%[[SZ]]} -> i16
// CHECK: %[[UCST1:.*]] = builtin.unrealized_conversion_cast %[[LD]] : i16 to i32
%1 = memref.load %buffer[%idx1] : memref<?xi16>
// CHECK: return %[[UCST1]]
@@ -132,10 +132,10 @@
// CHECK-DAG: %[[SIZEOF:.*]] = util.sizeof index
// CHECK-DAG: %[[SZ:.*]] = util.buffer.size %[[BUFFER]]
// CHECK-DAG: %[[OFS0:.*]] = arith.muli %[[SIZEOF]], %[[IDX0]] : index
- // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFS0]]] : index -> !util.buffer{%[[SZ]]}
+ // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFS0]] for %[[SIZEOF]]] : index -> !util.buffer{%[[SZ]]}
memref.store %value, %buffer[%idx0] : memref<?xindex>
// CHECK: %[[OFS1:.*]] = arith.muli %[[SIZEOF]], %[[IDX1]] : index
- // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]]] : !util.buffer{%[[SZ]]} -> index
+ // CHECK: %[[LD:.*]] = util.buffer.load %[[BUFFER]][%[[OFS1]] for %[[SIZEOF]]] : !util.buffer{%[[SZ]]} -> index
%1 = memref.load %buffer[%idx1] : memref<?xindex>
// CHECK: return %[[LD]]
return %1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index 2ef7cdd..0857220 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -1053,6 +1053,34 @@
];
}
+def Util_SubrangeOperandOpInterface : OpInterface<"SubrangeOperandOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ Interface for operations that consume subranges of size-aware resources.
+ The methods are used to manipulate the subranges during transformation.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the subrange operand values for the given flat operand index.
+ }],
+ /*retTy=*/"SubrangeOperand",
+ /*methodName=*/"getSubrangeOperand",
+ /*args=*/(ins "unsigned":$operandIndex)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Updates the subrange operand values for the given flat operand index.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"setSubrangeOperand",
+ /*args=*/(ins "unsigned":$operandIndex, "SubrangeOperand":$operand)
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// IREE::Util::SerializableAttrInterface
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index f073eb4..461eeb2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -585,45 +585,6 @@
}
//===----------------------------------------------------------------------===//
-// util.buffer.slice
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Folds subspan ranges into slice ranges.
-//
-// Example:
-// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
-// %1 = util.buffer.slice %0[%slice_offset] ... -> {%slice_length}
-// ->
-// %new_offset = arith.addi %slice_offset, %subspan_offset
-// %1 = util.buffer.slice %src[%new_offset] ... -> {%slice_length}
-struct FoldSubspansIntoSliceOp : public OpRewritePattern<BufferSliceOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(BufferSliceOp op,
- PatternRewriter &rewriter) const override {
- auto subspanOp = BufferSubspanOp::findSubspanOp(op.getSource());
- if (!subspanOp) return failure();
- auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, subspanOp.getSourceOffset(), op.getSourceOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getSourceMutable().assign(subspanOp.getSource());
- op.getSourceSizeMutable().assign(subspanOp.getSourceSize());
- op.getSourceOffsetMutable().assign(newOffset);
- });
- return success();
- }
-};
-
-} // namespace
-
-void BufferSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSubspansIntoSliceOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
// util.buffer.subspan
//===----------------------------------------------------------------------===//
@@ -656,6 +617,42 @@
}
};
+// Folds subspan ranges into consumer ranges.
+//
+// Example:
+// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
+// %1 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
+// util.buffer.copy %0[%offset], %1[%offset], %length
+// ->
+// %new_offset = arith.addi %offset, %subspan_offset
+// util.buffer.copy %src[%new_offset], %dst[%new_offset], %subspan_length
+struct FoldBufferSubspanOpsIntoConsumers
+ : public OpRewritePattern<BufferSubspanOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferSubspanOp op,
+ PatternRewriter &rewriter) const override {
+ bool didUpdateAny = false;
+ for (auto &use : llvm::make_early_inc_range(op.getResult().getUses())) {
+ auto subrangeOp =
+ dyn_cast<IREE::Util::SubrangeOperandOpInterface>(use.getOwner());
+ if (!subrangeOp) continue;
+ didUpdateAny = true;
+ rewriter.setInsertionPoint(subrangeOp);
+ auto oldRange = subrangeOp.getSubrangeOperand(use.getOperandNumber());
+ auto fusedLoc =
+ rewriter.getFusedLoc({op.getLoc(), use.getOwner()->getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, op.getSourceOffset(), oldRange.offset);
+ auto newRange = SubrangeOperand{op.getSource(), op.getSourceSize(),
+ newOffset, oldRange.length};
+ rewriter.updateRootInPlace(subrangeOp, [&]() {
+ subrangeOp.setSubrangeOperand(use.getOperandNumber(), newRange);
+ });
+ }
+ return success(didUpdateAny);
+ }
+};
+
// Turns selects of subspans of a buffer into selects of the offset.
// This only works if the subspan sizes match.
//
@@ -697,6 +694,7 @@
void BufferSubspanOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<FoldBufferSubspanOps>(context);
+ results.insert<FoldBufferSubspanOpsIntoConsumers>(context);
results.insert<SinkSubspanAcrossSelectOps>(context);
}
@@ -815,237 +813,14 @@
}
//===----------------------------------------------------------------------===//
-// util.buffer.copy
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Folds subspan ranges into copy ranges.
-//
-// Example:
-// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
-// %1 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
-// util.buffer.copy %0[%offset], %1[%offset], %length
-// ->
-// %new_offset = arith.addi %offset, %subspan_offset
-// util.buffer.copy %src[%new_offset], %dst[%new_offset], %subspan_length
-struct FoldSubspansIntoCopyOp : public OpRewritePattern<BufferCopyOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(BufferCopyOp op,
- PatternRewriter &rewriter) const override {
- auto sourceSubspanOp = BufferSubspanOp::findSubspanOp(op.getSource());
- auto targetSubspanOp = BufferSubspanOp::findSubspanOp(op.getTarget());
- if (!sourceSubspanOp && !targetSubspanOp) return failure();
- if (sourceSubspanOp) {
- auto fusedLoc =
- rewriter.getFusedLoc({sourceSubspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, sourceSubspanOp.getSourceOffset(), op.getSourceOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getSourceMutable().assign(sourceSubspanOp.getSource());
- op.getSourceSizeMutable().assign(sourceSubspanOp.getSourceSize());
- op.getSourceOffsetMutable().assign(newOffset);
- });
- }
- if (targetSubspanOp) {
- auto fusedLoc =
- rewriter.getFusedLoc({targetSubspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, targetSubspanOp.getSourceOffset(), op.getTargetOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getTargetMutable().assign(targetSubspanOp.getSource());
- op.getTargetSizeMutable().assign(targetSubspanOp.getSourceSize());
- op.getTargetOffsetMutable().assign(newOffset);
- });
- }
- return success();
- }
-};
-
-} // namespace
-
-void BufferCopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSubspansIntoCopyOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// util.buffer.compare
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Folds subspan ranges into copy ranges.
-//
-// Example:
-// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
-// %1 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
-// util.buffer.copy %0[%offset], %1[%offset], %length
-// ->
-// %new_offset = arith.addi %offset, %subspan_offset
-// util.buffer.copy %src[%new_offset], %dst[%new_offset], %subspan_length
-struct FoldSubspansIntoCompareOp : public OpRewritePattern<BufferCompareOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(BufferCompareOp op,
- PatternRewriter &rewriter) const override {
- auto sourceSubspanOp = BufferSubspanOp::findSubspanOp(op.getLhs());
- auto targetSubspanOp = BufferSubspanOp::findSubspanOp(op.getRhs());
- if (!sourceSubspanOp && !targetSubspanOp) return failure();
- if (sourceSubspanOp) {
- auto fusedLoc =
- rewriter.getFusedLoc({sourceSubspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, sourceSubspanOp.getSourceOffset(), op.getLhsOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getLhsMutable().assign(sourceSubspanOp.getSource());
- op.getLhsSizeMutable().assign(sourceSubspanOp.getSourceSize());
- op.getLhsOffsetMutable().assign(newOffset);
- });
- }
- if (targetSubspanOp) {
- auto fusedLoc =
- rewriter.getFusedLoc({targetSubspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, targetSubspanOp.getSourceOffset(), op.getRhsOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getRhsMutable().assign(targetSubspanOp.getSource());
- op.getRhsSizeMutable().assign(targetSubspanOp.getSourceSize());
- op.getRhsOffsetMutable().assign(newOffset);
- });
- }
- return success();
- }
-};
-
-} // namespace
-
-void BufferCompareOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSubspansIntoCompareOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// util.buffer.fill
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Folds subspan ranges into fill ranges.
-//
-// Example:
-// %0 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
-// util.buffer.fill %cst, %0[%offset for %length]
-// ->
-// %new_offset = arith.addi %offset, %subspan_offset
-// util.buffer.fill %cst, %dst[%new_offset for %subspan_length]
-struct FoldSubspansIntoFillOp : public OpRewritePattern<BufferFillOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(BufferFillOp op,
- PatternRewriter &rewriter) const override {
- auto subspanOp = BufferSubspanOp::findSubspanOp(op.getTarget());
- if (!subspanOp) return failure();
- auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, subspanOp.getSourceOffset(), op.getTargetOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getTargetMutable().assign(subspanOp.getSource());
- op.getTargetSizeMutable().assign(subspanOp.getSourceSize());
- op.getTargetOffsetMutable().assign(newOffset);
- });
- return success();
- }
-};
-
-} // namespace
-
-void BufferFillOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSubspansIntoFillOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
// util.buffer.load
//===----------------------------------------------------------------------===//
-namespace {
-
-// Folds subspan offsets into loads.
-//
-// Example:
-// %0 = util.buffer.subspan %src[%subspan_offset] ... -> {%subspan_length}
-// %1 = util.buffer.load %0[%offset]
-// ->
-// %new_offset = arith.addi %offset, %subspan_offset
-// %1 = util.buffer.load %src[%new_offset]
-struct FoldSubspanIntoLoadOp : public OpRewritePattern<BufferLoadOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(BufferLoadOp op,
- PatternRewriter &rewriter) const override {
- auto subspanOp = BufferSubspanOp::findSubspanOp(op.getSource());
- if (!subspanOp) return failure();
- auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, subspanOp.getSourceOffset(), op.getSourceOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getSourceMutable().assign(subspanOp.getSource());
- op.getSourceSizeMutable().assign(subspanOp.getSourceSize());
- op.getSourceOffsetMutable().assign(newOffset);
- });
- return success();
- }
-};
-
-} // namespace
-
OpFoldResult BufferLoadOp::fold(ArrayRef<Attribute> operands) {
// TODO(benvanik): if source is a constant then perform the load.
return {};
}
-void BufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSubspanIntoLoadOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// util.buffer.store
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Folds subspan offsets into stores.
-//
-// Example:
-// %0 = util.buffer.subspan %dst[%subspan_offset] ... -> {%subspan_length}
-// util.buffer.store %c123_i32, %0[%offset]
-// ->
-// %new_offset = arith.addi %offset, %subspan_offset
-// util.buffer.store %c123_i32, %dst[%new_offset]
-struct FoldSubspanIntoStoreOp : public OpRewritePattern<BufferStoreOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(BufferStoreOp op,
- PatternRewriter &rewriter) const override {
- auto subspanOp = BufferSubspanOp::findSubspanOp(op.getTarget());
- if (!subspanOp) return failure();
- auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()});
- auto newOffset = rewriter.createOrFold<arith::AddIOp>(
- fusedLoc, subspanOp.getSourceOffset(), op.getTargetOffset());
- rewriter.updateRootInPlace(op, [&]() {
- op.getTargetMutable().assign(subspanOp.getSource());
- op.getTargetSizeMutable().assign(subspanOp.getSourceSize());
- op.getTargetOffsetMutable().assign(newOffset);
- });
- return success();
- }
-};
-
-} // namespace
-
-void BufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSubspanIntoStoreOp>(context);
-}
-
} // namespace Util
} // namespace IREE
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index d1a71ad..21ac8c8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -992,6 +992,25 @@
setNameFn(getResult(), "buffer");
}
+SubrangeOperand BufferSliceOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 0) {
+ return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(),
+ getResultSize()};
+ } else {
+ assert(false && "only source is a subrange");
+ return {};
+ }
+}
+
+void BufferSliceOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ assert(operandIndex == 0 && "only source is a subrange");
+ getSourceMutable().assign(operand.resource);
+ getSourceSizeMutable().assign(operand.resourceSize);
+ getSourceOffsetMutable().assign(operand.offset);
+ getResultSizeMutable().assign(operand.length);
+}
+
void BufferSubspanOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "buffer_span");
@@ -1003,6 +1022,25 @@
return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
}
+SubrangeOperand BufferSubspanOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 0) {
+ return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(),
+ getResultSize()};
+ } else {
+ assert(false && "only source is a subrange");
+ return {};
+ }
+}
+
+void BufferSubspanOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ assert(operandIndex == 0 && "only source is a subrange");
+ getSourceMutable().assign(operand.resource);
+ getSourceSizeMutable().assign(operand.resourceSize);
+ getSourceOffsetMutable().assign(operand.offset);
+ getResultSizeMutable().assign(operand.length);
+}
+
::llvm::Optional<unsigned> BufferSubspanOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
@@ -1045,6 +1083,121 @@
setNameFn(getOffset(), "buffer_offset");
}
+SubrangeOperand BufferCopyOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 0) {
+ return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(),
+ getLength()};
+ } else if (operandIndex == 3) {
+ return SubrangeOperand{getTarget(), getTargetSize(), getTargetOffset(),
+ getLength()};
+ } else {
+ assert(false && "only source/target are subranges");
+ return {};
+ }
+}
+
+void BufferCopyOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ if (operandIndex == 0) {
+ getSourceMutable().assign(operand.resource);
+ getSourceSizeMutable().assign(operand.resourceSize);
+ getSourceOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+ } else if (operandIndex == 3) {
+ getTargetMutable().assign(operand.resource);
+ getTargetSizeMutable().assign(operand.resourceSize);
+ getTargetOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+ } else {
+ assert(false && "only source/target are subranges");
+ }
+}
+
+SubrangeOperand BufferCompareOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 0) {
+ return SubrangeOperand{getLhs(), getLhsSize(), getLhsOffset(), getLength()};
+ } else if (operandIndex == 3) {
+ return SubrangeOperand{getRhs(), getRhsSize(), getRhsOffset(), getLength()};
+ } else {
+ assert(false && "only lhs/rhs are subranges");
+ return {};
+ }
+}
+
+void BufferCompareOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ if (operandIndex == 0) {
+ getLhsMutable().assign(operand.resource);
+ getLhsSizeMutable().assign(operand.resourceSize);
+ getLhsOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+ } else if (operandIndex == 3) {
+ getRhsMutable().assign(operand.resource);
+ getRhsSizeMutable().assign(operand.resourceSize);
+ getRhsOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+ } else {
+ assert(false && "only lhs/rhs are subranges");
+ }
+}
+
+SubrangeOperand BufferFillOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 1) {
+ return SubrangeOperand{getTarget(), getTargetSize(), getTargetOffset(),
+ getLength()};
+ } else {
+ assert(false && "only target is a subrange");
+ return {};
+ }
+}
+
+void BufferFillOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ assert(operandIndex == 1 && "only target is a subrange");
+ getTargetMutable().assign(operand.resource);
+ getTargetSizeMutable().assign(operand.resourceSize);
+ getTargetOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+}
+
+SubrangeOperand BufferLoadOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 0) {
+ return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(),
+ getLength()};
+ } else {
+ assert(false && "only source is a subrange");
+ return {};
+ }
+}
+
+void BufferLoadOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ assert(operandIndex == 0 && "only source is a subrange");
+ getSourceMutable().assign(operand.resource);
+ getSourceSizeMutable().assign(operand.resourceSize);
+ getSourceOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+}
+
+SubrangeOperand BufferStoreOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 1) {
+ return SubrangeOperand{getTarget(), getTargetSize(), getTargetOffset(),
+ getLength()};
+ } else {
+ assert(false && "only target is a subrange");
+ return {};
+ }
+}
+
+void BufferStoreOp::setSubrangeOperand(unsigned operandIndex,
+ SubrangeOperand operand) {
+ assert(operandIndex == 1 && "only target is a subrange");
+ getTargetMutable().assign(operand.resource);
+ getTargetSizeMutable().assign(operand.resourceSize);
+ getTargetOffsetMutable().assign(operand.offset);
+ getLengthMutable().assign(operand.length);
+}
+
} // namespace Util
} // namespace IREE
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 7e67fd7..0fceb70 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -838,6 +838,7 @@
MemoryEffects<[MemAlloc, MemRead]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
]> {
let summary = [{clones a subregion of a buffer}];
let description = [{
@@ -865,8 +866,6 @@
Value getOperandSize(unsigned idx) { return getSourceSize(); }
Value getResultSize(unsigned idx) { return getResultSize(); }
}];
-
- let hasCanonicalizer = 1;
}
def Util_BufferSubspanOp : Util_PureOp<"buffer.subspan", [
@@ -875,6 +874,7 @@
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Util_SizeAwareOp,
Util_SubrangeOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedResult",
"getTiedResultOperandIndex",
@@ -996,6 +996,7 @@
def Util_BufferCopyOp : Util_Op<"buffer.copy", [
MemoryEffects<[MemRead, MemWrite]>,
Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
]> {
let summary = [{copies a range of bytes between buffers}];
let description = [{
@@ -1025,13 +1026,12 @@
Value getOperandSize(unsigned idx) { return idx == 0 ? getSourceSize() : getTargetSize(); }
Value getResultSize(unsigned idx) { return {}; }
}];
-
- let hasCanonicalizer = 1;
}
def Util_BufferCompareOp : Util_PureOp<"buffer.compare", [
MemoryEffects<[MemRead]>,
Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
]> {
let summary = [{compares a range of two buffers}];
let description = [{
@@ -1064,13 +1064,12 @@
Value getOperandSize(unsigned idx) { return idx == 0 ? getLhsSize() : getRhsSize(); }
Value getResultSize(unsigned idx) { return {}; }
}];
-
- let hasCanonicalizer = 1;
}
def Util_BufferFillOp : Util_Op<"buffer.fill", [
MemoryEffects<[MemWrite]>,
Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
]> {
let summary = [{fills a range of bytes with a value}];
let description = [{
@@ -1098,13 +1097,12 @@
Value getOperandSize(unsigned idx) { return getTargetSize(); }
Value getResultSize(unsigned idx) { return {}; }
}];
-
- let hasCanonicalizer = 1;
}
def Util_BufferLoadOp : Util_Op<"buffer.load", [
MemoryEffects<[MemRead]>,
Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
]> {
let summary = [{loads a value from a buffer}];
let description = [{
@@ -1115,14 +1113,15 @@
let arguments = (ins
Util_BufferType:$source,
Util_Size:$source_size,
- Util_Offset:$source_offset
+ Util_Offset:$source_offset,
+ Util_Size:$length
);
let results = (outs
Util_Primitive:$result
);
let assemblyFormat = [{
- $source `[` $source_offset `]`
+ $source `[` $source_offset `for` $length `]`
`:` type($source) `` `{` $source_size `}` `->` type($result)
attr-dict-with-keyword
}];
@@ -1132,13 +1131,13 @@
Value getResultSize(unsigned idx) { return {}; }
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
def Util_BufferStoreOp : Util_Op<"buffer.store", [
MemoryEffects<[MemWrite]>,
Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
]> {
let summary = [{stores a value into a buffer}];
let description = [{
@@ -1150,12 +1149,13 @@
Util_Primitive:$source,
Util_BufferType:$target,
Util_Size:$target_size,
- Util_Offset:$target_offset
+ Util_Offset:$target_offset,
+ Util_Size:$length
);
let assemblyFormat = [{
$source `,`
- $target `[` $target_offset `]`
+ $target `[` $target_offset `for` $length `]`
`:` type($source) `->` type($target) `` `{` $target_size `}`
attr-dict-with-keyword
}];
@@ -1164,8 +1164,6 @@
Value getOperandSize(unsigned idx) { return getTargetSize(); }
Value getResultSize(unsigned idx) { return {}; }
}];
-
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h
index 82c1b6f..0d990a2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h
@@ -77,6 +77,18 @@
static ValueAccess DiscardWrite() { return ValueAccess(false, true, true); }
};
+// An (offset, length) range within a size-aware resource.
+struct SubrangeOperand {
+ // Base resource the subrange references into.
+ Value resource;
+ // Size of the full base resource.
+ Value resourceSize;
+ // Offset into the base resource the range begins.
+ Value offset;
+ // Total length of the range within the base resource.
+ Value length;
+};
+
//===----------------------------------------------------------------------===//
// Global and structural interface utilities
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir
index 94d700f..811dacb 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_folding.mlir
@@ -60,10 +60,11 @@
// CHECK-LABEL: @FoldBufferSizeOp
func.func @FoldBufferSizeOp(%arg0: !util.buffer, %arg1: index) -> (index, i32) {
%c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
// CHECK-NOT: util.buffer.size
%0 = util.buffer.size %arg0 : !util.buffer
// CHECK: %[[LOAD:.+]] = util.buffer.load
- %1 = util.buffer.load %arg0[%c0] : !util.buffer{%arg1} -> i32
+ %1 = util.buffer.load %arg0[%c0 for %c4] : !util.buffer{%arg1} -> i32
// CHECK: return %arg1, %[[LOAD]]
return %0, %1 : index, i32
}
@@ -81,13 +82,13 @@
// CHECK: %[[BUFFER_SIZE_INNER:.+]] = util.buffer.size %[[BUFFER]]
%buffer_size_inner = util.buffer.size %buffer : !util.buffer
// CHECK: util.buffer.load %[[BUFFER]]{{.+}} : !util.buffer{%[[BUFFER_SIZE_INNER]]}
- %inner = util.buffer.load %buffer[%i] : !util.buffer{%buffer_size_inner} -> i8
+ %inner = util.buffer.load %buffer[%i for %c1] : !util.buffer{%buffer_size_inner} -> i8
util.optimization_barrier %inner : i8
}
// CHECK: %[[BUFFER_SIZE_OUTER:.+]] = util.buffer.size %[[BUFFER]]
%buffer_size_outer = util.buffer.size %buffer : !util.buffer
// CHECK: util.buffer.load %[[BUFFER]]{{.+}} : !util.buffer{%[[BUFFER_SIZE_OUTER]]}
- %outer = util.buffer.load %buffer[%c128] : !util.buffer{%buffer_size_outer} -> i8
+ %outer = util.buffer.load %buffer[%c128 for %c1] : !util.buffer{%buffer_size_outer} -> i8
util.optimization_barrier %outer : i8
return
}
@@ -184,13 +185,14 @@
// CHECK-LABEL: @FoldSubspanIntoLoadOp
func.func @FoldSubspanIntoLoadOp(%arg0: !util.buffer, %arg1: index) -> i32 {
+ %c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
// CHECK-NOT: util.buffer.subspan
%0 = util.buffer.subspan %arg0[%c128] : !util.buffer{%arg1} -> !util.buffer{%c256}
- // CHECK: = util.buffer.load %arg0[%c192] : !util.buffer{%arg1} -> i32
- %1 = util.buffer.load %0[%c64] : !util.buffer{%c256} -> i32
+ // CHECK: = util.buffer.load %arg0[%c192 for %c4] : !util.buffer{%arg1} -> i32
+ %1 = util.buffer.load %0[%c64 for %c4] : !util.buffer{%c256} -> i32
return %1 : i32
}
@@ -198,13 +200,14 @@
// CHECK-LABEL: @FoldSubspanIntoStoreOp
func.func @FoldSubspanIntoStoreOp(%arg0: !util.buffer, %arg1: index) {
+ %c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
%c123_i32 = arith.constant 123 : i32
// CHECK-NOT: util.buffer.subspan
%0 = util.buffer.subspan %arg0[%c128] : !util.buffer{%arg1} -> !util.buffer{%c256}
- // CHECK: util.buffer.store %c123_i32, %arg0[%c192] : i32 -> !util.buffer{%arg1}
- util.buffer.store %c123_i32, %0[%c64] : i32 -> !util.buffer{%c256}
+ // CHECK: util.buffer.store %c123_i32, %arg0[%c192 for %c4] : i32 -> !util.buffer{%arg1}
+ util.buffer.store %c123_i32, %0[%c64 for %c4] : i32 -> !util.buffer{%c256}
return
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir
index 21c5049..1f36809 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/buffer_ops.mlir
@@ -109,9 +109,10 @@
// CHECK-LABEL: @buffer_load
func.func @buffer_load(%arg0: !util.buffer, %arg1: index) -> i32 {
+ %c4 = arith.constant 4 : index
%c100 = arith.constant 100 : index
- // CHECK: = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i32
- %0 = util.buffer.load %arg0[%c100] : !util.buffer{%arg1} -> i32
+ // CHECK: = util.buffer.load %arg0[%c100 for %c4] : !util.buffer{%arg1} -> i32
+ %0 = util.buffer.load %arg0[%c100 for %c4] : !util.buffer{%arg1} -> i32
return %0 : i32
}
@@ -119,8 +120,9 @@
// CHECK-LABEL: @buffer_store
func.func @buffer_store(%arg0: !util.buffer, %arg1: index, %arg2: i32) {
+ %c4 = arith.constant 4 : index
%c100 = arith.constant 100 : index
- // CHECK: util.buffer.store %arg2, %arg0[%c100] : i32 -> !util.buffer{%arg1}
- util.buffer.store %arg2, %arg0[%c100] : i32 -> !util.buffer{%arg1}
+ // CHECK: util.buffer.store %arg2, %arg0[%c100 for %c4] : i32 -> !util.buffer{%arg1}
+ util.buffer.store %arg2, %arg0[%c100 for %c4] : i32 -> !util.buffer{%arg1}
return
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
index f20947a..b6eb2ef 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/buffer_ops.mlir
@@ -178,11 +178,12 @@
// CHECK-LABEL: @buffer_load_i1
func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i1 {
%byte_offset = arith.constant 128 : index
+ %element_size = arith.constant 1 : index
// CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128
// CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32
// CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128
// CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i8.s %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32
- %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> i1
+ %0 = util.buffer.load %arg0[%byte_offset for %element_size] : !util.buffer{%arg1} -> i1
// CHECK: return %[[VALUE]]
return %0 : i1
}
@@ -192,11 +193,12 @@
// CHECK-LABEL: @buffer_load_i32
func.func @buffer_load_i32(%arg0: !util.buffer, %arg1: index) -> i32 {
%byte_offset = arith.constant 128 : index
+ %element_size = arith.constant 4 : index
// CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32
// CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32
// CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32
// CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i32 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i32
- %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> i32
+ %0 = util.buffer.load %arg0[%byte_offset for %element_size] : !util.buffer{%arg1} -> i32
// CHECK: return %[[VALUE]]
return %0 : i32
}
@@ -206,11 +208,12 @@
// CHECK-LABEL: @buffer_load_i64
func.func @buffer_load_i64(%arg0: !util.buffer, %arg1: index) -> i64 {
%byte_offset = arith.constant 128 : index
+ %element_size = arith.constant 8 : index
// CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16
// CHECK-32: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i64
// CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16
// CHECK-64: %[[VALUE:.+]] = vm.buffer.load.i64 %arg0[%[[ELEMENT_OFFSET]]] : !vm.buffer -> i64
- %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> i64
+ %0 = util.buffer.load %arg0[%byte_offset for %element_size] : !util.buffer{%arg1} -> i64
// CHECK: return %[[VALUE]]
return %0 : i64
}
@@ -220,9 +223,10 @@
// CHECK-LABEL: @buffer_load_index
func.func @buffer_load_index(%arg0: !util.buffer, %arg1: index) -> index {
%byte_offset = arith.constant 100 : index
+ %element_size = util.sizeof index
// CHECK-32: vm.buffer.load.i32
// CHECK-64: vm.buffer.load.i64
- %0 = util.buffer.load %arg0[%byte_offset] : !util.buffer{%arg1} -> index
+ %0 = util.buffer.load %arg0[%byte_offset for %element_size] : !util.buffer{%arg1} -> index
return %0 : index
}
@@ -231,11 +235,12 @@
// CHECK-LABEL: @buffer_store_i1
func.func @buffer_store_i1(%arg0: !util.buffer, %arg1: index, %arg2: i1) {
%byte_offset = arith.constant 128 : index
+ %element_size = arith.constant 1 : index
// CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128
// CHECK-32: vm.buffer.store.i8 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer
// CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 128
// CHECK-64: vm.buffer.store.i8 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer
- util.buffer.store %arg2, %arg0[%byte_offset] : i1 -> !util.buffer{%arg1}
+ util.buffer.store %arg2, %arg0[%byte_offset for %element_size] : i1 -> !util.buffer{%arg1}
return
}
@@ -244,11 +249,12 @@
// CHECK-LABEL: @buffer_store_i32
func.func @buffer_store_i32(%arg0: !util.buffer, %arg1: index, %arg2: i32) {
%byte_offset = arith.constant 128 : index
+ %element_size = arith.constant 4 : index
// CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32
// CHECK-32: vm.buffer.store.i32 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer
// CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 32
// CHECK-64: vm.buffer.store.i32 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i32 -> !vm.buffer
- util.buffer.store %arg2, %arg0[%byte_offset] : i32 -> !util.buffer{%arg1}
+ util.buffer.store %arg2, %arg0[%byte_offset for %element_size] : i32 -> !util.buffer{%arg1}
return
}
@@ -257,11 +263,12 @@
// CHECK-LABEL: @buffer_store_i64
func.func @buffer_store_i64(%arg0: !util.buffer, %arg1: index, %arg2: i64) {
%byte_offset = arith.constant 128 : index
+ %element_size = arith.constant 8 : index
// CHECK-32-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16
// CHECK-32: vm.buffer.store.i64 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i64 -> !vm.buffer
// CHECK-64-DAG: %[[ELEMENT_OFFSET:.+]] = vm.const.i64 16
// CHECK-64: vm.buffer.store.i64 %arg2, %arg0[%[[ELEMENT_OFFSET]]] : i64 -> !vm.buffer
- util.buffer.store %arg2, %arg0[%byte_offset] : i64 -> !util.buffer{%arg1}
+ util.buffer.store %arg2, %arg0[%byte_offset for %element_size] : i64 -> !util.buffer{%arg1}
return
}
@@ -270,8 +277,9 @@
// CHECK-LABEL: @buffer_store_index
func.func @buffer_store_index(%arg0: !util.buffer, %arg1: index, %arg2: index) {
%byte_offset = arith.constant 100 : index
+ %element_size = util.sizeof index
// CHECK-32: vm.buffer.store.i32
// CHECK-64: vm.buffer.store.i64
- util.buffer.store %arg2, %arg0[%byte_offset] : index -> !util.buffer{%arg1}
+ util.buffer.store %arg2, %arg0[%byte_offset for %element_size] : index -> !util.buffer{%arg1}
return
}
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
index 43d0046..db61c5a 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
@@ -199,7 +199,7 @@
auto byteOffset = rewriter.createOrFold<arith::MulIOp>(
op.getLoc(), elementSize, constantIndex);
rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
- op, resultType, constantsArg, constantsSize, byteOffset);
+ op, resultType, constantsArg, constantsSize, byteOffset, elementSize);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp
index 0016aa8..96224eb 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp
@@ -107,19 +107,22 @@
Value buffer = setterOp.getArgument(0);
Value bufferSize = setterBuilder.create<arith::ConstantIndexOp>(
buffer.getLoc(), allLoadOps.size() * sizeof(uint32_t));
- Value elementSize = setterBuilder.create<arith::ConstantIntOp>(
+ Value elementSizeIndex = setterBuilder.create<arith::ConstantIndexOp>(
+ buffer.getLoc(), sizeof(uint32_t));
+ Value elementSizeI32 = setterBuilder.create<arith::ConstantIntOp>(
buffer.getLoc(), sizeof(uint32_t), 32);
for (auto [ordinalGlobalOp, valueGlobalOp] :
llvm::zip(ordinalGlobalOps, valueGlobalOps)) {
Value loadedOrdinal = setterBuilder.create<IREE::Util::GlobalLoadOp>(
ordinalGlobalOp.getLoc(), ordinalGlobalOp);
Value bufferOffset = setterBuilder.create<arith::MulIOp>(
- loadedOrdinal.getLoc(), loadedOrdinal, elementSize);
+ loadedOrdinal.getLoc(), loadedOrdinal, elementSizeI32);
Value loadedValue = setterBuilder.create<IREE::Util::BufferLoadOp>(
valueGlobalOp.getLoc(), loadedOrdinal.getType(), buffer, bufferSize,
setterBuilder.create<arith::IndexCastOp>(bufferOffset.getLoc(),
setterBuilder.getIndexType(),
- bufferOffset));
+ bufferOffset),
+ elementSizeIndex);
setterBuilder.create<IREE::Util::GlobalStoreOp>(
valueGlobalOp.getLoc(), loadedValue, valueGlobalOp);
}
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/materialize_constants.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/materialize_constants.mlir
index efbc483..9cd25c1 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/materialize_constants.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/materialize_constants.mlir
@@ -21,12 +21,12 @@
// CHECK: %[[FOO_ORDINAL:.+]] = util.global.load @__constant_foo_ordinal
// CHECK: %[[FOO_OFFSET:.+]] = arith.muli %[[FOO_ORDINAL]], %c4
// CHECK: %[[FOO_OFFSET_IDX:.+]] = arith.index_cast %[[FOO_OFFSET]]
-// CHECK: %[[FOO_VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[FOO_OFFSET_IDX]]] : !util.buffer{%[[BUFFER_SIZE]]}
+// CHECK: %[[FOO_VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[FOO_OFFSET_IDX]] for {{.+}}] : !util.buffer{%[[BUFFER_SIZE]]}
// CHECK: util.global.store %[[FOO_VALUE]], @__constant_foo : i32
// CHECK: %[[BAR_ORDINAL:.+]] = util.global.load @__constant_bar_ordinal
// CHECK: %[[BAR_OFFSET:.+]] = arith.muli %[[BAR_ORDINAL]], %c4
// CHECK: %[[BAR_OFFSET_IDX:.+]] = arith.index_cast %[[BAR_OFFSET]]
-// CHECK: %[[BAR_VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[BAR_OFFSET_IDX]]] : !util.buffer{%[[BUFFER_SIZE]]}
+// CHECK: %[[BAR_VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[BAR_OFFSET_IDX]] for {{.+}}] : !util.buffer{%[[BUFFER_SIZE]]}
// CHECK: util.global.store %[[BAR_VALUE]], @__constant_bar : i32
// CHECK: return
// CHECK: }
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp
index d35e9c8..4904d66 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp
@@ -60,8 +60,11 @@
Value storageSize = rewriter.create<IREE::HAL::Inline::BufferLengthOp>(
op.getLoc(), adaptor.getSourceBuffer());
auto loadType = getTypeConverter()->convertType(op.getResult().getType());
+ auto elementSize =
+ rewriter.createOrFold<IREE::Util::SizeOfOp>(op.getLoc(), loadType);
rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
- op, loadType, storageBuffer, storageSize, adaptor.getSourceOffset());
+ op, loadType, storageBuffer, storageSize, adaptor.getSourceOffset(),
+ elementSize);
return success();
}
};
@@ -77,9 +80,11 @@
op.getLoc(), adaptor.getTargetBuffer());
Value storageSize = rewriter.create<IREE::HAL::Inline::BufferLengthOp>(
op.getLoc(), adaptor.getTargetBuffer());
+ auto elementSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+ op.getLoc(), adaptor.getValue().getType());
rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>(
op, adaptor.getValue(), storageBuffer, storageSize,
- adaptor.getTargetOffset());
+ adaptor.getTargetOffset(), elementSize);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir
index b6ce93a..815bd44 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_ops.mlir
@@ -32,7 +32,7 @@
%rel_offset = arith.constant 100 : index
// CHECK-DAG: %[[STORAGE:.+]] = hal_inline.buffer.storage<%[[BUFFER:.+]] : !hal.buffer> : !util.buffer
// CHECK-DAG: %[[LENGTH:.+]] = hal_inline.buffer.length<%[[BUFFER]] : !hal.buffer> : index
- // CHECK: %[[VALUE:.+]] = util.buffer.load %[[STORAGE]][%[[REL_OFFSET]]] : !util.buffer{%[[LENGTH]]} -> i32
+ // CHECK: %[[VALUE:.+]] = util.buffer.load %[[STORAGE]][%[[REL_OFFSET]] for {{.+}}] : !util.buffer{%[[LENGTH]]} -> i32
%value = hal.buffer.load<%buffer : !hal.buffer>[%rel_offset] : i32
// CHECK-NEXT: return %[[VALUE]]
return %value : i32
@@ -47,7 +47,7 @@
%rel_offset = arith.constant 100 : index
// CHECK-DAG: %[[STORAGE:.+]] = hal_inline.buffer.storage<%[[BUFFER:.+]] : !hal.buffer> : !util.buffer
// CHECK-DAG: %[[LENGTH:.+]] = hal_inline.buffer.length<%[[BUFFER]] : !hal.buffer> : index
- // CHECK: util.buffer.store %[[VALUE]], %[[STORAGE]][%[[REL_OFFSET]]] : i32 -> !util.buffer{%[[LENGTH]]}
+ // CHECK: util.buffer.store %[[VALUE]], %[[STORAGE]][%[[REL_OFFSET]] for {{.+}}] : i32 -> !util.buffer{%[[LENGTH]]}
hal.buffer.store<%buffer : !hal.buffer>[%rel_offset] value(%value : i32)
return
}
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp
index 16fcb11..f893502 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp
@@ -184,9 +184,11 @@
adaptor.getSourceSize(), rewriter);
auto loadType =
getTypeConverter()->convertType(loadOp.getResult().getType());
+ auto elementSize =
+ rewriter.createOrFold<IREE::Util::SizeOfOp>(loc, loadType);
rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
loadOp, loadType, storage.buffer, storage.bufferSize,
- adaptor.getSourceOffset());
+ adaptor.getSourceOffset(), elementSize);
return success();
}
};
@@ -200,9 +202,11 @@
auto loc = storeOp.getLoc();
auto storage = getResourceStorage(loc, adaptor.getTarget(),
adaptor.getTargetSize(), rewriter);
+ auto elementSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
+ loc, adaptor.getValue().getType());
rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>(
storeOp, adaptor.getValue(), storage.buffer, storage.bufferSize,
- adaptor.getTargetOffset());
+ adaptor.getTargetOffset(), elementSize);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir
index 1d18138..f356e16 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/resource_ops.mlir
@@ -105,7 +105,7 @@
// CHECK-LABEL: @resourceLoad
// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer, %[[BUFFER_SIZE:.+]]: index, %[[OFFSET:.+]]: index)
func.func @resourceLoad(%resource: !stream.resource<staging>, %resource_size: index, %offset: index) -> i32 {
- // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[OFFSET]]] : !util.buffer{%[[BUFFER_SIZE]]} -> i32
+ // CHECK: %[[VALUE:.+]] = util.buffer.load %[[BUFFER]][%[[OFFSET]] for {{.+}}] : !util.buffer{%[[BUFFER_SIZE]]} -> i32
%0 = stream.resource.load %resource[%offset] : !stream.resource<staging>{%resource_size} -> i32
// CHECK: return %[[VALUE]]
return %0 : i32
@@ -118,7 +118,7 @@
func.func @resourceStore(%resource: !stream.resource<staging>, %resource_size: index, %offset: index) {
// CHECK-DAG: %[[VALUE:.+]] = arith.constant 123
%value = arith.constant 123 : i32
- // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFFSET]]] : i32 -> !util.buffer{%[[BUFFER_SIZE]]}
+ // CHECK: util.buffer.store %[[VALUE]], %[[BUFFER]][%[[OFFSET]] for {{.+}}] : i32 -> !util.buffer{%[[BUFFER_SIZE]]}
stream.resource.store %value, %resource[%offset] : i32 -> !stream.resource<staging>{%resource_size}
return
}
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
index a826f1b..ce313ae 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/inline_executables.mlir
@@ -36,10 +36,12 @@
%workgroup_x: i32, %workgroup_y: i32, %workgroup_z: i32,
%workgroup_size_x: i32, %workgroup_size_y: i32, %workgroup_size_z: i32,
%workgroup_count_x: i32, %workgroup_count_y: i32, %workgroup_count_z: i32) {
+ %c4 = arith.constant 4 : index
+
// Unpack push constants:
%constants_size = util.buffer.size %constants : !util.buffer
%constant1_offset = arith.constant 4 : index
- %constant1_i32 = util.buffer.load %constants[%constant1_offset] : !util.buffer{%constants_size} -> i32
+ %constant1_i32 = util.buffer.load %constants[%constant1_offset for %c4] : !util.buffer{%constants_size} -> i32
%constant1_f32 = arith.sitofp %constant1_i32 : i32 to f32
// Unpack buffer bindings:
@@ -57,16 +59,14 @@
%global_constant = util.global.load @global_constant : !util.buffer
util.optimization_barrier %global_constant : !util.buffer
-
- %c4 = arith.constant 4 : index
%workgroup_x_idx = arith.index_cast %workgroup_x : i32 to index
scf.for %i = %c0 to %workgroup_x_idx step %c1 {
%idx = arith.muli %i, %c4 : index
- %lhs = util.buffer.load %buffer0[%idx] : !util.buffer{%buffer0_size} -> f32
- %rhs = util.buffer.load %buffer1[%idx] : !util.buffer{%buffer1_size} -> f32
+ %lhs = util.buffer.load %buffer0[%idx for %c4] : !util.buffer{%buffer0_size} -> f32
+ %rhs = util.buffer.load %buffer1[%idx for %c4] : !util.buffer{%buffer1_size} -> f32
%mul = arith.mulf %lhs, %rhs : f32
%scaled = arith.mulf %mul, %constant1_f32 : f32
- util.buffer.store %scaled, %buffer2[%idx] : f32 -> !util.buffer{%buffer2_size}
+ util.buffer.store %scaled, %buffer2[%idx for %c4] : f32 -> !util.buffer{%buffer2_size}
}
return
}
@@ -110,11 +110,11 @@
// CHECK: %[[X_IDX:.+]] = arith.index_cast %[[X_I32]]
// CHECK: scf.for %[[ELEMENT_INDEX:.+]] = %c0 to %[[X_IDX]]
// CHECK: %[[ELEMENT_OFFSET:.+]] = arith.muli %[[ELEMENT_INDEX]]
-// CHECK: %[[LHS:.+]] = util.buffer.load %[[BINDING0]][%[[ELEMENT_OFFSET]]] : !util.buffer{%[[BINDING0_SIZE]]} -> f32
-// CHECK: %[[RHS:.+]] = util.buffer.load %[[BINDING1]][%[[ELEMENT_OFFSET]]] : !util.buffer{%[[BINDING1_SIZE]]} -> f32
+// CHECK: %[[LHS:.+]] = util.buffer.load %[[BINDING0]][%[[ELEMENT_OFFSET]] for {{.+}}] : !util.buffer{%[[BINDING0_SIZE]]} -> f32
+// CHECK: %[[RHS:.+]] = util.buffer.load %[[BINDING1]][%[[ELEMENT_OFFSET]] for {{.+}}] : !util.buffer{%[[BINDING1_SIZE]]} -> f32
// CHECK: %[[MUL:.+]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK: %[[SCALED:.+]] = arith.mulf %[[MUL]], %[[CONSTANT1_F32]] : f32
-// CHECK: util.buffer.store %[[SCALED]], %[[BINDING2]][%[[ELEMENT_OFFSET]]] : f32 -> !util.buffer{%[[BINDING2_SIZE]]}
+// CHECK: util.buffer.store %[[SCALED]], %[[BINDING2]][%[[ELEMENT_OFFSET]] for {{.+}}] : f32 -> !util.buffer{%[[BINDING2_SIZE]]}
// CHECK: return
// CHECK-LABEL: func private @__dispatch_ex_dispatch_0