| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| //===- BufferizationAnalysis.cpp - Pre bufferization analysis -------------===// |
| // |
| // Analysis to group together tensors within a dispatch region into an |
| // equivalence class such that all members of a set can be mapped to the same
| // memory region. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "iree/compiler/Codegen/Common/BufferizationAnalysis.h" |
| |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Analysis/SliceAnalysis.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| |
| #define DEBUG_TYPE "iree-codegen-bufferization-analysis" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| //===----------------------------------------------------------------------===// |
| // Analysis to compute equivalence sets. |
| // |
| // These functions compute the equivalence relationships between all tensors in
| // the program. Two tensors are equivalent if they are to be mapped to the same
| // buffer. For every operation, the operation semantics determine whether a
| // result of the operation can reuse the buffer of one of its operands. This
| // information is captured by adding these two tensors to the same equivalence
| // class. Eventually the result tensor of the dispatch is added to some
| // equivalence set. All tensors in that equivalence set can reuse the result
| // buffer and compute the values in place. A tensor can be added to an
| // equivalence set only if
| // - it has a single use, or
| // - it is derived from a read-only buffer.
| // |
| //===----------------------------------------------------------------------===// |
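| 
| // For example, in a dispatch region like the following (illustrative IR, with
| // types, offsets, and indexing maps elided):
| //
| //   %0 = flow.dispatch.tensor.load %in ... : !flow.dispatch.tensor<readonly:...>
| //   %1 = flow.dispatch.tensor.load %out ... : !flow.dispatch.tensor<writeonly:...>
| //   %2 = linalg.generic ... ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xf32>) ...
| //   flow.dispatch.tensor.store %2, %out ...
| //
| // the analysis places {%out, %1, %2} in one equivalence class, so %2 is
| // computed in place in the buffer bound to %out, while {%in, %0} form a
| // separate read-only class.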
| |
| /// Check if all users of an op that lowers to a subview eventually can use the |
| /// subview when converted to buffers. For example `memref.collapse_shape`
| /// (the buffer version of `tensor.collapse_shape`) cannot handle subviews.
| static bool canUsersHandleSubviews(Operation *op) { |
|   // TODO(ravishankarm): Maybe this is too aggressive; might have to switch
|   // this to a white-list instead of a black-list.
| for (Operation *user : op->getUsers()) { |
| if (isa<IREE::Flow::DispatchTensorStoreOp, tensor::CollapseShapeOp, |
| tensor::ExpandShapeOp>(user)) { |
| return false; |
| } |
| } |
| return true; |
| } |
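| 
| // For example, in the illustrative snippet below the slice result feeds a
| // tensor.collapse_shape, whose buffer form cannot consume a strided subview,
| // so the slice result must get its own buffer instead of a subview of %t:
| //
| //   %s = tensor.extract_slice %t[0, 0] [%d0, %d1] [1, 1]
| //          : tensor<?x?xf32> to tensor<?x?xf32>
| //   %r = tensor.collapse_shape %s [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>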
| |
| /// Walks the use-def chain to see if this value comes from a read-only tensor.
| static bool isFromReadOnlyTensor(Value v, const BufferizationPlan &plan) { |
| auto definingOp = v.getDefiningOp(); |
| if (!definingOp) { |
| auto arg = v.cast<BlockArgument>(); |
| return TypeSwitch<Operation *, bool>(arg.getOwner()->getParentOp()) |
| .Case<scf::ForOp>([&](scf::ForOp forOp) { |
| Value initOperand = forOp.getOpOperandForRegionIterArg(arg).get(); |
| if (plan.isEquivalent(arg, initOperand)) { |
| return isFromReadOnlyTensor(initOperand, plan); |
| } |
| return false; |
| }) |
| .Default([&](Operation *op) { return false; }); |
| } |
| return TypeSwitch<Operation *, bool>(definingOp) |
| .Case<arith::ConstantOp>( |
| [&](arith::ConstantOp constantOp) { return true; }) |
| .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>( |
| [&](auto op) { return isFromReadOnlyTensor(op.src(), plan); }) |
| .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) { |
| return isFromReadOnlyTensor(sliceOp.source(), plan); |
| }) |
| .Case<IREE::Flow::DispatchTensorLoadOp>( |
| [&](IREE::Flow::DispatchTensorLoadOp loadOp) { |
| return loadOp.source() |
| .getType() |
| .cast<IREE::Flow::DispatchTensorType>() |
| .getAccess() == IREE::Flow::TensorAccess::ReadOnly; |
| }) |
| .Default([&](Operation *op) { return false; }); |
| } |
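| 
| // For example, in the illustrative chain below the walk bottoms out at a
| // readonly !flow.dispatch.tensor, so %slice is treated as coming from a
| // read-only tensor and must not be updated in place:
| //
| //   %t = flow.dispatch.tensor.load %arg ...
| //          : !flow.dispatch.tensor<readonly:4x8xf32> -> tensor<4x8xf32>
| //   %slice = tensor.extract_slice %t[0, 0] [2, 4] [1, 1]
| //          : tensor<4x8xf32> to tensor<2x4xf32>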
| |
| /// Adds the result of `arith.constant` to its own set (there is nothing to
| /// tie it to here).
| static LogicalResult analyseConstantOp(arith::ConstantOp constantOp, |
| BufferizationPlan &plan) { |
| if (!constantOp.getResult().getType().isa<ShapedType>()) return success(); |
| plan.insert(constantOp.getResult()); |
| return success(); |
| } |
| |
| /// Adds the result of the `flow.dispatch.tensor.load` op to the same |
| /// equivalence class as the source. |
| static LogicalResult analyseInterfaceLoadTensorOp( |
| IREE::Flow::DispatchTensorLoadOp loadOp, BufferizationPlan &plan) { |
| plan.unionSets(loadOp.result(), loadOp.source()); |
| return success(); |
| } |
| |
| /// Helper method that returns an operation of type `OpType` whose result is
| /// in the same equivalence set as `value`. Returns the operation if there is
| /// only one such op in the equivalence set, and nullptr in all other cases.
| template <typename OpType> |
| static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) { |
| OpType equivalentOp; |
| SmallVector<Value> mappedTensors = plan.getTensorsMappedToSameSet(value); |
| for (auto v : mappedTensors) { |
| auto definingOp = v.getDefiningOp<OpType>(); |
| if (!definingOp) continue; |
| assert((!equivalentOp || equivalentOp == definingOp) && |
| "found two interface binding ops marked as equivalent"); |
| if (!equivalentOp) equivalentOp = definingOp; |
| } |
| return equivalentOp; |
| } |
| |
| /// Returns true if the value and target of a `flow.dispatch.tensor.store`
| /// operation can be added to the same equivalence set. This can be done only
| /// if
| /// - The `value` is not from an equivalence set that contains a read-only
| ///   tensor.
| /// - All `hal.interface.binding.subspan` operations in the equivalence classes
| ///   of `value` and `target` have the same binding and offset. For now, it is
| ///   assumed that each equivalence class contains only one such operation.
| /// This method asserts that the `target` equivalence class already contains a
| /// `hal.interface.binding.subspan` op.
| static bool canSetStoreValueAndTargetAsEquivalent( |
| IREE::Flow::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) { |
| Value value = storeOp.value(); |
| Value target = storeOp.target(); |
| auto targetInterfaceOp = |
| getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target, plan); |
| assert(targetInterfaceOp); |
|   if (getEquivalentOpOfType<arith::ConstantOp>(value, plan)) {
|     return false;
|   }
| if (auto valueInterfaceOp = |
| getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(value, |
| plan)) { |
| if (targetInterfaceOp.binding() != valueInterfaceOp.binding() || |
| targetInterfaceOp.byte_offset() != valueInterfaceOp.byte_offset()) { |
| // If the binding and offsets are different, map these to different |
| // memrefs. |
| return false; |
| } |
| // If the binding and offsets are the same, make sure that the |
| // !flow.dispatch.tensor is read-write. |
| auto sourceType = |
| valueInterfaceOp.getType().dyn_cast<IREE::Flow::DispatchTensorType>(); |
| return sourceType && |
| sourceType.getAccess() == IREE::Flow::TensorAccess::ReadWrite; |
| } |
| return true; |
| } |
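| 
| // For example, in the illustrative snippet below the stored value traces back
| // to the same binding and offset as the store target; since the underlying
| // !flow.dispatch.tensor is readwrite, value and target can be made equivalent
| // and the update happens in place:
| //
| //   %0 = hal.interface.binding.subspan ... : !flow.dispatch.tensor<readwrite:16xf32>
| //   %1 = flow.dispatch.tensor.load %0 ... -> tensor<16xf32>
| //   %2 = linalg.generic ... outs(%1 : tensor<16xf32>) ...
| //   flow.dispatch.tensor.store %2, %0 ...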
| |
| /// Tries to add the `value` and `target` to the same equivalence class. |
| static LogicalResult analyseInterfaceStoreTensorOp( |
| IREE::Flow::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) { |
|   // The value and target can be unioned if the set the value is part of does
|   // not contain any hal.interface.binding.subspan from a different binding.
| Value value = storeOp.value(); |
| Value target = storeOp.target(); |
| if (!getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target, |
| plan)) { |
| return storeOp.emitError( |
| "expected target of store op to already be added to an equivalence " |
| "set"); |
| } |
| if (canSetStoreValueAndTargetAsEquivalent(storeOp, plan)) { |
| plan.unionSets(value, target); |
| } else { |
| plan.insert(value); |
| } |
| plan.storeSet(target); |
| return success(); |
| } |
| |
| static LogicalResult analyseInterfaceBindingSubspanOp( |
| IREE::HAL::InterfaceBindingSubspanOp subspanOp, BufferizationPlan &plan) { |
| plan.insert(subspanOp.getResult()); |
| return success(); |
| } |
| |
| static LogicalResult analysePadTensorOp(tensor::PadOp padTensorOp, |
| BufferizationPlan &plan) { |
| plan.insert(padTensorOp.source()); |
| plan.insert(padTensorOp.result()); |
| return success(); |
| } |
| |
| /// For every result of the LinalgOp, gets the `outs` operand whose buffer can
| /// be reused for the result.
| static SmallVector<Value> getTiedOperandsForLinalgOps( |
| linalg::LinalgOp linalgOp, const BufferizationPlan &plan) { |
| SmallVector<Value> tiedOperands(linalgOp.getOperation()->getNumResults()); |
| auto outputOperands = linalgOp.getOutputOperands(); |
| for (auto outTensor : llvm::enumerate(outputOperands)) { |
| // If the `outs` tensor has a single use (this op) and is not from a |
| // read-only buffer, the `outs` tensor can be tied to the result. |
| if (outTensor.value()->get().hasOneUse() && |
| !isFromReadOnlyTensor(outTensor.value()->get(), plan)) { |
| tiedOperands[outTensor.index()] = outTensor.value()->get(); |
| } |
| } |
| return tiedOperands; |
| } |
| |
| static LogicalResult analyseLinalgExtOps(IREE::LinalgExt::LinalgExtOp op, |
| BufferizationPlan &plan) { |
| if (!op.hasTensorSemantics()) return success(); |
|   // TODO(hanchung): Revisit whether we can tie together op.getOutputOperands()
|   // with the corresponding op.getInputOperands(). For now we have a limited
|   // set of LinalgExt ops and there is no use case, so we ignore it.
|   // Note: this is what should be done for LinalgOps too, except for what is
|   // done for operand fusion today.
| for (auto input : op.getInputOperands()) { |
| plan.insert(input->get()); |
| } |
| for (auto output : op.getOutputOperands()) { |
| plan.insert(output->get()); |
| } |
| for (auto result : op->getResults()) { |
| plan.insert(result); |
| } |
| return success(); |
| } |
| |
| /// Adds the corresponding `outs` and result tensors of the linalg op into the |
| /// same equivalence class. |
| static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp, |
| BufferizationPlan &plan) { |
| if (!linalgOp.hasTensorSemantics()) return success(); |
| auto results = linalgOp->getResults(); |
| auto tiedOperands = getTiedOperandsForLinalgOps(linalgOp, plan); |
| for (auto it : llvm::enumerate(llvm::zip(results, tiedOperands))) { |
| Value resultTensor = std::get<0>(it.value()); |
| Value tiedOperand = std::get<1>(it.value()); |
| if (tiedOperand) { |
| plan.unionSets(resultTensor, tiedOperand); |
| } |
| plan.insert(linalgOp.getOutputOperand(it.index())->get()); |
| plan.insert(resultTensor); |
| } |
| return success(); |
| } |
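| 
| // For example, in the illustrative snippet below %init has a single use and
| // is not read-only, so %gemm is tied to it and {%init, %gemm} share a buffer:
| //
| //   %init = linalg.init_tensor [%m, %n] : tensor<?x?xf32>
| //   %gemm = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
| //             outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>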
| |
| /// Returns true if there is a single use of the `value` that is "real", |
| /// i.e. where the value itself is used, and not the type of the value. For |
| /// example, a use in a `memref.dim` is only looking at the type and not the |
| /// value. |
| static bool hasSingleRealUse(Value value) { |
| int numUsers = 0; |
| for (OpOperand &use : value.getUses()) { |
| if (!isa<memref::DimOp, tensor::DimOp>(use.getOwner())) { |
| numUsers++; |
| } |
| } |
| return numUsers == 1; |
| } |
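| 
| // For example, %t below has two uses, but the tensor.dim only inspects the
| // shape, so %t still counts as having a single "real" use:
| //
| //   %d = tensor.dim %t, %c0 : tensor<?xf32>
| //   %r = linalg.generic ... outs(%t : tensor<?xf32>) ...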
| |
| /// For operations that have a single operand and result, adds both to the same |
| /// equivalence class. |
| static LogicalResult analyseSingleOperandResultOp(Value source, Value result, |
| BufferizationPlan &plan) { |
| if (hasSingleRealUse(source) || isFromReadOnlyTensor(source, plan)) { |
| plan.unionSets(source, result); |
| return success(); |
| } |
| plan.insert(source); |
| plan.insert(result); |
| return success(); |
| } |
| |
| static LogicalResult analyseSubTensorOp(tensor::ExtractSliceOp subTensorOp, |
| BufferizationPlan &plan) { |
| if (!canUsersHandleSubviews(subTensorOp)) { |
| plan.insert(subTensorOp.source()); |
| plan.insert(subTensorOp.result()); |
| return success(); |
| } |
| return analyseSingleOperandResultOp(subTensorOp.source(), |
| subTensorOp.result(), plan); |
| } |
| |
| /// Adds the `dest` and `result` tensors of a destructive update operation
| /// (e.g. `tensor.insert_slice`) to the same equivalence class. If `source` is
| /// not null, also checks that `source` and `dest` are not equivalent.
| static LogicalResult analyseDestructiveUpdateOp(Operation *op, Value source, |
| Value dest, Value result, |
| BufferizationPlan &plan) { |
| if (hasSingleRealUse(dest) && !isFromReadOnlyTensor(dest, plan)) { |
| plan.unionSets(dest, result); |
| } else if (source && plan.isEquivalent(source, dest)) { |
|     // The destructive update pattern can put the source and dest in the same
|     // equivalence class, but that is checked explicitly later on. So at this
|     // stage this shouldn't happen.
| return op->emitError( |
| "unexpected source and dest being equivalent in destructive update op"); |
| } |
| plan.insert(dest); |
| plan.insert(result); |
| return success(); |
| } |
| |
| static LogicalResult analyseScfIfOp(scf::IfOp ifOp, BufferizationPlan &plan) { |
| if (!ifOp.getNumResults()) return success(); |
| for (auto it : llvm::zip(ifOp.getResults(), ifOp.thenYield().getOperands(), |
| ifOp.elseYield().getOperands())) { |
| Value result = std::get<0>(it); |
| if (!result.getType().isa<RankedTensorType>()) continue; |
| // All results and yields of the if-then-else are tied together. |
| plan.unionSets(result, std::get<1>(it)); |
| plan.unionSets(result, std::get<2>(it)); |
| } |
| return success(); |
| } |
| |
| static LogicalResult analyseScfForOp(scf::ForOp forOp, |
| BufferizationPlan &plan) { |
| if (forOp.getResults().empty()) return success(); |
| if (!llvm::all_of(forOp->getResultTypes(), [](Type resultType) { |
| return resultType.isa<RankedTensorType>(); |
| })) { |
| return success(); |
| } |
| |
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| auto regionArgs = forOp.getRegionIterArgs(); |
| auto initArgs = forOp.getInitArgs(); |
|   for (unsigned i = 0, e = yieldOp.getResults().size(); i != e; ++i) {
| Value yieldTensor = yieldOp.getResults()[i]; |
| Value resultTensor = forOp.getResults()[i]; |
| Value initArg = initArgs[i]; |
| Value arg = regionArgs[i]; |
|     // Always tie the yield, the result tensor, and the region arg.
| plan.unionSets(yieldTensor, resultTensor); |
| plan.unionSets(yieldTensor, arg); |
| |
|     // If the init value is not read-only and has a single use, tie the init
|     // and result (and by extension all 4 tensors here).
| if (hasSingleRealUse(initArg) && !isFromReadOnlyTensor(initArg, plan)) { |
| plan.unionSets(initArg, resultTensor); |
| } |
| } |
| return success(); |
| } |
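| 
| // For example, in the illustrative loop below %y (the yield), %arg (the
| // region iter_arg), and %r (the loop result) are always tied; %init joins the
| // same set when it has a single real use and is not read-only:
| //
| //   %r = scf.for %iv = %lb to %ub step %s iter_args(%arg = %init)
| //          -> tensor<?xf32> {
| //     %y = ... %arg ...
| //     scf.yield %y : tensor<?xf32>
| //   }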
| |
| /// Looks for the destructive update loop pattern involving `source` using
| /// these constraints:
| /// - a single update operation (`tensor.insert_slice` or
| ///   `vector.transfer_write`) where `source` is the `dest` operand.
| /// - all read operations (`tensor.extract_slice` or `vector.transfer_read`)
| ///   dominate the update op.
| static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) { |
| auto isUpdateOp = [](Operation *op) { |
| return isa<tensor::InsertSliceOp, vector::TransferWriteOp>(op); |
| }; |
| auto isReadOp = [](Operation *op) { |
| return isa<tensor::ExtractSliceOp, vector::TransferReadOp>(op); |
| }; |
| auto getDest = [](Operation *op) -> Value { |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) { |
| return insertSliceOp.dest(); |
| } |
| if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) { |
| return transferWriteOp.source(); |
| } |
| return nullptr; |
| }; |
| auto getSource = [](Operation *op) -> Value { |
| if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { |
| return extractSliceOp.source(); |
| } |
| if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) { |
| return transferReadOp.source(); |
| } |
| return nullptr; |
| }; |
|   // Source should have only one use that is an update op (insert_slice or
|   // transfer_write) with `source` as its `dest` operand.
| Operation *updateOp = nullptr; |
| for (OpOperand &use : source.getUses()) { |
| auto user = use.getOwner(); |
|     // Process only update-op uses here.
|     if (!isUpdateOp(user)) continue;
|     // If this is not the first update op use, abort.
| if (updateOp) { |
| return; |
| } |
| // If the use is not the `dest` operand, abort. |
| Value dest = getDest(user); |
| assert(dest && "unable to get dest of update op"); |
| if (use.get() != dest) { |
| return; |
| } |
| if (isFromReadOnlyTensor(dest, plan)) { |
| return; |
| } |
| updateOp = user; |
| } |
|   // Need one update op use for the destructive update pattern.
| if (!updateOp) { |
| return; |
| } |
| |
| Block *updateOpBlock = updateOp->getBlock(); |
| for (OpOperand &use : source.getUses()) { |
| Operation *user = use.getOwner(); |
| if (user == updateOp) continue; |
| if (isReadOp(user)) { |
| Value source = getSource(user); |
| assert(source && "unable to find source from read op"); |
| if (source != use.get()) { |
| return; |
| } |
|       // The read must dominate the update op. For now just check it's in the
|       // same block and before it.
| if (user->getBlock() != updateOpBlock || |
| !user->isBeforeInBlock(updateOp)) { |
| return; |
| } |
| continue; |
| } else if (isa<scf::YieldOp, tensor::DimOp>(user)) { |
| continue; |
| } |
|     // Unaccounted-for use. Return without doing anything.
| return; |
| } |
| |
|   // Found the destructive update pattern. Tie together
|   // - extract_slice source and result
|   // - insert_slice source and dest
| for (Operation *user : source.getUsers()) { |
| if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) { |
| plan.unionSets(extractSliceOp.source(), extractSliceOp.result()); |
| continue; |
| } |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(user)) { |
| if (!isFromReadOnlyTensor(insertSliceOp.source(), plan)) { |
| plan.unionSets(insertSliceOp.source(), insertSliceOp.dest()); |
| } |
| } |
| } |
| } |
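| 
| // For example, the illustrative loop body below matches the pattern: the
| // extract_slice dominates the single insert_slice that uses %arg as `dest`,
| // so the slice is tied to %arg and the insert's source is tied to its dest,
| // allowing the update to happen in place:
| //
| //   %slice = tensor.extract_slice %arg[%iv] [%sz] [1]
| //          : tensor<?xf32> to tensor<?xf32>
| //   %val = linalg.generic ... outs(%slice : tensor<?xf32>) ...
| //   %upd = tensor.insert_slice %val into %arg[%iv] [%sz] [1]
| //          : tensor<?xf32> into tensor<?xf32>
| //   scf.yield %upd : tensor<?xf32>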
| |
| /// Ties together operands for operand fusion as it exists today, by reusing
| /// the buffer of one of the inputs for the result to do an in-place update.
| /// Ideally this is not needed if the fusion happens at the vector level; to
| /// be removed when that is worked out and can be load-bearing. Conditions
| /// checked here are
| /// 1) the result does not use the value of the `outs` buffer.
| /// 2) the input has a single use (this op) and has the same indexing map as
| ///    the result.
| /// 3) the input equivalence set does not have an interface binding, i.e. it
| ///    is not using a buffer from the dispatch ABI.
| static void tieOperandsForOperandFusion(linalg::LinalgOp linalgOp, |
| BufferizationPlan &plan) { |
| for (auto result : enumerate(linalgOp.getOutputOperands())) { |
| if (linalgOp.payloadUsesValueFromOperand(result.value())) { |
| continue; |
| } |
| for (OpOperand *input : linalgOp.getInputTensorOperands()) { |
| Type inputElementType = |
| input->get().getType().cast<RankedTensorType>().getElementType(); |
| Type resultElementType = result.value() |
| ->get() |
| .getType() |
| .cast<RankedTensorType>() |
| .getElementType(); |
| if (input->get().hasOneUse() && (inputElementType == resultElementType) && |
| linalgOp.getTiedIndexingMap(input) == |
| linalgOp.getTiedIndexingMap(result.value()) && |
| !getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>( |
| input->get(), plan) && |
| !isFromReadOnlyTensor(input->get(), plan)) { |
| plan.unionSets(linalgOp->getResult(result.index()), input->get()); |
| break; |
| } |
| } |
| } |
| } |
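| 
| // For example, in the illustrative elementwise op below the payload never
| // reads %init, and %in has a single use and the same (identity) indexing map
| // as the result, so the result can reuse %in's buffer:
| //
| //   %0 = linalg.generic {indexing_maps = [#id, #id],
| //                        iterator_types = ["parallel"]}
| //          ins(%in : tensor<?xf32>) outs(%init : tensor<?xf32>) {
| //     ^bb0(%a: f32, %b: f32):
| //       %c = arith.addf %a, %a : f32
| //       linalg.yield %c : f32
| //   } -> tensor<?xf32>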
| |
| void BufferizationPlan::dump() { |
| llvm::dbgs() << "BufferMappings : \n"; |
| unsigned numSets = 0; |
| for (auto it = mappedTensors.begin(), ie = mappedTensors.end(); it != ie; |
| ++it) { |
| if (!it->isLeader()) continue; |
| llvm::dbgs() << "\tSet " << numSets << ":\n"; |
| for (auto member : llvm::make_range(mappedTensors.member_begin(it), |
| mappedTensors.member_end())) { |
| llvm::dbgs() << "\t\t"; |
| getValue(member).print(llvm::dbgs()); |
| llvm::dbgs() << "\n"; |
| } |
| numSets++; |
| } |
| } |
| |
| LogicalResult createTensorEquivalenceClasses(FuncOp funcOp, |
| BufferizationPlan &plan) { |
| auto bufferMappingFn = [&](Operation *op) -> WalkResult { |
| return TypeSwitch<Operation *, LogicalResult>(op) |
| .Case<arith::ConstantOp>([&](arith::ConstantOp constantOp) { |
| return analyseConstantOp(constantOp, plan); |
| }) |
| .Case<IREE::Flow::DispatchTensorLoadOp>( |
| [&](IREE::Flow::DispatchTensorLoadOp loadOp) { |
| return analyseInterfaceLoadTensorOp(loadOp, plan); |
| }) |
| .Case<IREE::Flow::DispatchTensorStoreOp>( |
| [&](IREE::Flow::DispatchTensorStoreOp storeOp) { |
| return analyseInterfaceStoreTensorOp(storeOp, plan); |
| }) |
| .Case<IREE::HAL::InterfaceBindingSubspanOp>( |
| [&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) { |
| return analyseInterfaceBindingSubspanOp(subspanOp, plan); |
| }) |
| .Case<tensor::PadOp>([&](tensor::PadOp padTensorOp) { |
| return analysePadTensorOp(padTensorOp, plan); |
| }) |
| .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { |
| return analyseLinalgOps(linalgOp, plan); |
| }) |
| .Case<IREE::LinalgExt::LinalgExtOp>( |
| [&](IREE::LinalgExt::LinalgExtOp linalgExtOp) { |
| return analyseLinalgExtOps(linalgExtOp, plan); |
| }) |
| .Case<tensor::CollapseShapeOp, tensor::ExpandShapeOp>( |
| [&](auto reshapeOp) { |
| return analyseSingleOperandResultOp(reshapeOp.src(), |
| reshapeOp.result(), plan); |
| }) |
| .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp sliceOp) { |
| return analyseSubTensorOp(sliceOp, plan); |
| }) |
| .Case<tensor::InsertSliceOp>( |
| [&](tensor::InsertSliceOp subTensorInsertOp) { |
| return analyseDestructiveUpdateOp( |
| subTensorInsertOp, subTensorInsertOp.source(), |
| subTensorInsertOp.dest(), subTensorInsertOp.result(), plan); |
| }) |
| .Case<tensor::CastOp>([&](tensor::CastOp castOp) { |
| return analyseSingleOperandResultOp(castOp.source(), castOp.dest(), |
| plan); |
| }) |
| .Case<tensor::InsertOp>([&](tensor::InsertOp insertOp) { |
| return analyseDestructiveUpdateOp(insertOp, /*source =*/nullptr, |
| insertOp.dest(), insertOp.result(), |
| plan); |
| }) |
| .Case<vector::TransferReadOp>( |
| [&](vector::TransferReadOp transferReadOp) { |
| if (transferReadOp.source().getType().isa<RankedTensorType>()) { |
| plan.insert(transferReadOp.source()); |
| } |
| return success(); |
| }) |
| .Case<vector::TransferWriteOp>( |
| [&](vector::TransferWriteOp transferWriteOp) { |
| if (!transferWriteOp.result().getType().isa<RankedTensorType>()) { |
| return success(); |
| } |
| return analyseDestructiveUpdateOp(transferWriteOp, nullptr, |
| transferWriteOp.source(), |
| transferWriteOp.result(), plan); |
| }) |
| .Case<scf::IfOp>( |
| [&](scf::IfOp ifOp) { return analyseScfIfOp(ifOp, plan); }) |
| .Case<scf::ForOp>( |
| [&](scf::ForOp forOp) { return analyseScfForOp(forOp, plan); }) |
| .Case<scf::YieldOp, linalg::InitTensorOp, tensor::DimOp, |
| tensor::ExtractOp, tensor::PadOp>( |
| [&](Operation *op) { return success(); }) |
| .Default([&](Operation *op) -> LogicalResult { |
| if (llvm::any_of(op->getOperands(), |
| [](Value v) { |
| return v.getType().isa<RankedTensorType>(); |
| }) || |
| llvm::any_of(op->getResultTypes(), |
| [](Type t) { return t.isa<RankedTensorType>(); })) { |
| return op->emitOpError("unhandled tensor operation"); |
| } |
| return success(); |
| }); |
| }; |
| if (funcOp.walk<WalkOrder::PreOrder>(bufferMappingFn).wasInterrupted()) { |
| return failure(); |
| } |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "After First walk "; |
| plan.dump(); |
| }); |
| |
| funcOp.walk([&](Operation *updateOp) { |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(updateOp)) { |
| hasDestructiveUpdatePattern(insertSliceOp.dest(), plan); |
| return; |
| } |
| if (auto vectorWriteOp = dyn_cast<vector::TransferWriteOp>(updateOp)) { |
| if (vectorWriteOp.source().getType().isa<RankedTensorType>()) { |
| hasDestructiveUpdatePattern(vectorWriteOp.source(), plan); |
| } |
| } |
| }); |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "After Destructive update walk "; |
| plan.dump(); |
| }); |
| |
| // Tie operands to allow for operand fusion support. To be dropped once the |
| // operand fusion is generalized in IREE. |
| funcOp.walk([&](linalg::LinalgOp linalgOp) { |
| return tieOperandsForOperandFusion(linalgOp, plan); |
| }); |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "After union for supporting operand fusion"; |
| plan.dump(); |
| }); |
| |
| if (funcOp |
| .walk([&](IREE::Flow::DispatchTensorStoreOp storeOp) -> WalkResult { |
| return analyseInterfaceStoreTensorOp(storeOp, plan); |
| }) |
| .wasInterrupted()) { |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |