| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/BlockAndValueMapping.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/PassRegistry.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| /// Pattern to combine instructions across the ForOp boundary. It is common
| /// when doing incremental lowering to generate transient ops that cancel each
| /// other out. Canonicalization usually cleans up those operations. However,
| /// when the value is loop carried, MLIR canonicalization currently doesn't
| /// remove the redundant operations.
| ///
| /// This pattern works around that MLIR limitation and does ad hoc cleanup of
| /// instructions found in IREE. Once there is a more general mechanism in MLIR
| /// this pattern can be removed entirely.
| /// It performs this kind of transformation:
| /// ``` |
| /// %21 = vector.shape_cast %20 : vector<4xf32> to vector<1x4xf32> |
| /// %22 = scf.for %arg3 = %c0 to %c4096 step %c4 iter_args(%arg4 = %21) |
| /// -> vector<1x4xf32> { |
| /// [...] |
| /// %100 = vector.shape_cast %arg4 : vector<1x4xf32> to vector<4xf32> |
| /// [...] |
| /// %109 = vector.shape_cast %108 : vector<4xf32> to vector<1x4xf32> |
| /// scf.yield %109 : vector<1x4xf32> |
| /// } |
| /// %24 = vector.shape_cast %22 : vector<1x4xf32> to vector<4xf32> |
| /// ``` |
| /// -> |
| /// ``` |
| /// %22 = scf.for %arg3 = %c0 to %c4096 step %c4 iter_args(%arg4 = %20) |
| /// -> vector<4xf32> { |
| /// [...] |
| /// scf.yield %108 : vector<4xf32> |
| /// } |
| /// ``` |
| struct CanonicalizeForOpInductionVarShape final |
| : public OpRewritePattern<scf::ForOp> { |
| using OpRewritePattern<scf::ForOp>::OpRewritePattern; |
| |
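| // Returns the pre-cast value to carry through the loop when `ivUser` (the
| // single user of the loop-carried value) and `ivDef` (the op defining the
| // corresponding yielded value) are casts that cancel each other out.
| // Returns an empty Value otherwise.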
| Value FoldCarryDep(scf::ForOp forOp, Operation* ivUser, |
| Operation* ivDef) const { |
| if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(ivUser)) {
| // shape_cast(shape_cast(x)) where the result casts back to the type of x.
| if (auto sourceOp = dyn_cast<vector::ShapeCastOp>(ivDef)) {
| if (shapeCast.getType() == sourceOp.source().getType()) {
| return sourceOp.source();
| }
| }
| } else if (auto extractOp = dyn_cast<vector::ExtractOp>(ivUser)) {
| // extract(broadcast(x)) where the result extracts back the type of x.
| if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(ivDef)) {
| if (extractOp.getType() == broadcastOp.getSourceType()) {
| return broadcastOp.source();
| }
| }
| } else if (auto targetOp = dyn_cast<UnrealizedConversionCastOp>(ivUser)) {
| // A pair of single-value unrealized_conversion_casts converting back to
| // the original type.
| if (auto sourceOp = dyn_cast<UnrealizedConversionCastOp>(ivDef)) {
| if (sourceOp->getNumOperands() == 1 && targetOp->getNumResults() == 1 &&
| sourceOp->getOperandTypes().front() ==
| targetOp->getResultTypes().front()) {
| return sourceOp.inputs().front();
| }
| }
| }
| return Value(); |
| } |
| |
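| // Moves the body of `source` into `dest` and updates the terminator of
| // `dest` to return `results`.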
| void transferBody(Block* source, Block* dest, ArrayRef<Value> results, |
| PatternRewriter& rewriter) const { |
| // Move all operations to the destination block. |
| rewriter.mergeBlocks(source, dest, dest->getArguments()); |
| // Update the yield op in place to return the new values.
| auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
| yieldOp->setOperands(results);
| } |
| |
| LogicalResult matchAndRewrite(scf::ForOp forOp, |
| PatternRewriter& rewriter) const override { |
| SmallVector<unsigned, 8> iteratorFolded; |
| SmallVector<Operation*, 8> resultOps; |
| auto terminator = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| auto returnValues = llvm::to_vector<8>(terminator.getOperands()); |
| auto initArgs = llvm::to_vector<8>(forOp.getIterOperands()); |
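| // For each loop-carried value, check whether the single user of the iter
| // arg and the op defining the corresponding yielded value cancel out.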
| for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) { |
| if (!it.value().hasOneUse()) continue; |
| Operation* op = it.value().use_begin()->getOwner(); |
| if (!isa<vector::ShapeCastOp, vector::ExtractOp, |
| UnrealizedConversionCastOp>(op)) { |
| continue; |
| } |
| Operation* returnValDef = returnValues[it.index()].getDefiningOp();
| // The yielded value may be a block argument with no defining op; skip it.
| if (!returnValDef) continue;
| Value newReturn = FoldCarryDep(forOp, op, returnValDef); |
| if (!newReturn) continue; |
| iteratorFolded.push_back(it.index()); |
| resultOps.push_back(returnValDef); |
| returnValues[it.index()] = newReturn; |
| |
| BlockAndValueMapping mapping; |
| mapping.map(it.value(), initArgs[it.index()]); |
| initArgs[it.index()] = rewriter.clone(*op, mapping)->getResult(0); |
| } |
| if (iteratorFolded.empty()) return failure(); |
| auto newLoop = rewriter.create<scf::ForOp>( |
| forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
| forOp.getStep(), initArgs); |
| transferBody(forOp.getBody(), newLoop.getBody(), returnValues, rewriter); |
| |
| // Replace the results of the old loop: clone the folded ops after the new
| // loop to recompute the original result types, and fix up the remaining
| // uses of the folded iter args inside the loop.
| SmallVector<Value, 8> repResults(newLoop.getResults().begin(),
| newLoop.getResults().end());
| for (auto en : llvm::enumerate(iteratorFolded)) {
| BlockAndValueMapping mapping;
| mapping.map(returnValues[en.value()], newLoop.getResult(en.value()));
| repResults[en.value()] =
| rewriter.clone(*resultOps[en.index()], mapping)->getResult(0);
| Operation* oldOp =
| newLoop.getRegionIterArgs()[en.value()].use_begin()->getOwner();
| SmallVector<Value, 1> arg(1, newLoop.getRegionIterArgs()[en.value()]);
| oldOp->replaceAllUsesWith(arg);
| }
| rewriter.replaceOp(forOp, repResults); |
| return success(); |
| } |
| }; |
| |
| /// An ad-hoc pattern to convert scf.for loop-carried values from |
| /// `vector<8xf16>` to `vector<4xf32>` by inserting `vector.bitcast` around |
| /// scf.for boundaries. |
| /// |
| /// Those loop-carried values will be lowered into SPIR-V local variables. This
| /// pattern allows packing f16 values into f32 variables tightly so that we can
| /// generate shader-conformant SPIR-V.
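| ///
| /// A sketch of the rewrite (value names below are illustrative only):
| /// ```
| /// %0 = scf.for %iv = %lb to %ub step %s iter_args(%arg = %init)
| /// -> vector<8xf16> {
| /// [...]
| /// scf.yield %next : vector<8xf16>
| /// }
| /// ```
| /// ->
| /// ```
| /// %init32 = vector.bitcast %init : vector<8xf16> to vector<4xf32>
| /// %0 = scf.for %iv = %lb to %ub step %s iter_args(%arg = %init32)
| /// -> vector<4xf32> {
| /// %arg16 = vector.bitcast %arg : vector<4xf32> to vector<8xf16>
| /// [...]
| /// %next32 = vector.bitcast %next : vector<8xf16> to vector<4xf32>
| /// scf.yield %next32 : vector<4xf32>
| /// }
| /// %1 = vector.bitcast %0 : vector<4xf32> to vector<8xf16>
| /// ```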
| struct PackForOpInductionVarVector final : public OpRewritePattern<scf::ForOp> { |
| using OpRewritePattern<scf::ForOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(scf::ForOp forOp, |
| PatternRewriter& rewriter) const override { |
| VectorType v8f16Type = VectorType::get({8}, rewriter.getF16Type()); |
| VectorType v4f32Type = VectorType::get({4}, rewriter.getF32Type()); |
| |
| SmallVector<unsigned, 8> ivIndices; |
| for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) { |
| if (it.value().getType() == v8f16Type) ivIndices.push_back(it.index()); |
| } |
| if (ivIndices.empty()) return failure(); |
| |
| // Bit cast all init values from v8f16 to v4f32. |
| auto ivInitValues = llvm::to_vector<8>(forOp.getIterOperands()); |
| for (unsigned index : ivIndices) { |
| Value oldValue = ivInitValues[index]; |
| ivInitValues[index] = rewriter.create<vector::BitCastOp>( |
| oldValue.getLoc(), v4f32Type, oldValue); |
| } |
| |
| // Create a new loop with the casted init values. This also creates |
| // induction variables with proper type. |
| auto newLoop = rewriter.create<scf::ForOp>( |
| forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
| forOp.getStep(), ivInitValues); |
| |
| // Move all operations to the new for op. This also replaces the old block
| // arguments with the new block arguments.
| rewriter.mergeBlocks(forOp.getBody(), newLoop.getBody(), |
| newLoop.getBody()->getArguments()); |
| |
| // Bit cast induction variables back to the original type to fix uses. |
| rewriter.setInsertionPointToStart(newLoop.getBody()); |
| for (unsigned index : ivIndices) { |
| Value newIv = newLoop.getRegionIterArgs()[index]; |
| auto bitcastOp = |
| rewriter.create<vector::BitCastOp>(newIv.getLoc(), v8f16Type, newIv); |
| // Replace all uses of the new induction variable with a bitcast. We need |
| // to exclude the bitcast op itself given it also uses the induction |
| // variable. |
| SmallPtrSet<Operation*, 1> exceptions{bitcastOp}; |
| newIv.replaceAllUsesExcept(bitcastOp, exceptions); |
| } |
| |
| auto yieldOp = cast<scf::YieldOp>(newLoop.getBody()->getTerminator()); |
| auto ivRetValues = llvm::to_vector<8>(yieldOp.getOperands()); |
| |
| // Bit cast return values to the new type to fix yield. |
| rewriter.setInsertionPoint(yieldOp); |
| for (unsigned index : ivIndices) { |
| Value oldRet = ivRetValues[index]; |
| ivRetValues[index] = rewriter.create<vector::BitCastOp>( |
| oldRet.getLoc(), v4f32Type, oldRet); |
| } |
| yieldOp->setOperands(ivRetValues); |
| |
| SmallVector<Value, 8> forRetValues; |
| for (Value result : newLoop.getResults()) forRetValues.push_back(result); |
| |
| // Bit cast return values to the old type to fix for op uses. |
| rewriter.setInsertionPointAfter(newLoop); |
| for (unsigned index : ivIndices) { |
| Value oldRet = forRetValues[index]; |
| forRetValues[index] = rewriter.create<vector::BitCastOp>( |
| oldRet.getLoc(), v8f16Type, oldRet); |
| } |
| |
| rewriter.replaceOp(forOp, forRetValues); |
| return success(); |
| } |
| }; |
| |
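| /// Pass applying the two ad hoc patterns above to all scf.for ops inside a
| /// function.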
| struct ForOpCanonicalizationPass |
| : public ForOpCanonicalizationBase<ForOpCanonicalizationPass> { |
| void getDependentDialects(DialectRegistry& registry) const override { |
| registry.insert<scf::SCFDialect, vector::VectorDialect>(); |
| } |
| |
| void runOnOperation() override { |
| FuncOp fn = getOperation(); |
| RewritePatternSet patterns(&getContext()); |
| patterns.insert<CanonicalizeForOpInductionVarShape, |
| PackForOpInductionVarVector>(fn.getContext()); |
| if (failed(applyPatternsAndFoldGreedily(fn, std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> createForOpCanonicalizationPass() { |
| return std::make_unique<ForOpCanonicalizationPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |