Create dispatches for scalar computation using program slices. (#13711)
Current dispatch region formation logic is geared towards "large"
tensors. For computations on small tensors, these just need to be
moved into a dispatch and executed sequentially on the device. Unlike
the current dispatch region formation that is centered around tile and
fuse approach, the scalar computations just need to be lowered to
loops (which might be small trip loops that can be
canonicalized/unrolled). So the approach used here uses backward slice
to group together a DAG of operations that can be moved into a
sequential dispatch for execution on the device. It also does some
basic "horizontal fusion" to group together operations that are not
producer->consumer but are still scalar computations.
Fixes #13545
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index b209f58..15dcf01 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -48,6 +48,7 @@
"ExportBenchmarkFuncs.cpp",
"FormDispatchRegions.cpp",
"FormDispatchWorkgroups.cpp",
+ "FormScalarDispatches.cpp",
"FusionOfTensorOps.cpp",
"InferNumericNarrowing.cpp",
"InitializeEmptyTensors.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 16e8d83..b0a41ed 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -47,6 +47,7 @@
"ExportBenchmarkFuncs.cpp"
"FormDispatchRegions.cpp"
"FormDispatchWorkgroups.cpp"
+ "FormScalarDispatches.cpp"
"FusionOfTensorOps.cpp"
"InferNumericNarrowing.cpp"
"InitializeEmptyTensors.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp
new file mode 100644
index 0000000..eef91c0
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp
@@ -0,0 +1,282 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
+
+#define DEBUG_TYPE "iree-flow-form-scalar-dispatches"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+/// Pass declaration.
+struct FormScalarDispatchesPass
+ : public FormScalarDispatchesBase<FormScalarDispatchesPass> {
+ using FormScalarDispatchesBase<
+ FormScalarDispatchesPass>::FormScalarDispatchesBase;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<affine::AffineDialect, IREE::Flow::FlowDialect,
+ linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Return true if type represents a value less than `n` elements.
+static bool isScalarOrTensorOfLinearSizeN(int n, Type type) {
+ if (type.isIntOrIndexOrFloat()) {
+ return true;
+ }
+ if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+ if (!tensorType.hasStaticShape()) {
+ return false;
+ }
+ return tensorType.getNumElements() <= n;
+ }
+ return false;
+}
+
+/// Return `true` for operations that are to be treated as compute operations.
+static bool isComputeOperation(Operation *op) {
+ MLIRContext *context = op->getContext();
+ if (op->getDialect() == context->getLoadedDialect<linalg::LinalgDialect>()) {
+ return true;
+ }
+ if (op->getDialect() == context->getLoadedDialect<tensor::TensorDialect>()) {
+ return !isa<tensor::CastOp, tensor::CollapseShapeOp, tensor::EmptyOp,
+ tensor::ExpandShapeOp>(op);
+ }
+ return false;
+}
+
+/// Return `true` if the workload of this operation is less than `n`.
+static bool isOperationWorkloadLessThanSizeN(int n, Operation *candidateOp) {
+ return llvm::all_of(candidateOp->getOperands(),
+ [&](Value v) {
+ return isScalarOrTensorOfLinearSizeN(n, v.getType());
+ }) &&
+ llvm::all_of(candidateOp->getResultTypes(), [&](Type t) {
+ return isScalarOrTensorOfLinearSizeN(n, t);
+ });
+}
+
+/// Return `true` is the operation is to be treated as a scalar operation
+/// and moved into a scalar dispatch (not necessarily as the root of the
+/// dispatch).
+static bool isScalarOperation(int workload, Operation *op) {
+ // 1. Ignore most operations. Only look for a whitelist set of operations.
+ if (!isComputeOperation(op)) {
+ return false;
+ }
+
+ // 2. Check that the workload of the operation is less then the limit
+ if (!isOperationWorkloadLessThanSizeN(workload, op)) {
+ return false;
+ }
+
+ // 3. Do not move operations that are cloned into the dispatch region.
+ // TODO: This might prevent moving all scalar operations into dispatch
+ // resulting in artifical splits. Revisit after more examples.
+ return !isClonableIntoDispatchOp(op);
+}
+
+/// Given a `rootOp` return a DAG of the program that represents
+/// operations that can be moved into a scalar dispatch with the `rootOp`
+/// as the root of the DAG.
+llvm::SetVector<Operation *> computeSliceToMoveIntoDispatch(
+ int workload, Operation *rootOp,
+ const llvm::DenseMap<Operation *, Operation *> &opToRootMap) {
+ BackwardSliceOptions options;
+ options.filter = [&](Operation *currentOp) {
+ assert(currentOp && "current op is null");
+ if (opToRootMap.count(currentOp)) {
+ return false;
+ }
+ // Operations needs to be in the same block as `rootOp`.
+ if (currentOp->getBlock() != rootOp->getBlock()) {
+ return false;
+ }
+
+ if (!isScalarOperation(workload, currentOp)) {
+ return false;
+ }
+
+ // All its uses must be in the `opToRootMap`, i.e. they are either
+ // in the current dispatches, or those already formed.
+ return llvm::all_of(currentOp->getUsers(), [&](Operation *user) {
+ return opToRootMap.count(user);
+ });
+ };
+ options.omitBlockArguments = true;
+ llvm::SetVector<Operation *> slice;
+ getBackwardSlice(rootOp, &slice, options);
+ return slice;
+}
+
+/// Return `true` if the op is to be treated as a root of a scalar dispatch.
+static bool isSliceRoot(int workload, Operation *op) {
+ return !op->getParentOfType<DispatchRegionOp>() &&
+ isScalarOperation(workload, op);
+}
+
+// Form dispatch regions from slice of the operation.
+static FailureOr<DispatchRegionOp> formDispatchRegionFromSlice(
+ RewriterBase &rewriter, Operation *rootOp, ArrayRef<Operation *> slice) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(rootOp);
+ FailureOr<DispatchRegionOp> dispatchRegionOp =
+ wrapOpInDispatchRegion(rewriter, rootOp);
+ if (failed(dispatchRegionOp)) {
+ return rootOp->emitOpError("failed to form dispatch region with root op");
+ }
+ FailureOr<DispatchRegionOp> newDispatchOp =
+ movePrecedingOpsIntoDispatchRegion(rewriter, slice,
+ dispatchRegionOp.value());
+ if (failed(newDispatchOp)) {
+ return dispatchRegionOp.value()->emitOpError(
+ "failed to move slice into op");
+ }
+ return newDispatchOp.value();
+}
+
+void FormScalarDispatchesPass::runOnOperation() {
+ mlir::FunctionOpInterface funcOp = getOperation();
+ MLIRContext *context = &getContext();
+
+ int scalarWorkloadLimit = 1;
+ // Convenient struct to hold all operations that need to be moved into a
+ // descriptor.
+ struct DispatchRegionDescriptor {
+ Operation *rootOp;
+ SmallVector<Operation *> fusedOps;
+ };
+
+ SmallVector<DispatchRegionDescriptor> dispatches;
+ llvm::DenseMap<Operation *, Operation *> opToRootMap;
+
+ // Walk the function in postorder, reverse orded ignore all operations
+ // not immediately nested within the `funcOp`.
+ funcOp.walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
+ if (op->getParentOp() != funcOp || opToRootMap.count(op)) {
+ return;
+ }
+
+ if (!isSliceRoot(scalarWorkloadLimit, op)) {
+ return;
+ }
+
+ llvm::SetVector<Operation *> fusedOpsSet =
+ computeSliceToMoveIntoDispatch(scalarWorkloadLimit, op, opToRootMap);
+ for (Operation *sliceOp : fusedOpsSet) {
+ assert(!opToRootMap.count(sliceOp) &&
+ "trying to add same op to two dispatches");
+ opToRootMap[sliceOp] = op;
+ }
+
+ // Iterate backwards within the block to get ops that dont necessarily
+ // have producer -> consumer relationship but can still be fused.
+ Block *currBlock = op->getBlock();
+ Operation *prevOp = op;
+ bool didHorizontalFusion = false;
+ while (prevOp != &currBlock->front()) {
+ prevOp = prevOp->getPrevNode();
+
+ if (opToRootMap.count(prevOp)) {
+ continue;
+ }
+
+ if (!isSliceRoot(scalarWorkloadLimit, prevOp)) {
+ if (isClonableIntoDispatchOp(prevOp)) {
+ continue;
+ }
+ break;
+ }
+
+ didHorizontalFusion = true;
+ fusedOpsSet.insert(prevOp);
+ opToRootMap[prevOp] = op;
+ llvm::SetVector<Operation *> currSlice = computeSliceToMoveIntoDispatch(
+ scalarWorkloadLimit, prevOp, opToRootMap);
+ for (auto sliceOp : currSlice) {
+ assert(!opToRootMap.count(sliceOp) &&
+ "trying to add same op to two dispatches");
+ opToRootMap[sliceOp] = op;
+ }
+ fusedOpsSet.insert(currSlice.begin(), currSlice.end());
+ }
+
+ DispatchRegionDescriptor &currDispatch =
+ dispatches.emplace_back(DispatchRegionDescriptor{});
+ currDispatch.rootOp = op;
+ currDispatch.fusedOps.assign(fusedOpsSet.begin(), fusedOpsSet.end());
+ if (didHorizontalFusion) {
+ mlir::computeTopologicalSorting(currDispatch.fusedOps);
+ }
+ });
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Num scalar dispatches : " << dispatches.size() << "\n";
+ for (auto [index, dispatch] : llvm::enumerate(dispatches)) {
+ llvm::dbgs() << "//--------------------------//\n";
+ llvm::dbgs() << "Dispatch : " << index << ", Root :";
+ dispatch.rootOp->print(llvm::dbgs());
+ llvm::dbgs() << "\nFusedOps :";
+ for (auto fusedOp : dispatch.fusedOps) {
+ fusedOp->print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ }
+ llvm::dbgs() << "//--------------------------//\n";
+ }
+ });
+
+ IRRewriter rewriter(context);
+ for (auto &currDispatch : dispatches) {
+ rewriter.setInsertionPoint(currDispatch.rootOp);
+ FailureOr<DispatchRegionOp> dispatchRegionOp = formDispatchRegionFromSlice(
+ rewriter, currDispatch.rootOp, currDispatch.fusedOps);
+ if (failed(dispatchRegionOp)) {
+ currDispatch.rootOp->emitOpError(
+ "failed to form scalar dispatch region with operation as root");
+ return signalPassFailure();
+ }
+
+ // Set the workgroup count to {1, 1, 1} since this is to be executed
+ // sequentially (at leats for now)
+ Region &countRegion = dispatchRegionOp->getWorkgroupCount();
+ Block *countBody = rewriter.createBlock(&countRegion, countRegion.begin());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToStart(countBody);
+ auto one = rewriter.create<arith::ConstantIndexOp>(
+ dispatchRegionOp.value()->getLoc(), 1);
+ rewriter.create<Flow::ReturnOp>(dispatchRegionOp.value()->getLoc(),
+ ValueRange{one, one, one});
+ }
+}
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createFormScalarDispatchesPass() {
+ return std::make_unique<FormScalarDispatchesPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 6ed5e6d..16b5d2c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -270,6 +270,8 @@
return createDispatchWithTransformDialect(
clDispatchTransformFileName);
})
+
+ .addPass(createFormScalarDispatchesPass)
// Only want use the transform dialect for some dispatch regions and let
// the FormDispatchRegions handle the rest. This only moves the root
// compute op into the dispatch region, so that we can run additional
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 6aea432..bc0a000 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -153,6 +153,10 @@
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFormDispatchRegionsPass(FormDispatchRegionsOptions options = {});
+// Pass to create `flow.dispatch.region`s for scalar computations.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createFormScalarDispatchesPass();
+
// Pass to collapse dimensions of Linalg Ops on tensor ops.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCollapseDimensionsPass();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 7205502..83e90f1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -88,6 +88,12 @@
];
}
+def FormScalarDispatches :
+ InterfacePass<"iree-flow-form-scalar-dispatches", "mlir::FunctionOpInterface"> {
+ let summary = "Form Dispatch Regions for scalar computations.";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createFormScalarDispatchesPass()";
+}
+
def CloneProducersIntoDispatchRegions :
InterfacePass<"iree-flow-clone-producers-into-dispatch-regions", "mlir::FunctionOpInterface"> {
let summary = "Clone producers into dispatch regions to be isolated above";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index a414ee6..c6225e3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/TopologicalSortUtils.h"
@@ -505,7 +506,9 @@
return true;
}
}
- if (llvm::all_of(op->getOperands(),
+ if (op->getDialect() ==
+ op->getContext()->getLoadedDialect<arith::ArithDialect>() &&
+ llvm::all_of(op->getOperands(),
[&](Value v) { return v.getType().isIntOrFloat(); }) &&
llvm::all_of(op->getResults(),
[&](Value v) { return v.getType().isIntOrFloat(); })) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index ea8370f..f0d04d1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -34,6 +34,7 @@
"export_benchmark_funcs.mlir",
"form_dispatch_regions.mlir",
"form_dispatch_workgroups.mlir",
+ "form_scalar_dispatches.mlir",
"fusion_of_tensor_ops.mlir",
"infer_numeric_narrowing.mlir",
"initialize_empty_tensors.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 6dbdb1c..6d11ab5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -32,6 +32,7 @@
"export_benchmark_funcs.mlir"
"form_dispatch_regions.mlir"
"form_dispatch_workgroups.mlir"
+ "form_scalar_dispatches.mlir"
"fusion_of_tensor_ops.mlir"
"infer_numeric_narrowing.mlir"
"initialize_empty_tensors.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir
new file mode 100644
index 0000000..b918a79
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir
@@ -0,0 +1,159 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-flow-form-scalar-dispatches))" --split-input-file %s | FileCheck %s
+
+#map = affine_map<() -> ()>
+func.func @simpleDAG(
+ %arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>, %arg3 : tensor<f32>)
+ -> (tensor<f32>, tensor<f32>) {
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%1, %arg3 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %4 = arith.mulf %b0, %b1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg2, %3 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %6 = arith.subf %b1, %b0 : f32
+ linalg.yield %6 : f32
+ } -> tensor<f32>
+ return %1, %5 : tensor<f32>, tensor<f32>
+}
+// CHECK-LABEL: func @simpleDAG(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<f32>)
+// CHECK: %[[RESULT:.+]]:2 = flow.dispatch.region
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK: %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC1]], %[[ARG3]] :
+// CHECK: %[[GENERIC3:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[GENERIC2]] :
+// CHECK: flow.return %[[GENERIC3]], %[[GENERIC1]]
+// CHECK: count() -> (index, index, index)
+// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: flow.return %[[C1]], %[[C1]], %[[C1]]
+// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+#map = affine_map<() -> ()>
+func.func @simpleHorizontal(
+ %arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>, %arg3 : tensor<f32>)
+ -> (tensor<f32>, tensor<f32>) {
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%1, %arg2 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %4 = arith.mulf %b0, %b1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []}
+ ins(%arg3 : tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32) :
+ %6 = arith.addf %b0, %b0 : f32
+ linalg.yield %6 : f32
+ } -> tensor<f32>
+ return %3, %5 : tensor<f32>, tensor<f32>
+}
+// CHECK-LABEL: func @simpleHorizontal
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK: %[[RESULT:.+]]:2 = flow.dispatch.region
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK: %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC1]], %[[ARG2]] :
+// CHECK: %[[GENERIC3:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG3]] :
+// CHECK: flow.return %[[GENERIC3]], %[[GENERIC2]]
+// CHECK: count() -> (index, index, index)
+// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: flow.return %[[C1]], %[[C1]], %[[C1]]
+// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+#map0 = affine_map<() -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+#map3 = affine_map<(d0) -> (d0)>
+func.func @interleaving(
+ %arg0 : tensor<1x1xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>, %arg3 : tensor<f32>)
+ -> (tensor<f32>, tensor<1xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.empty() : tensor<1xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1xf32>) -> tensor<1xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<1x1xf32>, tensor<1xf32>) outs(%1 : tensor<1xf32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %3 = arith.mulf %b0, %b1 : f32
+ %4 = arith.addf %3, %b2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<1xf32>
+ %5 = tensor.empty() : tensor<f32>
+ %6 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
+ ins(%arg2, %arg3 : tensor<f32>, tensor<f32>) outs(%5 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %7 = arith.subf %b1, %b0 : f32
+ linalg.yield %7 : f32
+ } -> tensor<f32>
+ cf.br ^b1
+ ^b1:
+ %7 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel"]}
+ ins(%2, %arg1 : tensor<1xf32>, tensor<1xf32>) outs(%0 : tensor<1xf32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %8 = arith.mulf %b0, %b1 : f32
+ linalg.yield %8 : f32
+ } -> tensor<1xf32>
+ %9 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
+ ins(%6, %arg3 : tensor<f32>, tensor<f32>) outs(%5 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %10 = arith.divf %b1, %b0 : f32
+ linalg.yield %10 : f32
+ } -> tensor<f32>
+ return %9, %7 : tensor<f32>, tensor<1xf32>
+}
+// CHECK-LABEL: func @interleaving(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1xf32>,
+// CHECK-SAME: %[[ARG2:.+]]: tensor<f32>,
+// CHECK-SAME: %[[ARG3:.+]]: tensor<f32>)
+// CHECK: %[[EMPTY0:.+]] = tensor.empty() : tensor<1xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY0]] :
+// CHECK: %[[EMPTY1:.+]] = tensor.empty() : tensor<f32>
+// CHECK: %[[DISPATCH0:.+]]:2 = flow.dispatch.region
+// CHECK: %[[GENERIC0:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]] :
+// CHECK-SAME: outs(%[[EMPTY1]] :
+// CHECK: flow.return %[[GENERIC1]], %[[GENERIC0]]
+// CHECK: ^bb1:
+// CHECK-DAG: %[[DISPATCH1:.+]]:2 = flow.dispatch.region
+// CHECK: %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[DISPATCH0]]#1, %[[ARG1]] :
+// CHECK-SAME: outs(%[[EMPTY0]] :
+// CHECK: %[[GENERIC3:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[DISPATCH0]]#0, %[[ARG3]] :
+// CHECK-SAME: outs(%[[EMPTY1]] :
+// CHECK: flow.return %[[GENERIC3]], %[[GENERIC2]]
+// CHECK: return %[[DISPATCH1]]#0, %[[DISPATCH1]]#1
diff --git a/tests/e2e/regression/BUILD.bazel b/tests/e2e/regression/BUILD.bazel
index 8a49a46..afe75e9 100644
--- a/tests/e2e/regression/BUILD.bazel
+++ b/tests/e2e/regression/BUILD.bazel
@@ -28,6 +28,7 @@
"i1_inlined_constant.mlir",
"linalg_ops.mlir",
"reduction_broadcast_elementwise.mlir",
+ "scalar_computation.mlir",
"softmax.mlir",
"strided_slice.mlir",
"transpose.mlir",
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index e0a5b32..d3bc167 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -50,6 +50,7 @@
"lowering_config.mlir"
"pack_pad_transpose_1x9_into_2x4x8x4_issue_12546.mlir"
"reduction_broadcast_elementwise.mlir"
+ "scalar_computation.mlir"
"softmax.mlir"
"strided_slice.mlir"
"transpose.mlir"
@@ -90,6 +91,7 @@
"layernorm.mlir"
"linalg_ops.mlir"
"reduction_broadcast_elementwise.mlir"
+ "scalar_computation.mlir"
"softmax.mlir"
"strided_slice.mlir"
"transpose.mlir"
@@ -116,6 +118,7 @@
"i1_inlined_constant.mlir"
"linalg_ops.mlir"
"reduction_broadcast_elementwise.mlir"
+ "scalar_computation.mlir"
"softmax.mlir"
"strided_slice.mlir"
"transpose.mlir"
@@ -144,6 +147,7 @@
"layernorm.mlir"
"linalg_ops.mlir"
"reduction_broadcast_elementwise.mlir"
+ "scalar_computation.mlir"
"softmax.mlir"
"strided_slice.mlir"
"transpose.mlir"
diff --git a/tests/e2e/regression/scalar_computation.mlir b/tests/e2e/regression/scalar_computation.mlir
new file mode 100644
index 0000000..492a9d2
--- /dev/null
+++ b/tests/e2e/regression/scalar_computation.mlir
@@ -0,0 +1,58 @@
+#map = affine_map<() -> ()>
+func.func @simpleDAG() {
+ %arg0 = arith.constant dense<1.0> : tensor<f32>
+ %arg1 = arith.constant dense<2.0> : tensor<f32>
+ %arg2 = arith.constant dense<3.0> : tensor<f32>
+ %arg3 = arith.constant dense<4.0> : tensor<f32>
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg2, %arg3 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %4 = arith.mulf %b0, %b1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%1, %3 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %6 = arith.subf %b1, %b0 : f32
+ linalg.yield %6 : f32
+ } -> tensor<f32>
+ check.expect_almost_eq_const(%1, dense<3.0> : tensor<f32>) : tensor<f32>
+ check.expect_almost_eq_const(%5, dense<9.0> : tensor<f32>) : tensor<f32>
+ return
+}
+
+func.func @simpleHorizontal() {
+ %arg0 = arith.constant dense<1.0> : tensor<f32>
+ %arg1 = arith.constant dense<2.0> : tensor<f32>
+ %arg2 = arith.constant dense<3.0> : tensor<f32>
+ %arg3 = arith.constant dense<4.0> : tensor<f32>
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%1, %arg2 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32) :
+ %4 = arith.mulf %b0, %b1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []}
+ ins(%arg3 : tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0: f32, %b1 : f32) :
+ %6 = arith.addf %b0, %b0 : f32
+ linalg.yield %6 : f32
+ } -> tensor<f32>
+ check.expect_almost_eq_const(%3, dense<9.0> : tensor<f32>) : tensor<f32>
+ check.expect_almost_eq_const(%5, dense<8.0> : tensor<f32>) : tensor<f32>
+ return
+}