[spirv] Fuse tensor.pad ops into their consumer ops (#8049)
It's typical for convolutions to require padding. Right now we handle
such padding with two separate steps, first performing `vkCmdFillBuffer`
to initialize the target buffer and then launch a specific kernel to
copy over the source content. These two steps also means two additional
pipeline barriers to sequence the execution.
This commit instead enables fusion of tensor.pad ops into the consumer
ops directly to avoid generating the buffer fill and copy. This is
expected to improve performance on mobile GPUs.
After fusing the tensor.pad op, we generate the kernel to contain both
a fast and a slow path. The fast path are for inner tiles that do not
need padding; the slow path is for boundary tiles that do need padding.
Both paths are vectorized properly.
diff --git a/benchmarks/TFLite/CMakeLists.txt b/benchmarks/TFLite/CMakeLists.txt
index 52c874a..4ff6f12 100644
--- a/benchmarks/TFLite/CMakeLists.txt
+++ b/benchmarks/TFLite/CMakeLists.txt
@@ -548,6 +548,7 @@
"GPU-Adreno"
TRANSLATION_FLAGS
${ANDROID_ADRENO_GPU_TRANSLATION_FLAGS}
+ "--iree-flow-enable-fuse-padding-into-consumer-ops"
BENCHMARK_TOOL
iree-benchmark-module
DRIVER
@@ -572,6 +573,7 @@
"GPU-Mali-Valhall"
TRANSLATION_FLAGS
${ANDROID_MALI_GPU_TRANSLATION_FLAGS}
+ "--iree-flow-enable-fuse-padding-into-consumer-ops"
BENCHMARK_TOOL
iree-benchmark-module
DRIVER
@@ -592,6 +594,7 @@
"--iree-input-type=tosa"
"--iree-flow-demote-f32-to-f16"
"--iree-vulkan-target-triple=valhall-unknown-android11"
+ "--iree-flow-enable-fuse-padding-into-consumer-ops"
BENCHMARK_TOOL
iree-benchmark-module
DRIVER
@@ -628,6 +631,7 @@
"GPU-Adreno"
TRANSLATION_FLAGS
${ANDROID_ADRENO_GPU_TRANSLATION_FLAGS}
+ "--iree-flow-enable-fuse-padding-into-consumer-ops"
"--iree-hal-benchmark-dispatch-repeat-count=16"
BENCHMARK_TOOL
iree-benchmark-module
@@ -655,6 +659,7 @@
"GPU-Mali-Valhall"
TRANSLATION_FLAGS
${ANDROID_MALI_GPU_TRANSLATION_FLAGS}
+ "--iree-flow-enable-fuse-padding-into-consumer-ops"
"--iree-hal-benchmark-dispatch-repeat-count=32"
BENCHMARK_TOOL
iree-benchmark-module
@@ -678,6 +683,7 @@
"--iree-input-type=tosa"
"--iree-flow-demote-f32-to-f16"
"--iree-vulkan-target-triple=valhall-unknown-android11"
+ "--iree-flow-enable-fuse-padding-into-consumer-ops"
"--iree-hal-benchmark-dispatch-repeat-count=32"
BENCHMARK_TOOL
iree-benchmark-module
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 5ae0ea4..ddda5eb 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -403,6 +403,15 @@
/// having pointer bitcast.
std::unique_ptr<OperationPass<ModuleOp>> createSPIRVVectorizeLoadStore();
+/// Fuses tensor.pad ops into their consumer ops' tiled loop nests.
+std::unique_ptr<OperationPass<FuncOp>>
+createSPIRVFuseTensorPadWithConsumerPass();
+
+// Uses `tensor.pad` ops as anchors to create separate fast and slow paths
+// inside the kernel. The fast path is for inner tiles where we don't need
+// padding, while the slow path is for boundary tiles where we do need padding.
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVCreateFastSlowPathPass();
+
//----------------------------------------------------------------------------//
// SPIRV Codegen Pass Pipelines.
//----------------------------------------------------------------------------//
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 5793bcb..f5d20a4 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -354,6 +354,18 @@
let constructor = "mlir::iree_compiler::createSPIRVVectorizeLoadStore()";
}
+def SPIRVFuseTensorPadWithConsumer :
+ Pass<"iree-spirv-fuse-tensor-pad-with-consumer", "FuncOp"> {
+ let summary = "Fuse tensor.pad op into its consumer op's tiled loop nest";
+ let constructor = "mlir::iree_compiler::createSPIRVFuseTensorPadWithConsumerPass()";
+}
+
+def SPIRVCreateFastSlowPath :
+ Pass<"iree-spirv-create-fast-slow-path", "FuncOp"> {
+ let summary = "Create separate fast and slow paths to handle padding";
+ let constructor = "mlir::iree_compiler::createSPIRVCreateFastSlowPathPass()";
+}
+
//------------------------------------------------------------------------------
// Test passes
//------------------------------------------------------------------------------
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index bc8043f..c3a6370 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -19,7 +19,9 @@
"MaliConfig.cpp",
"NVIDIAConfig.cpp",
"Passes.cpp",
+ "SPIRVCreateFastSlowPath.cpp",
"SPIRVDistribute.cpp",
+ "SPIRVFuseTensorPadWithConsumer.cpp",
"SPIRVInitConfigPass.cpp",
"SPIRVLowerExecutableTargetPass.cpp",
"SPIRVTile.cpp",
@@ -50,12 +52,14 @@
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineToStandardTransforms",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithmeticToSPIRV",
"@llvm-project//mlir:ArithmeticTransforms",
+ "@llvm-project//mlir:ArithmeticUtils",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:ControlFlowToSPIRV",
"@llvm-project//mlir:DialectUtils",
@@ -77,12 +81,15 @@
"@llvm-project//mlir:SCFToGPUPass",
"@llvm-project//mlir:SCFToSPIRV",
"@llvm-project//mlir:SCFTransforms",
+ "@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:SPIRVConversion",
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:SPIRVTransforms",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorToSPIRV",
+ "@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TosaToStandard",
"@llvm-project//mlir:Transforms",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index a164446..2e4a5bb 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -24,7 +24,9 @@
"MaliConfig.cpp"
"NVIDIAConfig.cpp"
"Passes.cpp"
+ "SPIRVCreateFastSlowPath.cpp"
"SPIRVDistribute.cpp"
+ "SPIRVFuseTensorPadWithConsumer.cpp"
"SPIRVInitConfigPass.cpp"
"SPIRVLowerExecutableTargetPass.cpp"
"SPIRVTile.cpp"
@@ -40,12 +42,14 @@
IREELinalgExtPasses
LLVMSupport
MLIRAffine
+ MLIRAffineAnalysis
MLIRAffineToStandard
MLIRAffineUtils
MLIRAnalysis
MLIRArithmetic
MLIRArithmeticToSPIRV
MLIRArithmeticTransforms
+ MLIRArithmeticUtils
MLIRBufferization
MLIRControlFlowToSPIRV
MLIRFunc
@@ -66,12 +70,15 @@
MLIRSCFToGPU
MLIRSCFToSPIRV
MLIRSCFTransforms
+ MLIRSCFUtils
MLIRSPIRV
MLIRSPIRVConversion
MLIRSPIRVTransforms
MLIRSideEffectInterfaces
MLIRSupport
+ MLIRTensor
MLIRTensorToSPIRV
+ MLIRTensorTransforms
MLIRTosa
MLIRTosaToStandard
MLIRTransforms
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 7973a36..e0e8cf0 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -46,8 +46,13 @@
.getType()
.cast<ShapedType>()
.getShape();
- if (llvm::any_of(inputShape, ShapedType::isDynamic)) return success();
- if (llvm::any_of(outputShape, ShapedType::isDynamic)) return success();
+ if (isa<linalg::Conv2DNhwcHwcfOp>(*linalgOp) &&
+ ShapedType::isDynamic(inputShape[3])) {
+ return success();
+ }
+ if (llvm::any_of(outputShape.drop_front(), ShapedType::isDynamic)) {
+ return success();
+ }
int64_t ic = inputShape[3];
int64_t oh = outputShape[1], ow = outputShape[2], oc = outputShape[3];
diff --git a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index de7c0b8..d5811b8 100644
--- a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -41,12 +41,16 @@
return setMatmulOpConfig(op, workgroupXY, threadMNK);
})
.Case<linalg::Conv2DNhwcHwcfOp>([subgroupSize](auto op) {
- return setConvOpConfig(op, subgroupSize,
- /*bestTilingFactor=*/16);
+ bool hasPaddedInput =
+ op.image().template getDefiningOp<tensor::PadOp>();
+ int bestTilingFactor = hasPaddedInput ? 8 : 16;
+ return setConvOpConfig(op, subgroupSize, bestTilingFactor);
})
.Case<linalg::DepthwiseConv2DNhwcHwcOp>([subgroupSize](auto op) {
- return setConvOpConfig(op, subgroupSize,
- /*bestTilingFactor=*/16);
+ bool hasPaddedInput =
+ op.image().template getDefiningOp<tensor::PadOp>();
+ int bestTilingFactor = hasPaddedInput ? 8 : 16;
+ return setConvOpConfig(op, subgroupSize, bestTilingFactor);
})
.Default([](Operation *) { return success(); });
}
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index 934ec02..ae7edac 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -136,15 +136,23 @@
void addSPIRVTileAndVectorizePassPipeline(OpPassManager &pm) {
pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+ pm.addNestedPass<FuncOp>(createSPIRVFuseTensorPadWithConsumerPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
+ pm.addNestedPass<FuncOp>(createFoldAffineMinInDistributedLoopsPass());
+ pm.addPass(memref::createResolveShapedTypeResultDimsPass());
+
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
// Tile to GPU invocations and vectorize.
+ pm.addNestedPass<FuncOp>(createSPIRVCreateFastSlowPathPass());
pm.addNestedPass<FuncOp>(createSPIRVTilePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
pm.addNestedPass<FuncOp>(createSPIRVVectorizePass());
+ pm.addNestedPass<FuncOp>(createForOpCanonicalizationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
@@ -240,6 +248,7 @@
pm.nest<ModuleOp>().nest<FuncOp>().addPass(createTypePropagationPass());
pm.addPass(createSPIRVLowerExecutableTargetPass());
+
addMemRefLoweringPasses(pm.nest<ModuleOp>());
addSPIRVLoweringPasses(pm.nest<ModuleOp>());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp b/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp
new file mode 100644
index 0000000..8177823
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVCreateFastSlowPath.cpp
@@ -0,0 +1,170 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and passes to use `tensor.pad` ops as anchors
+// to create separate fast and slow paths inside the kernel. The fast path
+// is for inner tiles where we don't need padding, while the slow path is for
+// boundary tiles where we do need padding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-create-fast-slow-path"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Returns true if the the given `attrOrValue` is a constant zero.
+static bool isZero(OpFoldResult attrOrValue) {
+ if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
+ return val.getValue() == 0;
+ return false;
+}
+
+namespace {
+
+/// Uses the `tensor.pad` ops as anchors to create separate fast and slow paths
+/// inside the kernel. The fast path is for inner tiles where we don't need
+/// padding, while the slow path is for boundary tiles where we do need padding.
+///
+/// This pattern works by creating an `scf.if` op with conditions derived from
+/// `tensor.pad` op padding sizes, and copying all ops excluding those for
+/// computing padding sizes to both regions of the `scf.if` op.
+struct CreateFastSlowPath final : public OpRewritePattern<scf::ForOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ // Flow tiled and distributed loops do not carry values.
+ if (!llvm::empty(forOp.getIterOpOperands())) return failure();
+ Block *forBody = forOp.getBody(0);
+
+ // Find the anchor tensor.pad op, from which we get the conditions for
+ // switching between the fast and slow path.
+ auto padOps = llvm::to_vector<4>(forBody->getOps<tensor::PadOp>());
+ if (llvm::size(padOps) != 1) return failure();
+ tensor::PadOp padOp = *padOps.begin();
+
+ // If all padding sizes are zero, we don't need to do anything.
+ SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+ if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
+ return failure();
+
+ rewriter.setInsertionPoint(forBody->getTerminator());
+ SmallVector<Operation *, 16> allOps;
+ for (Operation &op : forBody->without_terminator()) allOps.push_back(&op);
+
+ auto isDefinedInForRegion = [&](Operation *op) {
+ return op->getParentRegion() == &forOp.getLoopBody();
+ };
+ SetVector<Operation *> padSizeOps;
+
+ // Build the condition for the scf.if op: all pad sizes are zero.
+ Location loc = padOp.getLoc();
+ Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> eqZeroCmpVals;
+ for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
+ if (auto padValue = pad.dyn_cast<Value>()) {
+ getBackwardSlice(padValue, &padSizeOps, isDefinedInForRegion);
+ padSizeOps.insert(padValue.getDefiningOp());
+ }
+ if (!isZero(pad)) {
+ eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq,
+ getValueOrCreateConstantIndexOp(rewriter, loc, pad), cstZero));
+ }
+ }
+ Value ifCond = eqZeroCmpVals.front();
+ for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front())
+ ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
+
+ SmallVector<Operation *> cloneOps;
+ for (Operation *op : allOps) {
+ if (!padSizeOps.contains(op)) cloneOps.push_back(op);
+ }
+
+ // Build the scf.if op itself. Clone all ops other than those used for
+ // computing padding sizes. For the "then" branch, we can elide the padding.
+ // For the "else" branch, we retain the clone op.
+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+ BlockAndValueMapping bvm;
+ for (Operation *op : cloneOps) {
+ if (op == padOp.getOperation()) {
+ // We can elide the tensor.pad op. Just use its source.
+ bvm.map(padOp.getResult(), bvm.lookupOrDefault(padOp.source()));
+ } else {
+ builder.clone(*op, bvm);
+ }
+ }
+ builder.create<scf::YieldOp>(loc);
+ };
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ BlockAndValueMapping bvm;
+ for (Operation *op : cloneOps) builder.clone(*op, bvm);
+ builder.create<scf::YieldOp>(loc);
+ };
+ rewriter.create<scf::IfOp>(padOp.getLoc(), ifCond, thenBuilder,
+ elseBuilder);
+
+ // All of these ops have been cloned to both regions. Erease them now.
+ for (Operation *op : llvm::reverse(cloneOps)) rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct SPIRVCreateFastSlowPathPass final
+ : public SPIRVCreateFastSlowPathBase<SPIRVCreateFastSlowPathPass> {
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ FuncOp funcOp = getOperation();
+
+ {
+ RewritePatternSet patterns(context);
+ patterns.add<CreateFastSlowPath>(context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Canonicalize the generated scf.if ops. We might have trivially dead
+ // branches, in which the sizes might be incorrect due to eliding the
+ // tensor.pad op.
+ {
+ RewritePatternSet patterns(context);
+ scf::IfOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVCreateFastSlowPathPass() {
+ return std::make_unique<SPIRVCreateFastSlowPathPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVFuseTensorPadWithConsumer.cpp b/iree/compiler/Codegen/SPIRV/SPIRVFuseTensorPadWithConsumer.cpp
new file mode 100644
index 0000000..c6e9215
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVFuseTensorPadWithConsumer.cpp
@@ -0,0 +1,42 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct SPIRVFuseTensorPadWithConsumerPass final
+ : public SPIRVFuseTensorPadWithConsumerBase<
+ SPIRVFuseTensorPadWithConsumerPass> {
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ FuncOp funcOp = getOperation();
+
+ RewritePatternSet patterns(context);
+ patterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
+ context, [](tensor::ExtractSliceOp) { return false; });
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+createSPIRVFuseTensorPadWithConsumerPass() {
+ return std::make_unique<SPIRVFuseTensorPadWithConsumerPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index a274609..a33d213 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
index 8421c4a..80efa5e 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
@@ -10,15 +10,20 @@
//
//===----------------------------------------------------------------------===//
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-spirv-tile"
@@ -49,6 +54,106 @@
filter);
}
+/// Gets the given `attrOrValue` as an index value by creating constant ops
+/// for attributes.
+static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
+ Location loc) {
+ IntegerAttr attr;
+ if (Value val = attrOrValue.dyn_cast<Value>()) {
+ if (val.getType().isIndex()) return val;
+ matchPattern(val, m_Constant(&attr));
+ } else {
+ attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+ }
+ return builder.createOrFold<arith::ConstantIndexOp>(
+ loc, attr.getValue().getSExtValue());
+}
+
+namespace {
+
+/// Concretizes tensor.pad op's result shape if its source op implements
+/// OffsetSizeAndStrideOpInterface. For example, pad(extract_slice).
+struct ConcretizePadResultShape final : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ // If the result shape is already static, then nothing to do.
+ if (padOp.getResultType().hasStaticShape()) return failure();
+
+ int rank = padOp.getResultType().getRank();
+ SmallVector<int64_t> staticShape;
+ staticShape.reserve(rank);
+
+ auto sourceIfxOp = dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(
+ padOp.source().getDefiningOp());
+ if (!sourceIfxOp) return failure();
+
+ SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> source = sourceIfxOp.getMixedSizes();
+ SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+
+ MLIRContext *context = padOp.getContext();
+ Location loc = padOp.getLoc();
+
+ AffineExpr sym0, sym1, sym2;
+ bindSymbols(context, sym0, sym1, sym2);
+ auto addMap = AffineMap::get(0, 3, {sym0 + sym1 + sym2}, context);
+
+ SmallVector<Value, 3> valueSizes;
+ for (int dimIndex = 0; dimIndex < rank; ++dimIndex) {
+ valueSizes.clear();
+ valueSizes.push_back(getAsIndexValue(lowPad[dimIndex], rewriter, loc));
+ valueSizes.push_back(getAsIndexValue(source[dimIndex], rewriter, loc));
+ valueSizes.push_back(getAsIndexValue(highPad[dimIndex], rewriter, loc));
+
+ // The pad op's result shape is low padding + source size + high padding.
+ // Try to see if we can get a constant number by composing and
+ // canonicalizing the result. We use affine mechanisms here because
+ // generating arithmetic add ops over dim ops won't work, given they are
+ // SSA values that would need invoking other patterns to simplify. We
+ // cannot invoke patterns in patterns.
+ AffineMap map = addMap;
+ fullyComposeAffineMapAndOperands(&map, &valueSizes);
+ canonicalizeMapAndOperands(&map, &valueSizes);
+
+ auto cstExpr = map.getResult(0).dyn_cast<AffineConstantExpr>();
+ // Specially handle the case where we have both dimensions and symbols and
+ // they map to the same value, e.g.:
+ // affine_map<(d0, s0) -> (d0 - s0 + 4)>(%v, %v).
+ // Due to the restrictions over dimensions and symbols, the above won't
+ // simplify. Try to change dimensions for symbols for such cases.
+ if (!cstExpr && llvm::is_splat(valueSizes)) {
+ int numDims = map.getNumDims();
+ int numSyms = map.getNumSymbols();
+ DenseMap<AffineExpr, AffineExpr> dimToSymMap;
+ for (int i = 0; i < numDims; ++i) {
+ dimToSymMap[rewriter.getAffineDimExpr(i)] =
+ rewriter.getAffineSymbolExpr(numSyms + i);
+ }
+ map = map.replace(dimToSymMap, /*numResultDims=*/0,
+ /*numResultSyms=*/numDims + numSyms);
+
+ canonicalizeMapAndOperands(&map, &valueSizes);
+ cstExpr = map.getResult(0).dyn_cast<AffineConstantExpr>();
+ }
+ if (!cstExpr) return failure();
+
+ staticShape.push_back(cstExpr.getValue());
+ }
+
+ auto resultType = RankedTensorType::get(
+ staticShape, padOp.getResultType().getElementType(),
+ padOp.getResultType().getEncoding());
+
+ rewriter.updateRootInPlace(padOp,
+ [&]() { padOp.result().setType(resultType); });
+ return success();
+ }
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// Main pass
//===----------------------------------------------------------------------===//
@@ -64,14 +169,46 @@
MLIRContext *context = &getContext();
FuncOp funcOp = getOperation();
+ // Try to find computation ops which we will use as anchor to tile and fuse
+ // again. If there are `scf.if` ops, we have both a fast and slow paths for
+ // padding handling. Then we need to scan both regions to discover such
+ // computation ops so that we can tile and fuse both regions.
SmallVector<Operation *> computeOps;
- SmallVector<LoopTilingAndDistributionInfo> loopInfos;
- if (failed(getComputeOps(funcOp, computeOps, loopInfos))) {
- return signalPassFailure();
- }
+ SmallVector<scf::IfOp, 1> ifOps;
+ funcOp.walk([&ifOps](scf::IfOp ifOp) { ifOps.push_back(ifOp); });
+ if (ifOps.empty()) {
+ SmallVector<LoopTilingAndDistributionInfo> loopInfos;
+ if (failed(getComputeOps(funcOp, computeOps, loopInfos))) {
+ return signalPassFailure();
+ }
+ while (computeOps.size() > 1) computeOps.erase(computeOps.begin());
+ } else {
+ if (ifOps.size() > 1) {
+ funcOp.emitError("expected to contain no more than one scf.if ops");
+ return signalPassFailure();
+ }
- { // Tile to invocations.
- auto consumerOp = dyn_cast<linalg::LinalgOp>(computeOps.back());
+ for (Operation &op : llvm::reverse(*ifOps.front().thenBlock())) {
+ if (isa<linalg::LinalgOp, IREE::LinalgExt::TiledOpInterface>(op)) {
+ computeOps.push_back(&op);
+ break;
+ }
+ }
+ if (Block *elseBlock = ifOps.front().elseBlock()) {
+ for (Operation &op : llvm::reverse(*elseBlock)) {
+ if (isa<linalg::LinalgOp, IREE::LinalgExt::TiledOpInterface>(op)) {
+ computeOps.push_back(&op);
+ break;
+ }
+ }
+ }
+ }
+ assert(computeOps.size() <= 2);
+
+ // Now tile the last computation op to invocations and fuse all operand
+ // computation ops into the materialized loop nest.
+ for (Operation *computeOp : computeOps) {
+ auto consumerOp = dyn_cast<linalg::LinalgOp>(computeOp);
OpBuilder builder(context);
SmallVector<int64_t> tileSizes = getTileSizes(consumerOp, 1);
@@ -94,15 +231,28 @@
for (int i = loops.size() - 1, dim = 0; i >= 0; --i) {
loops[i]->setAttr(attrName, builder.getIndexAttr(dim++));
}
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling to invocations ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ { // Fuse `tensor.pad` op inside the materalized loop nest too.
+ RewritePatternSet patterns(context);
+ patterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
+ context, [](tensor::ExtractSliceOp) { return false; });
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
LLVM_DEBUG({
- llvm::dbgs() << "--- After tiling to invocations ---\n";
+ llvm::dbgs() << "--- After fusing padding into consumers ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
- {
+ { // Canonicalize.
RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
// Pulling in upstream scf.for and affine.min canonicalization patterns.
@@ -111,6 +261,11 @@
// Pulling in IREE scf.for and affine.min canonicalization patterns.
// They work on tiled and distributed loops.
populateFoldAffineMinInDistributedLoopsPatterns(patterns);
+ // Pulling in flow.dispatch.tensor.load op canonicalization patterns.
+ // Tiling can generate dim ops taking them as operands.
+ IREE::Flow::DispatchTensorLoadOp::getCanonicalizationPatterns(patterns,
+ context);
+ patterns.add<ConcretizePadResultShape>(context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
LLVM_DEBUG({
@@ -120,14 +275,13 @@
});
}
- {
- // Set markers to drive tiling reduction dimensions.
+ { // Set markers to drive tiling reduction dimensions.
OpBuilder builder(context);
auto marker = builder.getStringAttr(getTileReductionMarker());
funcOp.walk([&](linalg::LinalgOp op) {
if (isa<linalg::ContractionOpInterface>(*op) ||
isa<linalg::ConvolutionOpInterface>(*op)) {
- op->setAttr("__internal_linalg_transform__", marker);
+ op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker, marker);
}
});
}
@@ -147,16 +301,32 @@
});
}
- {
+ { // Fuse `tensor.pad` op inside the materalized loop nest too.
+ RewritePatternSet patterns(context);
+ patterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
+ context, [](tensor::ExtractSliceOp) { return false; });
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After fusing padding into consumers ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ { // Canonicalize.
RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
// Pulling in upstream scf.for and affine.min canonicalization patterns.
// They work on tiled (but not distributed) loops. We only tiled reduction
// loops previously so this should be fine.
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
+ // Pulling in flow.dispatch.tensor.load op canonicalization patterns.
+ // Tiling can generate dim ops taking them as operands.
+ IREE::Flow::DispatchTensorLoadOp::getCanonicalizationPatterns(patterns,
+ context);
+ patterns.add<ConcretizePadResultShape>(context);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling canonicalization ---\n";
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 9bdeb80..e0286b6 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -48,7 +48,8 @@
} else if (auto vtOp = dyn_cast<VectorTransferOpInterface>(op)) {
auto vecType = vtOp.getVectorType();
SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1);
- for (auto dim : llvm::enumerate(vtOp.permutation_map().getResults())) {
+ for (const auto &dim :
+ llvm::enumerate(vtOp.permutation_map().getResults())) {
if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) {
if (dimExpr.getPosition() == vtOp.permutation_map().getNumDims() - 1) {
nativeSize[dim.index()] =
@@ -59,7 +60,7 @@
return nativeSize;
} else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
unsigned lastParalleldim = 0;
- for (auto it : llvm::enumerate(contractOp.iterator_types())) {
+ for (const auto &it : llvm::enumerate(contractOp.iterator_types())) {
if (isParallelIterator(it.value())) lastParalleldim = it.index();
}
SmallVector<int64_t, 4> nativeSize(contractOp.iterator_types().size(), 1);
@@ -78,6 +79,7 @@
patterns.add<linalg::LinalgVectorizationPattern>(
patterns.getContext(), f.addOpFilter<linalg::ContractionOpInterface>(),
opt);
+ populateVectorizePadPatterns(patterns);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
vector::populateVectorReductionToContractPatterns(patterns);
}
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD
index 0ffb36d..2fa72eb 100644
--- a/iree/compiler/Codegen/SPIRV/test/BUILD
+++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -29,6 +29,7 @@
"config_mali_matmul.mlir",
"config_nvidia_matmul_cooperative_ops.mlir",
"convert_to_spirv.mlir",
+ "create_fast_slow_path.mlir",
"distribute_to_invocations.mlir",
"pipeline_matmul_cooperative_ops.mlir",
"pipeline_matmul_vectorization.mlir",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index c5dbb74..2714f0b 100644
--- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -24,6 +24,7 @@
"config_mali_matmul.mlir"
"config_nvidia_matmul_cooperative_ops.mlir"
"convert_to_spirv.mlir"
+ "create_fast_slow_path.mlir"
"distribute_to_invocations.mlir"
"pipeline_matmul_cooperative_ops.mlir"
"pipeline_matmul_vectorization.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/create_fast_slow_path.mlir b/iree/compiler/Codegen/SPIRV/test/create_fast_slow_path.mlir
new file mode 100644
index 0000000..008d001
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/create_fast_slow_path.mlir
@@ -0,0 +1,97 @@
+// RUN: iree-opt -split-input-file -iree-spirv-create-fast-slow-path -mlir-print-local-scope %s | FileCheck %s
+
+func @padded_conv() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %c112 = arith.constant 112 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:1x224x224x3xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:3x3x3x32xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:1x112x112x32xf32>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ scf.for %arg0 = %workgroup_id_z to %c112 step %workgroup_count_z {
+ %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_y]
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_y]
+ scf.for %arg1 = %4 to %c112 step %5 {
+ %6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
+ scf.for %arg2 = %6 to %c32 step %7 {
+ %8 = flow.dispatch.tensor.load %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, 1, 4, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x112x112x32xf32> -> tensor<1x1x4x32xf32>
+ %9 = linalg.init_tensor [1, 1, 4, 32] : tensor<1x1x4x32xf32>
+ %10 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
+ %11 = affine.min affine_map<(d0) -> (d0 * 2 + 3, 224)>(%arg0)
+ %12 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%11, %arg0)
+ %13 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 3)>(%11, %arg0)
+ %14 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
+ %15 = affine.min affine_map<(d0) -> (d0 * 2 + 9, 224)>(%arg1)
+ %16 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%15, %arg1)
+ %17 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 9)>(%15, %arg1)
+ %18 = flow.dispatch.tensor.load %0, offsets = [0, %10, %14, 0], sizes = [1, %12, %16, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x224x224x3xf32> -> tensor<1x?x?x3xf32>
+ %19 = tensor.pad %18 low[0, 0, 0, 0] high[0, %13, %17, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x?x3xf32> to tensor<1x?x?x3xf32>
+ %20 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x32xf32>
+ %21 = linalg.fill(%cst, %9) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4]]>} : f32, tensor<1x1x4x32xf32> -> tensor<1x1x4x32xf32>
+ %22 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4]]>, strides = dense<2> : tensor<2xi64>} ins(%19, %20 : tensor<1x?x?x3xf32>, tensor<3x3x3x32xf32>) outs(%21 : tensor<1x1x4x32xf32>) -> tensor<1x1x4x32xf32>
+ %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%22, %8 : tensor<1x1x4x32xf32>, tensor<1x1x4x32xf32>) outs(%9 : tensor<1x1x4x32xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4]]>} {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %24 = arith.subf %arg3, %arg4 : f32
+ linalg.yield %24 : f32
+ } -> tensor<1x1x4x32xf32>
+ flow.dispatch.tensor.store %23, %3, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, 1, 4, 32], strides = [1, 1, 1, 1] : tensor<1x1x4x32xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @padded_conv
+
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+
+// CHECK: scf.for %[[IV0:.+]] =
+// CHECK: scf.for %[[IV1:.+]] =
+// CHECK: scf.for
+
+// CHECK: %[[MIN0:.+]] = affine.min affine_map<(d0) -> (d0 * 2 + 3, 224)>(%[[IV0]])
+// CHECK: %[[SIZE0:.+]] = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 3)>(%[[MIN0]], %[[IV0]])
+// CHECK: %[[MIN1:.+]] = affine.min affine_map<(d0) -> (d0 * 2 + 9, 224)>(%[[IV1]])
+// CHECK: %[[SIZE1:.+]] = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 9)>(%[[MIN1]], %[[IV1]])
+// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[SIZE0]], %[[C0]] : index
+// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[SIZE1]], %[[C0]] : index
+// CHECK: %[[COND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1
+
+// CHECK: scf.if %[[COND]] {
+
+// CHECK: flow.dispatch.tensor.load
+// CHECK: %[[INPUT:.+]] = flow.dispatch.tensor.load
+// CHECK: %[[FILTER:.+]] = flow.dispatch.tensor.load
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: ins(%[[INPUT]], %[[FILTER]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CONV]]
+// CHECK: flow.dispatch.tensor.store %[[GENERIC]]
+
+// CHECK: } else {
+
+// CHECK: flow.dispatch.tensor.load
+// CHECK: %[[INPUT:.+]] = flow.dispatch.tensor.load
+// CHECK: %[[PAD:.+]] = tensor.pad %[[INPUT]]
+// CHECK: %[[FILTER:.+]] = flow.dispatch.tensor.load
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: ins(%[[PAD]], %[[FILTER]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CONV]]
+// CHECK: flow.dispatch.tensor.store %[[GENERIC]]
+
+// CHECK: }
+
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index 71485ba..e4ba66d 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile,iree-spirv-vectorize))))' %s | FileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,iree-spirv-vectorize))))' %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 4, 1, 4], [0, 0, 0, 0, 1, 1, 4]]>
#translation = #iree_codegen.translation_info<SPIRVVectorize, workload_per_wg = [16, 4, 4]>
@@ -177,3 +177,257 @@
// For linalg.depthwise_conv_2d_nhwc_hwc
// CHECK: vector.transfer_write
+
+// -----
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4]]>
+#translation = #iree_codegen.translation_info<SPIRVVectorize, workload_per_wg = [32, 4, 1]>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<3, storage_buffer>
+ ]>
+]>
+
+hal.executable private @low_padded_conv {
+ hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @low_padded_conv layout(#executable_layout) {
+ workgroup_size = [8: index, 2: index, 1: index],
+ translation_info = #translation
+ }
+ builtin.module {
+ func @low_padded_conv() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c112 = arith.constant 112 : index
+ %c32 = arith.constant 32 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x224x224x3xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:3x3x3x32xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x112x112x32xf32>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+ %4 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+ %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+ scf.for %arg0 = %5 to %c112 step %6 {
+ %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+ %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+ scf.for %arg1 = %7 to %c112 step %8 {
+ %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+ %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+ scf.for %arg2 = %9 to %c32 step %10 {
+ %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
+ %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
+ %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
+ %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
+ %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
+ %16 = flow.dispatch.tensor.load %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %13, %14, %15], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x112x112x32xf32> -> tensor<1x?x?x?xf32>
+ %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
+ %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
+ %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
+ %20 = tensor.extract_slice %4[0, %arg0, %arg1, %arg2] [1, %17, %18, %19] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x?x?x?xf32>
+ %21 = affine.min affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 225)>(%11, %arg0)
+ %22 = affine.min affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 225)>(%12, %arg1)
+ %23 = affine.min affine_map<(d0) -> (d0 * 2, 224)>(%arg0)
+ %24 = affine.min affine_map<(d0, d1) -> (d0 + d1 * 2, 224)>(%21, %arg0)
+ %25 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%24, %23)
+ %26 = affine.apply affine_map<(d0, d1, d2) -> (d0 - d1 + d2)>(%21, %24, %23)
+ %27 = affine.min affine_map<(d0) -> (d0 * 2, 224)>(%arg1)
+ %28 = affine.min affine_map<(d0, d1) -> (d0 + d1 * 2, 224)>(%22, %arg1)
+ %29 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%28, %27)
+ %30 = affine.apply affine_map<(d0, d1, d2) -> (d0 - d1 + d2)>(%22, %28, %27)
+ %31 = flow.dispatch.tensor.load %0, offsets = [0, %23, %27, 0], sizes = [1, %25, %29, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x224x224x3xf32> -> tensor<1x?x?x3xf32>
+ %32 = tensor.pad %31 low[0, 0, 0, 0] high[0, %26, %30, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x?x3xf32> to tensor<1x?x?x3xf32>
+ %33 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
+ %34 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %33], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x?xf32>
+ %35 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg0)[%workgroup_size_z]
+ %36 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg1)[%workgroup_size_y]
+ %37 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
+ %38 = linalg.init_tensor [1, %35, %36, %37] : tensor<1x?x?x?xf32>
+ %39 = linalg.fill(%cst, %38) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
+ %40 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%32, %34 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%39 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+ %41 = linalg.generic {lowering_config = #config, indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%40, %16 : tensor<1x?x?x?xf32>, tensor<1x?x?x?xf32>) outs(%20 : tensor<1x?x?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %42 = arith.subf %arg3, %arg4 : f32
+ linalg.yield %42 : f32
+ } -> tensor<1x?x?x?xf32>
+ flow.dispatch.tensor.store %41, %3, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+ }
+ }
+ }
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func @low_padded_conv()
+
+// Loop nest for workgroup tiling and distribution
+// CHECK-COUNT-3: scf.for
+
+// Switch between fast and slow path
+// CHECK: scf.if
+
+// Fast path:
+// Loop nest for thread tiling and reduction tiling
+// CHECK-COUNT-4: scf.for
+// Vector code
+// CHECK-COUNT-6: vector.fma
+
+// CHECK: } else {
+
+// Slow path:
+// Loop nest for thread tiling and reduction tiling
+// CHECK-COUNT-4: scf.for
+// CHECK: scf.if
+// CHECK-NEXT: vector.transfer_read
+// CHECK: scf.if
+// CHECK-NEXT: vector.transfer_read
+// CHECK-COUNT-6: vector.fma
+
+// -----
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1]]>
+#translation = #iree_codegen.translation_info<SPIRVVectorize, workload_per_wg = [32, 4, 1]>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+
+hal.executable private @low_high_padded_depthwise_conv {
+ hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
+ hal.executable.entry_point @low_high_padded_depthwise_conv layout(#executable_layout) {
+ workgroup_size = [8: index, 2: index, 1: index],
+ translation_info = #translation
+ }
+ builtin.module {
+ func @low_high_padded_depthwise_conv() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c112 = arith.constant 112 : index
+ %c32 = arith.constant 32 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x112x112x32xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:3x3x32xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:32xf32>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+ %4 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+ %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+ scf.for %arg0 = %5 to %c112 step %6 {
+ %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+ %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+ scf.for %arg1 = %7 to %c112 step %8 {
+ %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+ %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+ scf.for %arg2 = %9 to %c32 step %10 {
+ %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
+ %12 = flow.dispatch.tensor.load %2, offsets = [%arg2], sizes = [%11], strides = [1] : !flow.dispatch.tensor<readonly:32xf32> -> tensor<?xf32>
+ %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
+ %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
+ %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
+ %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
+ %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
+ %18 = tensor.extract_slice %4[0, %arg0, %arg1, %arg2] [1, %15, %16, %17] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x?x?x?xf32>
+ %19 = affine.min affine_map<(d0, d1) -> (d1 + 2, -d0 + 114)>(%arg0, %13)
+ %20 = affine.min affine_map<(d0, d1) -> (d1 + 2, -d0 + 114)>(%arg1, %14)
+ %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
+ %22 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg0)
+ %23 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg0)
+ %24 = affine.min affine_map<(d0) -> (d0, 112)>(%23)
+ %25 = affine.max affine_map<(d0, d1) -> (d0 + d1 - 1, 0)>(%19, %arg0)
+ %26 = affine.min affine_map<(d0) -> (d0, 112)>(%25)
+ %27 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%26, %24)
+ %28 = affine.apply affine_map<(d0, d1, d2, d3) -> (-d0 + d1 - d2 + d3)>(%22, %19, %26, %24)
+ %29 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg1)
+ %30 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg1)
+ %31 = affine.min affine_map<(d0) -> (d0, 112)>(%30)
+ %32 = affine.max affine_map<(d0, d1) -> (d0 + d1 - 1, 0)>(%20, %arg1)
+ %33 = affine.min affine_map<(d0) -> (d0, 112)>(%32)
+ %34 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%33, %31)
+ %35 = affine.apply affine_map<(d0, d1, d2, d3) -> (-d0 + d1 - d2 + d3)>(%29, %20, %33, %31)
+ %36 = affine.min affine_map<(d0) -> (d0, 32)>(%arg2)
+ %37 = affine.min affine_map<(d0, d1) -> (d0 + d1, 32)>(%arg2, %21)
+ %38 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%37, %36)
+ %39 = flow.dispatch.tensor.load %0, offsets = [0, %24, %31, %36], sizes = [1, %27, %34, %38], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x112x112x32xf32> -> tensor<1x?x?x?xf32>
+ %40 = tensor.pad %39 low[0, %22, %29, 0] high[0, %28, %35, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x?x?xf32> to tensor<1x?x?x?xf32>
+ %41 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
+ %42 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %41], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x32xf32> -> tensor<3x3x?xf32>
+ %43 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg0)[%workgroup_size_z]
+ %44 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg1)[%workgroup_size_y]
+ %45 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
+ %46 = linalg.init_tensor [1, %43, %44, %45] : tensor<1x?x?x?xf32>
+ %47 = linalg.fill(%cst, %46) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
+ %48 = linalg.depthwise_conv_2d_nhwc_hwc {lowering_config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%40, %42 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%47 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+ %49 = linalg.generic {lowering_config = #config, indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%12, %48 : tensor<?xf32>, tensor<1x?x?x?xf32>) outs(%18 : tensor<1x?x?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %50 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %50 : f32
+ } -> tensor<1x?x?x?xf32>
+ flow.dispatch.tensor.store %49, %3, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %15, %16, %17], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+ }
+ }
+ }
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func @low_high_padded_depthwise_conv()
+
+// Loop nest for workgroup tiling and distribution
+// CHECK-COUNT-3: scf.for
+
+// Switch between fast and slow path
+// CHECK: scf.if
+
+// Fast path:
+// Loop nest for thread tiling and reduction tiling
+// CHECK-COUNT-4: scf.for
+// Vector code
+// CHECK-COUNT-2: vector.transfer_read
+// CHECK: vector.fma
+// CHECK: vector.transfer_read
+// CHECK: vector.fma
+
+// CHECK: } else {
+
+// Slow path:
+// Loop nest for thread tiling and reduction tiling
+// CHECK-COUNT-4: scf.for
+// CHECK: scf.if
+// CHECK-NEXT: vector.transfer_read
+// CHECK: scf.if
+// CHECK-NEXT: vector.transfer_read
+// CHECK: vector.transfer_read
+// CHECK-COUNT-2: vector.fma
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 487d16d..af14cc3 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -314,12 +314,19 @@
DEBUG_WITH_TYPE(DEBUG_TYPE,
llvm::dbgs() << "current producer: " << producer << "\n");
- linalg::LinalgOp fusedProducer = rewriter.clone(*producer);
+ Operation *fusedProducer = rewriter.clone(*producer);
rewriter.replaceOpWithinBlock(producer, fusedProducer->getResults(),
&dispatchOp.getRegion().front());
removeFusionGroupsAttribute(fusedProducer);
pullInProducersInSameGroup(rewriter, dispatchOp, fusedProducer, groupNum);
+ } else if (auto producer = en.value().getDefiningOp<tensor::PadOp>()) {
+ DEBUG_WITH_TYPE(DEBUG_TYPE,
+ llvm::dbgs() << "current producer: " << producer << "\n");
+
+ Operation *fusedProducer = rewriter.clone(*producer);
+ rewriter.replaceOpWithinBlock(producer, fusedProducer->getResults(),
+ &dispatchOp.getRegion().front());
}
}
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index f743187..d6417bc 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -48,6 +48,11 @@
"flow-padding-size"),
llvm::cl::init(false));
+static llvm::cl::opt<bool> clEnableFusePaddingIntoConsumerOps(
+ "iree-flow-enable-fuse-padding-into-consumer-ops",
+ llvm::cl::desc("Enable fusing linalg pad_tensor ops into consumer ops"),
+ llvm::cl::init(false));
+
static llvm::cl::opt<int> clLinalgOpsPaddingSize(
"iree-flow-linalg-ops-padding-size",
llvm::cl::desc("Enable padding linalg ops to an integer multiple of "
@@ -150,7 +155,8 @@
FunctionLikeNest(passManager)
// Pad tensors.
- .addPass(IREE::Flow::createPadTensorToSubTensorInsertPass)
+ .addPredicatedPass((!clEnableFusePaddingIntoConsumerOps),
+ IREE::Flow::createPadTensorToSubTensorInsertPass)
// Preprocess the input to a form more amenable for fusion
// - Convert all elementwise ops to Linalg