Removing hal.buffer_view.compute_offset/compute_range. Buffer views are fully type erased and these won't work when we have non-dense tensors.
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 4df117a..312ead6 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -427,77 +427,6 @@ namespace { -/// Expands a hal.buffer_view.compute_offset op to use -/// hal.allocator.compute_offset. This allows for all of the shape math to -/// happen in the VM where we can better optimize it. -struct ExpandBufferViewComputeOffsetOp - : public OpRewritePattern<BufferViewComputeOffsetOp> { - using OpRewritePattern<BufferViewComputeOffsetOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(BufferViewComputeOffsetOp op, - PatternRewriter &rewriter) const override { - auto bufferValue = rewriter.createOrFold<BufferViewBufferOp>( - op.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), - op.buffer_view()); - auto allocatorValue = rewriter.createOrFold<BufferAllocatorOp>( - op.getLoc(), AllocatorType::get(rewriter.getContext()), bufferValue); - int rank = op.indices().size(); - SmallVector<Type, 4> dimTypes(rank, rewriter.getIndexType()); - auto dimsOp = rewriter.create<BufferViewDimsOp>(op.getLoc(), dimTypes, - op.buffer_view()); - auto elementTypeValue = rewriter.createOrFold<BufferViewElementTypeOp>( - op.getLoc(), rewriter.getI32Type(), op.buffer_view()); - rewriter.replaceOpWithNewOp<AllocatorComputeOffsetOp>( - op, allocatorValue, dimsOp.result(), elementTypeValue, op.indices()); - return success(); - } -}; - -} // namespace - -void BufferViewComputeOffsetOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<ExpandBufferViewComputeOffsetOp>(context); -} - -namespace { - -/// Expands a hal.buffer_view.compute_range op to use -/// hal.allocator.compute_range. This allows for all of the shape math to -/// happen in the VM where we can better optimize it. -struct ExpandBufferViewComputeRangeOp - : public OpRewritePattern<BufferViewComputeRangeOp> { - using OpRewritePattern<BufferViewComputeRangeOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(BufferViewComputeRangeOp op, - PatternRewriter &rewriter) const override { - auto bufferValue = rewriter.createOrFold<BufferViewBufferOp>( - op.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), - op.buffer_view()); - auto allocatorValue = rewriter.createOrFold<BufferAllocatorOp>( - op.getLoc(), AllocatorType::get(rewriter.getContext()), bufferValue); - int rank = op.indices().size(); - SmallVector<Type, 4> dimTypes(rank, rewriter.getIndexType()); - auto dimsOp = rewriter.create<BufferViewDimsOp>(op.getLoc(), dimTypes, - op.buffer_view()); - auto elementTypeValue = rewriter.createOrFold<BufferViewElementTypeOp>( - op.getLoc(), rewriter.getI32Type(), op.buffer_view()); - rewriter.replaceOpWithNewOp<AllocatorComputeRangeOp>( - op, allocatorValue, dimsOp.result(), elementTypeValue, op.indices(), - op.lengths()); - return success(); - } -}; - -} // namespace - -void BufferViewComputeRangeOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<ExpandBufferViewComputeRangeOp>(context); -} - -namespace { - /// Expands a hal.buffer_view.dims op into individual ops for each dimension. struct ExpandBufferViewDimsOp : public OpRewritePattern<BufferViewDimsOp> { using OpRewritePattern<BufferViewDimsOp>::OpRewritePattern;
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 52a03a9..7e3fb55 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -547,41 +547,6 @@ } //===----------------------------------------------------------------------===// -// hal.buffer_view.compute_offset -//===----------------------------------------------------------------------===// - -void BufferViewComputeOffsetOp::build(OpBuilder &builder, OperationState &state, - Value bufferView, ValueRange indices) { - state.addOperands({bufferView}); - state.addOperands(indices); - state.addTypes({builder.getIndexType()}); -} - -void BufferViewComputeOffsetOp::getAsmResultNames( - function_ref<void(Value, StringRef)> setNameFn) { - setNameFn(offset(), "off"); -} - -//===----------------------------------------------------------------------===// -// hal.buffer_view.compute_range -//===----------------------------------------------------------------------===// - -void BufferViewComputeRangeOp::build(OpBuilder &builder, OperationState &state, - Value bufferView, ValueRange indices, - ValueRange lengths) { - state.addOperands({bufferView}); - state.addOperands(indices); - state.addOperands(lengths); - state.addTypes({builder.getIndexType(), builder.getIndexType()}); -} - -void BufferViewComputeRangeOp::getAsmResultNames( - function_ref<void(Value, StringRef)> setNameFn) { - setNameFn(offset(), "off"); - setNameFn(length(), "len"); -} - -//===----------------------------------------------------------------------===// // hal.command_buffer.create //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td index 585561d..91cfbf3 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -154,7 +154,6 @@ let summary = [{buffer view indices to byte offset computation operation}]; let description = [{ Computes an element byte offset within a buffer produced by the allocator. - This returns the same value as `hal.buffer_view.compute_offset`. }]; let arguments = (ins @@ -194,7 +193,6 @@ let summary = [{buffer view byte range computation operation}]; let description = [{ Computes a byte range within a buffer for one or more elements. - This returns the same value as `hal.buffer_view.compute_range`. }]; let arguments = (ins @@ -649,76 +647,6 @@ ]; } -def HAL_BufferViewComputeOffsetOp : HAL_PureOp<"buffer_view.compute_offset", [ - DeclareOpInterfaceMethods<OpAsmOpInterface>, - ]> { - let summary = [{buffer view indices to byte offset computation operation}]; - let description = [{ - Computes an element byte offset within a buffer view. - }]; - - let arguments = (ins - HAL_BufferView:$buffer_view, - HAL_Dims:$indices - ); - - let results = (outs - HAL_DeviceSize:$offset - ); - - let assemblyFormat = [{ - $buffer_view `,` - `indices` `=` `[` $indices `]` - attr-dict `:` type($offset) - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$bufferView, "ValueRange":$indices)>, - ]; - - let hasCanonicalizer = 1; -} - -def HAL_BufferViewComputeRangeOp : HAL_PureOp<"buffer_view.compute_range", [ - DeclareOpInterfaceMethods<OpAsmOpInterface>, - SameVariadicOperandSize, - ]> { - let summary = [{buffer view byte range computation operation}]; - let description = [{ - Computes a byte range within a buffer for one or more elements. - }]; - - let arguments = (ins - HAL_BufferView:$buffer_view, - HAL_Dims:$indices, - HAL_Dims:$lengths - ); - let results = (outs - // TODO(benvanik): return a strides tuple instead, or one per dim. - HAL_DeviceSize:$offset, - HAL_DeviceSize:$length - ); - - let assemblyFormat = [{ - $buffer_view `,` - `indices` `=` `[` $indices `]` `,` - `lengths` `=` `[` $lengths `]` - attr-dict `:` type($offset) `,` type($length) - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins - "Value":$bufferView, - "ValueRange":$indices, - "ValueRange":$lengths - )>, - ]; - - let hasCanonicalizer = 1; -} - def HAL_BufferViewElementTypeOp : HAL_PureOp<"buffer_view.element_type"> { let summary = [{buffer view element type query}]; let description = [{
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir index cfc752e..802b0ae 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
@@ -14,46 +14,6 @@ // ----- -// CHECK-LABEL: func @buffer_view_compute_offset -// CHECK-SAME: %[[VIEW:.+]]: !hal.buffer_view -func @buffer_view_compute_offset(%arg0 : !hal.buffer_view) -> index { - // CHECK: %[[INDICES:.+]]:2 = "test_hal.indices"() : () -> (index, index) - %0:2 = "test_hal.indices"() : () -> (index, index) - // CHECK: %[[D0:.+]] = hal.buffer_view.dim %[[VIEW]], 1 : index - // CHECK: %[[TYPE:.+]] = hal.buffer_view.element_type %[[VIEW]] : i32 - // CHECK: %[[T0:.+]] = muli %[[INDICES]]#0, %[[D0]] : index - // CHECK: %[[T1:.+]] = addi %[[T0]], %[[INDICES]]#1 : index - // CHECK: %[[T2:.+]] = index_cast %[[TYPE]] : i32 to index - // CHECK: %[[T3:.+]] = and %[[T2]], %c255 : index - // CHECK: %[[T4:.+]] = addi %[[T3]], %c7 : index - // CHECK: %[[T6:.+]] = divi_unsigned %[[T4]], %c8 : index - // CHECK: %[[T7:.+]] = muli %[[T1]], %[[T6]] : index - %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] : index - // CHECK: return %[[T7]] - return %off : index -} - -// ----- - -// CHECK-LABEL: func @buffer_view_compute_range -// CHECK-SAME: %[[VIEW:.+]]: !hal.buffer_view -func @buffer_view_compute_range(%arg0 : !hal.buffer_view) -> (index, index) { - %0:2 = "test_hal.indices"() : () -> (index, index) - %1:2 = "test_hal.lengths"() : () -> (index, index) - // Testing things like this is brittle :/ - // Since the canonicalizers are taking these buffer view ops to allocator ops - // the testing there should cover with the checks here just to make sure the - // right values from the buffer view are passed in. - // CHECK: = hal.buffer_view.dim %[[VIEW]], 1 : index - // CHECK: = hal.buffer_view.element_type %[[VIEW]] : i32 - // << A BUNCH OF MATH >> - %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] : index, index - // CHECK: return - return %off, %len : index, index -} - -// ----- - // CHECK-LABEL: func @expand_buffer_view_dims // CHECK-SAME: %[[VIEW:.+]]: !hal.buffer_view func @expand_buffer_view_dims(%arg0 : !hal.buffer_view) -> (index, index, index) {
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir index 15b2add..d254350 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir
@@ -29,27 +29,6 @@ // ----- -// CHECK-LABEL: @buffer_view_compute_offset -func @buffer_view_compute_offset(%arg0: !hal.buffer_view) -> index { - %0:2 = "test_hal.indices"() : () -> (index, index) - // CHECK: %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] - %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] : index - return %off : index -} - -// ----- - -// CHECK-LABEL: @buffer_view_compute_range -func @buffer_view_compute_range(%arg0: !hal.buffer_view) -> (index, index) { - %0:2 = "test_hal.indices"() : () -> (index, index) - %1:2 = "test_hal.lengths"() : () -> (index, index) - // CHECK: %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] - %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] : index, index - return %off, %len : index, index -} - -// ----- - // CHECK-LABEL: @buffer_view_shape_queries func @buffer_view_shape_queries(%arg0: !hal.buffer_view) -> (index, index, index, index) { // CHECK: %{{.+}} = hal.buffer_view.rank %arg0 : index
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp index 1eb2477..36e091c 100644 --- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp +++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
@@ -201,30 +201,19 @@ } Value TensorRewriteAdaptor::computeOffset(ValueRange indices) { - if (isBufferView()) { - return rewriter_.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>( - loc_, getBufferView(), indices); - } else { - auto shapeDims = getShapeDims(); - if (!shapeDims) return {}; - return rewriter_.createOrFold<IREE::HAL::AllocatorComputeOffsetOp>( - loc_, getAllocator(), *shapeDims, getElementType(), indices); - } + auto shapeDims = getShapeDims(); + if (!shapeDims) return {}; + return rewriter_.createOrFold<IREE::HAL::AllocatorComputeOffsetOp>( + loc_, getAllocator(), *shapeDims, getElementType(), indices); } llvm::Optional<TensorRewriteAdaptor::Range> TensorRewriteAdaptor::computeRange( ValueRange indices, ValueRange lengths) { - if (isBufferView()) { - auto range = rewriter_.create<IREE::HAL::BufferViewComputeRangeOp>( - loc_, getBufferView(), indices, lengths); - return Range{range.offset(), range.length()}; - } else { - auto shapeDims = getShapeDims(); - if (!shapeDims) return llvm::None; - auto range = rewriter_.create<IREE::HAL::AllocatorComputeRangeOp>( - loc_, getAllocator(), *shapeDims, getElementType(), indices, lengths); - return Range{range.offset(), range.length()}; - } + auto shapeDims = getShapeDims(); + if (!shapeDims) return llvm::None; + auto range = rewriter_.create<IREE::HAL::AllocatorComputeRangeOp>( + loc_, getAllocator(), *shapeDims, getElementType(), indices, lengths); + return Range{range.offset(), range.length()}; } } // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.h b/iree/compiler/Dialect/HAL/Utils/TypeUtils.h index f640f2f..f91304a 100644 --- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.h +++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.h
@@ -91,7 +91,7 @@ // Performs the equivalent of a hal.buffer_view.byte_length. Value getByteLength(); - // Performs the equivalent of a hal.buffer_view.compute_offset. + // Performs the equivalent of a hal.allocator.compute_offset. Value computeOffset(ValueRange indices); struct Range { @@ -99,7 +99,7 @@ Value length; }; - // Performs the equivalent of a hal.buffer_view.compute_range. + // Performs the equivalent of a hal.allocator.compute_range. llvm::Optional<Range> computeRange(ValueRange indices, ValueRange lengths); private: