| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| //===- BufferizationAnalysis.cpp - Pre bufferization analysis -------------===// |
| // |
| // Analysis to group together tensors within a dispatch region into an |
| // equivalence class such that all members of a set can be mapped to the same |
| // memory region. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "iree/compiler/Codegen/Common/BufferizationAnalysis.h" |
| |
| #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Analysis/SliceAnalysis.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| |
| #define DEBUG_TYPE "iree-codegen-bufferization-analysis" |
| |
| namespace mlir::iree_compiler { |
| |
| //===----------------------------------------------------------------------===// |
| // Analysis to compute equivalence sets. |
| // |
| // These functions compute the equivalence relationships between all tensors in |
| // the program. Two tensors are equivalent if they are to be mapped to the same |
| // buffer. For every operation, based on the operation semantics the result of |
| // the operation can reuse the buffer for an operand of the operation. This |
| // information is captured by adding these two tensors to the same equivalence |
| // class. Eventually the result of the dispatch tensor is added to some |
| // equivalence set. All tensors in that equivalence set can reuse the result |
| // buffer and compute the values in place. You can add tensors to equivalence |
| // set only if |
| // - They have a single use |
| // - They are derived from a read-only buffer. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| /// Check if all users of an op that lowers to a subview eventually can use the |
| /// subview when converted to buffers. For example `linalg.reshape` (which is |
| /// the buffer version of `linalg.tensor_reshape`) cannot handle subviews. |
| static bool canUsersHandleSubviews(Operation *op) { |
| // TODO(ravishankarm): Maybe this is too aggressive, might have to switch this |
| // to have a white-list instead of blacklist. |
| for (Operation *user : op->getUsers()) { |
| if (isa<IREE::TensorExt::DispatchTensorStoreOp, tensor::CollapseShapeOp, |
| tensor::ExpandShapeOp>(user)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| /// Walks the use-def chain and see if this value comes from a read-only tensor. |
| static bool isFromReadOnlyTensor(Value v, const BufferizationPlan &plan) { |
| Operation *definingOp = v.getDefiningOp(); |
| if (!definingOp) { |
| auto arg = cast<BlockArgument>(v); |
| return TypeSwitch<Operation *, bool>(arg.getOwner()->getParentOp()) |
| .Case([&](scf::ForOp forOp) { |
| Value initOperand = forOp.getTiedLoopInit(arg)->get(); |
| if (plan.isEquivalent(arg, initOperand)) { |
| return isFromReadOnlyTensor(initOperand, plan); |
| } |
| return false; |
| }) |
| .Default(false); |
| } |
| return isReadOnly(v); |
| } |
| |
| /// Adds the result of `std.constant` to its set (there is nothing to tie to |
| /// here). |
| static LogicalResult analyseConstantOp(arith::ConstantOp constantOp, |
| BufferizationPlan &plan) { |
| if (!isa<ShapedType>(constantOp.getResult().getType())) { |
| return success(); |
| } |
| plan.insert(constantOp.getResult()); |
| return success(); |
| } |
| |
| /// Adds the result of the `iree_tensor_ext.dispatch.tensor.load` op to the same |
| /// equivalence class as the source. |
| static LogicalResult |
| analyseInterfaceLoadTensorOp(IREE::TensorExt::DispatchTensorLoadOp loadOp, |
| BufferizationPlan &plan) { |
| plan.unionSets(loadOp.getResult(), loadOp.getSource()); |
| return success(); |
| } |
| |
| /// Helper method to returns an operation of type `OpType` whose result is in |
| /// the same equivalence set as `value`. Returns an operation if there is only |
| /// one such op in the equivalence set or nullptr in all other cases. |
| template <typename OpType> |
| static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) { |
| OpType equivalentOp; |
| SmallVector<Value> mappedTensors = plan.getTensorsMappedToSameSet(value); |
| for (auto v : mappedTensors) { |
| auto definingOp = v.getDefiningOp<OpType>(); |
| if (!definingOp) { |
| continue; |
| } |
| assert((!equivalentOp || equivalentOp == definingOp) && |
| "found two interface binding ops marked as equivalent"); |
| if (!equivalentOp) { |
| equivalentOp = definingOp; |
| } |
| } |
| return equivalentOp; |
| } |
| |
| /// Check if two sets can be merged based on what operations exist in that set. |
| static bool canSetsBeMerged(Value v1, Value v2, BufferizationPlan &plan) { |
| // Dont merge two sets if one of the sets is a constant. |
| if (getEquivalentOpOfType<arith::ConstantOp>(v1, plan) || |
| getEquivalentOpOfType<arith::ConstantOp>(v2, plan)) { |
| return false; |
| } |
| auto v1InterfaceBinding = |
| getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(v1, plan); |
| auto v2InterfaceBinding = |
| getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(v2, plan); |
| // If any of these sets do not have a interface binding, they can be merged. |
| if (!v1InterfaceBinding || !v2InterfaceBinding) { |
| return true; |
| } |
| if (v1InterfaceBinding.getBinding() != v2InterfaceBinding.getBinding() || |
| v1InterfaceBinding.getByteOffset() != |
| v2InterfaceBinding.getByteOffset()) { |
| // If the set, binding or offsets are different, map these to different |
| // memrefs. |
| return false; |
| } |
| return true; |
| } |
| |
| /// Returns true if the value and target of a |
| /// `iree_tensor_ext.dispatch.tensor.store` operation can be added to the same |
| /// equivalence set. This can be done only if |
| /// - The `value` is not from a equivalence set that contains a read-only |
| /// tensor. |
| /// - All `hal.interface.binding.subspan` operations in the equivalence class of |
| /// `value` and `target` have the same binding and offset. For now, it is |
| /// assumed that the equivalence classes contain only 1 such instruction. |
| /// This method asserts that the `target` equivalence class already contains a |
| /// `hal.interface.binding.subspan` op.' |
| static bool canSetStoreValueAndTargetAsEquivalent( |
| IREE::TensorExt::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) { |
| if (!canSetsBeMerged(storeOp.getValue(), storeOp.getTarget(), plan)) { |
| return false; |
| } |
| auto valueInterfaceBinding = |
| getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>( |
| storeOp.getValue(), plan); |
| auto targetInterfaceBinding = |
| getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>( |
| storeOp.getTarget(), plan); |
| if (!valueInterfaceBinding || !targetInterfaceBinding) { |
| return true; |
| } |
| // If the binding and offsets are the same, make sure that the |
| // !iree_tensor_ext.dispatch.tensor is read-write. |
| auto sourceType = dyn_cast<IREE::TensorExt::DispatchTensorType>( |
| valueInterfaceBinding.getType()); |
| return sourceType && |
| sourceType.getAccess() == IREE::TensorExt::TensorAccess::ReadWrite; |
| } |
| |
| /// Tries to add the `value` and `target` to the same equivalence class. |
| static LogicalResult |
| analyseInterfaceStoreTensorOp(IREE::TensorExt::DispatchTensorStoreOp storeOp, |
| BufferizationPlan &plan) { |
| // The value and target can be union-ed if the set the value is part of does |
| // not contain any hal.interface.binding.subspan from a different binding. |
| Value value = storeOp.getValue(); |
| Value target = storeOp.getTarget(); |
| if (!getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target, |
| plan)) { |
| return storeOp.emitError( |
| "expected target of store op to already be added to an equivalence " |
| "set"); |
| } |
| if (canSetStoreValueAndTargetAsEquivalent(storeOp, plan)) { |
| plan.unionSets(value, target); |
| } else { |
| plan.insert(value); |
| } |
| plan.storeSet(target); |
| return success(); |
| } |
| |
| static LogicalResult analyseLoadFromBufferOp(IREE::Codegen::LoadFromBufferOp op, |
| BufferizationPlan &plan) { |
| plan.insert(op.getTensor()); |
| return success(); |
| } |
| |
| static LogicalResult analyseStoreToBufferOp(IREE::Codegen::StoreToBufferOp op, |
| BufferizationPlan &plan) { |
| plan.storeSet(op.getTensor()); |
| return success(); |
| } |
| |
| static LogicalResult |
| analyseInterfaceBindingSubspanOp(IREE::HAL::InterfaceBindingSubspanOp subspanOp, |
| BufferizationPlan &plan) { |
| if (isa<MemRefType>(subspanOp.getResult().getType())) { |
| return success(); |
| } |
| plan.insert(subspanOp.getResult()); |
| return success(); |
| } |
| |
| static LogicalResult analysePadTensorOp(tensor::PadOp padTensorOp, |
| BufferizationPlan &plan) { |
| plan.insert(padTensorOp.getSource()); |
| plan.insert(padTensorOp.getResult()); |
| return success(); |
| } |
| |
| /// For every result of the LinalgOp, gets the operands (`ins` or `outs`) whose |
| /// buffer can be reused for the result. |
| static SmallVector<Value> |
| getTiedOperandsForDPSOps(DestinationStyleOpInterface dpsOp, |
| const BufferizationPlan &plan) { |
| SmallVector<Value> tiedOperands(dpsOp.getOperation()->getNumResults()); |
| auto outputOperands = dpsOp.getDpsInits(); |
| for (auto [index, outTensor] : llvm::enumerate(outputOperands)) { |
| // If the `outs` tensor has a single use (this op) and is not from a |
| // read-only buffer, the `outs` tensor can be tied to the result. |
| if (outTensor.hasOneUse() && !isFromReadOnlyTensor(outTensor, plan)) { |
| tiedOperands[index] = outTensor; |
| } |
| } |
| return tiedOperands; |
| } |
| |
| /// Adds the corresponding `outs` and result tensors of the linalg op into the |
| /// same equivalence class. |
| static LogicalResult analyseDPSOps(DestinationStyleOpInterface dpsOp, |
| BufferizationPlan &plan) { |
| if (!dpsOp.hasPureTensorSemantics()) { |
| return success(); |
| } |
| auto results = dpsOp->getResults(); |
| auto tiedOperands = getTiedOperandsForDPSOps(dpsOp, plan); |
| if (tiedOperands.empty()) { |
| return failure(); |
| } |
| for (auto [index, resultTensor, tiedOperand] : llvm::zip_equal( |
| llvm::seq<int64_t>(0, results.size()), results, tiedOperands)) { |
| if (tiedOperand) { |
| plan.unionSets(resultTensor, tiedOperand); |
| } |
| plan.insert(dpsOp.getDpsInitOperand(index)->get()); |
| plan.insert(resultTensor); |
| } |
| return success(); |
| } |
| |
| /// Returns true if there is a single use of the `value` that is "real", |
| /// i.e. where the value itself is used, and not the type of the value. For |
| /// example, a use in a `memref.dim` is only looking at the type and not the |
| /// value. |
| static bool hasSingleRealUse(Value value) { |
| int numUsers = 0; |
| for (OpOperand &use : value.getUses()) { |
| if (!isa<memref::DimOp, tensor::DimOp>(use.getOwner())) { |
| numUsers++; |
| } |
| } |
| return numUsers == 1; |
| } |
| |
| /// For operations that have a single operand and result, adds both to the same |
| /// equivalence class. |
| static LogicalResult analyseSingleOperandResultOp(Value source, Value result, |
| BufferizationPlan &plan) { |
| if (hasSingleRealUse(source) || isFromReadOnlyTensor(source, plan)) { |
| plan.unionSets(source, result); |
| return success(); |
| } |
| plan.insert(source); |
| plan.insert(result); |
| return success(); |
| } |
| |
| static LogicalResult analyseSubTensorOp(tensor::ExtractSliceOp subTensorOp, |
| BufferizationPlan &plan) { |
| if (!canUsersHandleSubviews(subTensorOp)) { |
| plan.insert(subTensorOp.getSource()); |
| plan.insert(subTensorOp.getResult()); |
| return success(); |
| } |
| return analyseSingleOperandResultOp(subTensorOp.getSource(), |
| subTensorOp.getResult(), plan); |
| } |
| |
| /// Adds the `dest` and `result` tensor of a subtensor insert operation into the |
| /// same equivalence class. If `source` is not null also checks that the |
| /// `source` and `dest` are not equivalent. |
| static LogicalResult analyseDestructiveUpdateOp(Operation *op, Value source, |
| Value dest, Value result, |
| BufferizationPlan &plan) { |
| if (hasSingleRealUse(dest) && !isFromReadOnlyTensor(dest, plan)) { |
| plan.unionSets(dest, result); |
| } else if (source && plan.isEquivalent(source, dest)) { |
| // The destructive update pattern can put the source and dest in the same |
| // equivalence class, but that is checked explicitly later on. So at this |
| // stage this shouldnt happen. |
| return op->emitError( |
| "unexpected source and dest being equivalent in destructive update op"); |
| } |
| plan.insert(dest); |
| plan.insert(result); |
| return success(); |
| } |
| |
| static LogicalResult analyseScfIfOp(scf::IfOp ifOp, BufferizationPlan &plan) { |
| if (!ifOp.getNumResults()) { |
| return success(); |
| } |
| for (auto [result, thenOperand, elseOperand] : |
| llvm::zip_equal(ifOp.getResults(), ifOp.thenYield().getOperands(), |
| ifOp.elseYield().getOperands())) { |
| if (!isa<RankedTensorType>(result.getType())) { |
| continue; |
| } |
| // All results and yields of the if-then-else are tied together. |
| plan.unionSets(result, thenOperand); |
| plan.unionSets(result, elseOperand); |
| } |
| return success(); |
| } |
| |
| static LogicalResult analyseScfForOp(scf::ForOp forOp, |
| BufferizationPlan &plan) { |
| if (forOp.getResults().empty()) { |
| return success(); |
| } |
| if (!llvm::all_of(forOp->getResultTypes(), [](Type resultType) { |
| return isa<RankedTensorType>(resultType); |
| })) { |
| return success(); |
| } |
| |
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| auto regionArgs = forOp.getRegionIterArgs(); |
| auto initArgs = forOp.getInitArgs(); |
| for (int i = 0; i < yieldOp.getResults().size(); ++i) { |
| Value yieldTensor = yieldOp.getResults()[i]; |
| Value resultTensor = forOp.getResults()[i]; |
| Value initArg = initArgs[i]; |
| Value arg = regionArgs[i]; |
| // Always tie the yield, the result tensor, and the region arg |
| plan.unionSets(yieldTensor, resultTensor); |
| plan.unionSets(yieldTensor, arg); |
| |
| // If the init value is not read-only and has single use, the tie the init |
| // and result (and by extension all 4 tensors here). |
| if (hasSingleRealUse(initArg) && !isFromReadOnlyTensor(initArg, plan)) { |
| plan.unionSets(initArg, resultTensor); |
| } |
| } |
| return success(); |
| } |
| |
| /// Look for destructive update loop pattern involving `source` using these |
| /// constraints |
| /// - single tensor.insert_slice operation where `source` is the `dest` operand. |
| /// - all `tensor.extract_slice` operations dominate the `tensor.insert_slice` |
| /// op. |
| static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) { |
| auto isUpdateOp = |
| llvm::IsaPred<tensor::InsertSliceOp, vector::TransferWriteOp>; |
| auto isReadOp = llvm::IsaPred<tensor::ExtractSliceOp, vector::TransferReadOp>; |
| auto getDest = [](Operation *op) -> Value { |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) { |
| return insertSliceOp.getDest(); |
| } |
| if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) { |
| return transferWriteOp.getBase(); |
| } |
| return nullptr; |
| }; |
| auto getSource = [](Operation *op) -> Value { |
| if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { |
| return extractSliceOp.getSource(); |
| } |
| if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) { |
| return transferReadOp.getBase(); |
| } |
| return nullptr; |
| }; |
| // Source should have only one use that is a tensor::InsertSliceOp as a `dest` |
| // operand. |
| Operation *updateOp = nullptr; |
| for (OpOperand &use : source.getUses()) { |
| auto user = use.getOwner(); |
| // Process only update ops uses here. |
| if (!isUpdateOp(user)) { |
| continue; |
| } |
| // If this is not the first use in a tensor::InsertSliceOp abort. |
| if (updateOp) { |
| return; |
| } |
| // If the use is not the `dest` operand, abort. |
| Value dest = getDest(user); |
| assert(dest && "unable to get dest of update op"); |
| if (use.get() != dest) { |
| return; |
| } |
| if (isFromReadOnlyTensor(dest, plan)) { |
| return; |
| } |
| updateOp = user; |
| } |
| // Need to have one use of tensor::InsertSliceOp for destructive update |
| // pattern. |
| if (!updateOp) { |
| return; |
| } |
| |
| Block *updateOpBlock = updateOp->getBlock(); |
| for (OpOperand &use : source.getUses()) { |
| Operation *user = use.getOwner(); |
| if (user == updateOp) { |
| continue; |
| } |
| if (isReadOp(user)) { |
| Value source = getSource(user); |
| assert(source && "unable to find source from read op"); |
| if (source != use.get()) { |
| return; |
| } |
| // The read must dominate the insert op. For now just check its in the |
| // same block and before it. |
| if (user->getBlock() != updateOpBlock || |
| !user->isBeforeInBlock(updateOp)) { |
| return; |
| } |
| continue; |
| } else if (isa<scf::YieldOp, tensor::DimOp>(user)) { |
| continue; |
| } |
| // Unaccounted for use. Return without doing anything; |
| return; |
| } |
| |
| // Found destructive update pattern. Tie all the |
| // - extract_slice source and result |
| // - insert_slice value and dest |
| for (Operation *user : source.getUsers()) { |
| if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) { |
| plan.unionSets(extractSliceOp.getSource(), extractSliceOp.getResult()); |
| continue; |
| } |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(user)) { |
| if (!isFromReadOnlyTensor(insertSliceOp.getSource(), plan)) { |
| plan.unionSets(insertSliceOp.getSource(), insertSliceOp.getDest()); |
| } |
| } |
| } |
| } |
| |
| void BufferizationPlan::unionSets(Value v1, Value v2) { |
| if (!canSetsBeMerged(v1, v2, *this)) { |
| return; |
| } |
| // If one the sets was part of the store set, the store set |
| // needs to be updated to drop the all leaders from the store set |
| // and add the new leader to it. |
| Value leader1 = getLeaderValue(v1); |
| Value leader2 = getLeaderValue(v2); |
| bool insertNewStoreLeader = |
| storeLeaders.count(leader1) || storeLeaders.count(leader2); |
| storeLeaders.erase(leader1); |
| storeLeaders.erase(leader2); |
| mappedTensors.unionSets(getPointer(v1), getPointer(v2)); |
| if (insertNewStoreLeader) { |
| storeLeaders.insert(getLeaderValue(v1)); |
| } |
| } |
| |
| void BufferizationPlan::dump() { |
| llvm::dbgs() << "BufferMappings : \n"; |
| unsigned numSets = 0; |
| for (auto it = mappedTensors.begin(), ie = mappedTensors.end(); it != ie; |
| ++it) { |
| if (!(*it)->isLeader()) { |
| continue; |
| } |
| llvm::dbgs() << "\tSet " << numSets; |
| if (storeLeaders.count( |
| getLeaderValue(getValue(*mappedTensors.member_begin(**it))))) { |
| llvm::dbgs() << "(StoreSet) "; |
| } |
| llvm::dbgs() << ":\n"; |
| for (auto member : llvm::make_range(mappedTensors.member_begin(**it), |
| mappedTensors.member_end())) { |
| llvm::dbgs() << "\t\t"; |
| getValue(member).print(llvm::dbgs()); |
| llvm::dbgs() << "\n"; |
| } |
| numSets++; |
| } |
| llvm::dbgs() << "StoreLeaders : \n"; |
| for (auto storeLeader : storeLeaders) { |
| storeLeader.print(llvm::dbgs()); |
| llvm::dbgs() << "\n"; |
| } |
| } |
| |
| LogicalResult createTensorEquivalenceClasses(mlir::FunctionOpInterface funcOp, |
| BufferizationPlan &plan) { |
| auto bufferMappingFn = [&](Operation *op) -> WalkResult { |
| return TypeSwitch<Operation *, LogicalResult>(op) |
| .Case([&](arith::ConstantOp constantOp) { |
| return analyseConstantOp(constantOp, plan); |
| }) |
| .Case([&](IREE::TensorExt::DispatchTensorLoadOp loadOp) { |
| return analyseInterfaceLoadTensorOp(loadOp, plan); |
| }) |
| .Case([&](IREE::TensorExt::DispatchTensorStoreOp storeOp) { |
| return analyseInterfaceStoreTensorOp(storeOp, plan); |
| }) |
| .Case([&](IREE::Codegen::LoadFromBufferOp loadOp) { |
| return analyseLoadFromBufferOp(loadOp, plan); |
| }) |
| .Case([&](IREE::Codegen::StoreToBufferOp storeOp) { |
| return analyseStoreToBufferOp(storeOp, plan); |
| }) |
| .Case([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) { |
| return analyseInterfaceBindingSubspanOp(subspanOp, plan); |
| }) |
| .Case([&](tensor::PadOp padTensorOp) { |
| return analysePadTensorOp(padTensorOp, plan); |
| }) |
| .Case([&](DestinationStyleOpInterface dpsOp) { |
| return analyseDPSOps(dpsOp, plan); |
| }) |
| .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>( |
| [&](auto reshapeOp) { |
| return analyseSingleOperandResultOp(reshapeOp.getSrc(), |
| reshapeOp.getResult(), plan); |
| }) |
| .Case([&](tensor::ExtractSliceOp sliceOp) { |
| return analyseSubTensorOp(sliceOp, plan); |
| }) |
| .Case([&](tensor::InsertSliceOp subTensorInsertOp) { |
| return analyseDestructiveUpdateOp( |
| subTensorInsertOp, subTensorInsertOp.getSource(), |
| subTensorInsertOp.getDest(), subTensorInsertOp.getResult(), plan); |
| }) |
| .Case([&](tensor::ParallelInsertSliceOp subTensorInsertOp) { |
| return analyseDestructiveUpdateOp( |
| subTensorInsertOp, subTensorInsertOp.getSource(), |
| subTensorInsertOp.getDest(), subTensorInsertOp.getTiedOpResult(), |
| plan); |
| }) |
| .Case([&](tensor::CastOp castOp) { |
| return analyseSingleOperandResultOp(castOp.getSource(), |
| castOp.getDest(), plan); |
| }) |
| .Case([&](tensor::InsertOp insertOp) { |
| return analyseDestructiveUpdateOp(insertOp, /*source =*/nullptr, |
| insertOp.getDest(), |
| insertOp.getResult(), plan); |
| }) |
| .Case([&](vector::TransferReadOp transferReadOp) { |
| if (isa<RankedTensorType>(transferReadOp.getBase().getType())) { |
| plan.insert(transferReadOp.getBase()); |
| } |
| return success(); |
| }) |
| .Case([&](vector::TransferWriteOp transferWriteOp) { |
| if (!isa<RankedTensorType>(transferWriteOp.getBase().getType())) { |
| return success(); |
| } |
| return analyseDestructiveUpdateOp(transferWriteOp, nullptr, |
| transferWriteOp.getBase(), |
| transferWriteOp.getResult(), plan); |
| }) |
| .Case([&](scf::IfOp ifOp) { return analyseScfIfOp(ifOp, plan); }) |
| .Case([&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); }) |
| .Case<scf::YieldOp, tensor::EmptyOp, tensor::DimOp, tensor::ExtractOp, |
| tensor::GenerateOp, tensor::PadOp, bufferization::ToBufferOp, |
| bufferization::AllocTensorOp>( |
| [&](Operation *op) { return success(); }) |
| .Default([&](Operation *op) -> LogicalResult { |
| if (llvm::any_of( |
| op->getOperands(), |
| [](Value v) { return isa<RankedTensorType>(v.getType()); }) || |
| llvm::any_of(op->getResultTypes(), |
| llvm::IsaPred<RankedTensorType>)) { |
| return op->emitOpError("unhandled tensor operation"); |
| } |
| return success(); |
| }); |
| }; |
| if (funcOp.walk<WalkOrder::PreOrder>(bufferMappingFn).wasInterrupted()) { |
| return failure(); |
| } |
| LLVM_DEBUG({ |
| llvm::dbgs() << "After First walk "; |
| plan.dump(); |
| }); |
| |
| funcOp.walk([&](Operation *updateOp) { |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(updateOp)) { |
| hasDestructiveUpdatePattern(insertSliceOp.getDest(), plan); |
| return; |
| } |
| if (auto vectorWriteOp = dyn_cast<vector::TransferWriteOp>(updateOp)) { |
| if (isa<RankedTensorType>(vectorWriteOp.getBase().getType())) { |
| hasDestructiveUpdatePattern(vectorWriteOp.getBase(), plan); |
| } |
| } |
| }); |
| LLVM_DEBUG({ |
| llvm::dbgs() << "After Destructive update walk "; |
| plan.dump(); |
| }); |
| |
| if (funcOp |
| .walk([&](IREE::TensorExt::DispatchTensorStoreOp storeOp) |
| -> WalkResult { |
| return analyseInterfaceStoreTensorOp(storeOp, plan); |
| }) |
| .wasInterrupted()) { |
| return failure(); |
| } |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << "After Store walk "; |
| plan.dump(); |
| }); |
| |
| return success(); |
| } |
| |
| } // namespace mlir::iree_compiler |