| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| //===- SplitDispatchFunctionPass.cpp --------------------------------------===// |
| // |
| // This file implements a pass to split computation workload to multiple |
| // sequential dispatch functions. This pass operates on Linalg ops and |
| // scf.parallel op and prepares for lowering to GPU, where we need to tile the |
| // workload to workgroups and workitems. If the workload involves computation A |
| // and B, where B is dependent on A and A needs all workgroups to complete, then |
| // we need to split A and B into different kernels because there is no mechanism |
| // to perform cross-workgroup synchronization within a single kernel. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <iterator> |
| |
| #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h" |
| #include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h" |
| #include "iree/compiler/Conversion/Common/Attributes.h" |
| #include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h" |
| #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h" |
| #include "iree/compiler/Dialect/IREE/IR/IREEOps.h" |
| #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallPtrSet.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BlockAndValueMapping.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/RegionUtils.h" |
| |
| #define DEBUG_TYPE "split-dispatch-function" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns true if an op can be fused with the list of ops that are to be put |
| /// in the same entry point function. This should be consistent with whatthe |
| /// downstream passes can handle. |
| static bool isFusableWithCurrentOpsList( |
| Operation *nextOp, ArrayRef<Operation *> currOpsList, |
| const linalg::LinalgDependenceGraph &dependenceGraph) { |
| if (currOpsList.empty()) return true; |
| |
| linalg::LinalgOp dstOp = dyn_cast<linalg::LinalgOp>(nextOp); |
| linalg::LinalgOp srcOp = dyn_cast<linalg::LinalgOp>(currOpsList.back()); |
| if (dstOp && srcOp) { |
| // TODO(#2963): This splits independent linalg opreations into its own |
| // dispatch, but in reality if the iteration domain of the ops are the same, |
| // and they have all iterator types parallel, they could be put in the same |
| // dispatch region. |
| if (!dependenceGraph.hasDependenceFrom(srcOp, dstOp)) return false; |
| |
| #define ADD_FUSABLE_PAIR(SrcOpTy, DstOpTy, DependenceTy) \ |
| if (isa<SrcOpTy>(srcOp.getOperation()) && \ |
| isa<DstOpTy>(dstOp.getOperation()) && \ |
| dependenceGraph.hasDependenceFrom(srcOp, dstOp, DependenceTy)) \ |
| return true; |
| |
| ADD_FUSABLE_PAIR(linalg::BatchMatmulOp, linalg::GenericOp, |
| linalg::LinalgDependenceGraph::DependenceType::RAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMaxOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMinOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingSumOp, |
| linalg::LinalgDependenceGraph::DependenceType::WAW) |
| ADD_FUSABLE_PAIR(linalg::MatmulOp, linalg::GenericOp, |
| linalg::LinalgDependenceGraph::DependenceType::RAW) |
| |
| #undef ADD_FUSABLE_PAIR |
| } |
| return false; |
| } |
| |
| /// For the list of operations in `ops` returns a list of lists where each list |
| /// contains the operations that need to be put in a separate dispatch function. |
| static LogicalResult separateOps( |
| ArrayRef<Operation *> ops, |
| const linalg::LinalgDependenceGraph &dependenceGraph, |
| SmallVectorImpl<SmallVector<Operation *, 1>> &fusedOpList) { |
| assert(!ops.empty() && |
| "expected at least one separable op for splitting dispatch function"); |
| SmallVector<Operation *, 1> currList; |
| for (auto currOp = ops.begin(), nextOp = std::next(ops.begin()); |
| nextOp != ops.end(); ++currOp, ++nextOp) { |
| // Check that the operation has buffer semantics. |
| if (auto linalgOp = dyn_cast<linalg::LinalgOp>(*currOp)) { |
| if (!linalgOp.hasBufferSemantics()) return failure(); |
| } |
| |
| // Require no other non-metadata ops interleave with Linalg structured ops |
| // for now. This is the common case and it simplifies further analysis. |
| Operation *iter = (*currOp)->getNextNode(); |
| while (iter != *nextOp && (MemoryEffectOpInterface::hasNoEffect(iter) || |
| isa<IREE::PlaceholderOp>(iter))) |
| iter = iter->getNextNode(); |
| if (iter != *nextOp) return failure(); |
| |
| currList.push_back(*currOp); |
| |
| // If the nextOp is not fusible with the currOp, then record the list of ops |
| // so far, and start a new list. |
| if (isFusableWithCurrentOpsList(*nextOp, currList, dependenceGraph)) { |
| continue; |
| } |
| |
| // Push the current list of ops into the list of lists `currList` and |
| // start a new list. |
| fusedOpList.emplace_back(); |
| std::swap(fusedOpList.back(), currList); |
| } |
| currList.push_back(ops.back()); |
| fusedOpList.emplace_back(std::move(currList)); |
| return success(); |
| } |
| |
| /// Recursively collects all the operations that are referenced by given |
| /// `rootOp` into `closure`. |
| static void collectAllReferencedOps( |
| ArrayRef<Operation *> rootOps, |
| llvm::SmallPtrSetImpl<Operation *> &closure) { |
| llvm::SmallVector<Operation *, 8> workList; |
| workList.assign(rootOps.begin(), rootOps.end()); |
| |
| while (!workList.empty()) { |
| Operation *curOp = workList.pop_back_val(); |
| if (!curOp) continue; |
| if (!closure.insert(curOp).second) continue; // Seen before |
| // Collect all defining ops for operands. |
| for (Value operand : curOp->getOperands()) { |
| if (Operation *owner = operand.getDefiningOp()) workList.push_back(owner); |
| } |
| // Collect all defining ops for the values used in regions. |
| for (Region ®ion : curOp->getRegions()) { |
| visitUsedValuesDefinedAbove(region, [&workList](OpOperand *operand) { |
| workList.push_back(operand->get().getDefiningOp()); |
| }); |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass and patterns |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| struct SplitDispatchFunctionPass |
| : public PassWrapper<SplitDispatchFunctionPass, OperationPass<ModuleOp>> { |
| void runOnOperation() override; |
| LogicalResult splitDispatchFunction(FuncOp oldFn, OpBuilder &builder); |
| }; |
| |
| } // namespace |
| |
| void SplitDispatchFunctionPass::runOnOperation() { |
| ModuleOp moduleOp = getOperation(); |
| |
| // Collect all dispatch entry functions. |
| SmallVector<FuncOp, 1> functions; |
| for (FuncOp fn : moduleOp.getOps<FuncOp>()) { |
| if (isEntryPoint(fn)) functions.push_back(fn); |
| } |
| if (functions.empty()) return; |
| if (functions.size() > 1) { |
| moduleOp.emitError("expected only one entry function"); |
| return signalPassFailure(); |
| } |
| |
| auto builder = OpBuilder::atBlockBegin(moduleOp.getBody()); |
| if (failed(splitDispatchFunction(functions.front(), builder))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| LogicalResult SplitDispatchFunctionPass::splitDispatchFunction( |
| FuncOp oldFn, OpBuilder &builder) { |
| // Entry functions are supported to be of `void(void)` type. |
| assert(oldFn.getType().getNumInputs() == 0 && |
| oldFn.getType().getNumResults() == 0); |
| |
| if (!llvm::hasSingleElement(oldFn.getBlocks())) { |
| return oldFn.emitError("expected only one block"); |
| } |
| |
| // The dispatch function should have more than one separable ops. Otherwise |
| // there is nothing to do. |
| Block &fnBody = oldFn.getBlocks().front(); |
| |
| // Collect all Linalg and scf.parallel ops for splitting. |
| SmallVector<Operation *, 4> separableOps; |
| for (Operation &op : fnBody) |
| if (isa<linalg::LinalgOp, scf::ParallelOp, scf::ForOp>(op)) |
| separableOps.push_back(&op); |
| |
| if (separableOps.size() <= 1) return success(); |
| |
| linalg::Aliases aliases; |
| linalg::LinalgDependenceGraph dependenceGraph = |
| linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, oldFn); |
| SmallVector<SmallVector<Operation *, 1>, 1> fusedOpsList; |
| if (failed(separateOps(separableOps, dependenceGraph, fusedOpsList))) { |
| return oldFn.emitError( |
| "cannot separate Linalg/Parallel ops into multiple kernels"); |
| } |
| if (fusedOpsList.size() <= 1) return success(); |
| |
| ModuleOp moduleOp = cast<ModuleOp>(oldFn.getParentOp()); |
| Block &oldFnBlock = oldFn.getBlocks().front(); |
| Location loc = oldFn.getLoc(); |
| |
| SmallVector<std::string, 4> splitKernels; |
| splitKernels.reserve(separableOps.size()); |
| llvm::SmallPtrSet<Operation *, 16> closure; |
| |
| for (const auto &fusedOps : llvm::enumerate(fusedOpsList)) { |
| if (fusedOps.value().empty()) continue; |
| // Create a new function for hosting this op. |
| splitKernels.emplace_back( |
| llvm::formatv("{0}_dispatch_{1}", oldFn.getName(), fusedOps.index())); |
| StringRef newFnName = splitKernels.back(); |
| builder.setInsertionPointToStart(moduleOp.getBody()); |
| auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType()); |
| LLVM_DEBUG({ |
| llvm::dbgs() << "Created new function : func @" << newFn.getName() |
| << "\n"; |
| }); |
| |
| // Copy over all attributes except type and name. |
| for (const auto &namedAttr : oldFn.getAttrs()) { |
| if (namedAttr.first != impl::getTypeAttrName() && |
| namedAttr.first != SymbolTable::getSymbolAttrName() && |
| namedAttr.first != getNumWorkgroupsFnAttrName()) |
| newFn.setAttr(namedAttr.first, namedAttr.second); |
| } |
| // Need special handling for the number of workgroups function. |
| if (FuncOp numWorkgroupsFn = |
| getNumWorkgroupsFn(oldFn, getNumWorkgroupsFnAttrName())) { |
| FuncOp newNumWorkgroupsFn = |
| builder.create<FuncOp>(loc, newFnName.str() + "__num_workgroups__", |
| numWorkgroupsFn.getType()); |
| newNumWorkgroupsFn.setVisibility(FuncOp::Visibility::Private); |
| newFn.setAttr(getNumWorkgroupsFnAttrName(), |
| builder.getSymbolRefAttr(newNumWorkgroupsFn)); |
| LLVM_DEBUG({ |
| llvm::dbgs() << "Added func @" << newNumWorkgroupsFn.getName() |
| << " as num workgroups fn for func @" << newFn.getName() |
| << "\n"; |
| }); |
| } |
| |
| // Collect the closure for the current Linalg op. |
| closure.clear(); |
| collectAllReferencedOps(fusedOps.value(), closure); |
| |
| // Clone all ops in the closure to the new function. |
| Block *newFnBlock = newFn.addEntryBlock(); |
| builder.setInsertionPointToStart(newFnBlock); |
| BlockAndValueMapping remapper; |
| for (Operation &op : oldFnBlock) { |
| if (closure.count(&op) == 0) continue; |
| builder.insert(op.clone(remapper)); |
| if (&op == fusedOps.value().back()) break; |
| } |
| builder.insert(oldFnBlock.getTerminator()->clone(remapper)); |
| } |
| |
| // Add the entry point schedule to the module op. |
| SmallVector<Attribute, 4> entryPoints; |
| entryPoints.reserve(separableOps.size()); |
| for (const std::string &kernel : splitKernels) { |
| entryPoints.emplace_back(builder.getStringAttr(kernel)); |
| } |
| moduleOp.setAttr(getEntryPointScheduleAttrName(), |
| builder.getArrayAttr(entryPoints)); |
| |
| if (FuncOp numWorkgroupsFn = |
| getNumWorkgroupsFn(oldFn, getNumWorkgroupsFnAttrName())) { |
| LLVM_DEBUG({ |
| llvm::dbgs() << "Erased num workgroups fn func @" |
| << numWorkgroupsFn.getName() << " for func @" |
| << oldFn.getName() << "\n"; |
| }); |
| numWorkgroupsFn.erase(); |
| } |
| LLVM_DEBUG({ llvm::dbgs() << "Erased func @" << oldFn.getName() << "\n"; }); |
| oldFn.erase(); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass entry point and registration |
| //===----------------------------------------------------------------------===// |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createSplitDispatchFunctionPass() { |
| return std::make_unique<SplitDispatchFunctionPass>(); |
| } |
| |
| static PassRegistration<SplitDispatchFunctionPass> pass( |
| "iree-codegen-split-dispatch-function", |
| "Split workload to multiple dispatch functions to satisfy computation " |
| "dependency for GPU lowering"); |
| |
| } // namespace iree_compiler |
| } // namespace mlir |