[GlobalOpt] Improve unary elementwise propagation to consider broadcasted operands (#17903)
For elementwise operations with two or more operands, if every operand
but one is broadcasted or otherwise unaffected by a transposition, the
op can effectively be treated like a unary elementwise operation for
the purpose of propagation, because propagating the transpose
introduces at most one additional transpose on an input operand. This
improves the unary elementwise propagation patterns to handle such
cases.
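
As a sketch (mirroring the propagate_transpose_down_through_broadcast_elementwise
test added below), the sinking case now covers IR like the following, where the
second operand is broadcasted and therefore unaffected by the permutation:

  %transposed = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>)
      outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1]
  %0 = linalg.generic ...
      ins(%transposed, %arg1 : tensor<2x3x4xf32>, tensor<3x4xf32>) ...

The transpose sinks below the generic: the broadcasted operand keeps its
layout and only has its indexing map composed with the permutation, while a
single transpose is re-created on the generic's result.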
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 6347f93..ea18f86 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -342,7 +342,7 @@
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 315.0 \
--goldendispatch-rocm-unet 1714 \
- --goldendispatch-rocm-clip 1569 \
+ --goldendispatch-rocm-clip 1311 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
@@ -364,7 +364,7 @@
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 74.0 \
--goldendispatch-rocm-unet 1714 \
- --goldendispatch-rocm-clip 1569 \
+ --goldendispatch-rocm-clip 1311 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 2dea4ad..8fe7ed2 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -644,23 +644,71 @@
bool allowGeneralizing = false;
};
-bool isUnaryElementwiseGeneric(linalg::GenericOp genericOp) {
- if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInputs() != 1 ||
- !linalg::isElementwise(genericOp)) {
- return false;
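+// Returns true if |iterationSpacePermutation| changes the relative order of
+// the dimensions accessed by |indexingMap|. For example, with permutation
+// [1, 0, 2] the map (d0, d1, d2) -> (d0, d1) reads indices [1, 0], which is
+// no longer increasing, so the map is affected.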
+static bool isIndexingMapAffectedByTransposeMap(
+ AffineMap indexingMap, ArrayRef<int64_t> iterationSpacePermutation) {
+ int64_t prevIdx = -1;
+ for (auto result : indexingMap.getResults()) {
+ int64_t idx =
+ iterationSpacePermutation[cast<AffineDimExpr>(result).getPosition()];
+ // Verify that the relative ordering of indices in the map remains the same.
+ // If not, then the transposition affects the access order for the given
+ // map (and associated operand).
+ if (idx <= prevIdx) {
+ return true;
+ }
+ prevIdx = idx;
}
-
- // Skip transposes and broadcasts. Transposes make more sense to fuse
- // rather than propagate through, and broadcasts are cheaper to transpose
- // before broadcasting.
- if (genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0)) !=
- genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0))) {
- return false;
- }
- return true;
+ return false;
}
-// Sinks a transpose through the input of a unary elementwise operation.
+// Finds a single DPS input operand of the given |genericOp| that is affected by
+// the |iterationSpacePermutation|, i.e. the permutation changes the relative
+// ordering of the dimensions accessed by that input operand.
+//
+// For example, with permutation [1, 0, 2], affine map (d0, d1, d2) -> (d0, d1)
+// is affected by the permutation because the first two dimensions are iterated
+// in a different order while (d0, d1, d2) -> (d0, d2) is unaffected.
+//
+// If no such operand is found, or if more than one such operand exists,
+// nullptr is returned.
+static OpOperand *
+getSingleTransposedInputOperand(linalg::GenericOp genericOp,
+ ArrayRef<int64_t> iterationSpacePermutation) {
+ OpOperand *operand = nullptr;
+ for (auto input : genericOp.getDpsInputOperands()) {
+ if (!isIndexingMapAffectedByTransposeMap(
+ genericOp.getMatchingIndexingMap(input),
+ iterationSpacePermutation)) {
+ continue;
+ }
+ if (operand) {
+ return nullptr;
+ }
+ operand = input;
+ }
+ return operand;
+}
+
+// Returns a new list of indexing maps that composes the iteration space
+// permutation map |transposeMap| with the indexing map of every input operand
+// of |genericOp| except the |transposedInputIdx|'th one. That operand is
+// expected to have an explicit `linalg.transpose` op constructed for it, so
+// its map does not need to be updated.
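+//
+// For example, with transpose permutation [2, 0, 1] the permutation map is
+// (d0, d1, d2) -> (d2, d0, d1), and composing it with a broadcast map such as
+// (d0, d1, d2) -> (d1, d2) yields (d0, d1, d2) -> (d0, d1).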
+static SmallVector<AffineMap>
+getTransposedIndexingMaps(linalg::GenericOp genericOp,
+ int64_t transposedInputIdx, AffineMap transposeMap) {
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ for (unsigned i = 0, e = genericOp.getNumDpsInputs(); i < e; ++i) {
+ if (i == transposedInputIdx) {
+ continue;
+ }
+ indexingMaps[i] = indexingMaps[i].compose(transposeMap);
+ }
+ return indexingMaps;
+}
+
+// Sinks a transpose through the input of an elementwise operation where the
+// transposition of the iteration space only affects a single input operand.
class SinkTransposeThroughUnaryElementwiseInput
: public OpRewritePattern<linalg::GenericOp> {
public:
@@ -669,22 +717,57 @@
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!IREE::Flow::isNonNullAndOutsideDispatch(genericOp)) {
- return failure();
+ return rewriter.notifyMatchFailure(genericOp, "pre-formed dispatch");
}
- if (!isUnaryElementwiseGeneric(genericOp)) {
- return rewriter.notifyMatchFailure(genericOp, "not unary elementwise");
+ if (!linalg::isElementwise(genericOp)) {
+ return rewriter.notifyMatchFailure(genericOp, "non-elementwise generic");
}
- auto transposeOp =
- genericOp.getDpsInputs()[0].getDefiningOp<linalg::TransposeOp>();
- if (!transposeOp) {
- return rewriter.notifyMatchFailure(genericOp, "no transpose operand");
+ if (genericOp.getNumDpsInits() != 1) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "unimplemented: multiple results");
}
- if (!transposeOp->hasOneUse()) {
+ AffineMap resultMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+ if (!resultMap.isIdentity()) {
return rewriter.notifyMatchFailure(
- genericOp, "do not propagate multi-use transpose");
+ genericOp, "unimplemented: non-identity result map");
+ }
+
+ linalg::TransposeOp transposeOp;
+ OpOperand *inputOperand = nullptr;
+ for (auto input : genericOp.getDpsInputOperands()) {
+ // Skip broadcasted and transposed operands. If the input is broadcasted,
+ // we do not want to propagate because that would perform the transpose on
+ // more data; if it is transposed, we would rather compose the two
+ // transposes (handled by a separate pattern).
+ if (genericOp.getMatchingIndexingMap(input) != resultMap) {
+ continue;
+ }
+
+ auto maybeTransposeOp = input->get().getDefiningOp<linalg::TransposeOp>();
+ // Skip multi-use transposes.
+ if (!maybeTransposeOp || !maybeTransposeOp->hasOneUse()) {
+ continue;
+ }
+
+ auto transposableInputOperand = getSingleTransposedInputOperand(
+ genericOp, maybeTransposeOp.getPermutation());
+ // Skip if more than one operand is affected by the transpose.
+ if (transposableInputOperand != input) {
+ continue;
+ }
+
+ transposeOp = maybeTransposeOp;
+ inputOperand = transposableInputOperand;
+ break;
+ }
+
+ if (!transposeOp) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "no single use transpose operand");
}
ArrayRef<int64_t> perm = transposeOp.getPermutation();
@@ -694,18 +777,30 @@
Value newInit =
createTransposeInit(rewriter, genericOp.getDpsInits()[0], invPerm);
- // We do not need to update indexing maps because this is a unary
- // elementwise op where the input and output maps are the same. Just
- // replace the operands with transposed variants.
- auto newGenericOp = mlir::clone(rewriter, genericOp, newInit.getType(),
- {transposeOp.getInput(), newInit});
+ // We do not need to update iterator types because this is an elementwise
+ // op. We just need to update the indexing maps of all other input operands
+ // by composing the transpose map.
+ AffineMap transposeMap =
+ AffineMap::getPermutationMap(perm, rewriter.getContext());
+ SmallVector<AffineMap> indexingMaps = getTransposedIndexingMaps(
+ genericOp, inputOperand->getOperandNumber(), transposeMap);
+
+ SmallVector<Value> newOperands = genericOp->getOperands();
+ newOperands[inputOperand->getOperandNumber()] = transposeOp.getInput();
+ newOperands[genericOp.getDpsInitOperand(0)->getOperandNumber()] = newInit;
+
+ auto newGenericOp =
+ mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
+ newGenericOp.setIndexingMapsAttr(
+ rewriter.getAffineMapArrayAttr(indexingMaps));
rewriter.replaceOp(
genericOp, createTranspose(rewriter, newGenericOp->getResult(0), perm));
return success();
}
};
-// Bubbles a transpose through the init of a unary elementwise operation.
+// Bubbles a transpose through the init of an elementwise operation where the
+// transposition of the iteration space only affects a single input operand.
class BubbleTransposeThroughUnaryElementwiseDpsInit
: public OpRewritePattern<linalg::TransposeOp> {
public:
@@ -715,33 +810,64 @@
PatternRewriter &rewriter) const override {
auto genericOp = transposeOp.getInput().getDefiningOp<linalg::GenericOp>();
if (!genericOp) {
- return failure();
+ return rewriter.notifyMatchFailure(transposeOp, "non-generic producer");
}
+
+ if (genericOp.getNumDpsInits() != 1) {
+ return rewriter.notifyMatchFailure(transposeOp,
+ "unimplemented: multiple results");
+ }
+
if (!IREE::Flow::isNonNullAndOutsideDispatch({genericOp, transposeOp})) {
return failure();
}
- if (!isUnaryElementwiseGeneric(genericOp)) {
- return rewriter.notifyMatchFailure(genericOp, "not unary elementwise");
+ if (!linalg::isElementwise(genericOp) ||
+ !genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0))
+ .isIdentity()) {
+ return rewriter.notifyMatchFailure(transposeOp,
+ "not elementwise with identity result map");
}
if (!genericOp->hasOneUse()) {
- return rewriter.notifyMatchFailure(genericOp, "not single user");
+ return rewriter.notifyMatchFailure(transposeOp, "not single user");
}
ArrayRef<int64_t> perm = transposeOp.getPermutation();
- Value newTranspose =
- createTranspose(rewriter, genericOp.getOperand(0), perm);
+ auto invPerm = invertPermutationVector(perm);
+
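+ // The transpose consumes the result of the generic, so bubbling it up
+ // permutes the generic's iteration space by the inverse permutation; look
+ // for the single input operand whose access order changes under it.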
+ auto inputOperand = getSingleTransposedInputOperand(genericOp, invPerm);
+ if (!inputOperand ||
+ !genericOp.getMatchingIndexingMap(inputOperand).isIdentity()) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "no single transposable input operand");
+ }
+
+ Value newTranspose = createTranspose(rewriter, inputOperand->get(), perm);
// Create a new empty init for the transposed generic.
Value newInit =
createTransposeInit(rewriter, genericOp.getDpsInits()[0], perm);
+ SmallVector<Value> newOperands = genericOp->getOperands();
+ newOperands[inputOperand->getOperandNumber()] = newTranspose;
+ newOperands[genericOp.getDpsInitOperand(0)->getOperandNumber()] = newInit;
+
+ AffineMap transposeMap =
+ AffineMap::getPermutationMap(invPerm, rewriter.getContext());
+
+ // We do not need to update iterator types because this is an elementwise
+ // op. We just need to update the indexing maps of all other input operands
+ // by composing the transpose map.
+ SmallVector<AffineMap> indexingMaps = getTransposedIndexingMaps(
+ genericOp, inputOperand->getOperandNumber(), transposeMap);
+
- // We do not need to update indexing maps because this is a unary
- // elementwise op where the input and output maps are the same. Just
- // replace the operands with transposed variants.
- auto newGenericOp = mlir::clone(rewriter, genericOp, newInit.getType(),
- {newTranspose, newInit});
+ auto newGenericOp =
+ mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
+ newGenericOp.setIndexingMapsAttr(
+ rewriter.getAffineMapArrayAttr(indexingMaps));
rewriter.replaceOp(transposeOp, newGenericOp);
return success();
}
@@ -912,6 +1038,7 @@
context, /*benefit=*/2);
if (failed(
applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) {
+ funcOp.emitError("Transpose initial sinking patterns failed");
return signalPassFailure();
}
}
@@ -968,6 +1095,7 @@
populateCommonCanonicalizationPatterns(context, bubblingPatterns);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(bubblingPatterns)))) {
+ funcOp.emitError("Transpose bubbling patterns failed");
return signalPassFailure();
}
}
@@ -1020,8 +1148,13 @@
populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
context, /*benefit=*/2);
- if (failed(
- applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) {
+ GreedyRewriteConfig config;
+ // TODO: This is inefficient. Consider rewriting this pass to use a
+ // worklist of just the transpose operations.
+ config.maxIterations = GreedyRewriteConfig::kNoLimit;
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns),
+ config))) {
+ funcOp.emitError("Transpose sinking patterns failed");
return signalPassFailure();
}
}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index 939e650..6b95716 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -485,3 +485,183 @@
// APROP-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// APROP-SAME: outs(%[[EMPTY]] : tensor<16x16xf32>)
// APROP: util.return %[[MATMUL]]
+
+// -----
+
+util.func public @propagate_transpose_down_through_broadcast_elementwise(%arg0: tensor<3x4x2xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf32> {
+ %empty = tensor.empty(): tensor<2x3x4xf32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>)
+ outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1]
+ %0 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%transposed, %arg1 : tensor<2x3x4xf32>, tensor<3x4xf32>)
+ outs(%empty : tensor<2x3x4xf32>) {
+ ^bb0(%in: f32, %in1: f32, %out: f32):
+ %add = arith.addf %in, %in1 : f32
+ linalg.yield %add : f32
+ } -> tensor<2x3x4xf32>
+ util.return %0 : tensor<2x3x4xf32>
+}
+
+// SINK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// SINK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// SINK-LABEL: util.func public @propagate_transpose_down_through_broadcast_elementwise
+// SINK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x4x2xf32>
+// SINK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3x4xf32>
+// SINK: %[[ELEM:.+]] = linalg.generic
+// SINK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]]
+// SINK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3x4x2xf32>, tensor<3x4xf32>
+// SINK: arith.addf
+// SINK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ELEM]] : tensor<3x4x2xf32>
+// SINK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+// SINK-SAME: permutation = [2, 0, 1]
+// SINK: util.return %[[TRANSPOSE]] : tensor<2x3x4xf32>
+
+// -----
+
+util.func public @propagate_transpose_down_through_multi_operand_elementwise(%arg0: tensor<3x4x2xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf32> {
+ %empty = tensor.empty(): tensor<2x3x4xf32>
+ %t1 = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>)
+ outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1]
+ %empty2 = tensor.empty(): tensor<4x3xf32>
+ %t2 = linalg.transpose ins(%arg1 : tensor<3x4xf32>)
+ outs(%empty2 : tensor<4x3xf32>) permutation = [1, 0]
+ %0 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%t2, %t1 : tensor<4x3xf32>, tensor<2x3x4xf32>)
+ outs(%empty : tensor<2x3x4xf32>) {
+ ^bb0(%in: f32, %in1: f32, %out: f32):
+ %add = arith.addf %in, %in1 : f32
+ linalg.yield %add : f32
+ } -> tensor<2x3x4xf32>
+ util.return %0 : tensor<2x3x4xf32>
+}
+
+// Verify that the pattern propagates the transpose on the non-broadcasted
+// operand and that the transpose on the broadcasted operand is then fused
+// into its indexing map.
+
+// SINK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// SINK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// SINK-LABEL: util.func public @propagate_transpose_down_through_multi_operand_elementwise
+// SINK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x4x2xf32>
+// SINK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3x4xf32>
+// SINK: %[[ELEM:.+]] = linalg.generic
+// SINK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]]]
+// SINK-SAME: ins(%[[ARG1]], %[[ARG0]] : tensor<3x4xf32>, tensor<3x4x2xf32>
+// SINK: arith.addf
+// SINK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ELEM]] : tensor<3x4x2xf32>
+// SINK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+// SINK-SAME: permutation = [2, 0, 1]
+// SINK: util.return %[[TRANSPOSE]] : tensor<2x3x4xf32>
+
+// -----
+
+util.func public @sink_transpose_down_to_broadcast_elementwise(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x3x4xf32> {
+ %empty = tensor.empty(): tensor<2x3x4xf32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>)
+ outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1]
+ %0 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%transposed, %arg1 : tensor<2x3x4xf32>, tensor<2x4xf32>)
+ outs(%empty : tensor<2x3x4xf32>) {
+ ^bb0(%in: f32, %in1: f32, %out: f32):
+ %add = arith.addf %in, %in1 : f32
+ linalg.yield %add : f32
+ } -> tensor<2x3x4xf32>
+ util.return %0 : tensor<2x3x4xf32>
+}
+
+// Verify that the transpose is fused rather than propagated because the
+// broadcast operand would be affected by the transpose.
+
+// SINK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// SINK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// SINK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// SINK-LABEL: util.func public @sink_transpose_down_to_broadcast_elementwise
+// SINK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x4x2xf32>
+// SINK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<2x4xf32>
+// SINK: %[[ELEM:.+]] = linalg.generic
+// SINK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// SINK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3x4x2xf32>, tensor<2x4xf32>
+// SINK: arith.addf
+// SINK: util.return %[[ELEM]] : tensor<2x3x4xf32>
+
+// -----
+
+util.func public @propagate_transpose_up_through_broadcast_elementwise(%arg0: tensor<2x3x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4x2xf32> {
+ %empty = tensor.empty(): tensor<2x3x4xf32>
+ %0 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<3x4xf32>)
+ outs(%empty : tensor<2x3x4xf32>) {
+ ^bb0(%in: f32, %in1: f32, %out: f32):
+ %add = arith.addf %in, %in1 : f32
+ linalg.yield %add : f32
+ } -> tensor<2x3x4xf32>
+ %empty1 = tensor.empty(): tensor<3x4x2xf32>
+ %transposed = linalg.transpose ins(%0 : tensor<2x3x4xf32>)
+ outs(%empty1 : tensor<3x4x2xf32>) permutation = [1, 2, 0]
+ util.return %transposed : tensor<3x4x2xf32>
+}
+
+// BUBBLE-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// BUBBLE-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// BUBBLE-LABEL: util.func public @propagate_transpose_up_through_broadcast_elementwise
+// BUBBLE-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3x4xf32>
+// BUBBLE-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3x4xf32>
+// BUBBLE: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<2x3x4xf32>
+// BUBBLE-SAME: outs({{.*}} : tensor<3x4x2xf32>)
+// BUBBLE-SAME: permutation = [1, 2, 0]
+// BUBBLE: %[[ELEM:.+]] = linalg.generic
+// BUBBLE-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]]
+// BUBBLE-SAME: ins(%[[TRANSPOSE]], %[[ARG1]] : tensor<3x4x2xf32>, tensor<3x4xf32>
+// BUBBLE: arith.addf
+// BUBBLE: util.return %[[ELEM]] : tensor<3x4x2xf32>
+
+// -----
+
+util.func public @bubble_transpose_to_broadcast_elementwise(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x2xf32> {
+ %empty = tensor.empty(): tensor<2x3x4xf32>
+ %0 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x4xf32>)
+ outs(%empty : tensor<2x3x4xf32>) {
+ ^bb0(%in: f32, %in1: f32, %out: f32):
+ %add = arith.addf %in, %in1 : f32
+ linalg.yield %add : f32
+ } -> tensor<2x3x4xf32>
+ %empty1 = tensor.empty(): tensor<3x4x2xf32>
+ %transposed = linalg.transpose ins(%0 : tensor<2x3x4xf32>)
+ outs(%empty1 : tensor<3x4x2xf32>) permutation = [1, 2, 0]
+ util.return %transposed : tensor<3x4x2xf32>
+}
+
+// Verify that the transpose is fused rather than propagated because the
+// broadcast operand would be affected by the transpose.
+
+// BUBBLE-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// BUBBLE-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// BUBBLE-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// BUBBLE-LABEL: util.func public @bubble_transpose_to_broadcast_elementwise
+// BUBBLE-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3x4xf32>
+// BUBBLE-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<2x4xf32>
+// BUBBLE: %[[ELEM:.+]] = linalg.generic
+// BUBBLE-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// BUBBLE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x4xf32>
+// BUBBLE: arith.addf
+// BUBBLE: util.return %[[ELEM]] : tensor<3x4x2xf32>