Add a pass to bufferize copy-only dispatches early. (#8648)
For cases where a dispatch is copy only, i.e. data is transferred
from one interface binding to another, bufferizing early results in a
linalg.generic (i.e. a copy) operation in the dispatch that the
backends can use to generate code. This is the first step towards
dropping the TiledOpInterface implementation for the
tensor.insert_slice and tensor.extract_slice operations. This commit
also adds patterns to fold these operations with the
flow.dispatch.tensor.load and flow.dispatch.tensor.store operations.
Eventually these patterns will become canonicalizations once the
TiledOpInterface implementations for tensor.insert_slice and
tensor.extract_slice are dropped.
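
As a rough sketch (condensed from the UpSampling1D test added in this
change, with SSA names shortened and most types elided), a copy-only
dispatch like

  %src = flow.dispatch.tensor.load %in, offsets = [0, 0, 0], sizes = [2, 8, 3], strides = [1, 1, 1] : ... -> tensor<2x8x3xf32>
  %slice = tensor.extract_slice %src[0, 0, 0] [2, 1, 3] [1, 1, 1] : tensor<2x8x3xf32> to tensor<2x3xf32>
  %dst = flow.dispatch.tensor.load %out, offsets = [0, 0, 0], sizes = [2, 16, 3], strides = [1, 1, 1] : ... -> tensor<2x16x3xf32>
  %insert = tensor.insert_slice %slice into %dst[0, 0, 0] [2, 1, 3] [1, 1, 1] : tensor<2x3xf32> into tensor<2x16x3xf32>
  flow.dispatch.tensor.store %insert, %out, offsets = [0, 0, 0], sizes = [2, 16, 3], strides = [1, 1, 1] : tensor<2x16x3xf32> -> ...

is folded and bufferized into a copy between subviews of the two
interface bindings:

  %src_view = memref.subview %in[0, 0, 0] [2, 1, 3] [1, 1, 1]
  %dst_view = memref.subview %out[0, 0, 0] [2, 1, 3] [1, 1, 1]
  linalg.generic {indexing_maps = [...], iterator_types = ["parallel", "parallel"]}
      ins(%src_view : memref<2x3xf32, ...>) outs(%dst_view : memref<2x3xf32, ...>) {
    ^bb0(%arg0 : f32, %arg1 : f32):
      linalg.yield %arg0 : f32
  }

When offsets and strides are dynamic, the folding patterns compose them
as offset = producer_offset + consumer_offset * producer_stride and
stride = producer_stride * consumer_stride, materialized through
affine.apply operations (see the lit tests in
bufferize_copy_only_dispatches.mlir below).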
Fixes #8509
Resubmit of PR #8529
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index 2a0a363..284bed4 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -35,6 +35,7 @@
name = "Common",
srcs = [
"BufferizationAnalysis.cpp",
+ "BufferizeCopyOnlyDispatchesPass.cpp",
"CleanupBufferAllocViewPass.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"DemoteF32ToF16.cpp",
@@ -104,5 +105,6 @@
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
"@llvm-project//mlir:VectorTransforms",
+ "@llvm-project//mlir:ViewLikeInterface",
],
)
diff --git a/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
index 4b2d5de..f7b22b0 100644
--- a/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
+++ b/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -14,11 +14,13 @@
#include "iree/compiler/Codegen/Common/BufferizationAnalysis.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -66,7 +68,7 @@
/// Walks the use-def chain and see if this value comes from a read-only tensor.
static bool isFromReadOnlyTensor(Value v, const BufferizationPlan &plan) {
- auto definingOp = v.getDefiningOp();
+ Operation *definingOp = v.getDefiningOp();
if (!definingOp) {
auto arg = v.cast<BlockArgument>();
return TypeSwitch<Operation *, bool>(arg.getOwner()->getParentOp())
@@ -79,22 +81,7 @@
})
.Default([&](Operation *op) { return false; });
}
- return TypeSwitch<Operation *, bool>(definingOp)
- .Case<arith::ConstantOp>(
- [&](arith::ConstantOp constantOp) { return true; })
- .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
- [&](auto op) { return isFromReadOnlyTensor(op.src(), plan); })
- .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) {
- return isFromReadOnlyTensor(sliceOp.source(), plan);
- })
- .Case<IREE::Flow::DispatchTensorLoadOp>(
- [&](IREE::Flow::DispatchTensorLoadOp loadOp) {
- return loadOp.source()
- .getType()
- .cast<IREE::Flow::DispatchTensorType>()
- .getAccess() == IREE::Flow::TensorAccess::ReadOnly;
- })
- .Default([&](Operation *op) { return false; });
+ return isReadOnly(v);
}
/// Adds the result of `std.constant` to its set (there is nothing to tie to
@@ -574,7 +561,7 @@
})
.Case<vector::TransferWriteOp>(
[&](vector::TransferWriteOp transferWriteOp) {
- if (!transferWriteOp.result().getType().isa<RankedTensorType>()) {
+ if (!transferWriteOp.source().getType().isa<RankedTensorType>()) {
return success();
}
return analyseDestructiveUpdateOp(transferWriteOp, nullptr,
@@ -586,7 +573,7 @@
.Case<scf::ForOp>(
[&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); })
.Case<scf::YieldOp, linalg::InitTensorOp, tensor::DimOp,
- tensor::ExtractOp, tensor::PadOp>(
+ tensor::ExtractOp, tensor::PadOp, bufferization::ToMemrefOp>(
[&](Operation *op) { return success(); })
.Default([&](Operation *op) -> LogicalResult {
if (llvm::any_of(op->getOperands(),
diff --git a/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp b/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp
new file mode 100644
index 0000000..30d8a83
--- /dev/null
+++ b/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp
@@ -0,0 +1,283 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- BufferizeCopyOnlyDispatchesPass.cpp --------------------------------===//
+//
+// This pass converts dispatches that are copy only into a form where backends
+// can tile and distribute them appropriately.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Helper function to create an `AffineExpr` from an `OpFoldResult`. If the
+/// `OpFoldResult` is a `Value`, creates an `AffineSymbolExpr` and appends it
+/// to `symbols`.
+static AffineExpr getAffineExpr(OpFoldResult ofr, SmallVector<Value> &symbols) {
+ if (auto attr = ofr.dyn_cast<Attribute>()) {
+ return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
+ attr.getContext());
+ }
+ Value v = ofr.get<Value>();
+ AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
+ symbols.push_back(v);
+ return expr;
+}
+/// Converts an `AffineExpr` to `OpFoldResult` by generating an `affine.apply`
+/// operation.
+static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
+ AffineExpr expr,
+ SmallVector<Value> &symbols) {
+ AffineMap m = AffineMap::get(0, symbols.size(), expr);
+ return applyMapToValues(builder, loc, m, symbols)[0];
+}
+
+/// Methods to build `AffineExpr`s for arithmetic operations.
+static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
+ SmallVector<Value> &symbols) {
+ return expr + getAffineExpr(ofr, symbols);
+}
+static AffineExpr add(OpFoldResult lhs, OpFoldResult rhs,
+ SmallVector<Value> &symbols) {
+ return getAffineExpr(lhs, symbols) + getAffineExpr(rhs, symbols);
+}
+static AffineExpr mul(AffineExpr expr, OpFoldResult ofr,
+ SmallVector<Value> &symbols) {
+ return expr * getAffineExpr(ofr, symbols);
+}
+static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
+ SmallVector<Value> &symbols) {
+ return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
+}
+
+/// Computes the combined offsets, sizes and strides to use when folding two
+/// operations that implement the `OffsetSizeAndStrideOpInterface`.
+static LogicalResult foldOffsetsSizesAndStrides(
+ PatternRewriter &rewriter, Location loc,
+ OffsetSizeAndStrideOpInterface producer,
+ OffsetSizeAndStrideOpInterface consumer,
+ SmallVector<OpFoldResult> &combinedOffsets,
+ SmallVector<OpFoldResult> &combinedSizes,
+ SmallVector<OpFoldResult> &combinedStrides) {
+ SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+ SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+ SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+ SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+ if (producerOffsets.size() != consumerOffsets.size()) {
+ return rewriter.notifyMatchFailure(
+ consumer,
+ "expected op and producer to have same number of offset values");
+ }
+
+ combinedOffsets.resize(producerOffsets.size());
+ combinedSizes.resize(producerOffsets.size());
+ combinedStrides.resize(producerOffsets.size());
+ for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
+ SmallVector<Value> offsetSymbols, strideSymbols;
+ // The combined offset is computed as
+    // producer_offset + consumer_offset * producer_stride.
+ combinedOffsets[i] = getOpFoldResult(
+ rewriter, loc,
+ add(mul(consumerOffsets[i], producerStrides[i], offsetSymbols),
+ producerOffsets[i], offsetSymbols),
+ offsetSymbols);
+ // The combined stride is computed as
+ // producer_stride * consumer_stride.
+ combinedStrides[i] = getOpFoldResult(
+ rewriter, loc,
+ mul(producerStrides[i], consumerStrides[i], strideSymbols),
+ strideSymbols);
+ }
+ combinedSizes = consumer.getMixedSizes();
+ return success();
+}
+
+/// Returns the `hal.interface.binding.subspan` op a value comes from, if any.
+static Optional<IREE::HAL::InterfaceBindingSubspanOp> getBindingSubspanOp(
+ Value v) {
+ Operation *definingOp = v.getDefiningOp();
+ if (!definingOp) return llvm::None;
+ if (auto interfaceOp =
+ dyn_cast<IREE::HAL::InterfaceBindingSubspanOp>(definingOp)) {
+ return interfaceOp;
+ }
+ if (auto loadOp = dyn_cast<IREE::Flow::DispatchTensorLoadOp>(definingOp)) {
+ return getBindingSubspanOp(loadOp.source());
+ }
+ return llvm::None;
+}
+
+namespace {
+
+/// Pattern to fold `flow.dispatch.tensor.load` -> `tensor.extract_slice`.
+// TODO(ravishankarm): Eventually this should go in as a canonicalization at the
+// Flow level.
+struct FoldTensorLoadWithExtractSlice
+ : OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+ PatternRewriter &rewriter) const override {
+ auto dispatchTensorLoadOp =
+ extractSliceOp.source()
+ .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ if (!dispatchTensorLoadOp) return failure();
+
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(foldOffsetsSizesAndStrides(
+ rewriter, dispatchTensorLoadOp->getLoc(), dispatchTensorLoadOp,
+ extractSliceOp, offsets, sizes, strides))) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
+ extractSliceOp, extractSliceOp.getType(), dispatchTensorLoadOp.source(),
+ dispatchTensorLoadOp.source_dims(), offsets, sizes, strides);
+ return success();
+ }
+};
+
+/// Pattern to fold `tensor.insert_slice` with `flow.dispatch.tensor.store`
+/// operations.
+// TODO(ravishankarm): Eventually this should go in as a canonicalization at the
+// Flow level.
+struct FoldInsertSliceWithTensorStoreOp
+ : OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
+ using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ IREE::Flow::DispatchTensorStoreOp dispatchTensorStoreOp,
+ PatternRewriter &rewriter) const override {
+ auto insertSliceOp =
+ dispatchTensorStoreOp.value().getDefiningOp<tensor::InsertSliceOp>();
+ if (!insertSliceOp) return failure();
+
+ // Check that the `dest` of the `tensor.insert_slice` and target of the
+ // `flow.dispatch.tensor.store` are the same interface binding.
+ Optional<IREE::HAL::InterfaceBindingSubspanOp> destBinding =
+ getBindingSubspanOp(insertSliceOp.dest());
+ Optional<IREE::HAL::InterfaceBindingSubspanOp> targetBinding =
+ getBindingSubspanOp(dispatchTensorStoreOp.target());
+ if (!destBinding || !targetBinding ||
+ destBinding.getValue() != targetBinding.getValue()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ // Treat the `flow.dispatch.tensor.store` as the producer and the
+ // `tensor.insert_slice` as the consumer since that would be the case for
+ // the final subview created.
+ if (failed(foldOffsetsSizesAndStrides(
+ rewriter, dispatchTensorStoreOp->getLoc(), dispatchTensorStoreOp,
+ insertSliceOp, offsets, sizes, strides))) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
+ dispatchTensorStoreOp, insertSliceOp.source(),
+ dispatchTensorStoreOp.target(), dispatchTensorStoreOp.target_dims(),
+ offsets, sizes, strides);
+ return success();
+ }
+};
+
+/// Pass to bufferize copy-only dispatches early. This allows backends to use
+/// the `linalg.generic` (copy) operation generated when lowering the dispatch.
+struct BufferizeCopyOnlyDispatchesPass
+ : public BufferizeCopyOnlyDispatchesBase<BufferizeCopyOnlyDispatchesPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, bufferization::BufferizationDialect,
+ IREE::Flow::FlowDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, tensor::TensorDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void BufferizeCopyOnlyDispatchesPass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ ModuleOp module = getOperation();
+
+ /// First apply the `flow.dispatch.tensor.load` -> `tensor.extract_slice` and
+ /// `tensor.insert_slice` -> `flow.dispatch.tensor.store` patterns.
+ RewritePatternSet patterns(context);
+ patterns
+ .insert<FoldInsertSliceWithTensorStoreOp, FoldTensorLoadWithExtractSlice>(
+ context);
+ if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+
+ SmallVector<Operation *> copyOnlyFunctions;
+ auto funcOps = module.getOps<FuncOp>();
+ for (auto funcOp : funcOps) {
+    /// Check that the values written by all `flow.dispatch.tensor.store`
+    /// operations in the dispatch come from read-only tensors. If so, this
+    /// dispatch is just a copy dispatch.
+ auto walkResult = funcOp.walk(
+ [&](IREE::Flow::DispatchTensorStoreOp storeOp) -> WalkResult {
+ return success(isReadOnly(storeOp.value()));
+ });
+ if (walkResult.wasInterrupted()) continue;
+ // The function is just a copy.
+ copyOnlyFunctions.push_back(funcOp);
+ }
+
+  // If there are no copy-only functions, there is nothing to do.
+ if (copyOnlyFunctions.empty()) return;
+
+  // Bufferize the dispatch to create a `linalg.generic` copy operation that
+  // the backends can then tile and distribute.
+  // Currently bufferization cannot be applied to individual functions, so
+  // check that all functions in the module are copy only before bufferizing.
+ if (copyOnlyFunctions.size() !=
+ std::distance(funcOps.begin(), funcOps.end())) {
+ module.emitOpError(
+ "module contains functions that are both copy only and not copy only. "
+ "This is currently unhandled.");
+ return signalPassFailure();
+ }
+
+ // Apply the bufferization passes.
+ OpPassManager bufferizationPipeline(module.getOperationName());
+ addLinalgBufferizePasses(bufferizationPipeline);
+ if (failed(runPipeline(bufferizationPipeline, module))) {
+ return signalPassFailure();
+ }
+
+ // Check that there are no allocs created.
+ auto hasAlloc = module.walk(
+ [&](memref::AllocOp /*op*/) -> WalkResult { return failure(); });
+ if (hasAlloc.wasInterrupted()) {
+ module.emitOpError(
+ "unexpected allocations while bufferizing copy dispatch");
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createBufferizeCopyOnlyDispatchesPass() {
+ return std::make_unique<BufferizeCopyOnlyDispatchesPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index db3e857..9f42782 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -27,6 +27,7 @@
"DestructiveUpdateUtils.h"
SRCS
"BufferizationAnalysis.cpp"
+ "BufferizeCopyOnlyDispatchesPass.cpp"
"CleanupBufferAllocViewPass.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"DemoteF32ToF16.cpp"
@@ -79,6 +80,7 @@
MLIRTransforms
MLIRVector
MLIRVectorTransforms
+ MLIRViewLikeInterface
iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::Interfaces::BufferizationInterfaces
diff --git a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 6bb0a7e..0d9b83e 100644
--- a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -323,24 +323,6 @@
return preservedAttrs;
}
-static bool isFromReadOnlyTensor(Value v) {
- return TypeSwitch<Operation *, bool>(v.getDefiningOp())
- .Case<arith::ConstantOp>(
- [&](arith::ConstantOp constantOp) { return true; })
- .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
- [&](auto op) { return isFromReadOnlyTensor(op.src()); })
- .Case<tensor::CastOp, tensor::ExtractSliceOp>(
- [&](auto op) { return isFromReadOnlyTensor(op.source()); })
- .Case<IREE::Flow::DispatchTensorLoadOp>(
- [&](IREE::Flow::DispatchTensorLoadOp loadOp) {
- return loadOp.source()
- .getType()
- .cast<IREE::Flow::DispatchTensorType>()
- .getAccess() == IREE::Flow::TensorAccess::ReadOnly;
- })
- .Default([&](Operation *op) { return false; });
-}
-
namespace {
/// Adapts Linalg ops input operand to output operand. This is required for not
/// creating extra alloca ops. For more details, see
@@ -366,7 +348,7 @@
SmallVector<Value> newOperands;
SmallVector<AffineMap> maps;
for (auto in : op.getInputOperands()) {
- if (!operand && !isFromReadOnlyTensor(in->get()) &&
+ if (!operand && !isReadOnly(in->get()) &&
op.getTiedIndexingMap(in) == op.getTiedIndexingMap(outputOperand) &&
in->get().getType() == outputOperand->get().getType()) {
operand = in;
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 2a82d82..5a26dbf 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -1109,6 +1109,10 @@
})
.Case<vector::TransferWriteOp>(
[&](vector::TransferWriteOp transferWriteOp) {
+ if (!transferWriteOp.source().getType().isa<RankedTensorType>()) {
+ // Nothing to do when source is not a tensor.
+ return success();
+ }
if (failed(getOrAllocateResultBuffers(b, transferWriteOp, bvm,
plan, allocationFn))) {
return failure();
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index f5212fc..af89788 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -20,6 +20,7 @@
srcs = enforce_glob(
[
"affinemin_canonicalization.mlir",
+ "bufferize_copy_only_dispatches.mlir",
"canonicalize_interface_load_store.mlir",
"convert_to_destination_passing_style.mlir",
"dead_alloc.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index a650fa9..9a75c90 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"affinemin_canonicalization.mlir"
+ "bufferize_copy_only_dispatches.mlir"
"canonicalize_interface_load_store.mlir"
"convert_to_destination_passing_style.mlir"
"dead_alloc.mlir"
diff --git a/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir b/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir
new file mode 100644
index 0000000..653aac9
--- /dev/null
+++ b/iree/compiler/Codegen/Common/test/bufferize_copy_only_dispatches.mlir
@@ -0,0 +1,159 @@
+// RUN: iree-opt -iree-codegen-bufferize-copy-only-dispatches -split-input-file %s | FileCheck %s
+
+builtin.module {
+ func @tensor_insert_slice() {
+ %source_size_y = hal.interface.constant.load[0] : index
+ %source_size_x = hal.interface.constant.load[1] : index
+ %dest_size_y = hal.interface.constant.load[2] : index
+ %dest_size_x = hal.interface.constant.load[3] : index
+ %dest_offset_y = hal.interface.constant.load[4] : index
+ %dest_offset_x = hal.interface.constant.load[5] : index
+ %dest_stride_y = hal.interface.constant.load[6] : index
+ %dest_stride_x = hal.interface.constant.load[7] : index
+ %insert_offset_y = hal.interface.constant.load[8] : index
+ %insert_offset_x = hal.interface.constant.load[9] : index
+ %insert_stride_y = hal.interface.constant.load[10] : index
+ %insert_stride_x = hal.interface.constant.load[11] : index
+ %dest_binding_size_y = hal.interface.constant.load[12] : index
+ %dest_binding_size_x = hal.interface.constant.load[13] : index
+ %source = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+ : !flow.dispatch.tensor<readonly:?x?xi32>{%source_size_y, %source_size_x}
+ %dest = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+ : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_binding_size_y, %dest_binding_size_x}
+ %source_load = flow.dispatch.tensor.load %source, offsets = [0, 0],sizes = [%source_size_y, %source_size_x], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:?x?xi32>{%source_size_y, %source_size_x} -> tensor<?x?xi32>
+ %dest_load = flow.dispatch.tensor.load %dest, offsets = [%dest_offset_y, %dest_offset_x], sizes = [%dest_size_y, %dest_size_x],
+ strides = [%dest_stride_y, %dest_stride_x]
+ : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_binding_size_y, %dest_binding_size_x} -> tensor<?x?xi32>
+ %insert = tensor.insert_slice %source_load into
+ %dest_load[%insert_offset_y, %insert_offset_x] [%source_size_y, %source_size_x] [%insert_stride_y, %insert_stride_x]
+ : tensor<?x?xi32> into tensor<?x?xi32>
+ flow.dispatch.tensor.store %insert, %dest, offsets = [%dest_offset_y, %dest_offset_x], sizes = [%dest_size_y, %dest_size_x],
+ strides = [%dest_stride_y, %dest_stride_x]
+ : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_binding_size_y, %dest_binding_size_x}
+ return
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func @tensor_insert_slice()
+// CHECK-DAG: %[[SOURCE_SIZE_Y:.+]] = hal.interface.constant.load[0]
+// CHECK-DAG: %[[SOURCE_SIZE_X:.+]] = hal.interface.constant.load[1]
+// CHECK-DAG: %[[DEST_OFFSET_Y:.+]] = hal.interface.constant.load[4]
+// CHECK-DAG: %[[DEST_OFFSET_X:.+]] = hal.interface.constant.load[5]
+// CHECK-DAG: %[[DEST_STRIDE_Y:.+]] = hal.interface.constant.load[6]
+// CHECK-DAG: %[[DEST_STRIDE_X:.+]] = hal.interface.constant.load[7]
+// CHECK-DAG: %[[INSERT_OFFSET_Y:.+]] = hal.interface.constant.load[8]
+// CHECK-DAG: %[[INSERT_OFFSET_X:.+]] = hal.interface.constant.load[9]
+// CHECK-DAG: %[[INSERT_STRIDE_Y:.+]] = hal.interface.constant.load[10]
+// CHECK-DAG: %[[INSERT_STRIDE_X:.+]] = hal.interface.constant.load[11]
+// CHECK-DAG: %[[SOURCE:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[DEST:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[INSERT_OFFSET_Y]], %[[DEST_STRIDE_Y]], %[[DEST_OFFSET_Y]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[INSERT_OFFSET_X]], %[[DEST_STRIDE_X]], %[[DEST_OFFSET_X]]]
+// CHECK-DAG: %[[STRIDE_Y:.+]] = affine.apply #[[MAP1]]()[%[[DEST_STRIDE_Y]], %[[INSERT_STRIDE_Y]]]
+// CHECK-DAG: %[[STRIDE_X:.+]] = affine.apply #[[MAP1]]()[%[[DEST_STRIDE_X]], %[[INSERT_STRIDE_X]]]
+// CHECK-DAG: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][%[[OFFSET_Y]], %[[OFFSET_X]]] [%[[SOURCE_SIZE_Y]], %[[SOURCE_SIZE_X]]]
+// CHECK-SAME: [%[[STRIDE_Y]], %[[STRIDE_X]]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SOURCE]] :
+// CHECK-SAME: outs(%[[SUBVIEW]] :
+
+// -----
+
+builtin.module {
+ func @tensor_extract_slice() {
+ %source_size_y = hal.interface.constant.load[0] : index
+ %source_size_x = hal.interface.constant.load[1] : index
+ %dest_size_y = hal.interface.constant.load[2] : index
+ %dest_size_x = hal.interface.constant.load[3] : index
+ %source_offset_y = hal.interface.constant.load[4] : index
+ %source_offset_x = hal.interface.constant.load[5] : index
+ %extract_offset_y = hal.interface.constant.load[6] : index
+ %extract_offset_x = hal.interface.constant.load[7] : index
+ %extract_stride_y = hal.interface.constant.load[8] : index
+ %extract_stride_x = hal.interface.constant.load[9] : index
+ %source_stride_y = hal.interface.constant.load[10] : index
+ %source_stride_x = hal.interface.constant.load[11] : index
+ %source = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+ : !flow.dispatch.tensor<readonly:?x?xi32>{%source_size_y, %source_size_x}
+ %dest = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+ : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x}
+ %source_load = flow.dispatch.tensor.load %source, offsets = [%source_offset_y, %source_offset_x], sizes = [%source_size_y, %source_size_x],
+ strides = [%source_stride_y, %source_stride_x]
+ : !flow.dispatch.tensor<readonly:?x?xi32>{%source_size_y, %source_size_x} -> tensor<?x?xi32>
+ %extract = tensor.extract_slice %source_load[%extract_offset_y, %extract_offset_x] [%dest_size_y, %dest_size_x]
+ [%extract_stride_y, %extract_stride_x] : tensor<?x?xi32> to tensor<?x?xi32>
+ flow.dispatch.tensor.store %extract, %dest, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
+ : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x}
+ return
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: func @tensor_extract_slice()
+// CHECK-DAG: %[[DEST_SIZE_Y:.+]] = hal.interface.constant.load[2]
+// CHECK-DAG: %[[DEST_SIZE_X:.+]] = hal.interface.constant.load[3]
+// CHECK-DAG: %[[SOURCE_OFFSET_Y:.+]] = hal.interface.constant.load[4]
+// CHECK-DAG: %[[SOURCE_OFFSET_X:.+]] = hal.interface.constant.load[5]
+// CHECK-DAG: %[[EXTRACT_OFFSET_Y:.+]] = hal.interface.constant.load[6]
+// CHECK-DAG: %[[EXTRACT_OFFSET_X:.+]] = hal.interface.constant.load[7]
+// CHECK-DAG: %[[EXTRACT_STRIDE_Y:.+]] = hal.interface.constant.load[8]
+// CHECK-DAG: %[[EXTRACT_STRIDE_X:.+]] = hal.interface.constant.load[9]
+// CHECK-DAG: %[[SOURCE_STRIDE_Y:.+]] = hal.interface.constant.load[10]
+// CHECK-DAG: %[[SOURCE_STRIDE_X:.+]] = hal.interface.constant.load[11]
+// CHECK-DAG: %[[SOURCE:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[DEST:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[EXTRACT_OFFSET_Y]], %[[SOURCE_STRIDE_Y]], %[[SOURCE_OFFSET_Y]]]
+// CHECK-DAG: %[[STRIDE_Y:.+]] = affine.apply #[[MAP1]]()[%[[SOURCE_STRIDE_Y]], %[[EXTRACT_STRIDE_Y]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[EXTRACT_OFFSET_X]], %[[SOURCE_STRIDE_X]], %[[SOURCE_OFFSET_X]]]
+// CHECK-DAG: %[[STRIDE_X:.+]] = affine.apply #[[MAP1]]()[%[[SOURCE_STRIDE_X]], %[[EXTRACT_STRIDE_X]]]
+// CHECK-DAG: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[OFFSET_Y]], %[[OFFSET_X]]] [%[[DEST_SIZE_Y]], %[[DEST_SIZE_X]]]
+// CHECK-SAME: [%[[STRIDE_Y]], %[[STRIDE_X]]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SOURCE_SUBVIEW]] :
+// CHECK-SAME: outs(%[[DEST]] :
+
+// -----
+
+builtin.module {
+ func @UpSampling1D() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:2x16x3xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:2x8x3xf32>
+ %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 8, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:2x8x3xf32> -> tensor<2x8x3xf32>
+ %3 = tensor.extract_slice %2[0, 0, 0] [2, 1, 3] [1, 1, 1] : tensor<2x8x3xf32> to tensor<2x3xf32>
+ %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 16, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<readwrite:2x16x3xf32> -> tensor<2x16x3xf32>
+ %5 = tensor.insert_slice %3 into %4[0, 0, 0] [2, 1, 3] [1, 1, 1] : tensor<2x3xf32> into tensor<2x16x3xf32>
+ flow.dispatch.tensor.store %5, %0, offsets = [0, 0, 0], sizes = [2, 16, 3], strides = [1, 1, 1] : tensor<2x16x3xf32> -> !flow.dispatch.tensor<readwrite:2x16x3xf32>
+ return
+ }
+}
+// CHECK-LABEL: func @UpSampling1D()
+// CHECK-DAG: %[[DEST:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[SOURCE:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][0, 0, 0] [2, 1, 3]
+// CHECK-DAG: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0] [2, 1, 3]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SOURCE_SUBVIEW]] : memref<2x3xf32, #{{[a-zA-Z0-9]+}}>)
+// CHECK-SAME: outs(%[[DEST_SUBVIEW]] : memref<2x3xf32, #{{[a-zA-Z0-9]+}}>)
+
+// -----
+
+builtin.module {
+ func @concatenate_cst() {
+ %cst = arith.constant dense<0> : tensor<2x3xi32>
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:2x5xi32>
+ flow.dispatch.tensor.store %cst, %0, offsets = [0, 2], sizes = [2, 3], strides = [1, 1] : tensor<2x3xi32> -> !flow.dispatch.tensor<readwrite:2x5xi32>
+ return
+ }
+}
+// CHECK-LABEL: func @concatenate_cst()
+// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : tensor<2x3xi32>
+// CHECK-DAG: %[[ZERO:.+]] = bufferization.to_memref %[[CST]] : memref<2x3xi32>
+// CHECK-DAG: %[[DEST_BINDING:.+]] = hal.interface.binding.subspan
+// CHECK-DAG: %[[SUBVIEW:.+]] = memref.subview %[[DEST_BINDING]][0, 2] [2, 3]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[ZERO]] :
+// CHECK-SAME: outs(%[[SUBVIEW]] :
diff --git a/iree/compiler/Codegen/Common/test/insert_distribution_info.mlir b/iree/compiler/Codegen/Common/test/insert_distribution_info.mlir
index 41ad834..9ecbbf5 100644
--- a/iree/compiler/Codegen/Common/test/insert_distribution_info.mlir
+++ b/iree/compiler/Codegen/Common/test/insert_distribution_info.mlir
@@ -307,37 +307,43 @@
]>
]>
#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
-#translation = #iree_codegen.translation_info<CPUDefault>
-hal.executable public @tensor_insert {
+#translation = #iree_codegen.translation_info<CPUBufferOpsTileAndVectorize>
+hal.executable public @copy_op {
hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
- hal.executable.entry_point public @tensor_insert_slice layout(#executable_layout) {translation_info = #translation}
+ hal.executable.entry_point public @copy_op layout(#executable_layout) {translation_info = #translation}
builtin.module {
- func @tensor_insert_slice() {
- %0 = hal.interface.constant.load[0] : index
- %1 = hal.interface.constant.load[1] : index
- %2 = hal.interface.constant.load[2] : index
- %3 = hal.interface.constant.load[3] : index
- %4 = hal.interface.constant.load[4] : index
- %5 = hal.interface.constant.load[5] : index
- %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
- %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3}
- %8 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
- %9 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3} -> tensor<?x?xi32>
- %10 = tensor.insert_slice %8 into %9[%4, %5] [%0, %1] [1, 1] {lowering_config = #config} : tensor<?x?xi32> into tensor<?x?xi32>
- flow.dispatch.tensor.store %10, %7, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3}
+ func @copy_op() {
+ %source_size_y = hal.interface.constant.load[0] : index
+ %source_size_x = hal.interface.constant.load[1] : index
+ %dest_size_y = hal.interface.constant.load[2] : index
+ %dest_size_x = hal.interface.constant.load[3] : index
+ %source_offset_y = hal.interface.constant.load[4] : index
+ %source_offset_x = hal.interface.constant.load[5] : index
+ %dest_offset_y = hal.interface.constant.load[6] : index
+ %dest_offset_x = hal.interface.constant.load[7] : index
+ %slice_size_y = hal.interface.constant.load[8] : index
+ %slice_size_x = hal.interface.constant.load[9] : index
+ %source = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%source_size_y, %source_size_x}
+ %dest = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%dest_size_y, %dest_size_x}
+ %source_subview = memref.subview %source[%source_offset_y, %source_offset_x] [%slice_size_y, %slice_size_x] [1, 1] : memref<?x?xi32> to memref<?x?xi32, offset : ?, strides : [?, ?]>
+ %dest_subview = memref.subview %dest[%dest_offset_y, %dest_offset_x] [%slice_size_y, %slice_size_x] [1, 1] : memref<?x?xi32> to memref<?x?xi32, offset : ?, strides : [?, ?]>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%source_subview : memref<?x?xi32, offset : ?, strides : [?, ?]>)
+ outs(%dest_subview : memref<?x?xi32, offset : ?, strides : [?, ?]>)
+ attrs = {lowering_config = #config} {
+ ^bb0(%arg0: i32, %arg1: i32):
+ linalg.yield %arg0 : i32
+ }
return
}
}
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDefault, workload_per_wg = [64, 64]>
-// CHECK: hal.executable.entry_point public @tensor_insert
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUBufferOpsTileAndVectorize, workload_per_wg = [64, 64]>
+// CHECK: hal.executable.entry_point public @copy_op
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK-NEXT: (%[[ARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
@@ -346,56 +352,6 @@
// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
// CHECK: hal.return %[[D0]], %[[D1]], %[[C1]]
-// CHECK: func @tensor_insert_slice()
-
-// -----
-
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
-#translation = #iree_codegen.translation_info<CPUDefault>
-hal.executable public @extract_slice {
- hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
- hal.executable.entry_point public @extract_slice layout(#executable_layout) {translation_info = #translation}
- builtin.module {
- func @extract_slice() {
- %0 = hal.interface.constant.load[0] : index
- %1 = hal.interface.constant.load[1] : index
- %2 = hal.interface.constant.load[2] : index
- %3 = hal.interface.constant.load[3] : index
- %4 = hal.interface.constant.load[4] : index
- %5 = hal.interface.constant.load[5] : index
- %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
- %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
- %8 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
- %9 = tensor.extract_slice %8[%4, %5] [%2, %3] [1, 1] {lowering_config = #config} : tensor<?x?xi32> to tensor<?x?xi32>
- flow.dispatch.tensor.store %9, %7, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
- return
- }
- }
- }
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDefault, workload_per_wg = [64, 64]>
-// CHECK: hal.executable.entry_point public @extract_slice
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK-NEXT: (%[[ARG0:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-// CHECK-DAG: %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-// CHECK: hal.return %[[D0]], %[[D1]], %[[C1]]
-// CHECK: func @extract_slice()
// -----
diff --git a/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 6bd12a7..3407016 100644
--- a/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -327,29 +327,35 @@
]>
]>
#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
-#translation = #iree_codegen.translation_info<CPUDefault>
-hal.executable public @tensor_insert {
+#translation = #iree_codegen.translation_info<CPUBufferOpsTileAndVectorize, workload_per_wg = [64, 64]>
+hal.executable public @copy_op {
hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
- hal.executable.entry_point public @tensor_insert_slice layout(#executable_layout) {translation_info = #translation}
+ hal.executable.entry_point public @copy_op layout(#executable_layout) {translation_info = #translation}
builtin.module {
- func @tensor_insert_slice() {
- %0 = hal.interface.constant.load[0] : index
- %1 = hal.interface.constant.load[1] : index
- %2 = hal.interface.constant.load[2] : index
- %3 = hal.interface.constant.load[3] : index
- %4 = hal.interface.constant.load[4] : index
- %5 = hal.interface.constant.load[5] : index
- %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
- %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3}
- %8 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
- %9 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3} -> tensor<?x?xi32>
- %10 = tensor.insert_slice %8 into %9[%4, %5] [%0, %1] [1, 1] {lowering_config = #config} : tensor<?x?xi32> into tensor<?x?xi32>
- flow.dispatch.tensor.store %10, %7, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3}
+ func @copy_op() {
+ %source_size_y = hal.interface.constant.load[0] : index
+ %source_size_x = hal.interface.constant.load[1] : index
+ %dest_size_y = hal.interface.constant.load[2] : index
+ %dest_size_x = hal.interface.constant.load[3] : index
+ %source_offset_y = hal.interface.constant.load[4] : index
+ %source_offset_x = hal.interface.constant.load[5] : index
+ %dest_offset_y = hal.interface.constant.load[6] : index
+ %dest_offset_x = hal.interface.constant.load[7] : index
+ %slice_size_y = hal.interface.constant.load[8] : index
+ %slice_size_x = hal.interface.constant.load[9] : index
+ %source = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%source_size_y, %source_size_x}
+ %dest = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%dest_size_y, %dest_size_x}
+ %source_subview = memref.subview %source[%source_offset_y, %source_offset_x] [%slice_size_y, %slice_size_x] [1, 1] : memref<?x?xi32> to memref<?x?xi32, offset : ?, strides : [?, ?]>
+ %dest_subview = memref.subview %dest[%dest_offset_y, %dest_offset_x] [%slice_size_y, %slice_size_x] [1, 1] : memref<?x?xi32> to memref<?x?xi32, offset : ?, strides : [?, ?]>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%source_subview : memref<?x?xi32, offset : ?, strides : [?, ?]>)
+ outs(%dest_subview : memref<?x?xi32, offset : ?, strides : [?, ?]>)
+ attrs = {lowering_config = #config} {
+ ^bb0(%arg0: i32, %arg1: i32):
+ linalg.yield %arg0 : i32
+ }
return
}
}
@@ -357,103 +363,38 @@
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDefault>
-// CHECK: hal.executable.entry_point public @tensor_insert
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: func @tensor_insert_slice()
-// CHECK-DAG: %[[SIZE_Y:.+]] = hal.interface.constant.load[0] : index
-// CHECK-DAG: %[[SIZE_X:.+]] = hal.interface.constant.load[1] : index
-// CHECK-DAG: %[[DEST_SIZE_Y:.+]] = hal.interface.constant.load[2] : index
-// CHECK-DAG: %[[DEST_SIZE_X:.+]] = hal.interface.constant.load[3] : index
-// CHECK-DAG: %[[OFFSET_Y:.+]] = hal.interface.constant.load[4] : index
-// CHECK-DAG: %[[OFFSET_X:.+]] = hal.interface.constant.load[5] : index
-// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
-// CHECK-DAG: %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
-// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
-// CHECK-DAG: %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
-// CHECK-DAG: %[[LB_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]]]
-// CHECK-DAG: %[[STEP_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_Y]]]
-// CHECK: scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[SIZE_Y]] step %[[STEP_Y]]
-// CHECK-DAG: %[[TILESIZE_Y:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[SIZE_Y]]]
-// CHECK-DAG: %[[LB_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
-// CHECK-DAG: %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
-// CHECK: scf.for %[[IV1:.+]] = %[[LB_X]] to %[[SIZE_X]] step %[[STEP_X]]
-// CHECK-DAG: %[[TILESIZE_X:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[SIZE_X]]]
-// CHECK-DAG: %[[SOURCE:.+]] = flow.dispatch.tensor.load
-// CHECK-SAME: offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
-// CHECK-DAG: %[[STORE_OFFSET_Y:.+]] = affine.apply #[[MAP3]](%[[IV0]])[%[[OFFSET_Y]]]
-// CHECK-DAG: %[[STORE_OFFSET_X:.+]] = affine.apply #[[MAP3]](%[[IV1]])[%[[OFFSET_X]]]
-// CHECK: flow.dispatch.tensor.store %[[SOURCE]]
-// CHECK-SAME: offsets = [%[[STORE_OFFSET_Y]], %[[STORE_OFFSET_X]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
-
-// -----
-
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
-#translation = #iree_codegen.translation_info<CPUDefault>
-hal.executable public @extract_slice {
- hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
- hal.executable.entry_point public @extract_slice layout(#executable_layout) {translation_info = #translation}
- builtin.module {
- func @extract_slice() {
- %0 = hal.interface.constant.load[0] : index
- %1 = hal.interface.constant.load[1] : index
- %2 = hal.interface.constant.load[2] : index
- %3 = hal.interface.constant.load[3] : index
- %4 = hal.interface.constant.load[4] : index
- %5 = hal.interface.constant.load[5] : index
- %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
- %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
- %8 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
- %9 = tensor.extract_slice %8[%4, %5] [%2, %3] [1, 1] {lowering_config = #config} : tensor<?x?xi32> to tensor<?x?xi32>
- flow.dispatch.tensor.store %9, %7, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
- return
- }
- }
- }
-}
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDefault>
-// CHECK: hal.executable.entry_point public @extract_slice
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: func @extract_slice()
+// CHECK: func @copy_op()
// CHECK-DAG: %[[SOURCE_SIZE_Y:.+]] = hal.interface.constant.load[0] : index
// CHECK-DAG: %[[SOURCE_SIZE_X:.+]] = hal.interface.constant.load[1] : index
-// CHECK-DAG: %[[SIZE_Y:.+]] = hal.interface.constant.load[2] : index
-// CHECK-DAG: %[[SIZE_X:.+]] = hal.interface.constant.load[3] : index
-// CHECK-DAG: %[[OFFSET_Y:.+]] = hal.interface.constant.load[4] : index
-// CHECK-DAG: %[[OFFSET_X:.+]] = hal.interface.constant.load[5] : index
+// CHECK-DAG: %[[DEST_SIZE_Y:.+]] = hal.interface.constant.load[2] : index
+// CHECK-DAG: %[[DEST_SIZE_X:.+]] = hal.interface.constant.load[3] : index
+// CHECK-DAG: %[[SOURCE_OFFSET_Y:.+]] = hal.interface.constant.load[4] : index
+// CHECK-DAG: %[[SOURCE_OFFSET_X:.+]] = hal.interface.constant.load[5] : index
+// CHECK-DAG: %[[DEST_OFFSET_Y:.+]] = hal.interface.constant.load[6] : index
+// CHECK-DAG: %[[DEST_OFFSET_X:.+]] = hal.interface.constant.load[7] : index
+// CHECK-DAG: %[[SLICE_SIZE_Y:.+]] = hal.interface.constant.load[8] : index
+// CHECK-DAG: %[[SLICE_SIZE_X:.+]] = hal.interface.constant.load[9] : index
+// CHECK-DAG: %[[SOURCE_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[DEST_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[SOURCE:.+]] = memref.subview %[[SOURCE_BINDING]][%[[SOURCE_OFFSET_Y]], %[[SOURCE_OFFSET_X]]]
+// CHECK-DAG: %[[DEST:.+]] = memref.subview %[[DEST_BINDING]][%[[DEST_OFFSET_Y]], %[[DEST_OFFSET_X]]]
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-DAG: %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
// CHECK-DAG: %[[LB_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]]]
// CHECK-DAG: %[[STEP_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_Y]]]
-// CHECK: scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[SIZE_Y]] step %[[STEP_Y]]
-// CHECK-DAG: %[[TILESIZE_Y:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[SIZE_Y]]]
+// CHECK: scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[SLICE_SIZE_Y]] step %[[STEP_Y]]
// CHECK-DAG: %[[LB_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
// CHECK-DAG: %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
-// CHECK: scf.for %[[IV1:.+]] = %[[LB_X]] to %[[SIZE_X]] step %[[STEP_X]]
-// CHECK-DAG: %[[TILESIZE_X:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[SIZE_X]]]
-// CHECK-DAG: %[[LOAD_OFFSET_Y:.+]] = affine.apply #[[MAP3]](%[[IV0]])[%[[OFFSET_Y]]]
-// CHECK-DAG: %[[LOAD_OFFSET_X:.+]] = affine.apply #[[MAP3]](%[[IV1]])[%[[OFFSET_X]]]
-// CHECK-DAG: %[[SOURCE:.+]] = flow.dispatch.tensor.load
-// CHECK-SAME: offsets = [%[[LOAD_OFFSET_Y]], %[[LOAD_OFFSET_X]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
-// CHECK: flow.dispatch.tensor.store %[[SOURCE]]
-// CHECK-SAME: offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[LB_X]] to %[[SLICE_SIZE_X]] step %[[STEP_X]]
+// CHECK-DAG: %[[TILESIZE_Y:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[SLICE_SIZE_Y]]]
+// CHECK-DAG: %[[TILESIZE_X:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[SLICE_SIZE_X]]]
+// CHECK-DAG: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[IV0]], %[[IV1]]]
+// CHECK-DAG: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[IV0]], %[[IV1]]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SOURCE_SUBVIEW]] :
+// CHECK-SAME: outs(%[[DEST_SUBVIEW]] :
// -----
diff --git a/iree/compiler/Codegen/Dialect/LoweringConfig.td b/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 3e3fdc4..5225430 100644
--- a/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -21,29 +21,32 @@
: I32EnumAttrCase<"CPUConvTileAndDecomposeExpert", 3>;
def CPU_TileFuseAndVectorize
: I32EnumAttrCase<"CPUTileFuseAndVectorize", 4>;
-def CPU_SandboxCodegen
- : I32EnumAttrCase<"LinalgTransformInterpCodegen", 5>;
+def CPU_BufferOpsTileAndVectorize
+ : I32EnumAttrCase<"CPUBufferOpsTileAndVectorize", 5>;
+
+def Linalg_TransformInterpCodegen
+ : I32EnumAttrCase<"LinalgTransformInterpCodegen", 6>;
def LLVMGPU_SimpleDistribute
- : I32EnumAttrCase<"LLVMGPUDistribute",6>;
+ : I32EnumAttrCase<"LLVMGPUDistribute", 7>;
def LLVMGPU_Vectorize
- : I32EnumAttrCase<"LLVMGPUVectorize", 7>;
+ : I32EnumAttrCase<"LLVMGPUVectorize", 8>;
def LLVMGPU_MatmulSimt
- : I32EnumAttrCase<"LLVMGPUMatmulSimt", 8>;
+ : I32EnumAttrCase<"LLVMGPUMatmulSimt", 9>;
def LLVMGPU_MatmulTensorCore
- : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 9>;
+ : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 10>;
def SPIRV_Distribute
- : I32EnumAttrCase<"SPIRVDistribute", 10>;
+ : I32EnumAttrCase<"SPIRVDistribute", 11>;
def SPIRV_DistributeCopy
- : I32EnumAttrCase<"SPIRVDistributeCopy", 11>;
+ : I32EnumAttrCase<"SPIRVDistributeCopy", 12>;
def SPIRV_Vectorize
- : I32EnumAttrCase<"SPIRVVectorize", 12>;
+ : I32EnumAttrCase<"SPIRVVectorize", 13>;
def SPIRV_VectorizeToCooperativeOps
- : I32EnumAttrCase<"SPIRVVectorizeToCooperativeOps", 13>;
+ : I32EnumAttrCase<"SPIRVVectorizeToCooperativeOps", 14>;
def None
- : I32EnumAttrCase<"None", 14>;
+ : I32EnumAttrCase<"None", 15>;
// EnumAttrCase for all known lowerings for ops within dispatch region
// to scalar/native-vector code.
@@ -52,10 +55,10 @@
"identifier for pass pipeline use to lower dispatch region",
[CPU_Default, CPU_SingleTilingExpert, CPU_DoubleTilingExpert,
CPU_ConvTileAndDecomposeExpert, CPU_TileFuseAndVectorize,
- CPU_SandboxCodegen, LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize,
- LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore, SPIRV_Distribute,
- SPIRV_DistributeCopy, SPIRV_Vectorize, SPIRV_VectorizeToCooperativeOps,
- None]> {
+ CPU_BufferOpsTileAndVectorize, Linalg_TransformInterpCodegen,
+ LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize, LLVMGPU_MatmulSimt,
+ LLVMGPU_MatmulTensorCore, SPIRV_Distribute, SPIRV_DistributeCopy, SPIRV_Vectorize,
+ SPIRV_VectorizeToCooperativeOps, None]> {
let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
// Don't generate a C++ class! We want to use the AttrDef
let genSpecializedAttr = 0;
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 51a8d03..eb0ea6b 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -628,9 +628,13 @@
tileSizes.push_back(parallelTileSizes);
tileSizes.push_back(reductionTileSizes);
- return setOpConfigAndEntryPointFnTranslation(
- entryPointFn, genericOp, tileSizes,
- DispatchLoweringPassPipeline::CPUDoubleTilingExpert);
+  // For ops without tensor semantics, use the buffer-ops pipeline.
+ auto passPipeline =
+ genericOp.hasTensorSemantics()
+ ? DispatchLoweringPassPipeline::CPUDoubleTilingExpert
+ : DispatchLoweringPassPipeline::CPUBufferOpsTileAndVectorize;
+ return setOpConfigAndEntryPointFnTranslation(entryPointFn, genericOp,
+ tileSizes, passPipeline);
}
/// Sets the lowering configuration for linalg.conv_2d_nhwc_hwcf and
@@ -831,25 +835,11 @@
if (rootOperation) return rootOperation;
// If no root operation is found yet. Look for linalg generic ops.
- for (auto op : computeOps) {
- if (isa<linalg::GenericOp>(op)) {
+ for (auto op : llvm::reverse(computeOps)) {
+ if (isa<linalg::LinalgOp>(op)) {
if (failed(updateRootOperation(op))) return failure();
}
}
- if (rootOperation) return rootOperation;
-
- // TODO(ravishankarm): Currently there is a corner case of a dispatch region
- // with just a `tensor.extract_slice`/`tensor.insert_slice`. Those need to be
- // folded with `flow.dispatch.tensor.load`/`flow.dispatch.tensor.store` ops
- // respectively. This should go hand-in-hand with dropping the external model
- // implementation of the `TiledOpInterface` for these ops. Till we cross that
- // bridge, handle that case.
- // Throw in linalg.fill here as well, though that should never happen either.
- if (computeOps.size() == 1 &&
- isa<linalg::FillOp, tensor::ExtractSliceOp, tensor::InsertSliceOp>(
- computeOps[0])) {
- rootOperation = computeOps[0];
- }
return rootOperation;
}
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index e087c25..f43d11e 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -193,6 +193,10 @@
addCPUDefaultPassPipeline(nestedModulePM);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::
+ CPUBufferOpsTileAndVectorize:
+ addCPUBufferOpsTileAndVectorizePipeline(nestedModulePM);
+ break;
+ case IREE::Codegen::DispatchLoweringPassPipeline::
CPUSingleTilingExpert:
addSingleTilingExpertPassPipeline(nestedModulePM);
break;
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 0e14a85..2d0a082 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -192,6 +192,20 @@
}
}
+void addCPUBufferOpsTileAndVectorizePipeline(OpPassManager &passManager) {
+ // Do first level of tiling and distribution.
+ passManager.addNestedPass<FuncOp>(createInsertDistributionInfoPass());
+ passManager.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+ passManager.addPass(createCanonicalizerPass());
+ passManager.addPass(createCSEPass());
+
+  // This pipeline should also vectorize these ops, but they aren't vectorized
+  // today because of a correctness issue. See issue #8579.
+
+ // Run IREE specific passes before vector lowering expert.
+ passManager.addNestedPass<FuncOp>(createRemoveSingleIterationLoopPass());
+}
+
void addDoubleTilingExpertPassPipeline(OpPassManager &passManager) {
// Do first level of tiling and distribution.
passManager.addNestedPass<FuncOp>(createInsertDistributionInfoPass());
@@ -427,6 +441,7 @@
void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {
passManager.nest<ModuleOp>().nest<FuncOp>().addPass(
createTypePropagationPass());
+ passManager.nest<ModuleOp>().addPass(createBufferizeCopyOnlyDispatchesPass());
passManager.addPass(createLLVMCPULowerExecutableTargetPass());
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
addLowerToLLVMPasses(nestedModulePM);
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
index 56bc1b3..5aed8ea 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
@@ -268,84 +268,39 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
-hal.executable @tensor_insert {
+hal.executable @copy_op {
hal.executable.variant @system_elf_x86_64, target = <"llvm", "system-elf-x86_64"> {
- hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
+ hal.executable.entry_point @copy_op layout(#executable_layout)
builtin.module {
- func.func @tensor_insert_slice() {
+ func.func @copy_op() {
%d0 = hal.interface.constant.load[0] : index
%d1 = hal.interface.constant.load[1] : index
%d2 = hal.interface.constant.load[2] : index
%d3 = hal.interface.constant.load[3] : index
%o0 = hal.interface.constant.load[4] : index
%o1 = hal.interface.constant.load[5] : index
- %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1}
- %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%d2, %d3}
- %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1} -> tensor<?x?xi32>
- %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%d2, %d3} -> tensor<?x?xi32>
- %result = tensor.insert_slice %source into %dest[%o0, %o1] [%d0, %d1] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
- flow.dispatch.tensor.store %result, %dest_binding, offsets = [0, 0], sizes = [%d2, %d3], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%d2, %d3}
+ %source = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%d0, %d1}
+ %dest = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%d2, %d3}
+ %dest_view = memref.subview %dest[%o0, %o1] [%d0, %d1] [1, 1] : memref<?x?xi32> to memref<?x?xi32, offset : ?, strides : [?, ?]>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)> , affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%source : memref<?x?xi32>) outs(%dest_view : memref<?x?xi32, offset : ?, strides : [?, ?]>) {
+ ^bb0(%arg0 : i32, %arg1 : i32):
+ linalg.yield %arg0 : i32
+ }
return
}
}
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDefault>
-// CHECK: hal.executable.entry_point public @tensor_insert_slice
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [1, 4], [0, 0]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUBufferOpsTileAndVectorize>
+// CHECK: hal.executable.entry_point public @copy_op
// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: tensor.insert_slice
-// CHECK-SAME: lowering_config = #[[CONFIG]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-hal.executable @extract_slice {
- hal.executable.variant @system_elf_x86_64, target = <"llvm", "system-elf-x86_64"> {
- hal.executable.entry_point @extract_slice layout(#executable_layout)
- builtin.module {
- func.func @extract_slice() {
- %d0 = hal.interface.constant.load[0] : index
- %d1 = hal.interface.constant.load[1] : index
- %d2 = hal.interface.constant.load[2] : index
- %d3 = hal.interface.constant.load[3] : index
- %o0 = hal.interface.constant.load[4] : index
- %o1 = hal.interface.constant.load[5] : index
- %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1}
- %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<writeonly:?x?xi32>{%d2, %d3}
- %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1} -> tensor<?x?xi32>
- %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
- : !flow.dispatch.tensor<writeonly:?x?xi32>{%d2, %d3} -> tensor<?x?xi32>
- %result = tensor.extract_slice %source[%o0, %o1] [%d0, %d1] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
- flow.dispatch.tensor.store %result, %dest_binding, offsets = [0, 0], sizes = [%d2, %d3], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%d2, %d3}
- return
- }
- }
- }
-}
-
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDefault>
-// CHECK: hal.executable.entry_point public @extract_slice
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: tensor.extract_slice
-// CHECK-SAME: lowering_config = #[[CONFIG]]
-
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 38dcce8..933756a 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -438,20 +438,6 @@
}
if (!rootOperation) {
- // TODO(ravishankarm): Currently you could have dispatches with a single
- // tensor.insert_slice or a tensor.extract_slice. Those are handled by
- // tile + distribute as well since these ops have an external model
- // implementing the `TiledOpInterface`. This is legacy. These ops shouldnt
- // implement this interface, and backends must be able to handle a
- // dispatch with flow.dispatch.tensor.load -> flow.dispatch.tensor.store.
- // Till this is cleaned up, set a configuration for this.
- if (computeOps.size() == 1 &&
- isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(computeOps[0])) {
- rootOperation = computeOps[0];
- }
- }
-
- if (!rootOperation) {
// setTranslationInfo(
// funcOp,
// IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute,
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index 4c92f7c..773dffd 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -36,7 +37,7 @@
registry
.insert<IREE::Codegen::IREECodegenDialect, IREE::HAL::HALDialect,
linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect,
- vector::VectorDialect, gpu::GPUDialect>();
+ vector::VectorDialect, gpu::GPUDialect, scf::SCFDialect>();
}
LLVMGPULowerExecutableTargetPass() = default;
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index b6dd906..42eb9a0 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -188,6 +188,7 @@
void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
pm.nest<ModuleOp>().nest<FuncOp>().addPass(createTypePropagationPass());
+ pm.nest<ModuleOp>().addPass(createBufferizeCopyOnlyDispatchesPass());
pm.addPass(createLLVMGPULowerExecutableTargetPass());
OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
//===--------------------------------------------------------------------===//
diff --git a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 68af7ed..37b5d9a 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -130,50 +130,6 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
-hal.executable @tensor_insert_slice {
- hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
- hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
- builtin.module {
- func.func @tensor_insert_slice() {
- %c0 = arith.constant 0 : index
- %size_y = hal.interface.constant.load[0] : index
- %size_x = hal.interface.constant.load[1] : index
- %dest_size_y = hal.interface.constant.load[2] : index
- %dest_size_x = hal.interface.constant.load[3] : index
- %offset_y = hal.interface.constant.load[4] : index
- %offset_x = hal.interface.constant.load[5] : index
- %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xi32>{%size_y, %size_x}
- %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x}
- %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%size_y, %size_x], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xi32>{%size_y, %size_x} -> tensor<?x?xi32>
- %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
- : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x} -> tensor<?x?xi32>
- %result = tensor.insert_slice %source into %dest[%offset_y, %offset_x] [%size_y, %size_x] [1, 1]
- : tensor<?x?xi32> into tensor<?x?xi32>
- flow.dispatch.tensor.store %result, %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
- : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x}
- return
- }
- }
- }
-}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorize>
-// CHECK: hal.executable.entry_point public @tensor_insert_slice
-// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: tensor.insert_slice
-// CHECK-SAME: lowering_config = #[[CONFIG]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
hal.executable @copy_as_generic {
hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
hal.executable.entry_point @copy_as_generic layout(#executable_layout)
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 15d747f..4cb9151 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -60,6 +60,12 @@
/// allocations and view operations.
std::unique_ptr<OperationPass<FuncOp>> createCleanupBufferAllocViewPass();
+/// Pass to bufferize dispatches that only copy data from one interface binding
+/// to another. This creates a `linalg.generic` copy op that backends can then
+/// handle appropriately.
+std::unique_ptr<OperationPass<ModuleOp>>
+createBufferizeCopyOnlyDispatchesPass();
+
/// Create a pass to convert a model using f32 type to the equivalent one
/// using f16.
std::unique_ptr<OperationPass<ModuleOp>> createDemoteF32ToF16Pass();
@@ -226,6 +232,11 @@
/// to memrefs
void addCPUDefaultPassPipeline(OpPassManager &passManager);
+/// Populates the passes to lower linalg ops on buffers. Currently this pipeline
+/// is only used for dispatches that just copy data from input interfaces to
+/// output interfaces.
+void addCPUBufferOpsTileAndVectorizePipeline(OpPassManager &passManager);
+
/// Populates the passes needed to multi level tile and lowering of linalg ops
/// on tensors to vectors operations.
LogicalResult verifyTensorToVectorsPassPipelineConfig(
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index d24c13d..9b637bb 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -59,6 +59,13 @@
let constructor = "mlir::iree_compiler::createForOpCanonicalizationPass()";
}
+def BufferizeCopyOnlyDispatches :
+ Pass<"iree-codegen-bufferize-copy-only-dispatches", "ModuleOp"> {
+  let summary =
+      "Bufferize dispatches that copy to/from interfaces, converting them to a linalg.copy op";
+ let constructor = "mlir::iree_compiler::createBufferizeCopyOnlyDispatchesPass()";
+}
+
def LinalgBufferize :
Pass<"iree-codegen-linalg-bufferize", "func::FuncOp"> {
let summary = "Convert from to Linalg ops on tensors to buffers";
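
The tablegen entry above exposes the pass as iree-codegen-bufferize-copy-only-dispatches and ties it to the constructor declared in Passes.h. The skeleton below is only a sketch of the usual MLIR pattern that BufferizeCopyOnlyDispatchesPass.cpp presumably follows: the generated BufferizeCopyOnlyDispatchesBase class name is assumed, the include/PassDetail boilerplate is omitted, and the real rewrite logic (plus the flow.dispatch.tensor.load/store folding patterns this commit adds) is elided.

  namespace mlir {
  namespace iree_compiler {
  namespace {

  // Sketch only: the tablegen-generated base class supplies the registered
  // pass name and option boilerplate.
  struct BufferizeCopyOnlyDispatchesPass
      : public BufferizeCopyOnlyDispatchesBase<BufferizeCopyOnlyDispatchesPass> {
    void runOnOperation() override {
      // The real pass walks the module, finds dispatches whose only work is
      // copying between interface bindings, and bufferizes them early into a
      // linalg.generic copy on memrefs.
    }
  };

  }  // namespace

  std::unique_ptr<OperationPass<ModuleOp>>
  createBufferizeCopyOnlyDispatchesPass() {
    return std::make_unique<BufferizeCopyOnlyDispatchesPass>();
  }

  }  // namespace iree_compiler
  }  // namespace mlir
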
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index e0e8cf0..95538fb 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -357,7 +357,11 @@
// Special case for non-linalg ops.
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (!linalgOp || linalgOp.getNumOutputs() != 1) {
+  // TODO(#8580): Ops with buffer semantics, like those created by copy-only
+  // dispatches, can be vectorized too, but that path currently fails
+  // compilation, so it is disabled for now.
+ if (!linalgOp || linalgOp.getNumOutputs() != 1 ||
+ linalgOp.hasBufferSemantics()) {
auto pipeline =
IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistribute;
@@ -615,37 +619,16 @@
rootOperation = computeOp;
}
- // If there are still no root op, check for any linalg.generic op.
if (!rootOperation) {
+    // If there is still no root operation, check for any linalg.generic op.
Operation *computeOp = computeOps.back();
+ if (failed(setDefaultOpConfig(limits, computeOp))) return failure();
- // Handle the case of compute op being a
- // `tensor.extract_slice`/`tensor.insert_slice`. That needs bufferization
- // to run before configuration can be set again. Just set the translation
- // to use the `SPIRVDistributeAndCopy` pipeline. The configuration will be
- // set again after bufferization.
- //
- // TODO(ravishankarm): This is a awkward.
- // `tensor.extract_slice`/`tensor.insert_slice` will be dropped from
- // `TiledOpInterface` soon, and will not be compute op. At that time, they
- // will be folded with `flow.tensor.load` and `flow.tensor.store`
- // operations. Then this case will degenerate to having no compute ops.
- // Rework this at that stage to run bufferization early.
- if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(computeOp)) {
- setTranslationInfo(
- funcOp,
- IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistributeCopy,
- /*workloadPerWorkgroup=*/ArrayRef<int64_t>{},
- /*workgroupSize=*/ArrayRef<int64_t>{});
- } else {
- if (failed(setDefaultOpConfig(limits, computeOp))) return failure();
-
- // Check if the op configuration was set.
- if (!getLoweringConfig(computeOp)) {
- return computeOp->emitOpError(
- "without known roots, the last compute operation in the tiled "
- "loop body is expected to be set as root");
- }
+ // Check if the op configuration was set.
+ if (!getLoweringConfig(computeOp)) {
+ return computeOp->emitOpError(
+ "without known roots, the last compute operation in the tiled "
+ "loop body is expected to be set as root");
}
rootOperation = computeOp;
}
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index d7dffea..d5581d7 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -251,7 +251,7 @@
void buildSPIRVCodegenPassPipeline(OpPassManager &pm) {
pm.nest<ModuleOp>().nest<FuncOp>().addPass(createTypePropagationPass());
-
+ pm.nest<ModuleOp>().addPass(createBufferizeCopyOnlyDispatchesPass());
pm.addPass(createSPIRVLowerExecutableTargetPass());
addMemRefLoweringPasses(pm.nest<ModuleOp>());
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
index 7a491c2..7dc3add 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
@@ -193,7 +193,7 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
-hal.executable private @tensor_insert {
+hal.executable private @copy_op {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirvfb", {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
max_compute_shared_memory_size = 32768 : i32,
@@ -201,78 +201,30 @@
max_compute_workgroup_size = dense<512> : vector<3xi32>,
subgroup_size = 16 : i32}>
}> {
- hal.executable.entry_point @tensor_insert layout(#executable_layout)
+ hal.executable.entry_point @copy_op layout(#executable_layout)
builtin.module {
- func.func @tensor_insert() {
+ func.func @copy_op() {
%offset_y = hal.interface.constant.load[0] : index
%offset_x = hal.interface.constant.load[1] : index
%source_size_y = hal.interface.constant.load[2] : index
%source_size_x = hal.interface.constant.load[3] : index
%dest_size_y = hal.interface.constant.load[4] : index
%dest_size_x = hal.interface.constant.load[5] : index
- %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x}
- %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x}
- %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%source_size_y, %source_size_y], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x} -> tensor<?x?xf32>
- %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
- : !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x} -> tensor<?x?xf32>
- %insert = tensor.insert_slice %source into %dest[%offset_y, %offset_x] [%source_size_y, %source_size_x] [1, 1]
- : tensor<?x?xf32> into tensor<?x?xf32>
- flow.dispatch.tensor.store %insert, %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
- : tensor<?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x}
+ %source = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xf32>{%source_size_y, %source_size_x}
+ %dest = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xf32>{%dest_size_y, %dest_size_x}
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%source : memref<?x?xf32>) outs(%dest : memref<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ linalg.yield %b0 : f32
+ }
return
}
}
}
}
-// Check that the pipeline is set to `SPIRVDistributeAndCopy`
-
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVDistributeCopy>
-// CHECK: tensor.insert_slice
-// CHECK-NOT: lowering_config
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
- #hal.descriptor_set.layout<0, bindings = [
- #hal.descriptor_set.binding<0, storage_buffer>,
- #hal.descriptor_set.binding<1, storage_buffer>
- ]>
-]>
-hal.executable private @tensor_extract {
- hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirvfb", {
- spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
- max_compute_shared_memory_size = 32768 : i32,
- max_compute_workgroup_invocations = 512 : i32,
- max_compute_workgroup_size = dense<512> : vector<3xi32>,
- subgroup_size = 16 : i32}>
- }> {
- hal.executable.entry_point @tensor_extract layout(#executable_layout)
- builtin.module {
- func.func @tensor_extract() {
- %offset_y = hal.interface.constant.load[0] : index
- %offset_x = hal.interface.constant.load[1] : index
- %source_size_y = hal.interface.constant.load[2] : index
- %source_size_x = hal.interface.constant.load[3] : index
- %result_size_y = hal.interface.constant.load[4] : index
- %result_size_x = hal.interface.constant.load[5] : index
- %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
- : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x}
- %result_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
- : !flow.dispatch.tensor<writeonly:?x?xf32>{%result_size_y, %result_size_x}
- %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%source_size_y, %source_size_y], strides = [1, 1]
- : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x} -> tensor<?x?xf32>
- %extract = tensor.extract_slice %source[%offset_y, %offset_x] [%result_size_y, %result_size_x] [1, 1]
- : tensor<?x?xf32> to tensor<?x?xf32>
- flow.dispatch.tensor.store %extract, %result_binding, offsets = [0, 0], sizes = [%result_size_y, %result_size_x], strides = [1, 1]
- : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%result_size_y, %result_size_x}
- return
- }
- }
- }
-}
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVDistributeCopy>
-// CHECK: tensor.extract_slice
-// CHECK-NOT: lowering_config
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 16], [1, 1]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVDistribute>
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
index a850da3..db31e94 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
@@ -79,7 +79,7 @@
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 2, 32, 1], [0, 1, 1, 1]{{\]}}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 1, 64], [0, 1, 1, 1]{{\]}}>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVDistribute>
// CHECK: hal.executable.entry_point public @copy
// CHECK-SAME: translation_info = #[[TRANSLATION]]
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 98ad38b..c51e751 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -74,6 +74,26 @@
return triple && triple.getValue().isRISCV();
}
+bool isReadOnly(Value v) {
+ Operation *definingOp = v.getDefiningOp();
+ if (!definingOp) return false;
+ return TypeSwitch<Operation *, bool>(definingOp)
+ .Case<arith::ConstantOp>(
+ [&](arith::ConstantOp constantOp) { return true; })
+ .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(
+ [&](auto op) { return isReadOnly(op.src()); })
+ .Case<tensor::CastOp, tensor::ExtractSliceOp>(
+ [&](auto op) { return isReadOnly(op.source()); })
+ .Case<IREE::Flow::DispatchTensorLoadOp>(
+ [&](IREE::Flow::DispatchTensorLoadOp loadOp) {
+ return loadOp.source()
+ .getType()
+ .cast<IREE::Flow::DispatchTensorType>()
+ .getAccess() == IREE::Flow::TensorAccess::ReadOnly;
+ })
+ .Default([&](Operation *op) { return false; });
+}
+
//===----------------------------------------------------------------------===//
// Utility functions to set configurations
//===----------------------------------------------------------------------===//
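
isReadOnly, factored out of BufferizationAnalysis above, recognizes values that trace back to constants, read-only bindings, or casts/slices/reshapes of them. As a purely hypothetical illustration of a caller (this is not the new pass's actual detection logic), a helper like the following could flag a function whose stores only forward read-only data:

  #include "iree/compiler/Codegen/Utils/Utils.h"
  #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"

  namespace mlir {
  namespace iree_compiler {

  // Hypothetical helper: true if every flow.dispatch.tensor.store in the
  // function stores a value that traces back to a read-only source, i.e. the
  // function does nothing but copy data between bindings.
  static bool storesOnlyReadOnlyValues(FuncOp funcOp) {
    bool allReadOnly = true;
    funcOp.walk([&](IREE::Flow::DispatchTensorStoreOp storeOp) {
      // Operand 0 of flow.dispatch.tensor.store is the tensor being stored.
      if (!isReadOnly(storeOp->getOperand(0))) allReadOnly = false;
    });
    return allReadOnly;
  }

  }  // namespace iree_compiler
  }  // namespace mlir
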
diff --git a/iree/compiler/Codegen/Utils/Utils.h b/iree/compiler/Codegen/Utils/Utils.h
index e137eb1..24fd7cd 100644
--- a/iree/compiler/Codegen/Utils/Utils.h
+++ b/iree/compiler/Codegen/Utils/Utils.h
@@ -57,6 +57,11 @@
return isVMVXBackend(variantOp);
}
+/// Checks if a tensor value is generated from a read-only object, like an
+/// interface binding with a read-only attribute or an `arith.constant`
+/// operation.
+bool isReadOnly(Value v);
+
//===----------------------------------------------------------------------===//
// Utility functions to set configurations
//===----------------------------------------------------------------------===//
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 4bbcdb1..7e5e146 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 4bbcdb1bf868b9288329fa501d498761abcaa92c
+Subproject commit 7e5e146f0f31fab1e45e5554ce93068655098fcb