| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Implements IREE-specific preprocessing for XLA inputs. |
| |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h" |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "stablehlo/dialect/StablehloOps.h" |
| |
| namespace mlir::iree_compiler::stablehlo { |
| |
| #define GEN_PASS_DEF_FLATTENTUPLESINCFG |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h.inc" |
| |
| namespace { |
| // Given a set of types, unpack to a list of a types, removing all tuples. |
| void untupleTypes(TypeRange types, llvm::SmallVectorImpl<Type> &newTypes) { |
| for (Type type : types) { |
| if (auto tupleTy = dyn_cast<TupleType>(type)) { |
| untupleTypes(tupleTy.getTypes(), newTypes); |
| } else { |
| newTypes.push_back(type); |
| } |
| } |
| } |
| |
| template <typename T> |
| bool hasTuples(T values) { |
| bool isTuple = false; |
| for (auto val : values) { |
| isTuple |= isa<TupleType>(val.getType()); |
| } |
| |
| return isTuple; |
| } |
| |
| Value processTuple(Type type, Location loc, Block &block, OpBuilder &builder) { |
| auto tupleType = dyn_cast<TupleType>(type); |
| if (!tupleType) { |
| return block.addArgument(type, loc); |
| } |
| |
| llvm::SmallVector<Value> values; |
| values.reserve(tupleType.size()); |
| for (Type subtype : tupleType.getTypes()) { |
| values.push_back(processTuple(subtype, loc, block, builder)); |
| } |
| |
| return mlir::stablehlo::TupleOp::create(builder, loc, tupleType, values); |
| } |
| |
| void copyOperationAttrs(Operation *oldOp, Operation *newOp) { |
| for (NamedAttribute oldAttr : oldOp->getAttrs()) { |
| // Don't copy segment attributes as these correspond to the number operands, |
| // which may be different. |
| if (oldAttr.getName() == "operandSegmentSizes" || |
| oldAttr.getName() == "resultSegmentSizes") |
| continue; |
| |
| newOp->setAttr(oldAttr.getName(), oldAttr.getValue()); |
| } |
| } |
| |
| void recursiveUntuple(Value value, Location loc, OpBuilder &builder, |
| llvm::SmallVectorImpl<Value> &newValues) { |
| auto tupleType = dyn_cast<TupleType>(value.getType()); |
| if (!tupleType) { |
| // We can return the value as is. |
| newValues.push_back(value); |
| return; |
| } |
| |
| for (auto [idx, subType] : llvm::enumerate(tupleType.getTypes())) { |
| auto elementOp = mlir::stablehlo::GetTupleElementOp::create( |
| builder, loc, subType, value, builder.getI32IntegerAttr(idx)); |
| recursiveUntuple(elementOp.getResult(), loc, builder, newValues); |
| } |
| } |
| |
| Value recursiveRetuple(Type oldType, Operation::result_range *values, |
| OpBuilder &builder, Location loc) { |
| auto tupleType = dyn_cast<TupleType>(oldType); |
| if (!tupleType) { |
| Value returnValue = *values->begin(); |
| *values = {values->begin() + 1, values->end()}; |
| return returnValue; |
| } |
| |
| llvm::SmallVector<Value> subValues; |
| for (Type subType : tupleType.getTypes()) { |
| subValues.push_back(recursiveRetuple(subType, values, builder, loc)); |
| } |
| |
| return mlir::stablehlo::TupleOp::create(builder, loc, tupleType, subValues) |
| .getResult(); |
| } |
| |
| template <typename T> |
| LogicalResult untupleAndLookupValues(T values, |
| llvm::SmallVectorImpl<Value> &newValues, |
| OpBuilder &builder, Location loc) { |
| IRMapping mapping; |
| for (auto operand : values) { |
| recursiveUntuple(operand, loc, builder, newValues); |
| } |
| |
| return success(); |
| } |
| |
| class DetupleReturnOp : public OpRewritePattern<func::ReturnOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(func::ReturnOp op, |
| PatternRewriter &builder) const override { |
| if (!hasTuples(op.getOperands())) |
| return builder.notifyMatchFailure(op, "No detupling required"); |
| |
| llvm::SmallVector<Value> newOperands; |
| if (failed(untupleAndLookupValues(op.getOperands(), newOperands, builder, |
| op.getLoc()))) { |
| return builder.notifyMatchFailure(op, "failed to untuple"); |
| } |
| |
| mlir::func::ReturnOp::create(builder, op->getLoc(), newOperands); |
| builder.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| class DetupleCallOp : public OpRewritePattern<func::CallOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(func::CallOp oldOp, |
| PatternRewriter &builder) const override { |
| if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) |
| return builder.notifyMatchFailure(oldOp, "No detupling required"); |
| |
| llvm::SmallVector<Value> newArgs; |
| if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, |
| oldOp.getLoc()))) { |
| return builder.notifyMatchFailure(oldOp, "failed to untuple values"); |
| } |
| |
| SmallVector<Type> resultTypes; |
| untupleTypes(oldOp.getResultTypes(), resultTypes); |
| auto newOp = func::CallOp::create(builder, oldOp->getLoc(), |
| oldOp.getCallee(), resultTypes, newArgs); |
| copyOperationAttrs(oldOp, newOp); |
| |
| auto newResults = newOp.getResults(); |
| llvm::SmallVector<Value> retupledResults; |
| for (auto oldResult : oldOp.getResults()) { |
| auto newResult = recursiveRetuple(oldResult.getType(), &newResults, |
| builder, oldOp->getLoc()); |
| retupledResults.push_back(newResult); |
| } |
| |
| builder.replaceOp(oldOp, retupledResults); |
| return success(); |
| } |
| }; |
| |
| class DetupleIndirectCallOp : public OpRewritePattern<func::CallIndirectOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(func::CallIndirectOp oldOp, |
| PatternRewriter &builder) const override { |
| if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) |
| return builder.notifyMatchFailure(oldOp, "No detupling required"); |
| |
| llvm::SmallVector<Value> newArgs; |
| if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, |
| oldOp.getLoc()))) { |
| return builder.notifyMatchFailure(oldOp, "failed to untuple values"); |
| } |
| |
| auto newOp = func::CallIndirectOp::create(builder, oldOp.getLoc(), |
| oldOp.getCallee(), newArgs); |
| copyOperationAttrs(oldOp, newOp); |
| builder.replaceOp(oldOp, newOp.getResults()); |
| return success(); |
| } |
| }; |
| |
| class DetupleBranchOp : public OpRewritePattern<cf::BranchOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(cf::BranchOp oldOp, |
| PatternRewriter &builder) const override { |
| if (!hasTuples(oldOp.getOperands())) |
| return builder.notifyMatchFailure(oldOp, "No detupling required"); |
| |
| llvm::SmallVector<Value> newArgs; |
| if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, |
| oldOp.getLoc()))) { |
| return builder.notifyMatchFailure(oldOp, "failed to untuple values"); |
| } |
| |
| auto newOp = |
| cf::BranchOp::create(builder, oldOp.getLoc(), oldOp.getDest(), newArgs); |
| |
| copyOperationAttrs(oldOp, newOp); |
| builder.eraseOp(oldOp); |
| return success(); |
| } |
| }; |
| |
| class DetupleConditionOp : public OpRewritePattern<cf::CondBranchOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(cf::CondBranchOp oldOp, |
| PatternRewriter &builder) const override { |
| if (!hasTuples(oldOp.getOperands())) |
| return builder.notifyMatchFailure(oldOp, "No detupling required"); |
| |
| llvm::SmallVector<Value> trueArgs; |
| if (failed(untupleAndLookupValues(oldOp.getTrueOperands(), trueArgs, |
| builder, oldOp.getLoc()))) { |
| return builder.notifyMatchFailure(oldOp, "Failed to detuple true args"); |
| } |
| |
| llvm::SmallVector<Value> falseArgs; |
| if (failed(untupleAndLookupValues(oldOp.getFalseOperands(), falseArgs, |
| builder, oldOp.getLoc()))) { |
| return builder.notifyMatchFailure(oldOp, "Failed to detuple false args"); |
| } |
| |
| auto newOp = cf::CondBranchOp::create( |
| builder, oldOp.getLoc(), oldOp.getCondition(), oldOp.getTrueDest(), |
| trueArgs, oldOp.getFalseDest(), falseArgs); |
| |
| copyOperationAttrs(oldOp, newOp); |
| |
| builder.eraseOp(oldOp); |
| return success(); |
| } |
| }; |
| |
| LogicalResult convertFunction(func::FuncOp oldFunction, |
| func::FuncOp newFunction) { |
| OpBuilder builder(newFunction.getBody()); |
| IRMapping mapping; |
| |
| // Check whether has tuple in signature. |
| bool hasTupleSig = (oldFunction.getArgumentTypes().size() != |
| newFunction.getArgumentTypes().size()) || |
| (oldFunction.getResultTypes().size() != |
| newFunction.getResultTypes().size()); |
| |
| auto xlaAbiParam = StringAttr::get(newFunction.getContext(), |
| "xla_entry_computation_parameter_layouts"); |
| auto xlaAbiLayout = StringAttr::get(newFunction.getContext(), |
| "xla_entry_computation_result_layout"); |
| |
| for (NamedAttribute attr : oldFunction->getAttrs()) { |
| // Currently skipping all arg, result and XLA specific ABI attributes. |
| if (llvm::is_contained( |
| {oldFunction.getFunctionTypeAttrName(), xlaAbiParam, xlaAbiLayout}, |
| attr.getName())) { |
| continue; |
| } |
| |
| // If it has tuples in sig, then skip arg and res attrs. None of the |
| // existing ones along path that produces tuples are used further, so just |
| // remove instead of flattening. |
| if (hasTupleSig && (attr.getName() == oldFunction.getArgAttrsAttrName() || |
| attr.getName() == oldFunction.getResAttrsAttrName())) |
| continue; |
| newFunction->setAttr(attr.getName(), attr.getValue()); |
| } |
| |
| newFunction.getBlocks().clear(); |
| for (Block &oldBlock : oldFunction.getBlocks()) { |
| Block *newBlock = builder.createBlock(&newFunction.getBody()); |
| for (BlockArgument oldArg : oldBlock.getArguments()) { |
| llvm::SmallVector<Type> newTypes; |
| untupleTypes(oldArg.getType(), newTypes); |
| |
| Value newTuple = processTuple(oldArg.getType(), oldFunction.getLoc(), |
| *newBlock, builder); |
| if (!newTuple) { |
| return failure(); |
| } |
| |
| mapping.map(oldArg, newTuple); |
| } |
| mapping.map(&oldBlock, newBlock); |
| } |
| |
| // Convert all ops in the blocks. |
| for (Block &oldBlock : oldFunction.getBlocks()) { |
| builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock)); |
| for (Operation &oldOp : oldBlock.getOperations()) { |
| builder.clone(oldOp, mapping); |
| } |
| } |
| |
| return success(); |
| } |
| |
| struct FlattenTuplesInCFG final |
| : impl::FlattenTuplesInCFGBase<FlattenTuplesInCFG> { |
| void runOnOperation() override { |
| ModuleOp module = getOperation(); |
| MLIRContext *ctx = module.getContext(); |
| Builder builder(ctx); |
| |
| // Build a list of (oldFunction, newFunction) for all functions we need to |
| // replace. This will ensure that when we go to convert function bodies we |
| // have only new functions defined. |
| SmallVector<std::pair<func::FuncOp, func::FuncOp>> convertedFunctions; |
| for (auto oldFunction : module.getOps<func::FuncOp>()) { |
| FunctionType oldFunctionType = oldFunction.getFunctionType(); |
| |
| llvm::SmallVector<Type> newInputTypes; |
| untupleTypes(oldFunctionType.getInputs(), newInputTypes); |
| |
| llvm::SmallVector<Type> newResultTypes; |
| untupleTypes(oldFunctionType.getResults(), newResultTypes); |
| |
| FunctionType newFunctionType = |
| builder.getFunctionType(newInputTypes, newResultTypes); |
| func::FuncOp newFunction = |
| func::FuncOp::create(oldFunction.getLoc(), oldFunction.getName(), |
| newFunctionType, oldFunction->getDialectAttrs()); |
| convertedFunctions.push_back({oldFunction, newFunction}); |
| |
| // Perform the actual body conversion now that we have proper signatures. |
| if (failed(convertFunction(oldFunction, newFunction))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| // Replace functions in the module. |
| for (auto [oldFunction, newFunction] : convertedFunctions) { |
| oldFunction.erase(); |
| module.push_back(newFunction); |
| } |
| |
| // Run canonicalization patterns to cancel out remaining tuple ops. We need |
| // to run these manually here because StableHLO does not define |
| // folds/canonicalization patterns for its ops. |
| RewritePatternSet patterns(ctx); |
| patterns.insert<DetupleCallOp, DetupleIndirectCallOp, DetupleConditionOp, |
| DetupleReturnOp, DetupleBranchOp>(ctx); |
| populateCanonicalizationPatterns(ctx, &patterns); |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| } // namespace mlir::iree_compiler::stablehlo |