| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include <algorithm> |
| |
| #include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" |
| #include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h" |
| #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" |
| #include "iree/compiler/Utils/GraphUtils.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BlockAndValueMapping.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassRegistry.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/Utils.h" |
| |
| #define DEBUG_TYPE "iree-dispatch-detail" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace Flow { |
| |
| namespace { |
| |
| // Returns true if the given |op| can be dispatched in all cases. |
| // Other passes may handle special cases of these ops but this initial |
| // identification is conservative. |
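| // |
| // For example (illustrative): an mhlo op producing a tensor result falls |
| // through all of the checks below and is dispatchable, while flow.* ops, |
| // terminators, indirect calls, constants, and ops producing only scalar |
| // results are not. |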
| bool isDispatchableOp(Operation *op, Dispatchability &dispatchability) { |
| // TODO(b/144530470): replace with tablegen attributes/interfaces. |
| if (FlowDialect::isDialectOp(op)) { |
| // Ignore ops we've already produced, as they should only relate to |
| // sequencer operations. |
| LLVM_DEBUG(llvm::dbgs() << " NOT DISPATCHABLE (Flow Dialect): " |
| << op->getName() << "\n"); |
| return false; |
| } else if (op->isKnownTerminator()) { |
| // Currently we skip all terminators as we want to leave them in the block |
| // to keep it valid. Future folding passes may take care of them if they are |
| // worth bringing into the dispatch region. |
| LLVM_DEBUG(llvm::dbgs() << " NOT DISPATCHABLE (Known Terminator): " |
| << op->getName() << "\n"); |
| return false; |
| } else if (auto callOp = dyn_cast<CallOp>(op)) { |
| bool dispatchable = dispatchability.isDispatchable(callOp.getCallee()); |
| LLVM_DEBUG(llvm::dbgs() |
| << " " << (dispatchable ? "" : "NOT ") |
| << "DISPATCHABLE (Call): " << op->getName() << "\n"); |
| return dispatchable; |
| } else if (isa<CallIndirectOp>(op)) { |
| // Indirect calls are not supported in dispatch code. |
| LLVM_DEBUG(llvm::dbgs() << " NOT DISPATCHABLE (Call Indirect): " |
| << op->getName() << "\n"); |
| return false; |
| } else if (isa<ConstantOp>(op)) { |
| // Constants are handled in the RematerializeDispatchConstants pass. |
| // We do that independently so that we can more easily see the use of |
| // constants across all dispatches instead of just on an individual basis |
| // as we do here. |
| LLVM_DEBUG(llvm::dbgs() |
| << " NOT DISPATCHABLE (Constant): " << op->getName() << "\n"); |
| return false; |
| } else if (op->getNumResults() && |
| !op->getResult(0).getType().isa<ShapedType>()) { |
| // We don't put scalar manipulation into dispatch regions. |
| LLVM_DEBUG(llvm::dbgs() |
| << " NOT DISPATCHABLE (Non Shaped): " << op->getName() << "\n"); |
| return false; |
| } else if (!isOpOfKnownDialect(op)) { |
| // Probably a custom op. |
| LLVM_DEBUG(llvm::dbgs() << " NOT DISPATCHABLE (Unknown Dialect): " |
| << op->getName() << "\n"); |
| return false; |
| } |
| LLVM_DEBUG(llvm::dbgs() << " DISPATCHABLE: " << op->getName() << "\n"); |
| return true; |
| } |
| |
| // Returns true if the given |op| can have other ops fused into it. |
| // This is sketchy and it'd be nice to define this as an op property instead. |
| // |
| // What we are looking for in foldable ops is whether the execution of the op |
| // when fused has some possible benefit (or at least, a non-negative cost). |
| // Eventually we want to allow backends to vote on this and allow multiple |
| // folding strategies within the same executable. For now we just hardcode what |
| // we know for the ops we have. |
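| // |
| // For example (illustrative): mhlo.dot is excluded below because it has a |
| // hand-written kernel, while an elementwise op such as mhlo.add may act as |
| // a root and have producers fused into its region. |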
| // |
| // Preconditions: isDispatchableOp(op) == true. |
| bool isFusionRootOp(Operation *op) { |
| // TODO(b/144530470): replace with tablegen attributes/interfaces. |
| // TODO(#1605): Remove mhlo::PadOp from the check. |
| if (isa<mhlo::DotOp>(op) || isa<mhlo::ConvOp>(op) || |
| isa<mhlo::ReduceOp>(op) || isa<mhlo::PadOp>(op) || |
| isa<mhlo::ReduceWindowOp>(op)) { |
| // We have hand-written kernels for these that we currently want to stand |
| // alone. When we do a bit more magic we should allow these ops to fold. |
| LLVM_DEBUG(llvm::dbgs() << " NOT A FUSION ROOT (Special Op): " |
| << op->getName() << "\n"); |
| return false; |
| } |
| |
| return true; |
| } |
| |
| // Returns true if the given |op| must not be used as the root of a fusion |
| // subgraph even though it may otherwise be dispatchable. |
| bool isNonFusionRootOp(Operation *op) { |
| // Avoid forming dispatch regions around metadata ops that do no work. |
| if (isa<Shape::TieShapeOp>(op) || isa<Shape::MakeRankedShapeOp>(op)) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
| // Returns true if the given |op| can be fused into other ops. |
| // |
| // Ops that perform narrowing on shapes (such as reduction ops) should not |
| // generally be fused with other downstream ops (probably...). This avoids |
| // potential oversampling and indexing issues and allows backends to perform |
| // more efficient rooted cascading reduction dispatches. |
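| // |
| // For example (illustrative): given %0 = mhlo.reduce(...) and a consumer |
| // %1 = mhlo.add(%0, %cst), the add will not pull the reduce into its own |
| // region; the reduce instead forms a standalone dispatch (see |
| // isFusionRootOp above). |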
| // |
| // Preconditions: isDispatchableOp(op) == true. |
| bool isFusableOp(Operation *op) { |
| // TODO(b/144530470): replace with tablegen attributes/interfaces. |
| if (isa<mhlo::DotOp>(op) || isa<mhlo::ConvOp>(op)) { |
| return false; |
| } else if (isa<mhlo::ReduceOp>(op) || isa<mhlo::ReduceWindowOp>(op)) { |
| // Reduction is usually a dedicated root operation - we can shove things |
| // into the front of it but not fuse anything behind it. |
| return false; |
| } else if (isa<mhlo::PadOp>(op)) { |
| // TODO(#1605): Remove mhlo::PadOp from the check. |
| return false; |
| } |
| return true; |
| } |
| |
| // Recursively traverses the IR DAG along the operand edges to find ops we are |
| // able to fuse and appends them to |subgraph|. |
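| // |
| // For example (illustrative): starting from a root %2 = mhlo.add(%1, %1) |
| // where %1 = mhlo.multiply(%0, %0), the traversal pulls the multiply (plus |
| // any tie_shape ops between the two) into |subgraph| before inserting the |
| // add, so long as all users of each candidate are already in the subgraph. |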
| void gatherFusionOps(Operation *op, Dispatchability &dispatchability, |
| llvm::ArrayRef<Operation *> metadataOps, |
| llvm::SetVector<Operation *> *subgraph) { |
| // Skip ops that are used outside of the subgraph we are building. |
| for (auto result : op->getResults()) { |
| if (result.use_empty() || result.hasOneUse()) continue; |
| for (auto *user : result.getUsers()) { |
| if (subgraph->count(user) == 0) { |
| // Op that consumes the result is not (yet) in the subgraph. |
| // For now we'll ignore these, as they may represent a fork that we don't |
| // want to join too early. |
| return; |
| } |
| } |
| } |
| |
| // Walk backwards up the DAG to the ops providing our input operands. |
| for (auto operand : op->getOperands()) { |
| auto *sourceOp = operand.getDefiningOp(); |
| |
| // Scan through any intermediate "metadata" ops, which should be included |
| // iff they sit between the starting op and a viable fusion target. |
| llvm::SmallVector<Operation *, 1> nextMetadataOps; |
| while (sourceOp) { |
| if (auto tieShapeOp = llvm::dyn_cast<Shape::TieShapeOp>(sourceOp)) { |
| nextMetadataOps.push_back(tieShapeOp); |
| sourceOp = tieShapeOp.operand().getDefiningOp(); |
| continue; |
| } |
| break; |
| } |
| if (!sourceOp) continue; |
| |
| if (subgraph->count(sourceOp) == 0) { |
| if (isDispatchableOp(sourceOp, dispatchability) && |
| isFusableOp(sourceOp)) { |
| gatherFusionOps(sourceOp, dispatchability, nextMetadataOps, subgraph); |
| } |
| } |
| } |
| |
| for (auto *metadataOp : metadataOps) { |
| LLVM_DEBUG(llvm::dbgs() |
| << " : Add metadata op: " << metadataOp->getName() << "\n"); |
| subgraph->insert(metadataOp); |
| } |
| LLVM_DEBUG(llvm::dbgs() << " : Add dispatchable op: " << op->getName() |
| << "\n"); |
| subgraph->insert(op); |
| } |
| |
| // Clones tie_shape metadata ops that feed |subgraph| from outside of it and |
| // adds the clones to the subgraph so that shape information is carried into |
| // the dispatch region; the original ops remain behind to feed the clones. |
| void extendInboundMetadataOps(llvm::SetVector<Operation *> *subgraph) { |
| llvm::SmallMapVector<Operation *, Operation *, 4> metadataCloneMap; |
| // Discover and create clones. |
| for (Operation *subgraphOp : *subgraph) { |
| if (llvm::isa<Shape::TieShapeOp>(subgraphOp)) continue; |
| |
| LLVM_DEBUG(llvm::dbgs() << " : Extend inbound metadata for: " |
| << subgraphOp->getName() << "\n"); |
| OpBuilder b(subgraphOp->getContext()); |
| for (auto operand : subgraphOp->getOperands()) { |
| // Only consider edges outside of the subgraph. |
| Operation *metadataOp = operand.getDefiningOp(); |
| if (!metadataOp || subgraph->count(metadataOp) > 0 || |
| metadataCloneMap.count(metadataOp) > 0) |
| continue; |
| |
| if (auto tieShapeOp = llvm::dyn_cast<Shape::TieShapeOp>(metadataOp)) { |
| LLVM_DEBUG(llvm::dbgs() << " : Duplicating tie_shape op\n"); |
| b.setInsertionPointAfter(tieShapeOp.getOperation()); |
| auto duped = b.create<Shape::TieShapeOp>( |
| tieShapeOp.getLoc(), tieShapeOp.getType(), tieShapeOp, |
| tieShapeOp.shape()); |
| metadataCloneMap.insert({metadataOp, duped.getOperation()}); |
| } |
| } |
| } |
| |
| // Replace uses of clones and add to subgraph. |
| for (auto &kv : metadataCloneMap) { |
| Operation *originalOp = kv.first; |
| Operation *dupedOp = kv.second; |
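| // Point all users of the original metadata op at the clone, then restore |
| // the clone's operand (which the RAUW just redirected to itself) back to |
| // the original op's result. |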
| originalOp->replaceAllUsesWith(dupedOp); |
| dupedOp->replaceUsesOfWith(dupedOp->getResult(0), originalOp->getResult(0)); |
| subgraph->insert(dupedOp); |
| } |
| } |
| |
| // Pulls tie_shape metadata ops that consume results of |subgraph| into the |
| // subgraph and leaves clones behind so that shape information is preserved |
| // for users outside of the dispatch region. |
| void extendOutboundMetadataOps(llvm::SetVector<Operation *> *subgraph) { |
| llvm::SmallSetVector<Operation *, 4> metadataOps; |
| // Discover and create clones. |
| for (Operation *subgraphOp : *subgraph) { |
| if (llvm::isa<Shape::TieShapeOp>(subgraphOp)) continue; |
| |
| LLVM_DEBUG(llvm::dbgs() << " : Extend outbound metadata for: " |
| << subgraphOp->getName() << "\n"); |
| OpBuilder b(subgraphOp->getContext()); |
| for (auto result : subgraphOp->getResults()) { |
| for (auto &use : result.getUses()) { |
| // Only consider edges outside of the subgraph. |
| Operation *metadataOp = use.getOwner(); |
| if (subgraph->count(metadataOp) > 0 || metadataOps.count(metadataOp)) |
| continue; |
| |
| if (auto tieShapeOp = llvm::dyn_cast<Shape::TieShapeOp>(metadataOp)) { |
| LLVM_DEBUG(llvm::dbgs() << " : Duplicating tie_shape op\n"); |
| b.setInsertionPointAfter(tieShapeOp.getOperation()); |
| auto duped = b.create<Shape::TieShapeOp>( |
| tieShapeOp.getLoc(), tieShapeOp.getType(), tieShapeOp, |
| tieShapeOp.shape()); |
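| // Point users of the original tie_shape at the clone, then restore the |
| // clone's operand (redirected to itself by the RAUW) back to the original |
| // op's result. |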
| metadataOp->replaceAllUsesWith(duped); |
| duped.getOperation()->replaceUsesOfWith(duped.result(), |
| tieShapeOp.result()); |
| metadataOps.insert(metadataOp); |
| } |
| } |
| } |
| } |
| |
| for (auto *metadataOp : metadataOps) { |
| subgraph->insert(metadataOp); |
| } |
| } |
| |
| // Finds all ops that can be fused together with the given |rootOp| by searching |
| // backwards in the op order through input edges. |
| // Returns a topologically sorted list of all fused ops with |rootOp| at the |
| // end. |
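| // |
| // For example (illustrative): with %1 = mhlo.multiply(...) feeding the root |
| // %2 = mhlo.add(%1, %1), the returned list is [multiply, add]. |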
| std::vector<Operation *> findFusionSubgraphFromRoot( |
| Operation *rootOp, Dispatchability &dispatchability) { |
| LLVM_DEBUG(llvm::dbgs() << "+++ FINDING FUSION SUBGRAPH FROM ROOT: " |
| << rootOp->getName() << "\n"); |
| llvm::SetVector<Operation *> subgraph; |
| subgraph.insert(rootOp); |
| |
| if (isFusionRootOp(rootOp)) { |
| LLVM_DEBUG(llvm::dbgs() << "--- FUSING INTO ROOT\n\n"); |
| gatherFusionOps(rootOp, dispatchability, {}, &subgraph); |
| } else { |
| LLVM_DEBUG(llvm::dbgs() << "--- FUSED TO SINGLE NON-ROOT\n\n"); |
| } |
| |
| extendInboundMetadataOps(&subgraph); |
| extendOutboundMetadataOps(&subgraph); |
| LLVM_DEBUG(llvm::dbgs() << "--- FUSED SUBGRAPH OF " << subgraph.size() |
| << " OPS\n\n"); |
| return sortOpsTopologically(subgraph); |
| } |
| |
| // Identifies ranges of dispatchable ops and moves them into dispatch regions. |
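| // |
| // For example (illustrative; exact IR syntax elided): |
| //   %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> |
| //   %1 = mhlo.add %0, %arg0 : tensor<4xf32> |
| // becomes a single flow.dispatch.region with a workload derived from the |
| // result shape and a flow.return yielding %1. |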
| LogicalResult identifyBlockDispatchRegions(Block *block, |
| Dispatchability &dispatchability) { |
| // Fixed point iteration until we can no longer fuse anything. |
| bool didFindAnyNewRegions; |
| do { |
| // Iterate in reverse so we root further along in the op list. |
| didFindAnyNewRegions = false; |
| for (auto &rootOp : llvm::reverse(*block)) { |
| LLVM_DEBUG(llvm::dbgs() << "-> EVALUATING OP FOR ROOT FUSION: " |
| << rootOp.getName() << "\n"); |
| |
| if (!isDispatchableOp(&rootOp, dispatchability)) { |
| // Op should remain at the sequencer level. |
| LLVM_DEBUG(llvm::dbgs() << " -SKIP NON DISPATCHABLE OP-\n"); |
| continue; |
| } |
| if (isNonFusionRootOp(&rootOp)) { |
| // Don't form a root around ops that cannot be a fusion root (but |
| // may be otherwise dispatchable). |
| LLVM_DEBUG(llvm::dbgs() << " -SKIP NON FUSION ROOT OP-\n"); |
| continue; |
| } |
| |
| // Attempt to find all operations, including rootOp, that can be fused. |
| // The ops will be sorted in topological order with rootOp as the last op. |
| // Worst case we may end up with a subgraph of only the rootOp. |
| auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp, dispatchability); |
| |
| // Compute the workload based on the output shape. |
| // When variadic, all output shapes match, so we can just take the first. |
| auto workload = calculateWorkload(&rootOp, rootOp.getResult(0)); |
| if (!workload) { |
| return failure(); |
| } |
| |
| // Try to build a dispatch region from this root. |
| if (failed(buildDispatchRegion(block, workload, fusedSubgraph))) { |
| return failure(); |
| } |
| |
| // We successfully created a dispatch region from the ops and must now |
| // start over, as we've likely trashed the whole block structure. |
| didFindAnyNewRegions = true; |
| break; |
| } |
| } while (didFindAnyNewRegions); |
| return success(); |
| } |
| |
| } // namespace |
| |
| // Identifies dispatchable ops and moves them into dispatch regions. |
| // Some ops, such as call, are deferred to later passes. |
| class IdentifyDispatchRegionsPass |
| : public PassWrapper<IdentifyDispatchRegionsPass, FunctionPass> { |
| public: |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<IREE::Flow::FlowDialect>(); |
| } |
| |
| void runOnFunction() override { |
| // NOTE: we require the DispatchabilityAnalysisPass to have run first. |
| auto dispatchability = getCachedParentAnalysis<Dispatchability>(); |
| if (!dispatchability.hasValue()) { |
| getFunction().emitError() |
| << "dispatchability analysis not performed " |
| "on module; run -iree-flow-dispatchability-analysis first"; |
| return signalPassFailure(); |
| } |
| |
| for (auto &block : getFunction()) { |
| if (failed(identifyBlockDispatchRegions(&block, |
| dispatchability.getValue()))) { |
| return signalPassFailure(); |
| } |
| } |
| } |
| }; |
| |
| std::unique_ptr<OperationPass<FuncOp>> createIdentifyDispatchRegionsPass() { |
| return std::make_unique<IdentifyDispatchRegionsPass>(); |
| } |
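| |
| // A minimal usage sketch (assumptions: a standard MLIR PassManager and a |
| // factory for the -iree-flow-dispatchability-analysis pass referenced in |
| // runOnFunction above; that factory name is hypothetical): |
| //   PassManager pm(&context); |
| //   pm.addPass(createDispatchabilityAnalysisPass());  // hypothetical name |
| //   pm.addNestedPass<FuncOp>(createIdentifyDispatchRegionsPass()); |
| //   (void)pm.run(module); |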
| |
| static PassRegistration<IdentifyDispatchRegionsPass> pass( |
| "iree-flow-identify-dispatch-regions", |
| "Conservatively identifies dispatch regions in functions"); |
| |
| } // namespace Flow |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |