[Flow] Add pass to fuse encoding ops into dispatch regions after hoisting (#18069)
This PR is the follow up to https://github.com/iree-org/iree/pull/18063,
implementing the fusion pass to move set_encoding ops into producer
dispatch regions when when the producer op is a LinalgOp. If there is no
fusable producer, then the set_encoding is wrapped in a new dispatch
region.
This PR also implements the ReifyRankedShapedTypeOpInterface for
DispatchRegionOp, which is needed to resolve tensor dim ops after moving
dynamically shaped SetEncoding ops into producer dispatch regions. The
implementation simply takes the dynamic dims from the result dynamic
dims of the dispatch region op.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 0e47aa5..5e5b2ef 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -578,6 +578,25 @@
shapedType ? shapedType.getNumDynamicDims() : 0);
}
+LogicalResult DispatchRegionOp::reifyResultShapes(
+ OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ SmallVector<Type> resultTypes(getResultTypes());
+ unsigned counter = 0;
+ for (Type resultType : resultTypes) {
+ auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+ if (!shapedType) {
+ reifiedReturnShapes.push_back({});
+ continue;
+ }
+ SmallVector<Value> dynamicDims =
+ getResultDims().slice(counter, shapedType.getNumDynamicDims());
+ reifiedReturnShapes.push_back(
+ mlir::getMixedValues(shapedType.getShape(), dynamicDims, b));
+ counter += shapedType.getNumDynamicDims();
+ }
+ return success();
+}
+
/// Canonicalizes a DispatchRegionOp: Drop all unused results. Returns `true`
/// if the IR was modified.
bool dropUnusedDispatchRegionResults(RewriterBase &rewriter,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 938e7ad..301ce8b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -35,6 +35,7 @@
let opDocGroup = OpGroupPartitionedRegionOps in {
def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
Util_ShapeAwareOp,
AttrSizedOperandSegments]> {
let summary = [{a group of ops}];
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index c7f2ce9..60472e0 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -503,6 +503,82 @@
return newRegionOp.value();
}
+// Move a `target` op that is following the given dispatch region op into the
+// dispatch region.
+FailureOr<IREE::Flow::DispatchRegionOp>
+moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
+ IREE::Flow::DispatchRegionOp regionOp) {
+ // Fail if any of the `target` operands do not dominate the dispatch region.
+ mlir::DominanceInfo dominanceInfo(regionOp);
+ for (Value operand : target->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (definingOp && !dominanceInfo.dominates(definingOp, regionOp)) {
+ return rewriter.notifyMatchFailure(
+ target, "target operands do not dominate the dispatch region op.");
+ }
+ }
+
+ // Values replaced by moving the `target` into the dispatch region.
+ SmallVector<Value> replacedValues;
+
+ // List of dynamic dimensions for each new results added to the dispatch
+ // region.
+ SmallVector<SmallVector<Value>> dispatchOpNewResultsDynamicDims;
+
+ // New values that are yielded from dispatch.
+ SmallVector<Value> yieldedResults;
+
+ Block &body = regionOp.getBody().front();
+ // Clone op into dispatch region.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(body.getTerminator());
+ Operation *clonedTarget = rewriter.clone(*target);
+
+ // Replace any operands returned by the `regionOp` with the results yielded
+ // inside of the `regionOp`.
+ for (OpOperand &operand : clonedTarget->getOpOperands()) {
+ if (operand.get().getDefiningOp() != regionOp) {
+ continue;
+ }
+ auto returnOp =
+ cast<IREE::Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
+ auto opResult = cast<OpResult>(operand.get());
+ Value yieldedValue = returnOp->getOperand(opResult.getResultNumber());
+ rewriter.modifyOpInPlace(clonedTarget, [&]() {
+ clonedTarget->setOperand(operand.getOperandNumber(), yieldedValue);
+ });
+ }
+
+ // Gather all uses of `target`.
+ for (auto [index, result] : llvm::enumerate(target->getResults())) {
+ replacedValues.push_back(result);
+ yieldedResults.push_back(clonedTarget->getResult(index));
+ rewriter.setInsertionPoint(target);
+ SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
+ if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+ return target->emitOpError(
+ "failed to reify dynamic dims of result to be yielded from "
+ "dispatch region");
+ }
+ }
+
+ FailureOr<IREE::Flow::DispatchRegionOp> newRegionOp =
+ appendDispatchRegionResults(rewriter, regionOp, yieldedResults,
+ dispatchOpNewResultsDynamicDims);
+
+ if (failed(newRegionOp)) {
+ return regionOp->emitOpError("failed to append results to op");
+ }
+
+ ValueRange replacements =
+ newRegionOp->getResults().take_back(replacedValues.size());
+ for (auto [index, replacedVal] : llvm::enumerate(replacedValues)) {
+ rewriter.replaceAllUsesWith(replacedVal, replacements[index]);
+ }
+ rewriter.eraseOp(target);
+ return newRegionOp.value();
+}
+
FailureOr<IREE::Flow::DispatchRegionOp>
wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op) {
OpBuilder::InsertionGuard g(rewriter);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index 1451cc3..887613c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -93,6 +93,37 @@
ArrayRef<Operation *> targets,
Flow::DispatchRegionOp regionOp);
+/// Move a `target` op that is following the given dispatch region op into the
+/// dispatch region.
+///
+/// Results of the `target` are appended to the yielded results of the dispatch
+/// region, and uses of each result are replaced with the corresponding newly
+/// yielded result. Operands of the `target` that are produced by the dispatch
+/// region are replaced in the cloned op with the corresponding result yielded
+/// inside of the dispatch region.
+///
+/// Example:
+///
+/// %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
+/// %0 = "another_op"(%input) : (tensor<?xf32>) -> (tensor<?xf32>)
+/// flow.return %0 : tensor<?xf32>
+/// }
+/// %1 = "some_op"(%r) : () -> (tensor<?xf32>)
+/// %2 = "some_op_use"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+///
+/// Becomes:
+///
+/// %r:2 = flow.dispatch.region -> (tensor<?xf32>{%d0}, tensor<?xf32>{%d1}) {
+/// %0 = "another_op"(%input) : (tensor<?xf32>) -> (tensor<?xf32>)
+/// %1 = "some_op"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+/// flow.return %0, %1 : tensor<?xf32>, tensor<?xf32>
+/// }
+/// %2 = "some_op_use"(%r#1) : (tensor<?xf32>) -> (tensor<?xf32>)
+///
+FailureOr<IREE::Flow::DispatchRegionOp>
+moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
+ IREE::Flow::DispatchRegionOp regionOp);
+
/// Wrap the given op in a new dispatch region op.
FailureOr<Flow::DispatchRegionOp> wrapOpInDispatchRegion(RewriterBase &rewriter,
Operation *op);
diff --git a/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel
index 9df4709..5b6fc3d 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel
+++ b/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel
@@ -27,6 +27,7 @@
"FoldUnitExtentDims.cpp",
"FormDispatchRegions.cpp",
"FormScalarDispatches.cpp",
+ "FuseEncodingOpsIntoDispatchRegions.cpp",
"FuseHorizontalContractions.cpp",
"FuseMultiUseElementwiseProducer.cpp",
"FusionPreprocessing.cpp",
diff --git a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt
index 9f7b2b0..d4ba559 100644
--- a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt
+++ b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt
@@ -29,6 +29,7 @@
"FoldUnitExtentDims.cpp"
"FormDispatchRegions.cpp"
"FormScalarDispatches.cpp"
+ "FuseEncodingOpsIntoDispatchRegions.cpp"
"FuseHorizontalContractions.cpp"
"FuseMultiUseElementwiseProducer.cpp"
"FusionPreprocessing.cpp"
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
new file mode 100644
index 0000000..f71a6d0
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
@@ -0,0 +1,100 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
+#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/DispatchCreation/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-dispatch-creation-producers-into-dispatch-regions"
+
+namespace mlir::iree_compiler::DispatchCreation {
+
+#define GEN_PASS_DEF_FUSEENCODINGOPSINTODISPATCHREGIONSPASS
+#include "iree/compiler/DispatchCreation/Passes.h.inc"
+
+namespace {
+
+// Return true if the op is fusable with a SetEncodingOp consumer.
+// For now, just check if it is a LinalgOp.
+static bool isFusableWithSetEncoding(Operation *op) {
+ return isa<linalg::LinalgOp>(op);
+}
+
+struct FuseEncodingOpsIntoDispatchRegionsPass
+ : public DispatchCreation::impl::FuseEncodingOpsIntoDispatchRegionsPassBase<
+ FuseEncodingOpsIntoDispatchRegionsPass> {
+ void runOnOperation() override {
+ mlir::FunctionOpInterface funcOp = getOperation();
+ MLIRContext *context = &getContext();
+ IRRewriter rewriter(context);
+
+ SmallVector<IREE::Encoding::SetEncodingOp> encodingOps;
+ funcOp->walk([&](IREE::Encoding::SetEncodingOp encodingOp) {
+ encodingOps.push_back(encodingOp);
+ });
+
+ for (IREE::Encoding::SetEncodingOp encodingOp : encodingOps) {
+ OpOperand &operand = encodingOp.getSourceMutable();
+ auto producerDispatch =
+ operand.get().getDefiningOp<IREE::Flow::DispatchRegionOp>();
+ // Nothing to fuse with, so wrap the `encodingOp` in its own dispatch.
+ if (!producerDispatch) {
+ if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
+ return signalPassFailure();
+ }
+ continue;
+ }
+
+ // Find producer operation inside of the dispatch region to determine if
+ // fusion is possible.
+ auto result = cast<OpResult>(operand.get());
+ auto dispatchReturnOp = cast<IREE::Flow::ReturnOp>(
+ producerDispatch.getBody().front().getTerminator());
+ auto producerInRegion = dyn_cast<OpResult>(
+ dispatchReturnOp->getOperand(result.getResultNumber()));
+ if (!producerInRegion) {
+ if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
+ return signalPassFailure();
+ }
+ continue;
+ }
+
+ // Place the op in its own dispatch region if fusion is not possible.
+ if (!isFusableWithSetEncoding(producerInRegion.getOwner())) {
+ if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
+ return signalPassFailure();
+ }
+ continue;
+ }
+ // Fuse the `encodingOp` into the producer dispatch region.
+ if (failed(moveFollowingOpIntoDispatchRegion(rewriter, encodingOp,
+ producerDispatch))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Dynamic dims may have dominance issues after pulling encoding ops into
+ // producer dispatch regions, so we need to resolve tensor.dim ops.
+ RewritePatternSet patterns(context);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::DispatchCreation
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
index 01d132e..d5150cc 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -248,7 +248,11 @@
// op, so hoist them out of their current dispatch regions. Also, bubble
// SetEncodingOps through special operations like bit-extending ops and
// broadcasting ops.
- .addPass(DispatchCreation::createHoistEncodingOpsPass);
+ .addPass(DispatchCreation::createHoistEncodingOpsPass)
+ // After SetEncodingOps are hoisted, try to fuse them with their
+ // producer dispatches to try to hide packing costs.
+ .addPass(
+ DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass);
}
FunctionLikeNest(passManager)
// Collapse dimensions of linalg Ops.
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.td b/compiler/src/iree/compiler/DispatchCreation/Passes.td
index d198a4c..1f1132e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.td
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.td
@@ -254,6 +254,16 @@
];
}
+def FuseEncodingOpsIntoDispatchRegionsPass :
+ InterfacePass<"iree-dispatch-creation-fuse-encoding-ops-into-dispatch-regions-pass", "mlir::FunctionOpInterface"> {
+ let summary = "Fuses set_encoding ops into producer dispatch regions, or forms new dispatches around them.";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ "IREE::Flow::FlowDialect",
+ "IREE::Encoding::IREEEncodingDialect",
+ ];
+}
+
def HoistEncodingOpsPass :
InterfacePass<"iree-dispatch-creation-hoist-encoding-ops", "mlir::FunctionOpInterface"> {
let summary = "Hoists tensor encoding ops out of flow dispatch regions.";
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
index 60da198..880a55f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
@@ -35,6 +35,7 @@
"dispatch_linalg_on_tensors_default.mlir",
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
"form_scalar_dispatches.mlir",
+ "fuse_encoding_ops_into_dispatch_regions.mlir",
"fuse_horizontal_contractions.mlir",
"fuse_multiuse_elementwise_producer.mlir",
"fusion_preprocessing.mlir",
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
index e310cbc..7de76f9 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
@@ -32,6 +32,7 @@
"form_dispatch_regions.mlir"
"form_dispatch_workgroups.mlir"
"form_scalar_dispatches.mlir"
+ "fuse_encoding_ops_into_dispatch_regions.mlir"
"fuse_horizontal_contractions.mlir"
"fuse_multiuse_elementwise_producer.mlir"
"fusion_preprocessing.mlir"
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
new file mode 100644
index 0000000..8137d06
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
@@ -0,0 +1,144 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-fuse-encoding-ops-into-dispatch-regions-pass),canonicalize)" --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+module {
+ util.func public @parallel_fusion(%arg0: tensor<2x11008x128xf32>) -> tensor<2x11008x128xf32, #encoding> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<2x11008x128xf32>
+ %1 = flow.dispatch.region -> (tensor<2x11008x128xf32>) {
+ %3 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg0 : tensor<2x11008x128xf32>, tensor<2x11008x128xf32>)
+ outs(%0 : tensor<2x11008x128xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %4 = arith.addf %in, %in_0 : f32
+ linalg.yield %4 : f32
+ } -> tensor<2x11008x128xf32>
+ flow.return %3 : tensor<2x11008x128xf32>
+ }
+ %2 = iree_encoding.set_encoding %1 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+ util.return %2 : tensor<2x11008x128xf32, #encoding>
+ }
+}
+
+// CHECK-LABEL: @parallel_fusion
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK: %[[ADD:.+]] = linalg.generic
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: }
+// CHECK: util.return %[[DISPATCH0]] : tensor<2x11008x128xf32, #iree_encoding.encoding
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+module {
+ util.func public @reduction_fusion(%arg0: tensor<2x11008x128x16xf32>) -> tensor<2x11008x128xf32, #encoding> {
+ %0 = tensor.empty() : tensor<2x11008x128xf32>
+ %1 = flow.dispatch.region -> (tensor<2x11008x128xf32>) {
+ %5 = linalg.generic {
+ indexing_maps = [#map, #map4],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0 : tensor<2x11008x128x16xf32>)
+ outs(%0 : tensor<2x11008x128xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %6 = arith.addf %in, %out : f32
+ linalg.yield %6 : f32
+ } -> tensor<2x11008x128xf32>
+ flow.return %5 : tensor<2x11008x128xf32>
+ }
+ %2 = iree_encoding.set_encoding %1 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+ util.return %2 : tensor<2x11008x128xf32, #encoding>
+ }
+}
+
+// CHECK-LABEL: @reduction_fusion
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK: %[[REDUCTION:.+]] = linalg.generic
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[REDUCTION]]
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: }
+// CHECK: util.return %[[DISPATCH]] : tensor<2x11008x128xf32, #iree_encoding.encoding
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x11008x128xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+module {
+ util.func public @transpose_fusion(%arg0: tensor<2x128x11008xf32>) -> tensor<2x11008x128xf32, #encoding> {
+ %0 = tensor.empty() : tensor<2x11008x128xf32>
+ %1 = flow.dispatch.region -> (tensor<2x11008x128xf32>) {
+ %5 = linalg.generic {
+ indexing_maps = [#map, #map4],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<2x128x11008xf32>)
+ outs(%0 : tensor<2x11008x128xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x11008x128xf32>
+ flow.return %5 : tensor<2x11008x128xf32>
+ }
+ %2 = iree_encoding.set_encoding %1 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
+ util.return %2 : tensor<2x11008x128xf32, #encoding>
+ }
+}
+
+// CHECK-LABEL: @transpose_fusion
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32, #iree_encoding.encoding
+// CHECK: %[[TRANSPOSE:.+]] = linalg.generic
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[TRANSPOSE]]
+// CHECK: flow.return %[[SET_ENCODING]]
+// CHECK: }
+// CHECK: util.return %[[DISPATCH]] : tensor<2x11008x128xf32, #iree_encoding.encoding
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<?x?x?xf32>, user_indexing_maps = [#map1, #map2, #map3], round_dims_to = array<i64: 32, 32, 32>>
+module {
+ util.func public @fusion_dynamic(%arg0: tensor<?x?x?xf32>, %d0: index, %d1: index, %d2: index) -> tensor<?x?x?xf32, #encoding> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+ %1 = flow.dispatch.region -> (tensor<?x?x?xf32>{%d0, %d1, %d2}) {
+ %3 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg0 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%0 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %4 = arith.addf %in, %in_0 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?x?xf32>
+ flow.return %3 : tensor<?x?x?xf32>
+ }
+ %2 = iree_encoding.set_encoding %1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #encoding>
+ util.return %2 : tensor<?x?x?xf32, #encoding>
+ }
+}
+
+// CHECK-LABEL: @fusion_dynamic
+// CHECK-SAME: {{.+}}: tensor<?x?x?xf32>, %[[D0:.+]]: index, %[[D1:.+]]: index, %[[D2:.+]]: index)
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region -> (tensor<?x?x?xf32, #iree_encoding.encoding
+// CHECK-SAME: {%[[D0]], %[[D1]], %[[D2]]}
+// CHECK: %[[ADD:.+]] = linalg.generic
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: }
+// CHECK: util.return %[[DISPATCH0]] : tensor<?x?x?xf32, #iree_encoding.encoding