Changing stream conversion to use a value/op affinity analysis. This reworks some of the prior stack to support transfer ops and analysis to determine the placement of ops for execution and resource control.
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index b8a630d..553dbde 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -124,9 +124,17 @@ FENCE, }; +struct BarrierResult { + BlockArgument storage; + Type torchType; + int returnIndex = -1; +}; + struct ConvertedAsyncFunctionInfo { IREE::Util::FuncOp funcOp; SmallVector<IREE::Util::ReturnOp> returnOps; + SmallVector<DictionaryAttr> torchArgAttrs; + SmallVector<DictionaryAttr> torchResultAttrs; SmallVector<Type> torchInputTypes; SmallVector<Type> torchResultTypes; SmallVector<TypeDisposition> inputDispositions; @@ -136,7 +144,7 @@ // Values that must be captured in the coarse barrier. SmallVector<Value> barrierInputs; // Meta data per barrier input: storage, torchType, returnIndex (or -1) - SmallVector<std::tuple<Value, Type, int>> barrierResultMeta; + SmallVector<BarrierResult> barrierResultMeta; LogicalResult postProcess(); LogicalResult convertImmutableTensorArg(BlockArgument argValue, @@ -144,10 +152,25 @@ LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType, OpBuilder &builder); - void addBarrierInput(Value inputTensor, Value storage, Type torchType, + void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType, int returnIndex) { barrierInputs.push_back(inputTensor); - barrierResultMeta.emplace_back(storage, torchType, returnIndex); + barrierResultMeta.emplace_back(BarrierResult{ + storage, + torchType, + returnIndex, + }); + } + + Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) { + return torchArgAttrs.empty() + ? Attribute{} + : torchArgAttrs[argValue.getArgNumber()].get(attrName); + } + Attribute getTorchResultAttr(int returnIndex, StringRef attrName) { + return torchResultAttrs.empty() + ? Attribute{} + : torchResultAttrs[returnIndex].get(attrName); } }; @@ -232,7 +255,8 @@ } if (needsBarrier) { Value source = convertToBuiltinTensor(postambleBuilder, returnValue); - addBarrierInput(source, /*storage=*/Value{}, torchType, returnIndex); + addBarrierInput(source, /*storage=*/BlockArgument{}, torchType, + returnIndex); } break; } @@ -276,15 +300,13 @@ SmallVector<Value> aliasedResults; for (auto [barrierInput, meta] : llvm::zip_equal(barrierInputs, barrierResultMeta)) { - Value exportStorage; - Type torchType; - int returnIndex; - std::tie(exportStorage, torchType, returnIndex) = meta; - if (exportStorage) { + if (meta.storage) { // Use the wait fence indicating when the storage is available for // mutation. We need to ensure that no writes are made to the storage // until it indicates it's safe to do so. - auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage); + auto storageAffinityAttr = + getTorchArgAttr(meta.storage, "iree.abi.affinity"); + auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage); assert(waitSignalFences && "async function missing fences"); Value waitFence = waitSignalFences->first; auto barrierInputDims = IREE::Util::buildDynamicDimsForValue( @@ -292,28 +314,30 @@ aliasedResults.push_back( postambleBuilder.create<IREE::HAL::TensorAliasOp>( barrierInput.getLoc(), barrierInput.getType(), barrierInput, - barrierInputDims, exportStorage, waitFence, - /*affinity=*/nullptr)); + barrierInputDims, meta.storage, waitFence, + storageAffinityAttr)); } else { aliasedResults.push_back(barrierInput); } } auto barrierOp = postambleBuilder.create<IREE::HAL::TensorBarrierOp>( - funcOp.getLoc(), aliasedResults, coarseSignalFence, - /*affinity=*/nullptr); + funcOp.getLoc(), aliasedResults, coarseSignalFence); for (auto [barrierResult, meta] : llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) { - Value exportStorage; - Type torchType; - int returnIndex; - std::tie(exportStorage, torchType, returnIndex) = meta; + Attribute exportAffinityAttr; + if (meta.storage) { + exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity"); + } else if (meta.returnIndex >= 0) { + exportAffinityAttr = + getTorchResultAttr(meta.returnIndex, "iree.abi.affinity"); + } Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>( funcOp.getLoc(), postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult, TypeAttr::get(barrierResult.getType()), /*name=*/nullptr, - /*affinity=*/nullptr); - if (returnIndex >= 0) { - newReturnOperands[returnIndex] = exportedValue; + exportAffinityAttr); + if (meta.returnIndex >= 0) { + newReturnOperands[meta.returnIndex] = exportedValue; } } } @@ -377,14 +401,16 @@ << torchType; } + // Propagate explicit affinities to the read. + auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity"); + auto waitSignalFences = getEnclosingWaitSignalFences(argValue); assert(waitSignalFences && "async function missing fences"); Value waitFence = waitSignalFences->first; Value importedTensor = builder.create<IREE::HAL::TensorImportOp>( loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType), waitFence, - /*name=*/nullptr, - /*affinity=*/nullptr); + /*name=*/nullptr, affinityAttr); if (builtinTensorType != torchType) { importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>( loc, torchType, importedTensor); @@ -408,6 +434,9 @@ .toBuiltinTensor(); } + // Propagate explicit affinities to the read and write. + auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity"); + // There are only a small set of possible users of a mutable tensor. // Handle them by operation here. SmallVector<Operation *> users(argValue.getUsers()); @@ -419,8 +448,7 @@ loc, builtinTensorType, argValue, /*target_encoding=*/TypeAttr::get(builtinTensorType), /*wait_fence*/ fences->first, - /*name=*/nullptr, - /*affinity=*/nullptr); + /*name=*/nullptr, affinityAttr); rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>( userOp, copyToVtOp.getResult().getType(), imported); } else if (auto overwriteOp = @@ -444,7 +472,6 @@ // Allowlist of function attributes to retain when importing funcs. constexpr const char *kRetainedAttributes[] = { "iree.reflection", - "stream.affinity", }; auto retainedAttributes = ArrayRef<const char *>( kRetainedAttributes, @@ -476,6 +503,9 @@ syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr()); retainFunctionAttributes(asyncFuncOp, syncFuncOp); syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr()); + if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) { + syncFuncOp->setAttr("iree.abi.affinity", affinityAttr); + } Block *entryBlock = syncFuncOp.addEntryBlock(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(entryBlock); @@ -584,6 +614,10 @@ asyncFunctionName.append("$async"); } + // Stash arg/result attrs so they can be referenced during conversion. + torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs); + torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs); + // Convert function signature. Type fenceType = rewriter.getType<IREE::HAL::FenceType>(); FunctionType torchFuncType = torchFunc.getFunctionType(); @@ -644,6 +678,9 @@ asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr()); asyncFuncOp->setAttr("iree.abi.model", rewriter.getStringAttr("coarse-fences")); + if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) { + asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr); + } rewriter.inlineRegionBefore( torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end());
diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir index 3e167ad..3bca01b 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
@@ -111,6 +111,37 @@ } // ----- +// Tests the immutable + mutable argument case with explicit affinities. +// CHECK-LABEL: @mutable_input_overwrite_no_return +// CHECK: util.func public @main$async( +// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view, +// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view +// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg2) => %arg0 +// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]] +// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) wait(%arg2) => %arg1 +// CHECK-DAG: %[[TORCH_ARG1:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG1]] +// CHECK-DAG: %[[TORCH_RESULT0:.+]] = torch.operator "other_calc"(%[[TORCH_ARG0]]) +// CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]]) +// CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]] +// CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]] +// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias on(#hal.device.promise<@dev_b>) wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view +// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence +// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export on(#hal.device.promise<@dev_b>) %[[BARRIER_RESULTS]]#0 +// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[BARRIER_RESULTS]]#1 +// CHECK: util.return %[[EXPORT_RESULT1]] +builtin.module @mutable_input_overwrite_no_return_affinities { +func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}, + %arg1: !torch.tensor<[5,4],f32> {iree.abi.affinity = #hal.device.promise<@dev_b>}) + -> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) { + %0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32> + %1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32> + %2 = torch.operator "other_calc"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32> + torch.overwrite.tensor.contents %1 overwrites %arg1 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32> + return %2 : !torch.vtensor<[4,5],si32> +} +} + +// ----- // CHECK-LABEL: @retained_attribute_reflection // CHECK: util.func public @main$async( // CHECK-SAME: iree.reflection = {some.attr = 4 : index}
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index e91bc54..577cf52 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -43,16 +43,22 @@ return type; } +// Returns true if the given |attr| is a known ABI attribute that is only used +// by this pass. +static bool isABIAttr(NamedAttribute attr) { + return attr.getName() == "iree.abi.affinity" || + attr.getName() == "iree.abi.encoding" || + attr.getName() == "iree.abi.model" || + attr.getName() == "iree.abi.output"; +} + // Removes all ABI attrs handled by this pass from all dictionaries. static void stripABIAttrs(SmallVectorImpl<DictionaryAttr> &allAttrs) { for (auto &attrDict : allAttrs) { SmallVector<NamedAttribute> attrs; attrs.reserve(attrDict.size()); for (auto attr : attrDict) { - // TODO(benvanik): faster lookup. - if (attr.getName() != "iree.abi.output" && - attr.getName() != "iree.abi.encoding" && - attr.getName() != "iree.abi.affinity") { + if (!isABIAttr(attr)) { attrs.push_back(attr); } } @@ -60,7 +66,16 @@ } } +// Removes all ABI attrs from the |op| and its args/results. static void stripABIAttrs(FunctionOpInterface op) { + NamedAttrList attrs; + for (auto attr : op->getAttrs()) { + if (!isABIAttr(attr)) { + attrs.push_back(attr); + } + } + op->setAttrs(attrs); + SmallVector<DictionaryAttr> argAttrs; op.getAllArgAttrs(argAttrs); stripABIAttrs(argAttrs); @@ -71,6 +86,11 @@ op.setAllResultAttrs(resultAttrs); } +template <typename T> +static T fallback(T optionalValue, T defaultValue) { + return optionalValue ? optionalValue : defaultValue; +} + // Creates the corresponding wrapper function for the given import function. static IREE::Util::FuncOp createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, @@ -101,12 +121,7 @@ argAttrDict.push_back(nullptr); // signal break; } - - // Update the import type and propagate back the attributes we may have - // modified above. importOp.setType(newImportType); - importOp.setAllArgAttrs(argAttrDict); - importOp.setAllResultAttrs(resultAttrDict); auto *entryBlock = wrapperOp.addEntryBlock(); auto entryBuilder = OpBuilder::atBlockBegin(entryBlock); @@ -129,6 +144,12 @@ // users mark their functions 'nosideeffects' to avoid the host wait. const bool hasSideEffects = !importOp->hasAttr("nosideeffects"); + // Fetch and normalize any explicitly assigned affinity. + auto defaultAffinityAttr = importOp->getAttr("iree.abi.affinity"); + if (defaultAffinityAttr) { + importOp->setAttr("stream.affinity", defaultAffinityAttr); + } + // When running async we insert a barrier on tensor arguments and attach that // to the fence we pass to the import for waiting. We'll also allocate the // signal fence that the import must signal when the returned tensors are @@ -141,15 +162,24 @@ // No fences. break; case IREE::ABI::InvocationModel::CoarseFences: { - // HACK: this is relying on the fact that there's only one HAL device. - // We should instead have a way of creating fences on the device that - // is used to produce the tensors we're wrapping. - // - // TODO(multi-device): emit get with derived ordinal or lookup with attr. We - // could always say device 0 for now but could instead look for an - // iree.abi.affinity/iree.abi.device/etc. - Value device = - IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder); + Value device; + // TODO(benvanik): support other affinity types. + if (auto deviceAffinityAttr = + dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>( + defaultAffinityAttr)) { + device = entryBuilder + .create<IREE::HAL::DeviceResolveOp>( + importOp.getLoc(), + entryBuilder.getType<IREE::HAL::DeviceType>(), + deviceAffinityAttr) + .getResult(0); + } else { + // HACK: if no devices are available we get the first one available at + // runtime. This is suboptimal but we expect most usage to have affinities + // assigned prior to ABI conversion. + device = + IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder); + } // When exporting a fence we need to put a barrier between the rest of the // program and the tensors consumed by the import. @@ -162,7 +192,7 @@ importOp.getLoc(), entryBuilder.getType<IREE::HAL::FenceType>(), device, IREE::HAL::FenceFlagBitfield::None); auto barrierOp = entryBuilder.create<IREE::HAL::TensorBarrierOp>( - importOp.getLoc(), tensorArgs, waitFence, /*affinity=*/nullptr); + importOp.getLoc(), tensorArgs, waitFence); for (auto [argIndex, readyArg] : llvm::zip_equal(tensorArgIndices, barrierOp.getResults())) { entryArgs[argIndex] = readyArg; @@ -203,9 +233,10 @@ importOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding"); auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>( arg.getLoc(), newType, arg, - encodingAttr ? encodingAttr : TypeAttr::get(oldType), + fallback(encodingAttr, TypeAttr::get(oldType)), /*name=*/nullptr, - /*affinity=*/nullptr); + fallback(importOp.getArgAttr(argIndex, "iree.abi.affinity"), + defaultAffinityAttr)); arguments.push_back(tensorExportOp.getTarget()); } else { arguments.push_back(arg); @@ -245,9 +276,10 @@ resultIndex, "iree.abi.encoding"); auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>( importOp.getLoc(), oldType, result, - encodingAttr ? encodingAttr : TypeAttr::get(oldType), signalFence, + fallback(encodingAttr, TypeAttr::get(oldType)), signalFence, /*name=*/nullptr, - /*affinity=*/nullptr); + fallback(importOp.getResultAttr(resultIndex, "iree.abi.affinity"), + defaultAffinityAttr)); results.push_back(tensorImportOp); } else { results.push_back(result); @@ -255,6 +287,9 @@ } entryBuilder.create<IREE::Util::ReturnOp>(importOp.getLoc(), results); + + stripABIAttrs(importOp); + return wrapperOp; } @@ -518,8 +553,11 @@ // Populate the reflection attrs based on the original types. populateReflectionAttrs(invocationModel, exportOp, wrapperOp); exportOp->removeAttr("iree.reflection"); - if (auto affinityAttr = exportOp->getAttr("stream.affinity")) { - wrapperOp->setAttr("stream.affinity", affinityAttr); + + // Fetch and normalize any explicitly assigned affinity. + auto defaultAffinityAttr = exportOp->getAttr("iree.abi.affinity"); + if (defaultAffinityAttr) { + exportOp->setAttr("stream.affinity", defaultAffinityAttr); } auto *entryBlock = wrapperOp.addEntryBlock(); @@ -572,12 +610,13 @@ if (llvm::isa<TensorType>(oldType)) { auto encodingAttr = exportOp.getArgAttrOfType<TypeAttr>(argIndex, "iree.abi.encoding"); + auto argName = inferArgumentName(entryBuilder.getContext(), argIndex, + exportOp.getArgAttrDict(argIndex)); auto tensorImportOp = entryBuilder.create<IREE::HAL::TensorImportOp>( arg.getLoc(), oldType, arg, - encodingAttr ? encodingAttr : TypeAttr::get(oldType), waitFence, - inferArgumentName(entryBuilder.getContext(), argIndex, - exportOp.getArgAttrDict(argIndex)), - exportOp.getArgAttr(argIndex, "iree.abi.affinity")); + fallback(encodingAttr, TypeAttr::get(oldType)), waitFence, argName, + fallback(exportOp.getArgAttr(argIndex, "iree.abi.affinity"), + defaultAffinityAttr)); arguments.push_back(tensorImportOp.getTarget()); } else { arguments.push_back(arg); @@ -601,7 +640,8 @@ auto aliasOp = entryBuilder.create<IREE::HAL::TensorAliasOp>( exportOp.getLoc(), source.getType(), source, sourceDims, resultStorages[resultIndex], waitFence, - exportOp.getResultAttr(resultIndex, "iree.abi.affinity")); + fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"), + defaultAffinityAttr)); asyncResults[resultIndex] = cast<OpResult>(aliasOp.getResult()); } @@ -622,7 +662,7 @@ signalFence); } else { auto barrierOp = entryBuilder.create<IREE::HAL::TensorBarrierOp>( - exportOp.getLoc(), asyncTensors, signalFence, /*affinity=*/nullptr); + exportOp.getLoc(), asyncTensors, signalFence); asyncResults = llvm::to_vector(barrierOp.getResults()); } } @@ -635,15 +675,17 @@ if (llvm::isa<TensorType>(oldType)) { auto encodingAttr = exportOp.getResultAttrOfType<TypeAttr>( resultIndex, "iree.abi.encoding"); + auto resultName = + inferResultName(entryBuilder.getContext(), resultIndex, + exportOp.getResultAttrDict(resultIndex)); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( result.getLoc(), result, entryBuilder); auto tensorExportOp = entryBuilder.create<IREE::HAL::TensorExportOp>( result.getLoc(), newType, result, - encodingAttr ? encodingAttr : TypeAttr::get(result.getType()), - dynamicDims, - inferResultName(entryBuilder.getContext(), resultIndex, - exportOp.getResultAttrDict(resultIndex)), - exportOp.getResultAttr(resultIndex, "iree.abi.affinity")); + fallback(encodingAttr, TypeAttr::get(result.getType())), dynamicDims, + resultName, + fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"), + defaultAffinityAttr)); results.push_back(tensorExportOp); } else { results.push_back(result); @@ -731,12 +773,15 @@ exportOps.push_back(funcOp); } } + if (importOps.empty() && exportOps.empty()) { + return; // no-op + } SymbolTable symbolTable(moduleOp); // Create a wrapper function for each imported function. - // This will preserve the internal types (tensors/etc) but change the import - // to taking the ABI types and rewrite calls. + // This will preserve the internal types (tensors/etc) but change the + // import to taking the ABI types and rewrite calls. for (auto importOp : importOps) { if (failed(wrapImportFunc(getInvocationModel(importOp, invocationModel), moduleOp, importOp, symbolTable))) {
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir index 3780ee2..72a0441 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -181,6 +181,28 @@ // ----- +// Tests that explicit import affinity specification is carried through to +// the marshaling ops. + +// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view) -> !hal.buffer_view +util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_b>}) + +// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[ARG_TENSOR]] : tensor<2xi32> -> !hal.buffer_view +// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view +// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32> +// CHECK: util.return %[[RET_TENSOR]] +// CHECK: } + +// CHECK: util.func private @pinnedCaller(%arg0: tensor +util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + util.return %0 : tensor<2xi32> +} + +// ----- + // Tests that imports with encodings specified are propagated to the HAL ops. // CHECK-LABEL: util.func private @importEncodings(%arg0: !hal.buffer_view) -> !hal.buffer_view @@ -188,10 +210,10 @@ // CHECK: util.func private @_importEncodings(%[[ARG_TENSOR:.+]]: tensor<?x2xi32>) -> tensor<2x?xi32> { // CHECK: %[[ARG_DIM:.+]] = tensor.dim %[[ARG_TENSOR]], %c0 -// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor<?x2xi32>{%[[ARG_DIM]]} -> !hal.buffer_view +// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor<?x2xf32> as tensor<?x2xi32>{%[[ARG_DIM]]} -> !hal.buffer_view // CHECK: %[[RET_VIEW:.+]] = util.call @importEncodings(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view // CHECK: %[[RET_DIM:.+]] = hal.buffer_view.dim<%[[RET_VIEW]] : !hal.buffer_view>[1] -// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xi32>{%[[RET_DIM]]} +// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xf32> as tensor<2x?xi32>{%[[RET_DIM]]} // CHECK: util.return %[[RET_TENSOR]] // CHECK: }
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir index 4505a54..f9af505 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir
@@ -170,6 +170,40 @@ // ----- +// Tests that explicit import affinity specification is carried through to +// the marshaling ops. + +util.global private @dev_a : !hal.device +util.global private @dev_b : !hal.device +util.global private @dev_c : !hal.device + +// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view +util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_b>}) attributes { + iree.abi.affinity = #hal.device.affinity<@dev_c>, + iree.abi.model = "coarse-fences", + nosideeffects +} + +// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK-DAG: %[[DEVICE_C:.+]] = hal.device.resolve on(<@dev_c>) : !hal.device +// CHECK-DAG: %[[ARG_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence +// CHECK-DAG: %[[ARG_READY:.+]] = hal.tensor.barrier join(%[[ARG_TENSOR]] : tensor<2xi32>) => %[[ARG_FENCE]] : !hal.fence +// CHECK-DAG: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.affinity<@dev_a>) %[[ARG_READY]] : tensor<2xi32> -> !hal.buffer_view +// CHECK-DAG: %[[RESULT_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence +// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]], %[[ARG_FENCE]], %[[RESULT_FENCE]]) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view +// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.affinity<@dev_b>) wait(%[[RESULT_FENCE]]) => %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32> +// CHECK: util.return %[[RET_TENSOR]] +// CHECK: } + +// CHECK: util.func private @pinnedCaller(%arg0: tensor +util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + util.return %0 : tensor<2xi32> +} + +// ----- + // Tests a side-effect-free import that doesn't take/return reference types. // CHECK-LABEL: util.func private @importI32(%arg0: i32, %arg1: !hal.fence, %arg2: !hal.fence) -> i32
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel index 98da9d9..fb3b309 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -64,6 +64,7 @@ "//compiler/src/iree/compiler/Dialect/Encoding/IR", "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/Stream/Analysis", "//runtime/src/iree/builtins/ukernel:exported_bits", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt index 7236121..0b79e7d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -86,6 +86,7 @@ iree::compiler::Dialect::Encoding::IR iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::Stream::Analysis PUBLIC )
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp index c107bc7..4edaffa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp
@@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" @@ -547,6 +548,35 @@ return success(); } +// Returns the executable targets used within |funcOp|. +// +// TODO(multi-device): delete this pass and rely on tensor-based analysis to +// materialize encodings based on where tensors are used. This pass is not able +// to handle that. +static std::optional<SetVector<IREE::HAL::ExecutableTargetAttr>> +getFuncExecutableTargetAttrs(FunctionOpInterface funcOp, + IREE::Stream::AffinityAnalysis &affinityAnalysis, + IREE::HAL::DeviceAnalysis &deviceAnalysis) { + // Get a set of all unique affinities used by resources within the function. + SetVector<IREE::Stream::AffinityAttr> uniqueAffinityAttrs; + SmallVector<IREE::Stream::AffinityAttr> lookupAffinityAttrs; + funcOp.walk([&](Operation *op) { + if (affinityAnalysis.tryLookupExecutionAffinity(op, lookupAffinityAttrs)) { + uniqueAffinityAttrs.insert(lookupAffinityAttrs.begin(), + lookupAffinityAttrs.end()); + } + lookupAffinityAttrs.clear(); + }); + + // Resolve affinities to executable targets. + SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; + for (auto affinityAttr : uniqueAffinityAttrs) { + deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, funcOp, + executableTargetAttrs); + } + return executableTargetAttrs; +} + struct CPUMaterializeHostEncodingPass : public CPUMaterializeHostEncodingBase<CPUMaterializeHostEncodingPass> { CPUMaterializeHostEncodingPass() = default; @@ -560,23 +590,36 @@ auto moduleOp = getOperation(); // Run required analysis passes. - IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp); + if (failed(affinityAnalysis.run())) { return signalPassFailure(); + } + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) { // Gather the required executable targets for the function. Note that it's // possible there are more required for ops nested within the function but // this pass is a hack and can't handle that :shrug:. - SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets; - deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets); + auto executableTargets = getFuncExecutableTargetAttrs( + funcOp, affinityAnalysis, deviceAnalysis); + if (!executableTargets) { + funcOp.emitOpError() + << "could not determine executable targets for the function"; + return signalPassFailure(); + } else if (executableTargets->empty()) { + // Probably no tensors. + continue; + } // HACK: this pass is run on the host _but shouldn't be_. Because it's // run on the host and IREE is a compiler capable of multi-targeting there // may be multiple executable targets at any point in the host program. // This pass can't handle that and assumes it's been checked earlier by // spooky action at a distance. This needs to be fixed. - if (executableTargets.size() != 1) { + if (executableTargets->size() != 1) { funcOp.emitOpError() << "has multiple executable targets and CPU data " "tiling isn't built to support that"; return signalPassFailure(); @@ -584,7 +627,7 @@ // Materialize encodings within the function. if (failed( - materializeFuncOpEncodings(funcOp, executableTargets.front()))) { + materializeFuncOpEncodings(funcOp, executableTargets->front()))) { return signalPassFailure(); } } @@ -636,22 +679,35 @@ auto moduleOp = getOperation(); // Run required analysis passes. - IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp); + if (failed(affinityAnalysis.run())) { return signalPassFailure(); + } + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) { // Gather the required executable targets for the function. Note that it's // possible there are more required for ops nested within the function but // this pass is a hack and can't handle that :shrug:. - SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets; - deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets); + auto executableTargets = getFuncExecutableTargetAttrs( + funcOp, affinityAnalysis, deviceAnalysis); + if (!executableTargets) { + funcOp.emitOpError() + << "could not determine executable targets for the function"; + return signalPassFailure(); + } else if (executableTargets->empty()) { + // Probably no tensors. + continue; + } // Get patterns specialized for the executable targets used by the // function. RewritePatternSet patterns(&getContext()); MaterializeEncodingFn materializeEncodingFn = - getUpperBoundMaterializeEncodingFn(executableTargets.getArrayRef()); + getUpperBoundMaterializeEncodingFn(executableTargets->getArrayRef()); if (!materializeEncodingFn) return signalPassFailure(); populateMaterializeUpperBoundTileSizePatterns(patterns,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp index f762a1f..fc26729 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -211,6 +211,15 @@ } } +bool BindingLayoutAnalysis::hasDispatches() const { + for (auto &it : exportInfos) { + if (!it.second->dispatchOps.empty()) { + return true; // found at least one dispatch + } + } + return false; +} + ArrayRef<IREE::Stream::CmdDispatchOp> BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const { auto it = exportInfos.find(exportOp);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h index 050e18e..7d08959 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
@@ -59,6 +59,9 @@ public: explicit BindingLayoutAnalysis(Operation *rootOp, SymbolTable &symbolTable); + // Whether there are any dispatches in the program. + bool hasDispatches() const; + // Returns all of the dispatches to the given executable export. ArrayRef<IREE::Stream::CmdDispatchOp> getExportDispatches(IREE::Stream::ExecutableExportOp exportOp) const {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp index a17a100..705292b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp
@@ -15,6 +15,92 @@ namespace { +struct ConvertDeviceResolveAnyOp + : public OpConversionPattern<IREE::HAL::DeviceResolveOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getAffinity()) { + return rewriter.notifyMatchFailure( + resolveOp, "only resolving unspecified affinities to any device"); + } + + auto deviceType = rewriter.getType<IREE::HAL::DeviceType>(); + Value device; + auto resolveDevice = [&]() { + if (!device) { + device = rewriter.create<IREE::HAL::DevicesGetOp>( + resolveOp.getLoc(), deviceType, + rewriter.create<arith::ConstantIndexOp>(resolveOp.getLoc(), 0)); + } + return device; + }; + + SmallVector<Value> results; + for (auto resultType : resolveOp.getResultTypes()) { + if (isa<IREE::HAL::DeviceType>(resultType)) { + results.push_back(resolveDevice()); + } else if (isa<IREE::HAL::AllocatorType>(resultType)) { + results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>( + resolveOp.getLoc(), resolveDevice())); + } else if (isa<IntegerType>(resultType)) { + results.push_back(rewriter.create<arith::ConstantIntOp>( + resolveOp.getLoc(), -1ll, 64)); + } + } + + rewriter.replaceOp(resolveOp, results); + return success(); + } +}; + +struct ConvertDeviceResolveAffinityOp + : public OpConversionPattern<IREE::HAL::DeviceResolveOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto affinityAttr = adaptor.getAffinityAttr(); + if (!affinityAttr) { + return rewriter.notifyMatchFailure( + resolveOp, "only resolving fully specified affinities"); + } + auto flatDeviceAttr = dyn_cast<FlatSymbolRefAttr>(affinityAttr.getDevice()); + if (!flatDeviceAttr) { + return rewriter.notifyMatchFailure( + resolveOp, "nested device references not yet supported"); + } + + auto deviceType = rewriter.getType<IREE::HAL::DeviceType>(); + Value device; + auto resolveDevice = [&]() { + if (!device) { + device = rewriter.create<IREE::Util::GlobalLoadOp>( + resolveOp.getLoc(), deviceType, flatDeviceAttr.getValue(), + /*is_immutable=*/true); + } + return device; + }; + + SmallVector<Value> results; + for (auto resultType : resolveOp.getResultTypes()) { + if (isa<IREE::HAL::DeviceType>(resultType)) { + results.push_back(resolveDevice()); + } else if (isa<IREE::HAL::AllocatorType>(resultType)) { + results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>( + resolveOp.getLoc(), resolveDevice())); + } else if (isa<IntegerType>(resultType)) { + results.push_back(rewriter.create<arith::ConstantIntOp>( + resolveOp.getLoc(), affinityAttr.getQueueMask(), 64)); + } + } + + rewriter.replaceOp(resolveOp, results); + return success(); + } +}; + struct ConvertExecutableCalculateWorkgroupsOp : public OpConversionPattern<IREE::HAL::ExecutableCalculateWorkgroupsOp> { using OpConversionPattern::OpConversionPattern; @@ -43,6 +129,10 @@ ConversionTarget &conversionTarget, TypeConverter &typeConverter, RewritePatternSet &patterns) { + conversionTarget.addIllegalOp<IREE::HAL::DeviceResolveOp>(); + patterns.insert<ConvertDeviceResolveAnyOp>(typeConverter, context); + patterns.insert<ConvertDeviceResolveAffinityOp>(typeConverter, context); + conversionTarget.addIllegalOp<IREE::HAL::ExecutableCalculateWorkgroupsOp>(); patterns.insert<ConvertExecutableCalculateWorkgroupsOp>(typeConverter, context);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel index 6056652..b38bbbe 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel
@@ -15,7 +15,10 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( - ["pseudo_ops.mlir"], + [ + "device_ops.mlir", + "pseudo_ops.mlir", + ], include = ["*.mlir"], ), cfg = "//compiler:lit.cfg.py",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt index 2a57a80..6757109 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt
@@ -14,6 +14,7 @@ NAME lit SRCS + "device_ops.mlir" "pseudo_ops.mlir" TOOLS FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir new file mode 100644 index 0000000..44fb128 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir
@@ -0,0 +1,75 @@ +// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s + +// CHECK-LABEL: @deviceResolveAnyDevice +util.func public @deviceResolveAnyDevice() -> !hal.device { + // CHECK-DAG: %[[ANY_ORDINAL:.+]] = arith.constant 0 + // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %[[ANY_ORDINAL]] : !hal.device + %device = hal.device.resolve : !hal.device + // CHECK: util.return %[[DEVICE]] + util.return %device : !hal.device +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveDevice +util.func public @deviceResolveDevice() -> !hal.device { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + %device = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device + // CHECK: util.return %[[DEVICE]] + util.return %device : !hal.device +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveDeviceQueueAffinityAny +util.func public @deviceResolveDeviceQueueAffinityAny() -> (!hal.device, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64 + %device, %queue_affinity_any = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 + // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] + util.return %device, %queue_affinity_any : !hal.device, i64 +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveDeviceQueueAffinity45 +util.func public @deviceResolveDeviceQueueAffinity45() -> (!hal.device, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 + %device, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 + // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] + util.return %device, %queue_affinity_45 : !hal.device, i64 +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveAllocator +util.func public @deviceResolveAllocator() -> !hal.allocator { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + %allocator = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.allocator + // CHECK: util.return %[[ALLOCATOR]] + util.return %allocator : !hal.allocator +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveAllocatorQueueAffinity45 +util.func public @deviceResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 + %allocator, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64 + // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] + util.return %allocator, %queue_affinity_45 : !hal.allocator, i64 +}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp index f8efeba..a911de8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
@@ -69,32 +69,6 @@ return constantBuffer; } -IREE::VM::RodataOp -createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp, - OpBuilder &builder) { - auto executableOp = - binaryOp.getOperation()->getParentOfType<IREE::HAL::ExecutableOp>(); - auto insertPoint = builder.saveInsertionPoint(); - builder.setInsertionPoint(builder.getInsertionBlock()->getParentOp()); - - std::string rodataName = sanitizeSymbolName( - (executableOp.getName() + "_" + binaryOp.getName()).str()); - auto rodataOp = builder.create<IREE::VM::RodataOp>( - binaryOp.getLoc(), rodataName, binaryOp.getData()); - rodataOp.setPrivate(); - if (binaryOp.getMimeType().has_value()) { - rodataOp.setMimeTypeAttr(binaryOp.getMimeTypeAttr()); - } - - // TODO(benvanik): should these be page aligned? memcpy fastpath is fine for - // now. - rodataOp.setAlignmentAttr(builder.getI64IntegerAttr(16)); - - builder.restoreInsertionPoint(insertPoint); - - return rodataOp; -} - namespace { class RemoveExecutableOpConversion @@ -128,9 +102,15 @@ auto executableBinaryOp = SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableBinaryOp>( createOp, createOp.getExecutableTarget()); - auto rodataOp = createExecutableBinaryRodata(executableBinaryOp, rewriter); - auto executableRodata = rewriter.createOrFold<IREE::VM::ConstRefRodataOp>( - createOp.getLoc(), rodataOp); + auto executableOp = executableBinaryOp.getOperation() + ->getParentOfType<IREE::HAL::ExecutableOp>(); + std::string rodataName = sanitizeSymbolName( + (executableOp.getName() + "_" + executableBinaryOp.getName()).str()); + auto rodataOp = rewriter.create<IREE::VM::RodataInlineOp>( + executableBinaryOp.getLoc(), + IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>()), + rewriter.getStringAttr(rodataName), executableBinaryOp.getData(), + rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr()); // Get format string as a rodata blob. auto executableFormatStr = rewriter.create<IREE::VM::RodataInlineOp>( @@ -151,7 +131,7 @@ SmallVector<Value, 8> callOperands = { adaptor.getDevice(), executableFormatStr, - executableRodata, + rodataOp, constantBuffer, }; callOperands.append(adaptor.getLayouts().begin(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h index 071643a..0596524 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h
@@ -23,11 +23,6 @@ Value createPackedConstantBuffer(Location loc, ValueRange constantValues, OpBuilder &builder); -// Creates a vm.rodata containing the contents of a hal.executable.binary. -IREE::VM::RodataOp -createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp, - OpBuilder &builder); - } // namespace mlir::iree_compiler #endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOVM_PATTERNS_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir index 9249b56..5dd5341 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
@@ -1,7 +1,5 @@ // RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s -// CHECK: vm.rodata private @exe_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8> -// CHECK: vm.rodata private @exe_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8> hal.executable @exe { hal.executable.binary @binary1 attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, @@ -24,7 +22,7 @@ ) -> (!hal.executable, !hal.executable) { // CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_ - // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe_binary1 : !vm.buffer + // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> // CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer // CHECK: %[[EXE1:.+]] = vm.call.variadic @hal.executable.create( // CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]], [%[[LAYOUT0]], %[[LAYOUT1]]] @@ -32,7 +30,7 @@ %0 = hal.executable.create device(%device : !hal.device) target(@exe::@binary1) layouts([%layout0, %layout1]) : !hal.executable // CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_ - // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe_binary2 : !vm.buffer + // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8> // CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer // CHECK: %[[EXE2:.+]] = vm.call.variadic @hal.executable.create( // CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]], [%[[LAYOUT1]], %[[LAYOUT0]]] @@ -45,14 +43,12 @@ // ----- -// CHECK: vm.rodata private @exe1_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8> hal.executable @exe1 { hal.executable.binary @binary1 attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, format = "format" } } -// CHECK: vm.rodata private @exe2_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8> hal.executable @exe2 { hal.executable.binary @binary2 attributes { data = dense<[4, 5, 6, 7]> : vector<4xi8>, @@ -67,17 +63,16 @@ %layout1: !hal.pipeline_layout ) -> (!hal.executable, !hal.executable) { // CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format_ - // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe1_binary1 : !vm.buffer + // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe1_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> %0 = hal.executable.create device(%device : !hal.device) target(@exe1::@binary1) layouts([%layout0, %layout1]) : !hal.executable // CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format_ - // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe2_binary2 : !vm.buffer + // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe2_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8> %1 = hal.executable.create device(%device : !hal.device) target(@exe2::@binary2) layouts([%layout1, %layout0]) : !hal.executable util.return %0, %1 : !hal.executable, !hal.executable } // ----- -// CHECK: vm.rodata private @exe_binary {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8> hal.executable @exe { hal.executable.binary @binary attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, @@ -95,7 +90,7 @@ %constant0: i32, %constant1: i32 ) -> !hal.executable { // CHECK-DAG: %[[FORMAT:.+]] = vm.rodata.inline "_utf8_format_ - // CHECK-DAG: %[[BINARY:.+]] = vm.const.ref.rodata @exe_binary : !vm.buffer + // CHECK-DAG: %[[BINARY:.+]] = vm.rodata.inline "exe_binary" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> // CHECK: %[[CONSTANTS:.+]] = vm.buffer.alloc %c12, %c16 : !vm.buffer // CHECK-DAG: %[[INDEX0:.+]] = vm.const.i64 0
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index de9bbb4..a58f32a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -34,58 +34,32 @@ // Get the affinity from the op or an ancestor. Note that there may be no // affinity specified at all. - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(resolveOp); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(resolveOp); + + // If no affinity was specified then resolve as 'any'. + if (!affinityAttr) { + rewriter.replaceOpWithNewOp<IREE::HAL::DeviceResolveOp>( + resolveOp, resolveOp.getResultTypes(), + IREE::HAL::DeviceAffinityAttr{}); + return success(); + } // We currently only handle HAL device affinities. // We could make this an interface to select the device and allow users to // provide their own affinities to convert to HAL. In the future users may // also want to provide devices as function arguments post-initialization. // For now we just have one way to specify device globals. - auto deviceAffinityAttr = - dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr); - if (!deviceAffinityAttr) { - resolveOp.emitOpError() << "failed to resolve affinity: only HAL device " - "affinities are supported"; - return rewriter.notifyMatchFailure( - resolveOp, "only HAL device affinities are supported"); + if (auto deviceAffinityAttr = + dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr)) { + rewriter.replaceOpWithNewOp<IREE::HAL::DeviceResolveOp>( + resolveOp, resolveOp.getResultTypes(), deviceAffinityAttr); + return success(); } - // Get the device handle and queue. - // - // TODO(multi-device): specialized types; may need analysis we don't have - // or at least a symbol lookup. An alternative would be an optional type - // on the affinity in cases where we've evaluated it early but for now - // we assume all device types are unspecialized. - auto deviceType = rewriter.getType<IREE::HAL::DeviceType>(); - Value device = rewriter.create<IREE::Util::GlobalLoadOp>( - resolveOp.getLoc(), deviceType, - deviceAffinityAttr.getDevice().getValue(), - /*is_immutable=*/true); - int64_t queueMask = deviceAffinityAttr.getQueueMask(); - - SmallVector<Value> results; - if (isa<IREE::HAL::DeviceType>(resultTypes[0])) { - results.push_back(device); - } else if (isa<IREE::HAL::AllocatorType>(resultTypes[0])) { - results.push_back(rewriter.create<IREE::HAL::DeviceAllocatorOp>( - resolveOp.getLoc(), device)); - } else { - return rewriter.notifyMatchFailure( - resolveOp, "unrecognized context resolve types for a HAL target"); - } - if (resultTypes.size() > 1) { - if (isa<IntegerType>(resultTypes[1])) { - results.push_back(rewriter.create<arith::ConstantIntOp>( - resolveOp.getLoc(), queueMask, 64)); - } else { - return rewriter.notifyMatchFailure( - resolveOp, - "unrecognized context resolve types for a HAL target (extended)"); - } - } - - rewriter.replaceOp(resolveOp, results); - return success(); + resolveOp.emitOpError() << "failed to resolve affinity: only HAL device " + "affinities are supported"; + return rewriter.notifyMatchFailure( + resolveOp, "only HAL device affinities are supported"); } }; @@ -684,7 +658,7 @@ // make this difficult. For now we assume each stream region being lowered // has a singular affinity that may itself reference multiple devices in the // future but currently uniquely identifies a device. - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(dispatchOp); // Get the device handle we're executing against in this execution region. // Note that this is a dynamic value: we have to treat the device as unknown
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp index 8b628da..5c685c0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
@@ -22,7 +22,7 @@ namespace mlir::iree_compiler { Value lookupDeviceFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>( op->getLoc(), TypeRange{ @@ -34,7 +34,7 @@ std::tuple<Value, Value> lookupDeviceAndQueueAffinityFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>( op->getLoc(), TypeRange{ @@ -46,7 +46,7 @@ } Value lookupAllocatorFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>( op->getLoc(), TypeRange{ @@ -58,7 +58,7 @@ std::tuple<Value, Value> lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create<IREE::Stream::ContextResolveOp>( op->getLoc(), TypeRange{
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir index 20a9c59..c60d0da 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -1,5 +1,7 @@ // RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s +// NOTE: the hal.device.resolve lowering in HAL-to-HAL does most of the work. + util.global private @device : !hal.device // CHECK-LABEL: @contextResolveDefaultDevice @@ -16,63 +18,12 @@ util.global private @device : !hal.device -// CHECK-LABEL: @contextResolveDevice -util.func public @contextResolveDevice() -> !hal.device { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - %device = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device - // CHECK: util.return %[[DEVICE]] - util.return %device : !hal.device -} - -// ----- - -util.global private @device : !hal.device - -// CHECK-LABEL: @contextResolveDeviceQueueAffinityAny -util.func public @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64 - %device, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 - // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] - util.return %device, %queue_affinity_any : !hal.device, i64 -} - -// ----- - -util.global private @device : !hal.device - -// CHECK-LABEL: @contextResolveDeviceQueueAffinity45 -util.func public @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 - %device, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 - // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] - util.return %device, %queue_affinity_45 : !hal.device, i64 -} - -// ----- - -util.global private @device : !hal.device - -// CHECK-LABEL: @contextResolveAllocator -util.func public @contextResolveAllocator() -> !hal.allocator { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator - %allocator = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.allocator - // CHECK: util.return %[[ALLOCATOR]] - util.return %allocator : !hal.allocator -} - -// ----- - -util.global private @device : !hal.device - // CHECK-LABEL: @contextResolveAllocatorQueueAffinity45 -util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) { +util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.device, !hal.allocator, i64) { // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 - %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64 - // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] - util.return %allocator, %queue_affinity_45 : !hal.allocator, i64 + %device, %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, !hal.allocator, i64 + // CHECK: util.return %[[DEVICE]], %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] + util.return %device, %allocator, %queue_affinity_45 : !hal.device, !hal.allocator, i64 }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 6d3c4a2..fe32e8b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -448,28 +448,34 @@ // `[targets, ...]` (optional) do { IREE::HAL::ExecutableTargetAttr executableTargetAttr; - if (failed(p.parseAttribute(executableTargetAttr))) + if (failed(p.parseAttribute(executableTargetAttr))) { return {}; + } executableTargetAttrs.push_back(executableTargetAttr); } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRSquare())) + if (failed(p.parseRSquare())) { return {}; + } } else { // `{config dict}` (optional) - if (failed(p.parseAttribute(configAttr))) + if (failed(p.parseAttribute(configAttr))) { return {}; + } // `, [targets, ...]` (optional) if (succeeded(p.parseOptionalComma())) { - if (failed(p.parseLSquare())) + if (failed(p.parseLSquare())) { return {}; + } do { IREE::HAL::ExecutableTargetAttr executableTargetAttr; - if (failed(p.parseAttribute(executableTargetAttr))) + if (failed(p.parseAttribute(executableTargetAttr))) { return {}; + } executableTargetAttrs.push_back(executableTargetAttr); } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRSquare())) + if (failed(p.parseRSquare())) { return {}; + } } } } @@ -502,7 +508,14 @@ } std::string DeviceTargetAttr::getSymbolNameFragment() { - return sanitizeSymbolName(getDeviceID().getValue().lower()); + std::string name = getDeviceID().getValue().lower(); + if (auto ordinalAttr = + dyn_cast_if_present<IntegerAttr>(getConfigurationAttr("ordinal"))) { + name += "_"; + name += std::to_string(ordinalAttr.getInt()); + name += "_"; // can't have trailing numbers + } + return sanitizeSymbolName(name); } bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { @@ -510,6 +523,13 @@ return configAttr && configAttr.get(name); } +Attribute DeviceTargetAttr::getConfigurationAttr(StringRef name) { + if (auto configAttr = getConfiguration()) { + return configAttr.get(name); + } + return {}; +} + void DeviceTargetAttr::getExecutableTargets( SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs) { for (auto attr : getExecutableTargets()) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 2d10dc3..9511cf8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -754,6 +754,8 @@ // Returns true if there's an attribute with the given name in the // configuration dictionary. bool hasConfigurationAttr(StringRef name); + // Returns the configuration attribute with the given name if found. + Attribute getConfigurationAttr(StringRef name); // Returns zero or more executable targets that this device supports. void getExecutableTargets( @@ -916,7 +918,7 @@ }]; let parameters = (ins - AttrParameter<"FlatSymbolRefAttr", "">:$device, + AttrParameter<"SymbolRefAttr", "">:$device, AttrParameter<"int64_t", "">:$queue_mask );
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 23dab24..bf596ce 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -90,14 +90,16 @@ for (auto source : op.getSources()) { auto it = uniqueSources.insert(std::make_pair(source, orderedSources.size())); - if (it.second) + if (it.second) { orderedSources.push_back(source); + } resultMapping.push_back(it.first->second); } - if (orderedSources.size() == op.getSources().size()) + if (orderedSources.size() == op.getSources().size()) { return failure(); - auto newOp = rewriter.create<TensorBarrierOp>( - op.getLoc(), orderedSources, op.getSignalFence(), op.getAffinityAttr()); + } + auto newOp = rewriter.create<TensorBarrierOp>(op.getLoc(), orderedSources, + op.getSignalFence()); SmallVector<Value> newResults; newResults.reserve(newOp.getNumResults()); for (unsigned newIndex : resultMapping) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index e28025e..538b32d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -458,19 +458,6 @@ waitFence, name, affinity); } -Value TensorImportOp::getTiedResult(unsigned resultIndex) { - return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); -} - -::std::optional<unsigned> -TensorImportOp::getTiedResultOperandIndex(unsigned resultIndex) { - return {0}; // source -} - -SmallVector<int64_t> TensorImportOp::getTiedResultOperandIndices() { - return {0}; // source -} - static LogicalResult verifyTypeStorageCompatibility(Operation *op, Type encodingType, Type storageType) { @@ -539,19 +526,6 @@ affinity); } -Value TensorExportOp::getTiedResult(unsigned resultIndex) { - return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); -} - -::std::optional<unsigned> -TensorExportOp::getTiedResultOperandIndex(unsigned resultIndex) { - return {0}; // source -} - -SmallVector<int64_t> TensorExportOp::getTiedResultOperandIndices() { - return {0}; // source -} - LogicalResult TensorExportOp::verify() { TensorExportOp op = *this; auto sourceType = llvm::cast<TensorType>(op.getSource().getType()); @@ -595,11 +569,10 @@ //===----------------------------------------------------------------------===// void TensorBarrierOp::build(OpBuilder &builder, OperationState &result, - ValueRange sources, Value signalFence, - Attribute affinity) { + ValueRange sources, Value signalFence) { auto resultTypes = llvm::map_to_vector( sources, [](Value source) { return source.getType(); }); - build(builder, result, resultTypes, sources, signalFence, affinity); + build(builder, result, resultTypes, sources, signalFence); } Value TensorBarrierOp::getTiedResult(unsigned resultIndex) { @@ -1064,6 +1037,23 @@ } //===----------------------------------------------------------------------===// +// hal.device.resolve +//===----------------------------------------------------------------------===// + +void DeviceResolveOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + for (auto result : getResults()) { + if (isa<IREE::HAL::DeviceType>(result.getType())) { + setNameFn(result, "device"); + } else if (isa<IREE::HAL::AllocatorType>(result.getType())) { + setNameFn(result, "allocator"); + } else if (isa<IntegerType>(result.getType())) { + setNameFn(result, "queue_affinity"); + } + } +} + +//===----------------------------------------------------------------------===// // hal.device.allocator //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index d35a0cb..dff5438 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -98,11 +98,6 @@ def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ AttrSizedOperandSegments, - DeclareOpInterfaceMethods<Util_TiedOpInterface, [ - "getTiedResult", - "getTiedResultOperandIndex", - "getTiedResultOperandIndices", - ]>, Util_ShapeAwareOp, ]> { let summary = [{imports a tensor from a HAL buffer view}]; @@ -171,11 +166,6 @@ } def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ - DeclareOpInterfaceMethods<Util_TiedOpInterface, [ - "getTiedResult", - "getTiedResultOperandIndex", - "getTiedResultOperandIndices", - ]>, Util_ShapeAwareOp, ]> { let summary = [{exports a tensor to a HAL buffer view}]; @@ -320,15 +310,13 @@ let arguments = (ins Variadic<AnyTensor>:$sources, - HAL_Fence:$signal_fence, - OptionalAttr<AnyAttr>:$affinity + HAL_Fence:$signal_fence ); let results = (outs Variadic<AnyTensor>:$results ); let assemblyFormat = [{ - (`on` `(` $affinity^ `)`)? `join` `` `(` $sources `:` type($sources) `)` `=` `` `>` $signal_fence `:` type($signal_fence) @@ -338,8 +326,7 @@ let builders = [ OpBuilder<(ins "ValueRange":$sources, - "Value":$signalFence, - "Attribute":$affinity + "Value":$signalFence )>, ]; @@ -1616,6 +1603,41 @@ let opDocGroup = OpGroupDeviceOps in { +def HAL_DeviceResolveOp : HAL_PureOp<"device.resolve", [ + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +]> { + let summary = [{resolves device handles based on affinity}]; + let description = [{ + Examples: + ``` + // Returns a HAL device. + = hal.device.resolve on(#something) : !hal.device + // Returns a HAL device, allocator, and (optional) queue affinity. + = hal.device.resolve on(#something) : !hal.device, !hal.allocator, i64 + // Returns a HAL allocator and (optional) queue affinity. + = hal.device.resolve on(#something) : !hal.allocator, i64 + // Returns "any" device. Should only be used as a fallback. + = hal.device.resolve : !hal.device + ``` + }]; + + let arguments = (ins + OptionalAttr<HAL_DeviceAffinityAttr>:$affinity + ); + let results = (outs + Variadic<AnyTypeOf<[ + HAL_Device, + HAL_Allocator, + I64, + ]>>:$results + ); + + let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? + attr-dict `:` type($results) + }]; +} + def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, ]> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp index 06d2366..eedb427 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp
@@ -49,6 +49,7 @@ : public IREE::HAL::impl::ConvertToHALPassBase<ConvertToHALPass> { void runOnOperation() override { auto *context = &getContext(); + auto moduleOp = getOperation(); // Gather all interfaces from registered dialects. // These will perform the tensor->buffer mapping for their ops. @@ -64,8 +65,7 @@ HALTypeConverter typeConverter(conversionInterfaces); HALConversionTarget conversionTarget(context, typeConverter); - RewritePatternSet patterns(&getContext()); - + RewritePatternSet patterns(context); populateHALToHALPatterns(context, conversionTarget, typeConverter, patterns); populateUtilToHALPatterns(context, conversionTarget, typeConverter, @@ -84,13 +84,14 @@ // NOTE: we allow ops that we don't know about to allow custom dialects // that don't need anything HAL-specific to pass through. - if (failed(applyPartialConversion(getOperation(), conversionTarget, + if (failed(applyPartialConversion(moduleOp, conversionTarget, std::move(patterns)))) { return signalPassFailure(); } // Cleanup conversion attributes used for spooky action at a distance. - for (auto executableOp : getOperation().getOps<IREE::HAL::ExecutableOp>()) { + moduleOp->removeAttr("stream.affinity.default"); + for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) { for (auto variantOp : executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) { for (auto exportOp :
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 6e5f110..7ad1f4a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -165,12 +165,30 @@ return map; } +static std::pair<Value, Value> +getDeviceAndQueueAffinity(Location loc, IREE::Stream::AffinityAttr affinityAttr, + OpBuilder &builder) { + if (auto deviceAffinityAttr = + dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr)) { + auto resolveOp = builder.create<IREE::HAL::DeviceResolveOp>( + loc, + TypeRange{ + builder.getType<IREE::HAL::DeviceType>(), + builder.getI64Type(), + }, + deviceAffinityAttr); + return std::make_pair(resolveOp.getResult(0), resolveOp.getResult(1)); + } + auto device = IREE::HAL::DeviceType::resolveAny(loc, builder); + auto queueAffinity = builder.create<arith::ConstantIntOp>(loc, -1, 64); + return std::make_pair(device, queueAffinity); +} + // Appends a global hal.buffer initialized to the size required for all // of the bindings in |dispatchParams| (plus alignment). -static IREE::Util::GlobalOp -appendGlobalBuffer(Location loc, StringRef baseName, - const DispatchParams &dispatchParams, - OpBuilder &moduleBuilder) { +static IREE::Util::GlobalOp appendGlobalBuffer( + Location loc, StringRef baseName, const DispatchParams &dispatchParams, + IREE::Stream::AffinityAttr affinityAttr, OpBuilder &moduleBuilder) { // Create a global to hold the HAL buffer. auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>( loc, (baseName + "_buffer").str(), @@ -191,12 +209,12 @@ auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock()); IndexSet indexSet(loc, initBuilder); - // TODO(multi-device): support multiple devices in benchmark generation. - Value device = IREE::HAL::DeviceType::resolveAny(loc, initBuilder); + // Resolve allocator for the benchmark device. + auto [device, queueAffinity] = + getDeviceAndQueueAffinity(loc, affinityAttr, initBuilder); auto allocator = initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, device).getResult(); - auto queueAffinity = initBuilder.create<arith::ConstantIntOp>(loc, -1, 64); auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal; auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer | IREE::HAL::BufferUsageBitfield::DispatchStorage; @@ -234,8 +252,8 @@ } // Add a global variable holding an initialized buffer for the dispatch IO. - auto bufferGlobalOp = - appendGlobalBuffer(loc, baseName, dispatchParams, moduleBuilder); + auto bufferGlobalOp = appendGlobalBuffer(loc, baseName, dispatchParams, + affinityAttr, moduleBuilder); // Create an exported benchmark function that runs the dispatches. auto funcType = @@ -261,10 +279,9 @@ auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>( loc, funcBuilder.getIndexType(), entryBlock->getArgument(0)); - // TODO(multi-device): support multiple devices in benchmark generation. - // For now we should just use the affinityAttr to resolve the device. - Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder); - Value queueAffinity = funcBuilder.create<arith::ConstantIntOp>(loc, -1, 64); + // Resolve device for this particular benchmark. + auto [device, queueAffinity] = + getDeviceAndQueueAffinity(loc, affinityAttr, funcBuilder); // Create and begin command buffer. // TODO(benvanik): reuse the command buffer (initialize once and store). @@ -423,8 +440,9 @@ // would be to generate one module per device dispatches are made on such // that users can isolate to individual devices. For now we just deal with // it. - for (auto globalOp : deviceAnalysis.getDeviceGlobals()) + for (auto globalOp : deviceAnalysis.getDeviceGlobals()) { moduleBuilder.clone(*globalOp.getOperation()); + } // Clone the executable variant into the new module. auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index e1437b2..221f754 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -600,6 +600,18 @@ return signalPassFailure(); } + // If no devices were defined and there are dispatches in the program then + // error out. This provides a better error message than if we were to allow + // this pass to no-op and then fail during conversion later on. + if (layoutAnalysis.hasDispatches() && + deviceAnalysis.getDeviceGlobals().empty()) { + mlir::emitError(moduleOp.getLoc()) + << "no HAL devices defined in the module; use the module-level " + "hal.device.targets attribute, the --iree-hal-target-device= " + "flag, or provide inputs with global !hal.devices defined"; + return signalPassFailure(); + } + // Gather the required executable targets per executable and dispatch site. auto requiredExecutableTargets = buildRequiredExecutableTargetsMap( moduleOp, deviceAnalysis, layoutAnalysis);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp index 67bf383..5d70f1b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp
@@ -142,36 +142,32 @@ // have one set. static void assignDefaultDeviceAffinity(mlir::ModuleOp moduleOp, FlatSymbolRefAttr defaultDeviceRef) { - Builder builder(moduleOp); - auto affinityName = builder.getStringAttr("stream.affinity"); - auto affinityAttr = builder.getAttr<IREE::HAL::DeviceAffinityAttr>( - defaultDeviceRef, /*queue_mask=*/-1ll); + auto affinityAttr = IREE::HAL::DeviceAffinityAttr::get( + moduleOp.getContext(), defaultDeviceRef, /*queue_mask=*/-1ll); - // TODO(benvanik): make this an interface that can be registered on types. - auto isAnnotatableType = [](Type type) { - return isa<TensorType>(type) || isa<IREE::Stream::ResourceType>(type); - }; - for (auto &op : moduleOp.getOps()) { - bool shouldAnnotate = true; - if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) { - if (!isAnnotatableType(globalOp.getGlobalType())) { - shouldAnnotate = false; - } - } else if (op.hasTrait<OpTrait::SymbolTable>()) { - // Symbol table ops can't reference parent symbols properly. - shouldAnnotate = false; - } - if (!shouldAnnotate) { - continue; // skip op - } + // Default on the module that applies to any ops that don't otherwise have a + // placement. Ideally we never need this but some programs may take/return no + // tensors or have tensors come from unattributed containers (lists/dicts). + moduleOp->setAttr("stream.affinity.default", affinityAttr); - if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) { - if (!affinityOp.getAffinityAttr()) { - affinityOp.setAffinityAttr(affinityAttr); + // Set all arg/results to route through the default device unless they've + // already been assigned. + auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity"); + for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) { + if (funcOp.isPublic()) { + for (auto arg : funcOp.getArguments()) { + if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) { + if (!funcOp.getArgAttr(arg.getArgNumber(), affinityName)) { + funcOp.setArgAttr(arg.getArgNumber(), affinityName, affinityAttr); + } + } } - } else { - if (!op.hasAttr(affinityName)) { - op.setAttr(affinityName, affinityAttr); + for (auto result : llvm::enumerate(funcOp.getResultTypes())) { + if (isa<IREE::Stream::AffinityTypeInterface>(result.value())) { + if (!funcOp.getResultAttr(result.index(), affinityName)) { + funcOp.setResultAttr(result.index(), affinityName, affinityAttr); + } + } } } }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 773ed92..54a87b0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -198,11 +198,17 @@ passManager.addPass(IREE::HAL::createAssignTargetDevicesPass( {assignmentOptions.targetDevices})); } + + // Create globals for each device (if needed). passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass( {assignmentOptions.defaultDevice})); + + // Resolve #hal.device.promise and #hal.device.alias attributes. passManager.addPass(IREE::HAL::createResolveDevicePromisesPass()); passManager.addPass( IREE::HAL::createResolveDeviceAliasesPass({&targetRegistry})); + + // Verify devices are valid. passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); } @@ -222,6 +228,9 @@ // and initial interface analysis (we rely on CSE and such having been run). addCleanupPatterns(passManager); + // Verify devices are valid. + passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); + //---------------------------------------------------------------------------- // Device-specific interface materialization //----------------------------------------------------------------------------
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp index e1ca624..6ac9cc8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp
@@ -127,13 +127,38 @@ return signalPassFailure(); } - // Must have at least one device specified. - if (deviceAnalysis.getDeviceGlobals().empty()) { + // Devices are only required if we have dialects we may lower into device + // code. For now checking for tensor types is probably sufficient though we + // may want a pluggable way to decide this (e.g. dialect/type/op + // interfaces). + auto isTensor = [](Type type) { return isa<TensorType>(type); }; + bool anyTensors = false; + for (auto &op : moduleOp.getOps()) { + if (op.hasTrait<OpTrait::IREE::Util::ObjectLike>()) { + continue; // ignore executables + } + op.walk([&](Operation *childOp) { + if (llvm::any_of(childOp->getOperandTypes(), isTensor) || + llvm::any_of(childOp->getResultTypes(), isTensor)) { + anyTensors = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } + // TODO(multi-device): the logic above is insufficient; we only need devices + // if the program will end up requiring them but we don't know that here. + // We have to wait until we've lowered to the point where we do require a + // device _and_ we actually want one (aren't compiling a non-HAL program). + // We could probably have an op interface, better output from the pass that + // requires the devices, etc. For now we error out in HAL conversion when we + // try to resolve devices. + if (false && anyTensors && deviceAnalysis.getDeviceGlobals().empty()) { auto diagnostic = moduleOp.emitError(); diagnostic << "no HAL devices defined in the module; use the module-level " "hal.device.targets attribute, the --iree-hal-target-device= " - "flags, or provide inputs with global !hal.devices defined; "; + "flag, or provide inputs with global !hal.devices defined; "; printAvailable(diagnostic, *targetRegistry.value); return signalPassFailure(); }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir index 61926d3..adce11f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
@@ -4,7 +4,7 @@ // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a[0],device_a[1]})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ORDINALS // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.target<"local">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ATTR // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.alias<"device_a">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ALIAS -// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices="device_a,#hal.device.alias<"device_b">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices={"device_a,#hal.device.alias<"device_b">"}})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a=#hal.device.alias<"device_a">,"device_bc=device_b,#hal.device.alias<"device_c">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT-MULTI // CHECK: module
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir index 11b3918..6360abe 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir
@@ -33,6 +33,7 @@ // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_0> module @module attributes { hal.device.targets = [ #hal.device.select<[#device_a, #device_b]> : !hal.device @@ -42,18 +43,6 @@ // CHECK-SAME: #[[DEVICE_A]], // CHECK-SAME: #[[DEVICE_B]] // CHECK-SAME: ]> : !hal.device - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0> - util.global private @tensor_global : tensor<4xf32> - - // CHECK: util.global private @primitive_global - // CHECK-NOT: stream.affinity - util.global private @primitive_global : i32 - - // CHECK: util.func private @func - // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0> - util.func private @func() -> () } // ----- @@ -69,6 +58,7 @@ // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_a> module @module attributes { hal.device.targets = { device_a = #device_a, @@ -77,10 +67,6 @@ } { // CHECK: util.global private @device_a = #[[DEVICE_A]] // CHECK: util.global private @device_bc = #hal.device.select<[#[[DEVICE_B]], #[[DEVICE_C]]]> - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a> - util.global private @tensor_global : tensor<4xf32> } // ----- @@ -94,6 +80,7 @@ // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_b> module @module attributes { hal.device.targets = { device_a = #device_a, @@ -103,10 +90,6 @@ } { // CHECK: util.global private @device_a // CHECK: util.global private @device_b - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b> - util.global private @tensor_global : tensor<4xf32> } // ----- @@ -120,6 +103,7 @@ // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_1> module @module attributes { hal.device.targets = [ #device_a, @@ -129,9 +113,4 @@ } { // CHECK: util.global private @__device_0 // CHECK: util.global private @__device_1 - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_1> - util.global private @tensor_global : tensor<4xf32> } -
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir index 4511a0b..b4e2264 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir
@@ -1,12 +1,27 @@ // RUN: iree-opt --split-input-file --iree-hal-verify-devices %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s -// expected-error@+1 {{no HAL devices defined in the module}} +// Tests that modules without tensors don't need devices. + module @module { + // CHECK: util.func private @func util.func private @func() -> () } // ----- +// TODO(multi-device): find a way to verify that devices exist if they need to. +// Currently the check is disabled as it's difficult to tell if a device will be +// needed by the time we get to the HAL layer: plugins may absorb things, etc. +// NO-expected-errorx@+1 {{no HAL devices defined in the module}} +module @module { + util.func private @func() -> () { + arith.constant dense<1.0> : tensor<4xf32> + util.return + } +} + +// ----- + module @module { // expected-error@+1 {{unregistered target device "__unregistered__"}} util.global private @device = #hal.device.target<"__unregistered__"> : !hal.device
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel index fbc0e51..10d1456 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@ ], deps = [ "//compiler/src/iree/compiler/Dialect/Flow/IR", + "//compiler/src/iree/compiler/Dialect/Stream/Analysis", "//compiler/src/iree/compiler/Dialect/Stream/IR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt index 05bbb79..cc472aa 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ MLIRTransformUtils MLIRTransforms iree::compiler::Dialect::Flow::IR + iree::compiler::Dialect::Stream::Analysis iree::compiler::Dialect::Stream::IR PUBLIC )
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index bdc5aaf..31d6151 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -35,41 +35,42 @@ } struct ConvertTensorConstantOp - : public OpConversionPattern<IREE::Flow::TensorConstantOp> { + : public AffinityOpConversionPattern<IREE::Flow::TensorConstantOp> { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { // Capture the tensor constant strongly typed with constant lifetime. - Type constantType = IREE::Stream::ResourceType::get( - getContext(), IREE::Stream::Lifetime::Constant); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); + auto constantType = rewriter.getType<IREE::Stream::ResourceType>( + IREE::Stream::Lifetime::Constant); auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>( constantOp.getLoc(), constantType, convertAttributeToStream(constantOp.getValue()), - TypeAttr::get(constantOp.getType()), ValueRange{}, affinityAttr); + TypeAttr::get(constantOp.getType()), ValueRange{}, + executionAffinityAttr); // Transfer to unknown lifetime. - Type unknownType = IREE::Stream::ResourceType::get(getContext()); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>( constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( constantOp, unknownType, newOp.getResult(), constantSize, constantSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr); + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr); return success(); } }; struct ConvertTensorDynamicConstantOp - : public OpConversionPattern<IREE::Flow::TensorDynamicConstantOp> { + : public AffinityOpConversionPattern<IREE::Flow::TensorDynamicConstantOp> { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorDynamicConstantOp constantOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorDynamicConstantOp constantOp, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto attrType = dyn_cast<RankedTensorType>(constantOp.getValue().getType()); if (!attrType) return failure(); @@ -91,22 +92,21 @@ } // Capture the tensor constant strongly typed with constant lifetime. - Type constantType = IREE::Stream::ResourceType::get( - getContext(), IREE::Stream::Lifetime::Constant); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); + auto constantType = rewriter.getType<IREE::Stream::ResourceType>( + IREE::Stream::Lifetime::Constant); auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>( constantOp.getLoc(), constantType, convertAttributeToStream(constantOp.getValue()), - TypeAttr::get(resultType), dynamicDims, affinityAttr); + TypeAttr::get(resultType), dynamicDims, executionAffinityAttr); // Transfer to unknown lifetime. - Type unknownType = IREE::Stream::ResourceType::get(getContext()); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>( constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( constantOp, unknownType, newOp.getResult(), constantSize, constantSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr); + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr); return success(); } }; @@ -114,157 +114,169 @@ // Reshapes and bitcasts become clones here to preserve shape/element type // information (which may become actual transfers depending on source/target // shape) - they'll be elided if not needed. +// +// NOTE: we transfer to the target before cloning. This may not be optimal +// as the clone may otherwise have been able to be elided on the producer +// side but we leave that for future copy elision to determine. template <typename CastOpTy> -struct ConvertTensorCastLikeOp : public OpConversionPattern<CastOpTy> { - using OpConversionPattern<CastOpTy>::OpConversionPattern; +struct ConvertTensorCastLikeOp + : public AffinityAwareConversionPattern<CastOpTy> { + using AffinityAwareConversionPattern< + CastOpTy>::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(CastOpTy op, typename CastOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); - auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultAffinityAttr = this->lookupResultAffinity(op.getResult()); + auto source = this->transferTensorOperand(op.getLoc(), op.getSource(), + adaptor.getSource(), + resultAffinityAttr, rewriter); auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + resultAffinityAttr, rewriter); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>( op, unknownType, source.resource, op.getSource().getType(), op.getSourceDims(), source.resourceSize, op.getResult().getType(), - adaptor.getResultDims(), resultSize, affinityAttr); + adaptor.getResultDims(), resultSize, resultAffinityAttr); return success(); } }; struct ConvertTensorAllocaOp - : public OpConversionPattern<IREE::Flow::TensorAllocaOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + : public AffinityOpConversionPattern<IREE::Flow::TensorAllocaOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncAllocaOp>( - op, unknownType, resultSize, affinityAttr); + op, unknownType, resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorEmptyOp - : public OpConversionPattern<IREE::Flow::TensorEmptyOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + : public AffinityOpConversionPattern<IREE::Flow::TensorEmptyOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); rewriter.replaceOpWithNewOp<IREE::Stream::TensorEmptyOp>( op, unknownType, op.getResult().getType(), adaptor.getResultDims(), - resultSize, affinityAttr); + resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorSplatOp - : public OpConversionPattern<IREE::Flow::TensorSplatOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorSplatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + : public AffinityOpConversionPattern<IREE::Flow::TensorSplatOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorSplatOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>( op, unknownType, adaptor.getValue(), op.getResult().getType(), - adaptor.getResultDims(), resultSize, affinityAttr); + adaptor.getResultDims(), resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorCloneOp - : public OpConversionPattern<IREE::Flow::TensorCloneOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorCloneOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern<IREE::Flow::TensorCloneOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorCloneOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto operand = transferTensorOperand(op.getLoc(), op.getOperand(), + adaptor.getOperand(), + executionAffinityAttr, rewriter); auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); - auto operand = - consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>( op, unknownType, operand.resource, op.getOperand().getType(), op.getArgumentDims(), operand.resourceSize, op.getResult().getType(), - adaptor.getArgumentDims(), operand.resourceSize, - IREE::Stream::AffinityAttr::lookup(op)); + adaptor.getArgumentDims(), operand.resourceSize, executionAffinityAttr); return success(); } }; struct ConvertTensorTransferOp - : public OpConversionPattern<IREE::Flow::TensorTransferOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorTransferOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto targetAffinityAttr = - dyn_cast<IREE::Stream::AffinityAttr>(adaptor.getTarget()); - if (!targetAffinityAttr) + : public AffinityOpConversionPattern<IREE::Flow::TensorTransferOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorTransferOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + if (!executionAffinityAttr) { return rewriter.notifyMatchFailure(op, "invalid stream affinity attr"); + } + auto operand = resolveTensorOperand(op.getLoc(), op.getOperand(), + adaptor.getOperand(), rewriter); auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); - auto operand = - consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( op, unknownType, operand.resource, operand.resourceSize, operand.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr{}, targetAffinityAttr); + /*source_affinity=*/operand.affinity, + /*result_affinity=*/executionAffinityAttr); return success(); } }; struct ConvertTensorSliceOp - : public OpConversionPattern<IREE::Flow::TensorSliceOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorSliceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); + : public AffinityOpConversionPattern<IREE::Flow::TensorSliceOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorSliceOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); rewriter.replaceOpWithNewOp<IREE::Stream::TensorSliceOp>( op, unknownType, source.resource, op.getSource().getType(), op.getSourceDims(), source.resourceSize, adaptor.getStartIndices(), adaptor.getLengths(), op.getResult().getType(), adaptor.getResultDims(), - resultSize, affinityAttr); + resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorUpdateOp - : public OpConversionPattern<IREE::Flow::TensorUpdateOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto update = - consumeTensorOperand(op.getLoc(), adaptor.getUpdate(), rewriter); + : public AffinityOpConversionPattern<IREE::Flow::TensorUpdateOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto target = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); + auto update = + transferTensorOperand(op.getLoc(), op.getUpdate(), adaptor.getUpdate(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::TensorUpdateOp>( op, target.resource.getType(), target.resource, op.getTarget().getType(), adaptor.getTargetDims(), target.resourceSize, adaptor.getStartIndices(), update.resource, op.getUpdate().getType(), - op.getUpdateDims(), update.resourceSize, - IREE::Stream::AffinityAttr::lookup(op)); + op.getUpdateDims(), update.resourceSize, executionAffinityAttr); return success(); } }; @@ -281,14 +293,13 @@ } struct ConvertTensorLoadOp - : public OpConversionPattern<IREE::Flow::TensorLoadOp> { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern<IREE::Flow::TensorLoadOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Flow::TensorLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resultType = getTypeConverter()->convertType(op.getResult().getType()); - auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + auto source = resolveTensorOperand(op.getLoc(), op.getSource(), + adaptor.getSource(), rewriter); // If the source is not a staging resource then we need to transfer it to // a staging resource. We slice out just what is being loaded so that we @@ -299,6 +310,7 @@ // If already a staging resource then we can fast-path load the value. auto stagingType = rewriter.getType<IREE::Stream::ResourceType>( IREE::Stream::Lifetime::Staging); + auto resultType = getTypeConverter()->convertType(op.getResult().getType()); if (source.resource.getType() == stagingType) { rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>( op, resultType, source.resource, op.getSource().getType(), @@ -306,16 +318,14 @@ return success(); } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - // Scalar tensors get transferred without slicing. auto sourceEncoding = op.getSource().getType(); if (isScalarTensor(sourceEncoding)) { auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>( op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); + /*source_affinity=*/source.affinity, + /*result_affinity=*/source.affinity); rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>( op, resultType, transferOp.getResult(), sourceEncoding, adaptor.getSourceDims(), transferOp.getResultSize(), @@ -341,16 +351,17 @@ RankedTensorType::get(resultDims, sourceEncoding.getElementType(), sourceEncoding.getEncoding()); Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>( - op.getLoc(), resultEncoding, ValueRange{}, affinityAttr); + op.getLoc(), resultEncoding, ValueRange{}, source.affinity); auto sliceOp = rewriter.create<IREE::Stream::TensorSliceOp>( op.getLoc(), source.resource.getType(), source.resource, sourceEncoding, adaptor.getSourceDims(), source.resourceSize, sliceIndices, - sliceLengths, resultEncoding, ValueRange{}, resultSize, affinityAttr); + sliceLengths, resultEncoding, ValueRange{}, resultSize, + source.affinity); auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>( op.getLoc(), stagingType, sliceOp.getResult(), sliceOp.getResultSize(), sliceOp.getResultSize(), - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); + /*source_affinity=*/source.affinity, + /*result_affinity=*/source.affinity); rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>( op, resultType, transferOp.getResult(), sliceOp.getResultEncoding(), sliceOp.getResultEncodingDims(), transferOp.getResultSize(), @@ -360,13 +371,13 @@ }; struct ConvertTensorStoreOp - : public OpConversionPattern<IREE::Flow::TensorStoreOp> { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern<IREE::Flow::TensorStoreOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Flow::TensorStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto target = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + auto target = resolveTensorOperand(op.getLoc(), op.getTarget(), + adaptor.getTarget(), rewriter); // If the target is a staging resource then we can directly store into it // with a fast-path. Otherwise we need to stage an upload. @@ -380,34 +391,23 @@ return success(); } - // Scalar tensors disconnect from the original target. - auto targetEncoding = op.getTarget().getType(); - if (isScalarTensor(targetEncoding)) { - rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>( - op, target.resource.getType(), adaptor.getValue(), targetEncoding, - adaptor.getTargetDims(), target.resourceSize, - IREE::Stream::AffinityAttr::lookup(op)); - return success(); - } - // Use fill to store the value. // TODO(benvanik): support larger buffer slices (stage + update). IndexSet indexSet(op.getLoc(), rewriter); indexSet.populate(adaptor.getIndices()); - SmallVector<Value> lengths; - for (auto index : adaptor.getIndices()) - lengths.push_back(indexSet.get(1)); + SmallVector<Value> lengths(adaptor.getIndices().size(), indexSet.get(1)); + auto targetEncoding = op.getTarget().getType(); rewriter.replaceOpWithNewOp<IREE::Stream::TensorFillOp>( op, target.resource, targetEncoding, adaptor.getTargetDims(), target.resourceSize, adaptor.getIndices(), lengths, adaptor.getValue(), - IREE::Stream::AffinityAttr::lookup(op)); + target.affinity); return success(); } }; struct ConvertTensorTraceOp - : public OpConversionPattern<IREE::Flow::TensorTraceOp> { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern<IREE::Flow::TensorTraceOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Flow::TensorTraceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -416,8 +416,8 @@ SmallVector<Attribute> resourceEncodings; for (auto [tensorOperand, resourceOperand] : llvm::zip_equal(op.getValues(), adaptor.getValues())) { - auto source = - consumeTensorOperand(op.getLoc(), resourceOperand, rewriter); + auto source = resolveTensorOperand(op.getLoc(), tensorOperand, + resourceOperand, rewriter); auto stagingType = rewriter.getType<IREE::Stream::ResourceType>( IREE::Stream::Lifetime::Staging); auto traceSource = source.resource; @@ -425,13 +425,14 @@ traceSource = rewriter.create<IREE::Stream::AsyncTransferOp>( op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/nullptr); + /*source_affinity=*/source.affinity, + /*result_affinity=*/source.affinity); } resources.push_back(traceSource); resourceSizes.push_back(source.resourceSize); resourceEncodings.push_back(TypeAttr::get(tensorOperand.getType())); } + rewriter.replaceOpWithNewOp<IREE::Stream::TensorTraceOp>( op, adaptor.getKey(), resources, resourceSizes, rewriter.getArrayAttr(resourceEncodings), adaptor.getValueDims()); @@ -440,16 +441,18 @@ }; struct ConvertChannelDefaultOp - : public OpConversionPattern<IREE::Flow::ChannelDefaultOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern<IREE::Flow::ChannelDefaultOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<IREE::Stream::ChannelCreateOp>( - op, /*id=*/Value{}, + op, + /*id=*/Value{}, /*group=*/adaptor.getGroupAttr(), /*rank=*/Value{}, - /*count=*/Value{}, IREE::Stream::AffinityAttr::lookup(op)); + /*count=*/Value{}, executionAffinityAttr); return success(); } }; @@ -491,164 +494,190 @@ }; struct ConvertAllGatherOp - : public OpConversionPattern<IREE::Flow::CollectiveAllGatherOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast<ShapedType>(op.getSource().getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::AllGather, + : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllGatherOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>( + IREE::Stream::CollectiveKind::AllGather, /*reduction=*/std::nullopt, static_cast<IREE::Stream::CollectiveElementType>(op.getElementType())); auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); auto elementCount = rewriter.create<arith::ConstantIndexOp>( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; struct ConvertAllReduceOp - : public OpConversionPattern<IREE::Flow::CollectiveAllReduceOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast<ShapedType>(op.getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::AllReduce, + : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllReduceOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>( + IREE::Stream::CollectiveKind::AllReduce, static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()), static_cast<IREE::Stream::CollectiveElementType>(op.getElementType())); auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); auto elementCount = rewriter.create<arith::ConstantIndexOp>( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; struct ConvertAllToAllOp - : public OpConversionPattern<IREE::Flow::CollectiveAllToAllOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast<ShapedType>(op.getSource().getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::AllToAll, + : public AffinityOpConversionPattern<IREE::Flow::CollectiveAllToAllOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>( + IREE::Stream::CollectiveKind::AllToAll, /*reduction=*/std::nullopt, static_cast<IREE::Stream::CollectiveElementType>(op.getElementType())); auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); auto elementCount = rewriter.create<arith::ConstantIndexOp>( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; -struct ConvertReduceScatterOp - : public OpConversionPattern<IREE::Flow::CollectiveReduceScatterOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast<ShapedType>(op.getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::ReduceScatter, +struct ConvertReduceScatterOp : public AffinityOpConversionPattern< + IREE::Flow::CollectiveReduceScatterOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>( + IREE::Stream::CollectiveKind::ReduceScatter, static_cast<IREE::Stream::CollectiveReductionOp>(op.getReductionOp()), static_cast<IREE::Stream::CollectiveElementType>(op.getElementType())); auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); auto elementCount = rewriter.create<arith::ConstantIndexOp>( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; struct ConvertCollectiveSendRecvOp - : public OpConversionPattern<IREE::Flow::CollectiveSendRecvOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast<ShapedType>(op.getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::SendRecv, + : public AffinityOpConversionPattern<IREE::Flow::CollectiveSendRecvOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr<IREE::Stream::CollectiveAttr>( + IREE::Stream::CollectiveKind::SendRecv, /*reduction=*/std::nullopt, static_cast<IREE::Stream::CollectiveElementType>(op.getElementType())); auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); auto elementCount = rewriter.create<arith::ConstantIndexOp>( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); // Pack send, recv into param. The values are checked to be within the // 16-bit range during lowering to Flow dialect. @@ -665,27 +694,31 @@ auto param = rewriter.create<arith::OrIOp>(op.getLoc(), hi, lo); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/param, IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/param, executionAffinityAttr); return success(); } }; -struct ConvertDispatchOp : public OpConversionPattern<IREE::Flow::DispatchOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::DispatchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - +struct ConvertDispatchOp + : public AffinityOpConversionPattern<IREE::Flow::DispatchOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::DispatchOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { // Zero is going to be used for each operand to start. auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); @@ -700,7 +733,8 @@ llvm::zip_equal(op.getArguments(), adaptor.getArguments())) { if (llvm::isa<ShapedType>(oldOperand.getType())) { auto newOperandCast = - consumeTensorOperand(op.getLoc(), newOperand, rewriter); + transferTensorOperand(op.getLoc(), oldOperand, newOperand, + executionAffinityAttr, rewriter); newOperand = newOperandCast.resource; dispatchOperandSizes.push_back(newOperandCast.resourceSize); operandSizes.push_back(newOperandCast.resourceSize); @@ -732,9 +766,9 @@ } else { auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue( op.getLoc(), result.value(), rewriter); - resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(), - resultDynamicDims, affinityAttr, - rewriter)); + resultSizes.push_back( + buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims, + executionAffinityAttr, rewriter)); resultTypes.push_back(unknownType); } } @@ -743,7 +777,7 @@ op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(), dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths, resultSizes, - adaptor.getTiedOperandsAttr(), affinityAttr); + adaptor.getTiedOperandsAttr(), executionAffinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } @@ -759,8 +793,8 @@ // Tensors become resources without sizes. The default type converter // adds the size so we bypass that here. We may want to allow the user // to override the lifetime with attributes, too. - return IREE::Stream::ResourceType::get(type.getContext(), - IREE::Stream::Lifetime::Unknown); + return rewriter.getType<IREE::Stream::ResourceType>( + IREE::Stream::Lifetime::Unknown); } return getTypeConverter()->convertType(type); }; @@ -784,13 +818,12 @@ } }; -struct ConvertCallOp : public OpConversionPattern<IREE::Flow::CallOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - +struct ConvertCallOp : public AffinityOpConversionPattern<IREE::Flow::CallOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CallOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { // Zero is going to be used for each operand to start. auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); @@ -805,7 +838,8 @@ llvm::zip_equal(op.getArguments(), adaptor.getArguments())) { if (llvm::isa<ShapedType>(oldOperand.getType())) { auto newOperandCast = - consumeTensorOperand(op.getLoc(), newOperand, rewriter); + transferTensorOperand(op.getLoc(), oldOperand, newOperand, + executionAffinityAttr, rewriter); newOperand = newOperandCast.resource; callOperandSizes.push_back(newOperandCast.resourceSize); operandSizes.push_back(newOperandCast.resourceSize); @@ -837,9 +871,9 @@ } else { auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue( op.getLoc(), result.value(), rewriter); - resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(), - resultDynamicDims, affinityAttr, - rewriter)); + resultSizes.push_back( + buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims, + executionAffinityAttr, rewriter)); resultTypes.push_back(unknownType); } } @@ -848,7 +882,7 @@ op, resultTypes, adaptor.getCalleeAttr(), callOperands, callOperandSizes, callOperandOffsets, callOperandEnds, callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(), - affinityAttr); + executionAffinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } @@ -1065,9 +1099,10 @@ } // namespace -void populateFlowToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateFlowToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { patterns .insert<ConvertTensorConstantOp, ConvertTensorDynamicConstantOp, ConvertTensorCastLikeOp<IREE::Flow::TensorReshapeOp>, @@ -1075,17 +1110,19 @@ ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp, ConvertTensorCloneOp, ConvertTensorTransferOp, ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp, - ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter, - context); - patterns.insert<ConvertChannelDefaultOp, ConvertChannelSplitOp, - ConvertChannelRankOp, ConvertChannelCountOp>(typeConverter, - context); + ConvertTensorStoreOp, ConvertTensorTraceOp>( + typeConverter, context, affinityAnalysis); + patterns.insert<ConvertChannelDefaultOp>(typeConverter, context, + affinityAnalysis); + patterns.insert<ConvertChannelSplitOp, ConvertChannelRankOp, + ConvertChannelCountOp>(typeConverter, context); patterns .insert<ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp, - ConvertAllToAllOp, ConvertCollectiveSendRecvOp>(typeConverter, - context); - patterns.insert<ConvertDispatchOp>(typeConverter, context); - patterns.insert<ConvertFuncOp, ConvertCallOp>(typeConverter, context); + ConvertAllToAllOp, ConvertCollectiveSendRecvOp>( + typeConverter, context, affinityAnalysis); + patterns.insert<ConvertDispatchOp>(typeConverter, context, affinityAnalysis); + patterns.insert<ConvertFuncOp>(typeConverter, context); + patterns.insert<ConvertCallOp>(typeConverter, context, affinityAnalysis); patterns.insert<ConvertExecutableOp>(typeConverter, context); patterns.insert< ConvertDispatchWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp, @@ -1098,10 +1135,11 @@ patterns.insert<ConvertReturnOp>(typeConverter, context); } -void populateFlowToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateFlowToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { // Disallow all flow ops besides the ones we pass through (today). // We don't have a stream-equivalent of several of the dispatch-level flow // ops as the codegen backends directly touch them and so long as we have both @@ -1111,7 +1149,8 @@ conversionTarget.addLegalOp<IREE::Stream::ExecutableOp>(); conversionTarget.markOpRecursivelyLegal<IREE::Stream::ExecutableOp>(); - populateFlowToStreamConversionPatterns(context, typeConverter, patterns); + populateFlowToStreamConversionPatterns(context, typeConverter, + affinityAnalysis, patterns); } } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h index ad2c95a..0379be2 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h
@@ -11,18 +11,24 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform flow->stream conversion. // These patterns ensure that nested types are run through the provided // |typeConverter|. -void populateFlowToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); -void populateFlowToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateFlowToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); +void populateFlowToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir index 4106307..063389f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
@@ -52,13 +52,15 @@ // CHECK-LABEL: @dispatchAffinity // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index) util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<?x?x1024xf32>) { + // CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]} // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} - // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %0 = flow.dispatch @ex::@entry0(%input) { stream.affinity = #hal.device.affinity<@device_a> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3} + // CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]} // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]} - // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %1 = flow.dispatch @ex::@entry1(%input) { stream.affinity = #hal.device.affinity<@device_b> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index a755d44..ee68211 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -231,9 +231,9 @@ util.func public @tensorStoreScalar(%target : tensor<i32>) -> tensor<i32> { // CHECK: %[[VALUE:.+]] = arith.constant 9 %value = arith.constant 9 : i32 - // CHECK: %[[SPLAT:.+]] = stream.tensor.splat %[[VALUE]] : i32 -> tensor<i32> in !stream.resource<*>{%[[TARGET_SIZE]]} + // CHECK: %[[FILL:.+]] = stream.tensor.fill %[[VALUE]], %[[TARGET]] : i32 -> tensor<i32> in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]} %0 = flow.tensor.store %value, %target : tensor<i32> - // CHECK: util.return %[[SPLAT]] + // CHECK: util.return %[[FILL]] util.return %0 : tensor<i32> }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 4323473..76eef8b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp
@@ -21,11 +21,12 @@ // %1 = stream.tensor.import %0 : !hal.buffer_view -> // tensor<4xf32> in !stream.resource<*> struct ConvertTensorImportOp - : public OpConversionPattern<IREE::HAL::TensorImportOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::HAL::TensorImportOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern<IREE::HAL::TensorImportOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::HAL::TensorImportOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto sourceType = op.getSource().getType(); auto targetType = op.getTargetEncoding(); if (!llvm::isa<IREE::HAL::BufferType>(sourceType) && @@ -49,25 +50,23 @@ } } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - // Import (buffer view to stream resource). auto resultType = rewriter.getType<IREE::Stream::ResourceType>( IREE::Stream::Lifetime::External); Value resultSize = rewriter.create<IREE::Stream::TensorSizeOfOp>( op.getLoc(), rewriter.getIndexType(), TypeAttr::get(op.getTarget().getType()), adaptor.getTargetDims(), - affinityAttr); + executionAffinityAttr); Value resource = rewriter.create<IREE::Stream::TensorImportOp>( op.getLoc(), resultType, adaptor.getSource(), TypeAttr::get(targetType), - adaptor.getTargetDims(), resultSize, affinityAttr); + adaptor.getTargetDims(), resultSize, executionAffinityAttr); // Await the fence, if needed. When not specified the resource is assumed to // be immediately available. if (auto waitFence = op.getWaitFence()) { Value waitTimepoint = rewriter.create<IREE::Stream::TimepointImportOp>( op.getLoc(), rewriter.getType<IREE::Stream::TimepointType>(), - ValueRange{waitFence}, affinityAttr); + ValueRange{waitFence}, executionAffinityAttr); resource = rewriter .create<IREE::Stream::TimepointAwaitOp>( op.getLoc(), ValueRange{resource}, @@ -77,8 +76,9 @@ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( - op, unknownType, resource, resultSize, resultSize, affinityAttr, - /*target_affinity=*/IREE::Stream::AffinityAttr{}); + op, unknownType, resource, resultSize, resultSize, + /*source_affinity=*/executionAffinityAttr, + /*target_affinity=*/executionAffinityAttr); return success(); } @@ -122,11 +122,12 @@ // %1 = stream.tensor.export %0 : tensor<4xf32> in !stream.resource<*> -> // !hal.buffer_view struct ConvertTensorExportOp - : public OpConversionPattern<IREE::HAL::TensorExportOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::HAL::TensorExportOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern<IREE::HAL::TensorExportOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::HAL::TensorExportOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto sourceType = op.getSourceEncoding(); auto targetType = op.getTarget().getType(); if (!llvm::isa<IREE::HAL::BufferType>(targetType) && @@ -134,9 +135,9 @@ return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion"); } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); // Exporting a produced value - transfer our source value to an externally // usable resource and directly export it. This will cause an allocation. @@ -146,14 +147,14 @@ if (source.resource.getType() != externalType) { exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>( op.getLoc(), externalType, source.resource, source.resourceSize, - source.resourceSize, /*source_affinity=*/IREE::Stream::AffinityAttr{}, - affinityAttr); + source.resourceSize, /*source_affinity=*/source.affinity, + /*target_affinity=*/executionAffinityAttr); } // Export (stream resource to buffer view). rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>( op, targetType, exportSource, TypeAttr::get(sourceType), - adaptor.getSourceDims(), source.resourceSize, affinityAttr); + adaptor.getSourceDims(), source.resourceSize, executionAffinityAttr); return success(); } }; @@ -170,29 +171,23 @@ // %update = stream.async.update %0, %storage[...] // %2 = stream.async.slice %update[...] struct ConvertTensorAliasOp - : public OpConversionPattern<IREE::HAL::TensorAliasOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::HAL::TensorAliasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern<IREE::HAL::TensorAliasOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::HAL::TensorAliasOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto sourceType = op.getSource().getType(); auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - - // All operations (if any) will happen on the device specified by the alias - // as that indicates the affinity of the storage. - auto affinityAttr = - dyn_cast_if_present<IREE::Stream::AffinityAttr>(op.getAffinityAttr()); - if (!affinityAttr) { - affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - } + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); // Query the target storage buffer length; we will only populate up to // what is required for the output. Value storageSize = rewriter.create<IREE::Stream::TensorSizeOfOp>( op.getLoc(), rewriter.getIndexType(), TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), - affinityAttr); + executionAffinityAttr); // Import the target storage as a resource that we can use as an update // target. We overwrite the contents and just cast the storage to the @@ -202,7 +197,7 @@ auto importOp = rewriter.create<IREE::Stream::TensorImportOp>( op.getLoc(), externalType, adaptor.getStorage(), TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize, - affinityAttr); + executionAffinityAttr); // Await the fence, if needed. When not specified the storage is assumed to // be immediately available. @@ -210,7 +205,7 @@ if (auto waitFence = op.getWaitFence()) { Value waitTimepoint = rewriter.create<IREE::Stream::TimepointImportOp>( op.getLoc(), rewriter.getType<IREE::Stream::TimepointType>(), - ValueRange{waitFence}, affinityAttr); + ValueRange{waitFence}, executionAffinityAttr); storage = rewriter .create<IREE::Stream::TimepointAwaitOp>( op.getLoc(), ValueRange{storage}, @@ -223,7 +218,7 @@ auto updateOp = rewriter.create<IREE::Stream::AsyncUpdateOp>( op.getLoc(), externalType, storage, storageSize, zeroOffset, source.resourceSize, source.resource, source.resourceSize, - affinityAttr); + executionAffinityAttr); // Slice out the value from the updated tensor. // This preserves the use-def chain but is almost always elided by aliasing @@ -231,14 +226,14 @@ auto sliceOp = rewriter.create<IREE::Stream::AsyncSliceOp>( op.getLoc(), externalType, updateOp.getResult(), updateOp.getTargetSize(), zeroOffset, source.resourceSize, - source.resourceSize, affinityAttr); + source.resourceSize, executionAffinityAttr); // Transfer to match original lifetime (if needed). Value result = sliceOp.getResult(); if (source.resource.getType() != result.getType()) { result = rewriter.create<IREE::Stream::AsyncTransferOp>( op.getLoc(), source.resource.getType(), result, source.resourceSize, - source.resourceSize, affinityAttr, affinityAttr); + source.resourceSize, executionAffinityAttr, executionAffinityAttr); } rewriter.replaceOp(op, result); @@ -256,28 +251,38 @@ // %t01 = stream.timepoint.join max(%t0, %t1) // stream.timepoint.export %t01 => %fence struct ConvertTensorBarrierOp - : public OpConversionPattern<IREE::HAL::TensorBarrierOp> { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern<IREE::HAL::TensorBarrierOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::HAL::TensorBarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto timepointType = rewriter.getType<IREE::Stream::TimepointType>(); + IREE::Stream::AffinityAttr anyAffinityAttr; SmallVector<Value> signaledResources; SmallVector<Value> signaledTimepoints; - for (auto sourceResource : adaptor.getSources()) { - auto source = consumeTensorOperand(op.getLoc(), sourceResource, rewriter); + for (auto [sourceTensor, sourceResource] : + llvm::zip_equal(op.getSources(), adaptor.getSources())) { + auto source = resolveTensorOperand(op.getLoc(), sourceTensor, + sourceResource, rewriter); auto barrierOp = rewriter.create<IREE::Stream::TimepointBarrierOp>( sourceResource.getLoc(), source.resource.getType(), timepointType, - source.resource, source.resourceSize, affinityAttr); + source.resource, source.resourceSize, source.affinity); signaledResources.push_back(barrierOp.getResult()); signaledTimepoints.push_back(barrierOp.getResultTimepoint()); + + // When joining from multiple affinities we need to pick one to perform + // the chain. For now we do the affinity of the last tensor with the hope + // that we can perform the final signal on the affinity that is running. + // We should instead probably change this to be set after timepoint + // propagation such that we ensure it happens on the final signal when not + // acting as a join. + anyAffinityAttr = source.affinity; } Value joinedTimepoint = IREE::Stream::TimepointJoinOp::join( op.getLoc(), signaledTimepoints, rewriter); rewriter.create<IREE::Stream::TimepointChainExternalOp>( op.getLoc(), joinedTimepoint, ValueRange{adaptor.getSignalFence()}, - affinityAttr); + anyAffinityAttr); rewriter.replaceOp(op, signaledResources); return success(); } @@ -285,21 +290,27 @@ } // namespace -void populateHALToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateHALToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { typeConverter.addConversion( [](IREE::HAL::BufferViewType type) { return type; }); - patterns.insert<ConvertTensorImportOp>(typeConverter, context); - patterns.insert<ConvertTensorExportOp>(typeConverter, context); - patterns.insert<ConvertTensorAliasOp>(typeConverter, context); - patterns.insert<ConvertTensorBarrierOp>(typeConverter, context); + patterns.insert<ConvertTensorImportOp>(typeConverter, context, + affinityAnalysis); + patterns.insert<ConvertTensorExportOp>(typeConverter, context, + affinityAnalysis); + patterns.insert<ConvertTensorAliasOp>(typeConverter, context, + affinityAnalysis); + patterns.insert<ConvertTensorBarrierOp>(typeConverter, context, + affinityAnalysis); } -void populateHALToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateHALToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { // Allow executables through without modification. conversionTarget.addLegalOp<IREE::HAL::ExecutableOp>(); conversionTarget.markOpRecursivelyLegal<IREE::HAL::ExecutableOp>(); @@ -315,7 +326,8 @@ typeConverter.isLegal(op.getTarget().getType()); }); - populateHALToStreamConversionPatterns(context, typeConverter, patterns); + populateHALToStreamConversionPatterns(context, typeConverter, + affinityAnalysis, patterns); } } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h index ed2a3c0..f3e955d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h
@@ -11,18 +11,24 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform hal->stream conversion. // These patterns ensure that nested types are run through the provided // |typeConverter|. -void populateHALToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); -void populateHALToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateHALToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); +void populateHALToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp index 6bb26f1..fee06f2 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
@@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" @@ -23,13 +24,59 @@ return attr; } +IREE::Stream::AffinityAttr +tryLookupGlobalAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis) { + return affinityAnalysis->lookupGlobalAffinity(op); +} + +IREE::Stream::AffinityAttr +tryLookupExecutionAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis) { + assert(llvm::isa<IREE::Stream::AffinityOpInterface>(op) && + "must be an affinity op"); + return affinityAnalysis->lookupExecutionAffinity(op); +} + +IREE::Stream::AffinityAttr +tryLookupResultAffinity(Value value, + IREE::Stream::AffinityAnalysis *affinityAnalysis) { + return affinityAnalysis->lookupResourceAffinity(value); +} + +static std::pair<Value, Value> +resolveTensorOperand(Location loc, Value convertedOperand, OpBuilder &builder) { + auto operandType = convertedOperand.getType(); + if (llvm::isa<IREE::Stream::ResourceType>(operandType)) { + // Prior to https://reviews.llvm.org/D111620 this is the path we'd take; + // the tensor operands would be remapped into their new resource types. + // This is still possible during rewriting if we ourselves produce a new + // resource type, but the automatic materialization will go down the + // unrealized_conversion_cast path below. + return std::make_pair(convertedOperand, + builder.createOrFold<IREE::Stream::ResourceSizeOp>( + loc, builder.getIndexType(), convertedOperand)); + } else if (auto castOp = + convertedOperand + .getDefiningOp<mlir::UnrealizedConversionCastOp>()) { + // We only have a single tensor type conversion and it expands to (resource, + // size) so that's all we look for here. + assert(castOp.getNumOperands() == 2 && "expected (resource, size)"); + return std::make_pair(castOp.getOperand(0), castOp.getOperand(1)); + } + assert(false && + "unexpected operand; expected either a IREE::Stream::ResourceType or " + "the result of a mlir::UnrealizedConversionCastOp"); + return std::make_pair(Value{}, Value{}); +} + void expandResourceOperand(Location loc, Value operand, SmallVectorImpl<Value> &newOperands, OpBuilder &builder) { if (llvm::isa<TensorType>(operand.getType())) { - auto value = consumeTensorOperand(loc, operand, builder); - newOperands.push_back(value.resource); - newOperands.push_back(value.resourceSize); + auto [resource, resourceSize] = resolveTensorOperand(loc, operand, builder); + newOperands.push_back(resource); + newOperands.push_back(resourceSize); } else if (llvm::isa<IREE::Stream::ResourceType>(operand.getType())) { newOperands.push_back(operand); newOperands.push_back( @@ -49,34 +96,28 @@ return expandedOperands; } -ConvertedTensor consumeTensorOperand(Location loc, Value operand, - OpBuilder &builder) { - auto operandType = operand.getType(); - if (llvm::isa<IREE::Stream::ResourceType>(operandType)) { - // Prior to https://reviews.llvm.org/D111620 this is the path we'd take; - // the tensor operands would be remapped into their new resource types. - // This is still possible during rewriting if we ourselves produce a new - // resource type, but the automatic materialization will go down the - // unrealized_conversion_cast path below. - return { - operand, - builder.createOrFold<IREE::Stream::ResourceSizeOp>( - loc, builder.getIndexType(), operand), - }; - } else if (auto castOp = - operand.getDefiningOp<mlir::UnrealizedConversionCastOp>()) { - // We only have a single tensor type conversion and it expands to (resource, - // size) so that's all we look for here. - assert(castOp.getNumOperands() == 2 && "expected (resource, size)"); - return { - castOp.getOperand(0), - castOp.getOperand(1), - }; +ConvertedTensor resolveTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) { + auto [resource, resourceSize] = + resolveTensorOperand(loc, convertedOperand, builder); + auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); + return {affinityAttr, resource, resourceSize}; +} + +ConvertedTensor transferTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAttr requiredAffinityAttr, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) { + auto [resource, resourceSize] = + resolveTensorOperand(loc, convertedOperand, builder); + auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); + if (affinityAttr != requiredAffinityAttr) { + resource = builder.create<IREE::Stream::AsyncTransferOp>( + loc, resource.getType(), resource, resourceSize, resourceSize, + affinityAttr, requiredAffinityAttr); } - assert(false && - "unexpected operand; expected either a IREE::Stream::ResourceType or " - "the result of a mlir::UnrealizedConversionCastOp"); - return ConvertedTensor(); + return {requiredAffinityAttr, resource, resourceSize}; } } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h index fd9249e..43cfbb0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
@@ -7,37 +7,123 @@ #ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_ #define IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_ +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Converts a supported attribute type to the corresponding stream dialect // value. Returns the provided value if it is natively supported. TypedAttr convertAttributeToStream(TypedAttr attr); -void expandResourceOperand(Location loc, Value operand, - SmallVectorImpl<Value> &newOperands, - OpBuilder &builder); +IREE::Stream::AffinityAttr +tryLookupGlobalAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis); +IREE::Stream::AffinityAttr +tryLookupExecutionAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis); +IREE::Stream::AffinityAttr +tryLookupResultAffinity(Value value, + IREE::Stream::AffinityAnalysis *affinityAnalysis); -SmallVector<Value> expandResourceOperands(Location loc, ValueRange operands, - ConversionPatternRewriter &rewriter); - -// https://reviews.llvm.org/D111620 broke 1->N type expansion during dialect -// conversion. It inserts unrealized_conversion_casts but then passes the -// illegal source dialect types for pattern operands, meaning that even though -// we say tensors are illegal the patterns get the new remapped values as -// tensors. This, naturally, breaks everything. To work around this we have this -// helper that tries to peek through the unrealized_conversion_casts and get out -// the actual values we expected to see from the conversion (and did before that -// change). struct ConvertedTensor { + // Optional affinity of the resource at the time it is consumed. + // May be nullptr if the affinity could not be determined. + IREE::Stream::AffinityAttr affinity; + // Resource storing the tensor. Value resource; + // Size of the resource in bytes. Value resourceSize; }; -ConvertedTensor consumeTensorOperand(Location loc, Value operand, - OpBuilder &builder); + +void expandResourceOperand(Location loc, Value convertedOperand, + SmallVectorImpl<Value> &newOperands, + OpBuilder &builder); +SmallVector<Value> expandResourceOperands(Location loc, + ValueRange convertedOperands, + ConversionPatternRewriter &rewriter); + +ConvertedTensor resolveTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder); +ConvertedTensor transferTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAttr requiredAffinityAttr, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder); + +template <typename OpT> +struct AffinityAwareConversionPattern : public OpConversionPattern<OpT> { +public: + AffinityAwareConversionPattern( + const TypeConverter &typeConverter, MLIRContext *context, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + PatternBenefit benefit = 1) + : OpConversionPattern<OpT>(typeConverter, context, benefit), + affinityAnalysis(affinityAnalysis) {} + + IREE::Stream::AffinityAnalysis *getAffinityAnalysis() const { + return affinityAnalysis; + } + +protected: + ConvertedTensor resolveTensorOperand(Location loc, Value originalOperand, + Value convertedOperand, + OpBuilder &builder) const { + return mlir::iree_compiler::resolveTensorOperand( + loc, originalOperand, convertedOperand, affinityAnalysis, builder); + } + + ConvertedTensor + transferTensorOperand(Location loc, Value originalOperand, + Value convertedOperand, + IREE::Stream::AffinityAttr requiredAffinityAttr, + OpBuilder &builder) const { + return mlir::iree_compiler::transferTensorOperand( + loc, originalOperand, convertedOperand, requiredAffinityAttr, + affinityAnalysis, builder); + } + + IREE::Stream::AffinityAttr lookupResultAffinity(Value originalResult) const { + return mlir::iree_compiler::tryLookupResultAffinity(originalResult, + affinityAnalysis); + } + + IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr; +}; + +template <typename OpT> +struct AffinityOpConversionPattern + : public AffinityAwareConversionPattern<OpT> { +public: + AffinityOpConversionPattern(const TypeConverter &typeConverter, + MLIRContext *context, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + PatternBenefit benefit = 1) + : AffinityAwareConversionPattern<OpT>(typeConverter, context, + affinityAnalysis, benefit) {} + +protected: + virtual LogicalResult matchAndRewriteOnAffinity( + OpT op, typename OpConversionPattern<OpT>::OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const = 0; + +private: + LogicalResult + matchAndRewrite(OpT op, typename OpConversionPattern<OpT>::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override final { + auto executionAffinityAttr = + tryLookupExecutionAffinity(op, this->getAffinityAnalysis()); + return matchAndRewriteOnAffinity(op, adaptor, executionAffinityAttr, + rewriter); + } +}; } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel index 646d55e..38ad9ee 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel
@@ -15,8 +15,6 @@ iree_compiler_cc_library( name = "StandardToStream", srcs = [ - "ConvertConstantOps.cpp", - "ConvertStructuralOps.cpp", "Patterns.cpp", ], hdrs = [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt index b910c60..3def716 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
@@ -16,8 +16,6 @@ HDRS "Patterns.h" SRCS - "ConvertConstantOps.cpp" - "ConvertStructuralOps.cpp" "Patterns.cpp" DEPS LLVMSupport
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp deleted file mode 100644 index 5ff99f7..0000000 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp +++ /dev/null
@@ -1,66 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" -#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" -#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir::iree_compiler { - -namespace { - -struct ConvertTensorConstantOp : public OpConversionPattern<arith::ConstantOp> { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only handle tensor types - other arith.constant types (like i32) are - // ignored. - if (!llvm::isa<TensorType>(constantOp.getType())) - return failure(); - - Type constantType = IREE::Stream::ResourceType::get( - getContext(), IREE::Stream::Lifetime::Constant); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); - auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>( - constantOp.getLoc(), constantType, - convertAttributeToStream(constantOp.getValue()), - TypeAttr::get(constantOp.getType()), - /*result_encoding_dims=*/ValueRange{}, affinityAttr); - - Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>( - constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); - rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( - constantOp, unknownType, newOp.getResult(), constantSize, constantSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr); - return success(); - } -}; - -} // namespace - -void populateStandardConstantToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns) { - conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>( - [](arith::ConstantOp op) { - return !llvm::isa<TensorType>(op.getType()); - }); - - patterns.insert<ConvertTensorConstantOp>(typeConverter, context); -} - -} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp deleted file mode 100644 index 5b29504..0000000 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp +++ /dev/null
@@ -1,406 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" -#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" -#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir::iree_compiler { - -namespace { - -struct BranchOpConversion : public OpConversionPattern<mlir::cf::BranchOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = expandResourceOperands( - op.getLoc(), adaptor.getDestOperands(), rewriter); - rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(), - expandedOperands); - return success(); - } -}; - -struct CondBranchOpConversion - : public OpConversionPattern<mlir::cf::CondBranchOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto trueDestOperands = expandResourceOperands( - op.getLoc(), adaptor.getTrueDestOperands(), rewriter); - auto falseDestOperands = expandResourceOperands( - op.getLoc(), adaptor.getFalseDestOperands(), rewriter); - rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>( - op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands, - op.getFalseDest(), falseDestOperands); - return success(); - } -}; - -static ValueRange asValueRange(ArrayRef<Value> values) { return values; } - -struct SwitchOpConversion : public OpConversionPattern<mlir::cf::SwitchOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto defaultOperands = expandResourceOperands( - op.getLoc(), adaptor.getDefaultOperands(), rewriter); - auto caseOperands = llvm::to_vector( - llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) { - return expandResourceOperands(op.getLoc(), operands, rewriter); - })); - rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>( - op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands, - op.getCaseValuesAttr(), op.getCaseDestinations(), - llvm::to_vector(llvm::map_range(caseOperands, asValueRange))); - return success(); - } -}; - -struct SelectOpConversion : public OpConversionPattern<mlir::arith::SelectOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only handle selects where the operands are tensors (resources). - if (!llvm::isa<TensorType>(op.getTrueValue().getType())) - return failure(); - auto trueOperand = - consumeTensorOperand(op.getLoc(), adaptor.getTrueValue(), rewriter); - auto falseOperand = - consumeTensorOperand(op.getLoc(), adaptor.getFalseValue(), rewriter); - auto resourceSelectOp = rewriter.create<mlir::arith::SelectOp>( - op.getLoc(), adaptor.getCondition(), trueOperand.resource, - falseOperand.resource); - auto sizeSelectOp = rewriter.create<mlir::arith::SelectOp>( - op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize, - falseOperand.resourceSize); - rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>( - op, adaptor.getTrueValue().getType(), - ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()}); - return success(); - } -}; - -struct ScfIfOpConversion : public OpConversionPattern<mlir::scf::IfOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); - - // Expand any resource results to resource + size. - SmallVector<Type> expandedTypes; - struct Result { - size_t originalIndex; - size_t newIndex; - Type newType; - }; - SmallVector<Result> resultMap; - for (auto originalType : llvm::enumerate(op.getResultTypes())) { - SmallVector<Type> newTypes; - if (failed(getTypeConverter()->convertType(originalType.value(), - newTypes))) { - return rewriter.notifyMatchFailure(op, - "unable to convert result types"); - } - resultMap.push_back( - Result{originalType.index(), expandedTypes.size(), newTypes.front()}); - expandedTypes.append(newTypes); - } - - // Create a new call that takes the expanded input operands and returns the - // expanded output results. We can't directly replace the original call as - // the result counts differ. - auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), expandedTypes, - op.getCondition()); - - ifOp.getThenRegion().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(), - ifOp.getThenRegion().end()); - - ifOp.getElseRegion().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(), - ifOp.getElseRegion().end()); - - // Tie all resource results together so we end up with 1:1 results with the - // original op. - SmallVector<Value> results; - for (auto result : resultMap) { - if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) { - auto oldType = op.getResult(result.originalIndex).getType(); - auto resource = ifOp.getResult(result.newIndex + 0); - auto resourceSize = ifOp.getResult(result.newIndex + 1); - results.push_back(rewriter - .create<mlir::UnrealizedConversionCastOp>( - op.getLoc(), TypeRange{oldType}, - ValueRange{resource, resourceSize}) - .getResult(0)); - } else { - results.push_back(ifOp.getResult(result.newIndex)); - } - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ScfForOpConversion : public OpConversionPattern<mlir::scf::ForOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto &typeConverter = *getTypeConverter(); - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter); - - // Expand any resource results to resource + size. - SmallVector<Type> expandedTypes; - struct Result { - size_t originalIndex; - size_t newIndex; - Type newType; - }; - SmallVector<Result> resultMap; - for (auto originalType : llvm::enumerate(op.getResultTypes())) { - SmallVector<Type> newTypes; - if (failed(getTypeConverter()->convertType(originalType.value(), - newTypes))) { - return rewriter.notifyMatchFailure(op, - "unable to convert result types"); - } - resultMap.push_back( - Result{originalType.index(), expandedTypes.size(), newTypes.front()}); - expandedTypes.append(newTypes); - } - - auto &block = op.getRegion().front(); - TypeConverter::SignatureConversion newSignature(block.getNumArguments()); - for (auto arg : llvm::enumerate(block.getArgumentTypes())) { - if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(), - newSignature))) { - return failure(); - } - } - - // Create a new loop that takes the expanded input operands and returns the - // expanded output results. We can't directly replace the original loop as - // the result counts differ. - auto forOp = rewriter.create<mlir::scf::ForOp>( - op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), - adaptor.getStep(), expandedOperands); - - // Inline the block and update the block arguments. - rewriter.eraseBlock(forOp.getBody()); - rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(), - forOp.getRegion().end()); - if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter, - &newSignature))) { - return failure(); - } - - // Tie all resource results together so we end up with 1:1 results with the - // original op. - SmallVector<Value> results; - for (auto result : resultMap) { - if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) { - auto oldType = op.getResult(result.originalIndex).getType(); - auto resource = forOp.getResult(result.newIndex + 0); - auto resourceSize = forOp.getResult(result.newIndex + 1); - results.push_back(rewriter - .create<mlir::UnrealizedConversionCastOp>( - op.getLoc(), TypeRange{oldType}, - ValueRange{resource, resourceSize}) - .getResult(0)); - } else { - results.push_back(forOp.getResult(result.newIndex)); - } - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ScfWhileOpConversion : public OpConversionPattern<mlir::scf::WhileOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto &typeConverter = *getTypeConverter(); - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); - - // Expand any resource results to resource + size. - SmallVector<Type> expandedTypes; - struct Result { - size_t originalIndex; - size_t newIndex; - Type newType; - }; - SmallVector<Result> resultMap; - for (auto originalType : llvm::enumerate(op.getResultTypes())) { - SmallVector<Type> newTypes; - if (failed(getTypeConverter()->convertType(originalType.value(), - newTypes))) { - return rewriter.notifyMatchFailure(op, - "unable to convert result types"); - } - resultMap.push_back( - Result{originalType.index(), expandedTypes.size(), newTypes.front()}); - expandedTypes.append(newTypes); - } - - TypeConverter::SignatureConversion newSignature(op.getNumOperands()); - for (auto argType : llvm::enumerate(op.getOperandTypes())) { - if (failed(typeConverter.convertSignatureArg( - argType.index(), argType.value(), newSignature))) { - return failure(); - } - } - - // Create a new call that takes the expanded input operands and returns the - // expanded output results. We can't directly replace the original call as - // the result counts differ. - auto whileOp = rewriter.create<mlir::scf::WhileOp>( - op.getLoc(), expandedTypes, expandedOperands); - - // Inline the `before` block and update the block arguments. - whileOp.getBefore().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(), - whileOp.getBefore().end()); - if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter, - &newSignature))) { - return failure(); - } - - // Inline the `after` block and update the block arguments. - whileOp.getAfter().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(), - whileOp.getAfter().end()); - if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter, - &newSignature))) { - return failure(); - } - - // Tie all resource results together so we end up with 1:1 results with the - // original op. - SmallVector<Value> results; - for (auto result : resultMap) { - if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) { - auto oldType = op.getResult(result.originalIndex).getType(); - auto resource = whileOp.getResult(result.newIndex + 0); - auto resourceSize = whileOp.getResult(result.newIndex + 1); - results.push_back(rewriter - .create<mlir::UnrealizedConversionCastOp>( - op.getLoc(), TypeRange{oldType}, - ValueRange{resource, resourceSize}) - .getResult(0)); - } else { - results.push_back(whileOp.getResult(result.newIndex)); - } - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ScfConditionOpConversion - : public OpConversionPattern<mlir::scf::ConditionOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter); - rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>( - op, adaptor.getCondition(), expandedOperands); - return success(); - } -}; - -struct ScfYieldOpConversion : public OpConversionPattern<mlir::scf::YieldOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); - rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, expandedOperands); - return success(); - } -}; - -} // namespace - -template <typename OpT> -static inline void addGenericLegalOp(ConversionTarget &conversionTarget, - TypeConverter &typeConverter) { - conversionTarget.addDynamicallyLegalOp<OpT>([&](OpT op) { - return llvm::all_of( - op->getOperandTypes(), - [&typeConverter](Type t) { return typeConverter.isLegal(t); }) && - llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) { - return typeConverter.isLegal(t); - }); - }); -} - -void populateStandardStructuralToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns) { - conversionTarget.addLegalOp<mlir::ModuleOp>(); - - // We need to rewrite certain types on operands/results so use the default - // dynamic legality checker to force any ops using such types to run through - // our patterns. - - addGenericLegalOp<mlir::cf::BranchOp>(conversionTarget, typeConverter); - addGenericLegalOp<mlir::cf::CondBranchOp>(conversionTarget, typeConverter); - addGenericLegalOp<mlir::cf::SwitchOp>(conversionTarget, typeConverter); - patterns - .insert<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>( - typeConverter, context); - - addGenericLegalOp<mlir::arith::SelectOp>(conversionTarget, typeConverter); - patterns.insert<SelectOpConversion>(typeConverter, context); - - addGenericLegalOp<mlir::scf::IfOp>(conversionTarget, typeConverter); - addGenericLegalOp<mlir::scf::ForOp>(conversionTarget, typeConverter); - addGenericLegalOp<mlir::scf::WhileOp>(conversionTarget, typeConverter); - addGenericLegalOp<mlir::scf::ConditionOp>(conversionTarget, typeConverter); - addGenericLegalOp<mlir::scf::YieldOp>(conversionTarget, typeConverter); - patterns - .insert<ScfConditionOpConversion, ScfIfOpConversion, ScfForOpConversion, - ScfWhileOpConversion, ScfYieldOpConversion>(typeConverter, - context); -} - -} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp index 1725fb0..9924fd2 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp
@@ -6,26 +6,419 @@ #include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" +#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir::iree_compiler { -void populateStandardConstantToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns); +namespace { -void populateStandardStructuralToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns); +struct ConvertTensorConstantOp + : public AffinityOpConversionPattern<arith::ConstantOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + arith::ConstantOp constantOp, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + // Only handle tensor types - other arith.constant types (like i32) are + // ignored. + if (!llvm::isa<TensorType>(constantOp.getType())) { + return failure(); + } + + auto constantType = rewriter.getType<IREE::Stream::ResourceType>( + IREE::Stream::Lifetime::Constant); + auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>( + constantOp.getLoc(), constantType, + convertAttributeToStream(constantOp.getValue()), + TypeAttr::get(constantOp.getType()), + /*result_encoding_dims=*/ValueRange{}, executionAffinityAttr); + + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); + auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>( + constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); + rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( + constantOp, unknownType, newOp.getResult(), constantSize, constantSize, + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr); + return success(); + } +}; + +struct BranchOpConversion + : public AffinityAwareConversionPattern<mlir::cf::BranchOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto expandedOperands = expandResourceOperands( + op.getLoc(), adaptor.getDestOperands(), rewriter); + rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(), + expandedOperands); + return success(); + } +}; + +struct CondBranchOpConversion + : public AffinityAwareConversionPattern<mlir::cf::CondBranchOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto trueDestOperands = expandResourceOperands( + op.getLoc(), adaptor.getTrueDestOperands(), rewriter); + auto falseDestOperands = expandResourceOperands( + op.getLoc(), adaptor.getFalseDestOperands(), rewriter); + rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>( + op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands, + op.getFalseDest(), falseDestOperands); + return success(); + } +}; + +static ValueRange asValueRange(ArrayRef<Value> values) { return values; } + +struct SwitchOpConversion + : public AffinityAwareConversionPattern<mlir::cf::SwitchOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto defaultOperands = expandResourceOperands( + op.getLoc(), adaptor.getDefaultOperands(), rewriter); + auto caseOperands = llvm::to_vector( + llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) { + return expandResourceOperands(op.getLoc(), operands, rewriter); + })); + rewriter.replaceOpWithNewOp<mlir::cf::SwitchOp>( + op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands, + op.getCaseValuesAttr(), op.getCaseDestinations(), + llvm::to_vector(llvm::map_range(caseOperands, asValueRange))); + return success(); + } +}; + +struct SelectOpConversion + : public AffinityAwareConversionPattern<mlir::arith::SelectOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle selects where the operands are tensors (resources). + if (!llvm::isa<TensorType>(op.getTrueValue().getType())) + return failure(); + auto trueOperand = resolveTensorOperand(op.getLoc(), op.getTrueValue(), + adaptor.getTrueValue(), rewriter); + auto falseOperand = resolveTensorOperand(op.getLoc(), op.getFalseValue(), + adaptor.getFalseValue(), rewriter); + auto resourceSelectOp = rewriter.create<mlir::arith::SelectOp>( + op.getLoc(), adaptor.getCondition(), trueOperand.resource, + falseOperand.resource); + auto sizeSelectOp = rewriter.create<mlir::arith::SelectOp>( + op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize, + falseOperand.resourceSize); + rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>( + op, adaptor.getTrueValue().getType(), + ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()}); + return success(); + } +}; + +struct ScfIfOpConversion + : public AffinityAwareConversionPattern<mlir::scf::IfOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource results to resource + size. + SmallVector<Type> expandedTypes; + struct Result { + size_t originalIndex; + size_t newIndex; + Type newType; + }; + SmallVector<Result> resultMap; + for (auto originalType : llvm::enumerate(op.getResultTypes())) { + SmallVector<Type> newTypes; + if (failed(getTypeConverter()->convertType(originalType.value(), + newTypes))) { + return rewriter.notifyMatchFailure(op, + "unable to convert result types"); + } + resultMap.push_back( + Result{originalType.index(), expandedTypes.size(), newTypes.front()}); + expandedTypes.append(newTypes); + } + + // Create a new call that takes the expanded input operands and returns the + // expanded output results. We can't directly replace the original call as + // the result counts differ. + auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), expandedTypes, + op.getCondition()); + + ifOp.getThenRegion().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(), + ifOp.getThenRegion().end()); + + ifOp.getElseRegion().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(), + ifOp.getElseRegion().end()); + + // Tie all resource results together so we end up with 1:1 results with the + // original op. + SmallVector<Value> results; + for (auto result : resultMap) { + if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) { + auto oldType = op.getResult(result.originalIndex).getType(); + auto resource = ifOp.getResult(result.newIndex + 0); + auto resourceSize = ifOp.getResult(result.newIndex + 1); + results.push_back(rewriter + .create<mlir::UnrealizedConversionCastOp>( + op.getLoc(), TypeRange{oldType}, + ValueRange{resource, resourceSize}) + .getResult(0)); + } else { + results.push_back(ifOp.getResult(result.newIndex)); + } + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ScfForOpConversion + : public AffinityAwareConversionPattern<mlir::scf::ForOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto &typeConverter = *getTypeConverter(); + + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter); + + // Expand any resource results to resource + size. + SmallVector<Type> expandedTypes; + struct Result { + size_t originalIndex; + size_t newIndex; + Type newType; + }; + SmallVector<Result> resultMap; + for (auto originalType : llvm::enumerate(op.getResultTypes())) { + SmallVector<Type> newTypes; + if (failed(getTypeConverter()->convertType(originalType.value(), + newTypes))) { + return rewriter.notifyMatchFailure(op, + "unable to convert result types"); + } + resultMap.push_back( + Result{originalType.index(), expandedTypes.size(), newTypes.front()}); + expandedTypes.append(newTypes); + } + + auto &block = op.getRegion().front(); + TypeConverter::SignatureConversion newSignature(block.getNumArguments()); + for (auto arg : llvm::enumerate(block.getArgumentTypes())) { + if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(), + newSignature))) { + return failure(); + } + } + + // Create a new loop that takes the expanded input operands and returns the + // expanded output results. We can't directly replace the original loop as + // the result counts differ. + auto forOp = rewriter.create<mlir::scf::ForOp>( + op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), expandedOperands); + + // Inline the block and update the block arguments. + rewriter.eraseBlock(forOp.getBody()); + rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(), + forOp.getRegion().end()); + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter, + &newSignature))) { + return failure(); + } + + // Tie all resource results together so we end up with 1:1 results with the + // original op. + SmallVector<Value> results; + for (auto result : resultMap) { + if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) { + auto oldType = op.getResult(result.originalIndex).getType(); + auto resource = forOp.getResult(result.newIndex + 0); + auto resourceSize = forOp.getResult(result.newIndex + 1); + results.push_back(rewriter + .create<mlir::UnrealizedConversionCastOp>( + op.getLoc(), TypeRange{oldType}, + ValueRange{resource, resourceSize}) + .getResult(0)); + } else { + results.push_back(forOp.getResult(result.newIndex)); + } + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ScfWhileOpConversion + : public AffinityAwareConversionPattern<mlir::scf::WhileOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto &typeConverter = *getTypeConverter(); + + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); + + // Expand any resource results to resource + size. + SmallVector<Type> expandedTypes; + struct Result { + size_t originalIndex; + size_t newIndex; + Type newType; + }; + SmallVector<Result> resultMap; + for (auto originalType : llvm::enumerate(op.getResultTypes())) { + SmallVector<Type> newTypes; + if (failed(getTypeConverter()->convertType(originalType.value(), + newTypes))) { + return rewriter.notifyMatchFailure(op, + "unable to convert result types"); + } + resultMap.push_back( + Result{originalType.index(), expandedTypes.size(), newTypes.front()}); + expandedTypes.append(newTypes); + } + + TypeConverter::SignatureConversion newSignature(op.getNumOperands()); + for (auto argType : llvm::enumerate(op.getOperandTypes())) { + if (failed(typeConverter.convertSignatureArg( + argType.index(), argType.value(), newSignature))) { + return failure(); + } + } + + // Create a new call that takes the expanded input operands and returns the + // expanded output results. We can't directly replace the original call as + // the result counts differ. + auto whileOp = rewriter.create<mlir::scf::WhileOp>( + op.getLoc(), expandedTypes, expandedOperands); + + // Inline the `before` block and update the block arguments. + whileOp.getBefore().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(), + whileOp.getBefore().end()); + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter, + &newSignature))) { + return failure(); + } + + // Inline the `after` block and update the block arguments. + whileOp.getAfter().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(), + whileOp.getAfter().end()); + if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter, + &newSignature))) { + return failure(); + } + + // Tie all resource results together so we end up with 1:1 results with the + // original op. + SmallVector<Value> results; + for (auto result : resultMap) { + if (llvm::isa<IREE::Stream::ResourceType>(result.newType)) { + auto oldType = op.getResult(result.originalIndex).getType(); + auto resource = whileOp.getResult(result.newIndex + 0); + auto resourceSize = whileOp.getResult(result.newIndex + 1); + results.push_back(rewriter + .create<mlir::UnrealizedConversionCastOp>( + op.getLoc(), TypeRange{oldType}, + ValueRange{resource, resourceSize}) + .getResult(0)); + } else { + results.push_back(whileOp.getResult(result.newIndex)); + } + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ScfConditionOpConversion + : public AffinityAwareConversionPattern<mlir::scf::ConditionOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter); + rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>( + op, adaptor.getCondition(), expandedOperands); + return success(); + } +}; + +struct ScfYieldOpConversion + : public AffinityAwareConversionPattern<mlir::scf::YieldOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); + rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, expandedOperands); + return success(); + } +}; + +template <typename OpT> +static inline void addGenericLegalOp(ConversionTarget &conversionTarget, + TypeConverter &typeConverter) { + conversionTarget.addDynamicallyLegalOp<OpT>([&](OpT op) { + return llvm::all_of( + op->getOperandTypes(), + [&typeConverter](Type t) { return typeConverter.isLegal(t); }) && + llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) { + return typeConverter.isLegal(t); + }); + }); +} + +} // namespace void populateStandardToStreamConversionPatterns( MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns) { + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { typeConverter.addConversion([](IndexType type) { return type; }); typeConverter.addConversion([](IntegerType type) { return type; }); typeConverter.addConversion([](FloatType type) { return type; }); @@ -35,10 +428,38 @@ conversionTarget.addIllegalOp<memref::DimOp, memref::RankOp, tensor::DimOp, tensor::RankOp>(); - populateStandardConstantToStreamPatterns(context, conversionTarget, - typeConverter, patterns); - populateStandardStructuralToStreamPatterns(context, conversionTarget, - typeConverter, patterns); + conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>( + [](arith::ConstantOp op) { + return !llvm::isa<TensorType>(op.getType()); + }); + patterns.insert<ConvertTensorConstantOp>(typeConverter, context, + affinityAnalysis); + + conversionTarget.addLegalOp<mlir::ModuleOp>(); + + // We need to rewrite certain types on operands/results so use the default + // dynamic legality checker to force any ops using such types to run through + // our patterns. + + addGenericLegalOp<mlir::cf::BranchOp>(conversionTarget, typeConverter); + addGenericLegalOp<mlir::cf::CondBranchOp>(conversionTarget, typeConverter); + addGenericLegalOp<mlir::cf::SwitchOp>(conversionTarget, typeConverter); + patterns + .insert<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>( + typeConverter, context, affinityAnalysis); + + addGenericLegalOp<mlir::arith::SelectOp>(conversionTarget, typeConverter); + patterns.insert<SelectOpConversion>(typeConverter, context, affinityAnalysis); + + addGenericLegalOp<mlir::scf::IfOp>(conversionTarget, typeConverter); + addGenericLegalOp<mlir::scf::ForOp>(conversionTarget, typeConverter); + addGenericLegalOp<mlir::scf::WhileOp>(conversionTarget, typeConverter); + addGenericLegalOp<mlir::scf::ConditionOp>(conversionTarget, typeConverter); + addGenericLegalOp<mlir::scf::YieldOp>(conversionTarget, typeConverter); + patterns + .insert<ScfConditionOpConversion, ScfIfOpConversion, ScfForOpConversion, + ScfWhileOpConversion, ScfYieldOpConversion>( + typeConverter, context, affinityAnalysis); } } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h index 3314dcf..112e602 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h
@@ -10,6 +10,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform standard/builtin->stream @@ -17,7 +21,9 @@ // provided |typeConverter|. void populateStandardToStreamConversionPatterns( MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns); + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp index 4d1fa5f..35e1ca8 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -67,8 +67,9 @@ } }; -struct CallOpConversion : public OpConversionPattern<IREE::Util::CallOp> { - using OpConversionPattern::OpConversionPattern; +struct CallOpConversion + : public AffinityAwareConversionPattern<IREE::Util::CallOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Util::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -122,8 +123,9 @@ } }; -struct ReturnOpConversion : public OpConversionPattern<IREE::Util::ReturnOp> { - using OpConversionPattern::OpConversionPattern; +struct ReturnOpConversion + : public AffinityAwareConversionPattern<IREE::Util::ReturnOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Util::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -142,6 +144,7 @@ struct ExpandedGlobalResource { IREE::Util::GlobalOp resourceOp; IREE::Util::GlobalOp resourceSizeOp; + IREE::Stream::AffinityAttr affinityAttr; }; struct GlobalExpansionState { @@ -163,13 +166,16 @@ public: BaseGlobalConversionPattern( std::shared_ptr<GlobalExpansionState> expansionState, - TypeConverter &typeConverter, MLIRContext *context, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern<T>(typeConverter, context, benefit), - expansionState(std::move(expansionState)) {} + expansionState(std::move(expansionState)), + affinityAnalysis(affinityAnalysis) {} protected: mutable std::shared_ptr<GlobalExpansionState> expansionState; + IREE::Stream::AffinityAnalysis *affinityAnalysis; }; struct GlobalOpExpansion @@ -230,9 +236,13 @@ globalOp.getIsMutable(), indexType, std::optional<TypedAttr>{}); resourceSizeOp.setVisibility(globalOp.getVisibility()); + // Resolve the affinity of the global. + // We require this to be a single value today that is usually chosen from + // consumers (we take the hit on transfer from producers if needed). + auto affinityAttr = tryLookupGlobalAffinity(globalOp, affinityAnalysis); + // Materialize the initializer if we need to setup a tensor-like constant. if (tensorInitializerRequired) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp); auto initializerOp = rewriter.create<IREE::Util::InitializerOp>(globalOp.getLoc()); auto *entryBlock = rewriter.createBlock(&initializerOp.getBody()); @@ -265,6 +275,7 @@ expansionState->globalMap[globalOp.getSymName()] = ExpandedGlobalResource{ resourceOp, resourceSizeOp, + affinityAttr, }; return success(); @@ -289,7 +300,7 @@ auto &expandedGlobal = expandedGlobalIt->getSecond(); // Insert a load/transfer to the unknown resource lifetime. - auto unknownType = IREE::Stream::ResourceType::get(rewriter.getContext()); + auto unknownType = rewriter.getType<IREE::Stream::ResourceType>(); auto resource = rewriter .create<IREE::Util::GlobalLoadOp>( @@ -303,8 +314,8 @@ .getResult(); rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>( loadOp, unknownType, resource, resourceSize, resourceSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr); + /*source_affinity=*/expandedGlobal.affinityAttr, + /*result_affinity=*/expandedGlobal.affinityAttr); return success(); } @@ -330,12 +341,14 @@ // Insert a transfer/store to the global with unknown lifetime. Lifetime // refinement will make this go away if possible. auto value = - consumeTensorOperand(storeOp.getLoc(), adaptor.getValue(), rewriter); + resolveTensorOperand(storeOp.getLoc(), storeOp.getValue(), + adaptor.getValue(), affinityAnalysis, rewriter); assert(expandedGlobal.resourceOp && "Missing resource op"); auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>( storeOp.getLoc(), expandedGlobal.resourceOp.getType(), value.resource, - value.resourceSize, value.resourceSize, /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr); + value.resourceSize, value.resourceSize, + /*source_affinity=*/value.affinity, + /*result_affinity=*/expandedGlobal.affinityAttr); rewriter.replaceOpWithNewOp<IREE::Util::GlobalStoreOp>( storeOp, transferOp.getResult(), expandedGlobal.resourceOp.getSymName()); @@ -347,30 +360,59 @@ } }; +struct OptimizationBarrierOpConversion + : public AffinityAwareConversionPattern<IREE::Util::OptimizationBarrierOp> { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> newOperands; + for (auto [originalOperand, convertedOperand] : + llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { + if (isa<TensorType>(convertedOperand.getType())) { + newOperands.push_back(resolveTensorOperand(op.getLoc(), originalOperand, + convertedOperand, rewriter) + .resource); + } else { + newOperands.push_back(convertedOperand); + } + } + rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op, + newOperands); + return success(); + } +}; + } // namespace -void populateUtilToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns - .insert<FuncOpSignatureConversion, CallOpConversion, ReturnOpConversion>( - typeConverter, context); +void populateUtilToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { + patterns.insert<FuncOpSignatureConversion>(typeConverter, context); + patterns.insert<CallOpConversion, ReturnOpConversion>(typeConverter, context, + affinityAnalysis); auto expansionState = std::make_shared<GlobalExpansionState>(); // TODO(#7432): add indirect global expansion support to streams. patterns .insert<GlobalOpExpansion, GlobalLoadOpExpansion, GlobalStoreOpExpansion>( - expansionState, typeConverter, context); + expansionState, typeConverter, affinityAnalysis, context); patterns.add<GenericConvertTypesPattern<IREE::Util::GlobalOp>, GenericConvertTypesPattern<IREE::Util::GlobalLoadOp>, GenericConvertTypesPattern<IREE::Util::GlobalStoreOp>>( typeConverter, context); + + patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context, + affinityAnalysis, + /*benefit=*/2); } -void populateUtilToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateUtilToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { typeConverter.addConversion([=](IREE::Util::PtrType type, SmallVectorImpl<Type> &resultTypes) { // Expand pointers to tensors to [resource, sizeof resource] pointers. @@ -432,7 +474,8 @@ return typeConverter.isLegal(op.getResultTypes()); }); - populateUtilToStreamConversionPatterns(context, typeConverter, patterns); + populateUtilToStreamConversionPatterns(context, typeConverter, + affinityAnalysis, patterns); } } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h index 5673c74..56fcca2 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h
@@ -11,18 +11,24 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform util->stream conversion. // These patterns ensure that nested types are run through the provided // |typeConverter|. -void populateUtilToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); -void populateUtilToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateUtilToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); +void populateUtilToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index b0b66ac..91f5c0f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h" #include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h" #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" @@ -39,79 +40,6 @@ namespace { -// Builds a stream.tensor.import op that imports an external tensor value into -// a stream resource. |consumingOps| will be populated with all ops that consume -// the original |sourceTensor| and that should not be replaced with the returned -// value. -static Value buildTensorImportOp(Location loc, Value sourceTensor, - Type targetType, - SmallPtrSetImpl<Operation *> &consumingOps, - IREE::Stream::AffinityAttr affinityAttr, - OpBuilder &builder) { - // Gather dynamic dimensions from the input value. - auto dynamicDims = - IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder); - - // Compute the size of the tensor once in the stream resource. - // This may differ from the external encoding of the tensor as imports are - // a transfer operation that may need to reformat the tensor. - auto encodingAttr = TypeAttr::get(sourceTensor.getType()); - Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>( - loc, builder.getIndexType(), encodingAttr, dynamicDims, affinityAttr); - - // Associate the external SSA value, encoding, and shape information with the - // stream resource. When lowering we'll then have all the metadata required - // even after erasing it all on the resource. - auto externalType = builder.getType<IREE::Stream::ResourceType>( - IREE::Stream::Lifetime::External); - auto importOp = builder.create<IREE::Stream::TensorImportOp>( - loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize, - affinityAttr); - consumingOps.insert(importOp); - - // If needed insert a transfer to the target lifetime. - Value result = importOp.getResult(); - if (targetType != externalType) { - result = builder - .create<IREE::Stream::AsyncTransferOp>( - loc, targetType, result, resultSize, resultSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr) - .getResult(); - } - - auto castOp = builder.create<mlir::UnrealizedConversionCastOp>( - loc, sourceTensor.getType(), ValueRange{result, resultSize}); - consumingOps.insert(castOp); - return castOp.getResult(0); -} - -// Builds a stream.tensor.export op that exports a stream resource into an -// external tensor value. -static Value buildTensorExportOp(Location loc, Value sourceValue, - TensorType targetType, ValueRange dynamicDims, - IREE::Stream::AffinityAttr affinityAttr, - OpBuilder &builder) { - auto source = consumeTensorOperand(loc, sourceValue, builder); - - // If needed insert a transfer to external resource lifetime. - auto externalType = builder.getType<IREE::Stream::ResourceType>( - IREE::Stream::Lifetime::External); - if (source.resource.getType() != externalType) { - source.resource = builder.create<IREE::Stream::AsyncTransferOp>( - loc, externalType, source.resource, source.resourceSize, - source.resourceSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/affinityAttr); - } - - // Associate the stream resource and external encoding and shape information. - auto newOp = builder.create<IREE::Stream::TensorExportOp>( - loc, targetType, source.resource, TypeAttr::get(targetType), dynamicDims, - source.resourceSize, affinityAttr); - return newOp.getResult(); -} - // Returns true if |op| has tensor I/O that is not yet imported/exported using // the stream ops that capture encodings and shapes. static bool doesOperationNeedWrapping(Operation *op) { @@ -123,8 +51,9 @@ operand.getDefiningOp()); }) || llvm::any_of(op->getResults(), [](Value result) { - if (!isa<TensorType>(result.getType())) + if (!isa<TensorType>(result.getType())) { return false; + } return !llvm::all_of(result.getUsers(), llvm::IsaPred<TensorImportOp>); }); @@ -133,15 +62,19 @@ // Fallback handler for unknown ops taking/returning tensors that need to be // marshaled into/outof stream resource types. struct GenericResourcePattern : public ConversionPattern { - GenericResourcePattern(MLIRContext *context, TypeConverter &converter) - : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {} + GenericResourcePattern(MLIRContext *context, TypeConverter &converter, + IREE::Stream::AffinityAnalysis *affinityAnalysis) + : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context), + affinityAnalysis(affinityAnalysis) {} + LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - if (!doesOperationNeedWrapping(op)) + if (!doesOperationNeedWrapping(op)) { return failure(); + } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto executionAffinityAttr = affinityAnalysis->inferExecutionAffinity(op); // Export resources into tensor operands for the op to consume. SmallVector<Value> newOperands; @@ -156,11 +89,14 @@ auto tensorType = dyn_cast<TensorType>(oldOperand.getType()); assert(tensorType && "must have a tensor type to map to a resource"); + auto exportAffinityAttr = + affinityAnalysis->lookupResourceAffinity(oldOperand); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( op->getLoc(), oldOperand, rewriter); - newOperands.push_back(buildTensorExportOp(op->getLoc(), newOperand, - tensorType, dynamicDims, - affinityAttr, rewriter)); + newOperands.push_back(buildTensorExportOp( + op->getLoc(), oldOperand, newOperand, tensorType, dynamicDims, + exportAffinityAttr ? exportAffinityAttr : executionAffinityAttr, + rewriter)); } rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); }); @@ -168,46 +104,107 @@ rewriter.setInsertionPointAfter(op); for (auto result : op->getResults()) { auto tensorType = dyn_cast<TensorType>(result.getType()); - if (!tensorType) + if (!tensorType) { continue; + } + auto importAffinityAttr = + affinityAnalysis->lookupResourceAffinity(result); auto dynamicDims = IREE::Util::buildDynamicDimsForValue(op->getLoc(), result, rewriter); SmallPtrSet<Operation *, 4> consumingOps; auto importedValue = buildTensorImportOp( op->getLoc(), result, rewriter.getType<IREE::Stream::ResourceType>(), - consumingOps, affinityAttr, rewriter); + consumingOps, + importAffinityAttr ? importAffinityAttr : executionAffinityAttr, + rewriter); result.replaceAllUsesExcept(importedValue, consumingOps); } return success(); } -}; -struct OptimizationBarrierOpConversion - : public OpConversionPattern<IREE::Util::OptimizationBarrierOp> { - using OpConversionPattern< - IREE::Util::OptimizationBarrierOp>::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> newOperands; - for (Value v : adaptor.getOperands()) { - if (isa<TensorType>(v.getType())) { - newOperands.push_back( - consumeTensorOperand(op.getLoc(), v, rewriter).resource); - } else { - newOperands.push_back(v); - } + // Builds a stream.tensor.export op that exports a stream resource into an + // external tensor value. + Value buildTensorExportOp(Location loc, Value originalValue, + Value convertedValue, TensorType targetType, + ValueRange dynamicDims, + IREE::Stream::AffinityAttr executionAffinityAttr, + OpBuilder &builder) const { + auto source = + transferTensorOperand(loc, originalValue, convertedValue, + executionAffinityAttr, affinityAnalysis, builder); + + // If needed insert a transfer to external resource lifetime. + auto externalType = builder.getType<IREE::Stream::ResourceType>( + IREE::Stream::Lifetime::External); + if (source.resource.getType() != externalType) { + source.resource = builder.create<IREE::Stream::AsyncTransferOp>( + loc, externalType, source.resource, source.resourceSize, + source.resourceSize, + /*source_affinity=*/source.affinity, + /*result_affinity=*/executionAffinityAttr); } - rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op, - newOperands); - return success(); + + // Associate the stream resource and external encoding and shape + // information. + auto newOp = builder.create<IREE::Stream::TensorExportOp>( + loc, targetType, source.resource, TypeAttr::get(targetType), + dynamicDims, source.resourceSize, executionAffinityAttr); + return newOp.getResult(); } + + // Builds a stream.tensor.import op that imports an external tensor value into + // a stream resource. |consumingOps| will be populated with all ops that + // consume the original |sourceTensor| and that should not be replaced with + // the returned value. + Value buildTensorImportOp(Location loc, Value sourceTensor, Type targetType, + SmallPtrSetImpl<Operation *> &consumingOps, + IREE::Stream::AffinityAttr executionAffinityAttr, + OpBuilder &builder) const { + // Gather dynamic dimensions from the input value. + auto dynamicDims = + IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder); + + // Compute the size of the tensor once in the stream resource. + // This may differ from the external encoding of the tensor as imports are + // a transfer operation that may need to reformat the tensor. + auto encodingAttr = TypeAttr::get(sourceTensor.getType()); + Value resultSize = builder.create<IREE::Stream::TensorSizeOfOp>( + loc, builder.getIndexType(), encodingAttr, dynamicDims, + executionAffinityAttr); + + // Associate the external SSA value, encoding, and shape information with + // the stream resource. When lowering we'll then have all the metadata + // required even after erasing it all on the resource. + auto externalType = builder.getType<IREE::Stream::ResourceType>( + IREE::Stream::Lifetime::External); + auto importOp = builder.create<IREE::Stream::TensorImportOp>( + loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize, + executionAffinityAttr); + consumingOps.insert(importOp); + + // If needed insert a transfer to the target lifetime. + Value result = importOp.getResult(); + if (targetType != externalType) { + result = builder + .create<IREE::Stream::AsyncTransferOp>( + loc, targetType, result, resultSize, resultSize, + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr) + .getResult(); + } + + auto castOp = builder.create<mlir::UnrealizedConversionCastOp>( + loc, sourceTensor.getType(), ValueRange{result, resultSize}); + consumingOps.insert(castOp); + return castOp.getResult(0); + } + + IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr; }; static void stripAffinityAttrs(ModuleOp moduleOp) { - moduleOp->removeAttr("stream.affinity.default"); auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity"); for (auto &op : moduleOp.getOps()) { op.removeDiscardableAttr(affinityName); @@ -223,6 +220,13 @@ void runOnOperation() override { auto *context = &getContext(); + // Run affinity analysis so that the required producer/consumer affinities + // for all SSA values we'll use during conversion are available. + AffinityAnalysis affinityAnalysis(getOperation()); + if (failed(affinityAnalysis.run())) { + return signalPassFailure(); + } + TypeConverter typeConverter; ConversionTarget conversionTarget(getContext()); RewritePatternSet patterns(&getContext()); @@ -235,10 +239,9 @@ // Allow unknown types to pass through; these come from custom dialects that // may be mixed into the IR we are converting. typeConverter.addConversion([=](Type type) -> Type { - // convert flow.channel into stream.channel - if (llvm::isa<IREE::Flow::ChannelType>(type)) + if (llvm::isa<IREE::Flow::ChannelType>(type)) { return IREE::Stream::ChannelType::get(context); - + } return !llvm::isa<TensorType>(type) ? type : Type{}; }); @@ -275,21 +278,20 @@ populateUtilConversionPatterns(context, conversionTarget, typeConverter, patterns); - populateUtilToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); + populateUtilToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); - populateStandardToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); - populateFlowToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); - populateHALToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); + populateStandardToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); + populateFlowToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); + populateHALToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); conversionTarget.markUnknownOpDynamicallyLegal( [&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); }); - patterns.insert<GenericResourcePattern>(context, typeConverter); - patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context, - /*benefit=*/2); + patterns.insert<GenericResourcePattern>(context, typeConverter, + &affinityAnalysis); // NOTE: we allow ops that we don't know about to allow custom dialects // that don't need anything Stream-specific to pass through.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp index 7e9c231..6732b97 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
@@ -399,7 +399,8 @@ if (sourceType != targetType && sourceType.getLifetime() == IREE::Stream::Lifetime::Constant) { LLVM_DEBUG(llvm::dbgs() - << " - clone source is a constant; cannot elide\n"); + << " - clone is a resource lifetime cast (" << sourceType + << " to " << targetType << "); cannot elide\n"); return false; }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp index f3cd827..2db2cd6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp
@@ -33,7 +33,9 @@ // Emplacement //===----------------------------------------------------------------------===// -static void replaceUsesAndTransfer(Value oldValue, Value newValue) { +static void +replaceUsesAndTransfer(Value oldValue, Value newValue, + IREE::Stream::AffinityAttr usageAffinityAttr) { assert(isa<IREE::Stream::ResourceType>(oldValue.getType())); assert(isa<IREE::Stream::ResourceType>(newValue.getType())); if (oldValue.getType() == newValue.getType()) { @@ -44,8 +46,8 @@ builder.setInsertionPointAfterValue(newValue); Value newValueSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( newValue.getLoc(), newValue, builder); - IREE::Stream::AffinityAttr sourceAffinity; - IREE::Stream::AffinityAttr resultAffinity; + IREE::Stream::AffinityAttr sourceAffinity = usageAffinityAttr; + IREE::Stream::AffinityAttr resultAffinity = usageAffinityAttr; Value transferValue = builder.create<IREE::Stream::AsyncTransferOp>( newValue.getLoc(), oldValue.getType(), newValue, newValueSize, newValueSize, sourceAffinity, resultAffinity); @@ -74,7 +76,7 @@ break; } - // Find potential. + // Find potential update to place the dispatch result into. Value targetResource; Value targetResourceSize; Value targetOffset; @@ -82,12 +84,22 @@ Value targetLength; Value targetResult; Value targetResultSize; + Attribute targetAffinityAttr; Operation *userOp = *result.user_begin(); if (auto updateOp = dyn_cast<IREE::Stream::AsyncUpdateOp>(userOp)) { if (updateOp.getUpdate() != result) { // TODO(#14566): continue if sparse emplacement on multiple results. break; } + + // Currently only allow exactly matching affinities. + // TODO(multi-device): memory compatibility - if compatible then allow. + if (updateOp.getAffinityAttr() != dispatchOp.getAffinityAttr()) { + continue; + } + + // Try to move all SSA values required into the appropriate place. + // TODO(benvanik): undo this if there's a failure (or record/roll-back). if (!IREE::Util::tryMoveProducerBefore(updateOp.getUpdateSize(), dispatchOp) || !IREE::Util::tryMoveProducerBefore(updateOp.getTargetSize(), @@ -102,6 +114,7 @@ // TODO(#14566): continue if sparse emplacement on multiple results. break; } + targetResource = updateOp.getTarget(); if (targetResource.getDefiningOp() == dispatchOp) { // NOTE: we may have already replaced the update target with one of our @@ -115,6 +128,7 @@ targetLength = updateOp.getUpdateSize(); targetResult = updateOp.getResult(); targetResultSize = updateOp.getTargetSize(); + targetAffinityAttr = updateOp.getAffinityAttr(); } if (!targetResource) { // TODO(#14566): continue if sparse emplacement on multiple results. @@ -136,7 +150,7 @@ dispatchOp.getResultSizesMutable().assign(resultSizes); // Replace users with the result of the dispatch op. - replaceUsesAndTransfer(targetResult, result); + replaceUsesAndTransfer(targetResult, result, dispatchOp.getAffinityAttr()); userOp->erase(); didChange = true;
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir index a3eae92..b11839b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}> #map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)> @@ -57,7 +57,7 @@ } } -// vulkan uses default materialization patterns which unsets the encodings. +// Vulkan uses default materialization patterns which unsets the encodings. // CHECK-LABEL: util.func public @lhs_encoding // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK: util.return %[[ARG0]]
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir index 4842e51..5354b82 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir
@@ -1,5 +1,7 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion --canonicalize %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @parameterLoad // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.buffer, !hal.fence) util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) { @@ -7,7 +9,7 @@ %c51_i64 = arith.constant 51 : i64 %c100 = arith.constant 100 : index %c101 = arith.constant 101 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: %[[BUFFERS:.+]]:2 = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) @@ -15,7 +17,7 @@ // CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") // CHECK-NEXT: "scope"::"key0"[%c50_i64] : !hal.buffer{%c100} // CHECK-NEXT: "scope"::"key1"[%c51_i64] : !hal.buffer{%c101} - %results:2, %result_timepoint = stream.parameter.load await(%wait) => { + %results:2, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => { "scope"::"key0"[%c50_i64] : !stream.resource<constant>{%c100}, "scope"::"key1"[%c51_i64] : !stream.resource<constant>{%c101} } => !stream.timepoint @@ -25,19 +27,21 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterLoadNoScope // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.fence) util.func public @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource<constant>, !stream.timepoint) { %c50_i64 = arith.constant 50 : i64 %c100 = arith.constant 100 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") // CHECK-NEXT: "key"[%c50_i64] : !hal.buffer{%c100} - %result, %result_timepoint = stream.parameter.load await(%wait) => { + %result, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => { "key"[%c50_i64] : !stream.resource<constant>{%c100} } => !stream.timepoint // CHECK: return %[[BUFFER]], %[[SIGNAL]] @@ -46,6 +50,8 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterRead // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterRead(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint { @@ -53,19 +59,21 @@ %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-NEXT: "scope"::"key"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer - %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint + %timepoint = stream.parameter.read on(#hal.device.affinity<@device>) await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300} => !stream.timepoint // CHECK: return %[[SIGNAL]] util.return %timepoint : !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterWrite // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterWrite(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint { @@ -73,19 +81,21 @@ %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-NEXT: %[[SOURCE]][%c100 for %c200] : !hal.buffer -> "scope"::"key"[%c50_i64] - %timepoint = stream.parameter.write await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint + %timepoint = stream.parameter.write on(#hal.device.affinity<@device>) await(%wait) => %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint // CHECK: return %[[SIGNAL]] util.return %timepoint : !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterGather // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint { @@ -99,7 +109,7 @@ %c201 = arith.constant 201 : index %c202 = arith.constant 202 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) @@ -107,7 +117,7 @@ // CHECK-NEXT: "scope"::"key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer, // CHECK-NEXT: "scope"::"key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer, // CHECK-NEXT: "scope"::"key2"[%c52_i64] -> %[[TARGET]][%c102 for %c202] : !hal.buffer - %timepoint = stream.parameter.gather await(%wait) => { + %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => { "scope"::"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300}, "scope"::"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300}, "scope"::"key2"[%c52_i64] -> %target[%c102 for %c202] : !stream.resource<transient>{%c300} @@ -118,6 +128,8 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterGatherNoScope // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterGatherNoScope(%wait: !stream.timepoint, %target: !stream.resource<transient>) -> !stream.timepoint { @@ -128,14 +140,14 @@ %c200 = arith.constant 200 : index %c201 = arith.constant 201 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-NEXT: "key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer, // CHECK-NEXT: "key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer - %timepoint = stream.parameter.gather await(%wait) => { + %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => { "key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource<transient>{%c300}, "key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource<transient>{%c300} } => !stream.timepoint @@ -145,6 +157,8 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterScatter // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterScatter(%wait: !stream.timepoint, %source: !stream.resource<transient>) -> !stream.timepoint { @@ -158,7 +172,7 @@ %c201 = arith.constant 201 : index %c202 = arith.constant 202 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) @@ -167,7 +181,7 @@ // CHECK-NEXT: %[[SOURCE]][%c101 for %c201] : !hal.buffer -> "scope"::"key1"[%c51_i64], // CHECK-NEXT: %[[SOURCE]][%c102 for %c202] : !hal.buffer -> "scope"::"key2"[%c52_i64] // CHECK-NEXT: } - %timepoint = stream.parameter.scatter await(%wait) => { + %timepoint = stream.parameter.scatter on(#hal.device.affinity<@device>) await(%wait) => { %source[%c100 for %c200] : !stream.resource<transient>{%c300} -> "scope"::"key0"[%c50_i64], %source[%c101 for %c201] : !stream.resource<transient>{%c300} -> "scope"::"key1"[%c51_i64], %source[%c102 for %c202] : !stream.resource<transient>{%c300} -> "scope"::"key2"[%c52_i64]
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 0d730cd..bd56dc8 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -136,18 +136,6 @@ if (compileTo == IREEVMPipelinePhase::Input) return; // early-exit - // If the user specified a set of target devices we attach them to the module - // IR so that they are available for all passes that may want to use this - // information. If trying to compile in a generic mode the user should omit - // specifying targets. - IREE::HAL::AssignmentOptions halAssignmentOptions; - halAssignmentOptions.legacyTargetBackends = - halTargetOptions.legacyTargetBackends; - halAssignmentOptions.targetDevices = halTargetOptions.targetDevices; - halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice; - IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, - halAssignmentOptions); - // Now that inputs are legalized, generate wrapper for entry functions. if (compileFrom < IREEVMPipelinePhase::ABI) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "ABI"); @@ -172,6 +160,18 @@ if (compileTo == IREEVMPipelinePhase::ABI) return; // early-exit + // If the user specified a set of target devices we attach them to the module + // IR so that they are available for all passes that may want to use this + // information. If trying to compile in a generic mode the user should omit + // specifying targets. + IREE::HAL::AssignmentOptions halAssignmentOptions; + halAssignmentOptions.legacyTargetBackends = + halTargetOptions.legacyTargetBackends; + halAssignmentOptions.targetDevices = halTargetOptions.targetDevices; + halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice; + IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + halAssignmentOptions); + GlobalOptimization::TransformOptions globalTransformOptions; globalTransformOptions.options = globalOptimizationOptions;
diff --git a/tools/test/compile_pipelines.mlir b/tools/test/compile_pipelines.mlir index 2fd4a6c..fb6dbbe 100644 --- a/tools/test/compile_pipelines.mlir +++ b/tools/test/compile_pipelines.mlir
@@ -1,10 +1,10 @@ // RUN: iree-opt --iree-common-input-transformation-pipeline %s | \ // RUN: iree-opt --iree-abi-transformation-pipeline - | \ -// RUN: iree-opt --iree-common-input-transformation-pipeline - | \ +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-device-assignment-pipeline{target-devices=local})' --iree-hal-local-target-device-backends=vmvx - | \ // RUN: iree-opt --iree-global-optimization-transformation-pipeline - | \ // RUN: iree-opt --iree-flow-transformation-pipeline - | \ // RUN: iree-opt --iree-stream-transformation-pipeline - | \ -// RUN: iree-opt --iree-hal-transformation-pipeline --iree-hal-target-backends=vmvx - | \ +// RUN: iree-opt --iree-hal-transformation-pipeline - | \ // RUN: iree-opt --iree-vm-transformation-pipeline - | \ // RUN: FileCheck %s
diff --git a/tools/test/compile_to_continuation.mlir b/tools/test/compile_to_continuation.mlir index 9c78153..5476462 100644 --- a/tools/test/compile_to_continuation.mlir +++ b/tools/test/compile_to_continuation.mlir
@@ -1,79 +1,89 @@ // RUN: iree-compile --compile-to=input %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=INPUT-PHASE // INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> // RUN: iree-compile --compile-to=abi %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=ABI-PHASE // ABI-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=flow %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FLOW-PHASE // FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=stream %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=flow %s | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ +// RUN: FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE +// FLOW-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> + +// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=STREAM-PHASE // STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=stream %s | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ +// RUN: FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE +// STREAM-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> + +// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE // EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE // EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=HAL-PHASE // HAL-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=VM-PHASE // VM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> // RUN: iree-compile --compile-to=input %s | \ -// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE // FROM-INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> // RUN: iree-compile --compile-to=abi %s | \ -// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE // FROM-ABI-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=flow %s | \ -// RUN: iree-compile --compile-from=flow --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --compile-from=flow --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-FLOW-PHASE // FROM-FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=stream %s | \ -// RUN: iree-compile --compile-from=stream --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --compile-from=stream --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-STREAM-PHASE // FROM-STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=executable-sources --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-SOURCES-PHASE // FROM-EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=executable-targets --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-TARGETS-PHASE // FROM-EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=hal --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-HAL-PHASE // FROM-HAL-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view> -// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=vm --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-VM-PHASE // FROM-VM-PHASE: vm.func private @abs(%arg0: !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer_view>
diff --git a/tools/test/compile_to_phase.mlir b/tools/test/compile_to_phase.mlir index 0390564..f1861a0 100644 --- a/tools/test/compile_to_phase.mlir +++ b/tools/test/compile_to_phase.mlir
@@ -7,36 +7,44 @@ // ABI-PHASE: %[[INPUT:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor<f32> // ABI-PHASE: math.absf %[[INPUT]] : tensor<f32> -// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE +// RUN: iree-compile --compile-to=flow %s --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx | FileCheck %s --check-prefix=FLOW-PHASE // FLOW-PHASE: flow.executable.export public @abs_dispatch_0 // FLOW-PHASE: flow.dispatch @abs_dispatch_0 -// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE +// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE +// FLOW-PHASE-NO-DEVICE: flow.executable.export public @abs_dispatch_0 +// FLOW-PHASE-NO-DEVICE: flow.dispatch @abs_dispatch_0 + +// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=STREAM-PHASE // STREAM-PHASE: stream.executable.export public @abs_dispatch_0 // STREAM-PHASE: stream.cmd.dispatch @abs_dispatch_0 -// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE +// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE +// STREAM-PHASE-NO-DEVICE: stream.executable.export public @abs_dispatch_0 +// STREAM-PHASE-NO-DEVICE: stream.cmd.dispatch @abs_dispatch_0 + +// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE // EXECUTABLE-SOURCES-PHASE: hal.executable private @abs_dispatch_0 // EXECUTABLE-SOURCES-PHASE: hal.executable.variant // EXECUTABLE-SOURCES-PHASE: linalg.generic // EXECUTABLE-SOURCES-PHASE: math.absf -// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE +// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE // EXECUTABLE-TARGETS-PHASE: hal.executable private @abs_dispatch_0 // EXECUTABLE-TARGETS-PHASE: hal.executable.variant // EXECUTABLE-TARGETS-PHASE: vm.abs.f32 -// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE +// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE // HAL-PHASE: hal.executable private @abs_dispatch_0 // HAL-PHASE: hal.executable.binary // HAL-PHASE: hal.command_buffer.dispatch -// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE +// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE // VM-PHASE: vm.rodata private @abs_dispatch_0 // VM-PHASE: vm.call @hal.command_buffer.dispatch -// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE +// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE // END-PHASE: vm.rodata private @abs_dispatch_0 // END-PHASE: vm.call @hal.command_buffer.dispatch