Fix SSA use-def violation created by Tile and fuse pass. (#15133)
Fixes #15126
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
index 074cfb7..b4daed9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
@@ -92,12 +92,12 @@
};
LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
+ DominanceInfo &dominanceInfo,
scf::SCFTilingOptions options) {
llvm::SmallDenseSet<Operation *> origTiledAndFusedOps;
collectTiledAndFusedOps(rootOp, origTiledAndFusedOps);
auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
- return origTiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
- outerMostTiledLoop->isAncestor(user);
+ return origTiledAndFusedOps.count(user) || isa<tensor::DimOp>(user);
};
// The rest of this method is similar to
@@ -184,7 +184,8 @@
// to be yielded from within the tiled loop.
OpResult untiledProducer = fusedProducer->origProducer;
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
- return !isIgnoredUser(user, forLoops.front());
+ return !isIgnoredUser(user, forLoops.front()) &&
+ !forLoops.front()->isAncestor(user);
})) {
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
fusedProducer.value(), forLoops);
@@ -202,7 +204,8 @@
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
Value replacement = outermostLoop.getResult(index);
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
- return !isIgnoredUser(use.getOwner(), outermostLoop);
+ return !isIgnoredUser(use.getOwner(), outermostLoop) &&
+ dominanceInfo.properlyDominates(outermostLoop, use.getOwner());
});
}
@@ -259,8 +262,9 @@
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+ DominanceInfo dominanceInfo(funcOp);
auto options = scf::SCFTilingOptions().setTileSizes(tileSizesOfr);
- if (failed(applyTileAndFuse(rewriter, consumerOp, options))) {
+ if (failed(applyTileAndFuse(rewriter, consumerOp, dominanceInfo, options))) {
LLVM_DEBUG(llvm::dbgs() << "----- tile and fuse failed -----\n");
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse.mlir
index e2c9bf8..4e5479e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse.mlir
@@ -153,3 +153,63 @@
// CHECK: %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[ITER1]]
// CHECK: %{{.+}} = linalg.generic
// CHECK-SAME: outs(%[[OUT_SLICE2]]
+
+// -----
+
+// This test is to check it doesn't crash. See #15126
+func.func @softmax() {
+ %c2 = arith.constant 2 : index
+ %c5 = arith.constant 5 : index
+ %cst = arith.constant 0xFF800000 : f32
+ %c10 = arith.constant 10 : index
+ %c1 = arith.constant 1 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %cst_1 = arith.constant -1.000000e+30 : f32
+ %c512 = arith.constant 512 : index
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x10xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x10xf32>>
+ %2 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 10], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x10xf32>> -> tensor<1x10xf32>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 10], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x10xf32>> -> tensor<1x10xf32>
+ %4 = tensor.empty() : tensor<1xf32>
+ %5 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0], [0], [0], [0]]>} ins(%cst_1 : f32) outs(%4 : tensor<1xf32>) -> tensor<1xf32>
+ %expanded = tensor.expand_shape %3 [[0], [1, 2]] : tensor<1x10xf32> into tensor<1x5x2xf32>
+ %6 = tensor.empty() : tensor<1x2xf32>
+ %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1x2xf32>) -> tensor<1x2xf32>
+ %8 = scf.for %arg0 = %c0 to %c5 step %c1 iter_args(%arg1 = %7) -> (tensor<1x2xf32>) {
+ %extracted_slice = tensor.extract_slice %expanded[0, %arg0, 0] [1, 1, 2] [1, 1, 1] : tensor<1x5x2xf32> to tensor<1x1x2xf32>
+ %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "reduction", "parallel"]} ins(%extracted_slice : tensor<1x1x2xf32>) outs(%arg1 : tensor<1x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %14 = arith.maximumf %in, %out : f32
+ linalg.yield %14 : f32
+ } -> tensor<1x2xf32>
+ scf.yield %13 : tensor<1x2xf32>
+ }
+ %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%8 : tensor<1x2xf32>) outs(%5 : tensor<1xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %13 = arith.maximumf %in, %out : f32
+ linalg.yield %13 : f32
+ } -> tensor<1xf32>
+ %10 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0], [0], [0], [0]]>} ins(%cst_0 : f32) outs(%4 : tensor<1xf32>) -> tensor<1xf32>
+ %11 = scf.for %arg0 = %c0 to %c10 step %c2 iter_args(%arg1 = %10) -> (tensor<1xf32>) {
+ %extracted_slice = tensor.extract_slice %3[0, %arg0] [1, 2] [1, 1] : tensor<1x10xf32> to tensor<1x2xf32>
+ %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice, %9 : tensor<1x2xf32>, tensor<1xf32>) outs(%arg1 : tensor<1xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0], [0, 0], [0, 2], [0, 0]]>} {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %14 = arith.subf %in, %in_2 : f32
+ %15 = math.exp %14 : f32
+ %16 = arith.addf %15, %out : f32
+ linalg.yield %16 : f32
+ } -> tensor<1xf32>
+ scf.yield %13 : tensor<1xf32>
+ }
+ %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %9, %11 : tensor<1x10xf32>, tensor<1xf32>, tensor<1xf32>) outs(%2 : tensor<1x10xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0], [0, 0], [0, 0], [2, 2]]>} {
+ ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+ %13 = arith.subf %in, %in_2 : f32
+ %14 = math.exp %13 : f32
+ %15 = arith.divf %14, %in_3 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x10xf32>
+ flow.dispatch.tensor.store %12, %1, offsets = [0, 0], sizes = [1, 10], strides = [1, 1] : tensor<1x10xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x10xf32>>
+ return
+}
+// CHECK-LABEL: func @softmax()