| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" |
| |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "mlir/Dialect/StandardOps/Ops.h" |
| #include "mlir/IR/BlockAndValueMapping.h" |
| #include "mlir/IR/Builders.h" |
| #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace Flow { |
| |
| bool isOpOfKnownDialect(Operation *op) { |
| if (!op->getDialect()) return false; |
| // TODO(benvanik): replace with op dispatchability interface to allow dialects |
| // to opt into dispatch. |
| auto dialectNamespace = op->getDialect()->getNamespace(); |
| return dialectNamespace == xla_hlo::XlaHloDialect::getDialectNamespace() || |
| dialectNamespace == mlir::StandardOpsDialect::getDialectNamespace() || |
| dialectNamespace == FlowDialect::getDialectNamespace(); |
| } |
| |
| namespace { |
| |
| // Returns the set of values that must be captured for use by |ops| and the |
| // set of values defined by |ops| that are used outside of the set. |
| LogicalResult analyzeOpRangeValues( |
| const llvm::SmallDenseSet<Operation *> &opSet, |
| llvm::SetVector<ValuePtr> *capturedValues, |
| llvm::SetVector<ValuePtr> *escapingValues) { |
| for (auto *op : opSet) { |
| for (auto value : op->getOperands()) { |
| if (!llvm::is_contained(opSet, value->getDefiningOp())) { |
| // Op is using a value not in the ops set, ensure we capture it. |
| capturedValues->insert(value); |
| } |
| } |
| for (auto value : op->getResults()) { |
| for (auto &use : value->getUses()) { |
| if (!llvm::is_contained(opSet, use.getOwner())) { |
| // An op outside of the ops set is using the value, needs to escape. |
| escapingValues->insert(value); |
| } |
| } |
| } |
| } |
| return success(); |
| } |
| |
| } // namespace |
| |
| LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, |
| ValuePtr workload, |
| ArrayRef<Operation *> ops) { |
| // Fused location with all ops. |
| SmallVector<Location, 16> opLocs; |
| for (auto *op : ops) { |
| opLocs.push_back(op->getLoc()); |
| } |
| auto regionLoc = FusedLoc::get(opLocs, func.getContext()); |
| |
| // Get a list of values that we need to capture and values that escape the |
| // region and need to be returned. |
| llvm::SmallDenseSet<Operation *> opSet; |
| opSet.reserve(ops.size()); |
| opSet.insert(ops.begin(), ops.end()); |
| llvm::SetVector<ValuePtr> capturedValues; |
| llvm::SetVector<ValuePtr> escapingValues; |
| if (failed(analyzeOpRangeValues(opSet, &capturedValues, &escapingValues))) { |
| return failure(); |
| } |
| SmallVector<Type, 8> escapingTypes; |
| for (auto value : escapingValues) escapingTypes.push_back(value->getType()); |
| |
| // Build the region op and add it to the parent block. |
| OpBuilder parentBuilder(parentBlock); |
| parentBuilder.setInsertionPoint(ops.back()); |
| auto dispatchRegionOp = parentBuilder.create<IREE::Flow::DispatchRegionOp>( |
| regionLoc, escapingTypes, workload, capturedValues.getArrayRef()); |
| |
| // Create the block and setup the arg mapping for captured values. |
| auto *regionBlock = new Block(); |
| dispatchRegionOp.body().push_back(regionBlock); |
| OpBuilder regionBuilder(regionBlock); |
| BlockAndValueMapping mapping; |
| for (auto capturedValue : capturedValues) { |
| auto blockArg = regionBlock->addArgument(capturedValue->getType()); |
| mapping.map(capturedValue, blockArg); |
| } |
| |
| // Clone ops into the new region block. |
| for (auto *op : ops) { |
| // Note that this updates the mapping with the new values (so at the end |
| // we have those new values). |
| regionBuilder.clone(*op, mapping); |
| } |
| |
| // Return results (as we need a terminator in our block). |
| // These are all of the values that escape our region. |
| SmallVector<ValuePtr, 8> resultValues; |
| for (auto oldValue : escapingValues) { |
| resultValues.push_back(mapping.lookupOrDefault(oldValue)); |
| } |
| regionBuilder.create<IREE::Flow::ReturnOp>(opLocs.back(), resultValues); |
| |
| // Replace usage of values with the results of the region. |
| for (int i = 0; i < escapingValues.size(); ++i) { |
| escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i)); |
| } |
| |
| // Remove original ops from the parent region. |
| for (auto it = ops.rbegin(); it != ops.rend(); ++it) { |
| (*it)->erase(); |
| } |
| |
| return success(); |
| } |
| |
| namespace { |
| |
| // Recursively finds all reachable functions from the given |rootFunc| and adds |
| // them to the |reachableFuncs| set. |
| // |
| // Note that indirect calls are not supported, however we don't allow those in |
| // dispatch regions anyway so they should not be present here. |
| LogicalResult findReachableFunctions( |
| FuncOp rootFuncOp, llvm::SetVector<FuncOp> &reachableFuncs, |
| llvm::StringMap<FuncOp> &dispatchableFuncOps) { |
| llvm::SetVector<FuncOp> worklist; |
| worklist.insert(rootFuncOp); |
| while (!worklist.empty()) { |
| auto funcOp = worklist.pop_back_val(); |
| funcOp.walk([&](CallOp callOp) { |
| auto calleeOp = dispatchableFuncOps.find(callOp.callee())->second; |
| if (reachableFuncs.insert(calleeOp)) { |
| worklist.insert(calleeOp); |
| } |
| }); |
| } |
| return success(); |
| } |
| |
| } // namespace |
| |
| std::pair<IREE::Flow::ExecutableOp, FuncOp> createRegionExecutable( |
| Operation *op, FunctionType functionType, StringRef symbolSuffix, |
| llvm::StringMap<FuncOp> &dispatchableFuncOps) { |
| // Create the function and take the region body directly. |
| // NOTE: this will get uniquified if we have multiple in the same block. |
| auto parentFunc = op->getParentOfType<FuncOp>(); |
| std::string functionName = |
| (parentFunc.getName().str() + "_rgn" + symbolSuffix).str(); |
| auto outlinedFunc = FuncOp::create(op->getLoc(), functionName, functionType); |
| BlockAndValueMapping mapping; |
| op->getRegion(0).cloneInto(&outlinedFunc.getBody(), mapping); |
| |
| // Replace flow.return with std.return. |
| for (auto &block : outlinedFunc.getBlocks()) { |
| if (auto returnOp = dyn_cast<IREE::Flow::ReturnOp>(block.back())) { |
| OpBuilder builder(returnOp); |
| builder.create<mlir::ReturnOp>( |
| returnOp.getLoc(), llvm::to_vector<4>(returnOp.getOperands())); |
| returnOp.erase(); |
| } |
| } |
| |
| // Gather all reachable functions. |
| llvm::SetVector<FuncOp> reachableFuncs; |
| findReachableFunctions(outlinedFunc, reachableFuncs, dispatchableFuncOps); |
| |
| // Create the executable that will contain the outlined region. |
| // NOTE: this will get uniquified if we have multiple in the same block. |
| auto parentModule = parentFunc.getParentOfType<ModuleOp>(); |
| OpBuilder parentModuleBuilder(parentModule); |
| parentModuleBuilder.setInsertionPoint(parentFunc); |
| std::string executableName = |
| (parentFunc.getName().str() + "_ex" + symbolSuffix).str(); |
| auto executableOp = parentModuleBuilder.create<IREE::Flow::ExecutableOp>( |
| outlinedFunc.getLoc(), executableName); |
| |
| // Create the inner ModuleOp that contains the original functions. We need |
| // to provide this shim as some ops (like std.call) look for the |
| // containing module to provide symbol resolution. |
| OpBuilder executableBuilder(executableOp); |
| executableBuilder.setInsertionPointToStart(&executableOp.getBlock()); |
| auto innerModule = executableBuilder.create<ModuleOp>(outlinedFunc.getLoc()); |
| innerModule.push_back(outlinedFunc); |
| |
| // Copy all reachable functions into the executable. |
| // Linker passes may dedupe these later on. |
| OpBuilder innerModuleBuilder(innerModule.getBody()); |
| innerModuleBuilder.setInsertionPoint(innerModule.getBody(), |
| ++innerModule.getBody()->begin()); |
| for (auto reachableFunc : reachableFuncs) { |
| innerModuleBuilder.clone(*reachableFunc); |
| } |
| |
| return std::make_pair(executableOp, outlinedFunc); |
| } |
| |
| } // namespace Flow |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |