Add the #hal.device.affinity attribute, replacing #hal.affinity.queue. The queue affinity attribute was added as a placeholder for testing but was never actually used or useful.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir index 3caa5b0..92a0208 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir
@@ -1,5 +1,7 @@ // RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" -split-input-file | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: util.func public @foo( // CHECK: %[[argA:.*]]: tensor<?x?xf32>, %[[argB:.*]]: tensor<5x10xf32>, %[[argC:.*]]: tensor<10x11xf32> util.func public @foo(%argA: tensor<?x?xf32>, %argB: tensor<5x10xf32>, %argC: tensor<10x11xf32>) -> (tensor<?x?xf32>, tensor<5x11xf32>) { @@ -21,7 +23,7 @@ flow.return %argA : tensor<?x?xf32> } // CHECK: %[[r1:.*]] = flow.dispatch.workgroups(%[[argB]], %[[argC]]) : (tensor<5x10xf32>, tensor<10x11xf32>) -> tensor<5x11xf32> - // CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device> // CHECK-NEXT: (%[[arg3:.*]]: !flow.dispatch.tensor<readonly:tensor<5x10xf32>>, %[[arg4:.*]]: !flow.dispatch.tensor<readonly:tensor<10x11xf32>>, %[[arg5:.*]]: !flow.dispatch.tensor<writeonly:tensor<5x11xf32>>) // CHECK-DAG: %[[loadB:.*]] = flow.dispatch.tensor.load %[[arg3]], offsets = [0, 0], sizes = [5, 10], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<5x10xf32>> -> tensor<5x10xf32> // CHECK-DAG: %[[loadC:.*]] = flow.dispatch.tensor.load %[[arg4]], offsets = [0, 0], sizes = [10, 11], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10x11xf32>> -> tensor<10x11xf32> @@ -31,7 +33,9 @@ // CHECK: flow.dispatch.tensor.store %[[matmul]], %[[arg5]], offsets = [0, 0], sizes = [5, 11], strides = [1, 1] : tensor<5x11xf32> -> !flow.dispatch.tensor<writeonly:tensor<5x11xf32>> // CHECK: flow.return // CHECK: } - %r1 = flow.dispatch.region -> (tensor<5x11xf32>) attributes {stream.affinity = #hal.affinity.queue<[0]>} { + %r1 = flow.dispatch.region -> (tensor<5x11xf32>) attributes { + stream.affinity = #hal.device.affinity<@device> + } { %zero = arith.constant 0.0 : f32 %0 = tensor.empty() : tensor<5x11xf32> %1 = linalg.fill ins(%zero : f32) 
outs(%0 : tensor<5x11xf32>) -> tensor<5x11xf32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir index e3db1b6..57304a1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir
@@ -67,11 +67,13 @@ // Tests that any hoistable attrs are propagated to the outlined globals. +util.global private @device : !hal.device + // CHECK: util.global private @__constant_tensor_2xi32 -// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> +// CHECK-SAME: stream.affinity = #hal.device.affinity<@device, [0]> // CHECK-NEXT: util.func private @set_affinity util.func private @set_affinity() attributes { - stream.affinity = #hal.affinity.queue<[0]> + stream.affinity = #hal.device.affinity<@device, [0]> } { // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32 %cst = arith.constant dense<[0, 1]> : tensor<2xi32>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir index 0a0f9e5..dd9d651 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -78,6 +78,9 @@ // ----- +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + // CHECK: flow.executable private @dispatchFn1_dispatch_0 // CHECK-LABEL: util.func public @dispatchFn1 @@ -85,9 +88,9 @@ %x = arith.constant 100 : index %y = arith.constant 50 : index // CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0 - // CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes { - stream.affinity = #hal.affinity.queue<[0]> + stream.affinity = #hal.device.affinity<@device_a> } = ( %arg: !flow.dispatch.tensor<readonly:tensor<8x4xf32>>, %ret: !flow.dispatch.tensor<writeonly:tensor<4x8xf32>> ) { @@ -103,9 +106,9 @@ %x = arith.constant 100 : index %y = arith.constant 50 : index // CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0 - // CHECK-SAME: stream.affinity = #hal.affinity.queue<[1]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes { - stream.affinity = #hal.affinity.queue<[1]> + stream.affinity = #hal.device.affinity<@device_b> } = ( %arg: !flow.dispatch.tensor<readonly:tensor<8x4xf32>>, %ret: !flow.dispatch.tensor<writeonly:tensor<4x8xf32>> ) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 74c0fc9..de9bbb4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -23,24 +23,6 @@ namespace { -// Returns the device queue affinity mask indicating which device queues the -// operations are allowed to execute on. -static Value buildQueueAffinityMask(Location loc, - IREE::Stream::AffinityAttr affinityAttr, - Value device, OpBuilder &builder) { - // Try to find a specified affinity. This may be on the op provided or one of - // its parent regions. - if (auto queueAffinityAttr = - llvm::dyn_cast_if_present<IREE::HAL::AffinityQueueAttr>( - affinityAttr)) { - return builder.create<arith::ConstantIntOp>( - loc, queueAffinityAttr.getMask(), 64); - } - - // No affinity specified; use default (any) affinity. - return builder.create<arith::ConstantIntOp>(loc, -1, 64); -} - struct ContextResolveOpPattern : public StreamConversionPattern<IREE::Stream::ContextResolveOp> { using StreamConversionPattern::StreamConversionPattern; @@ -50,9 +32,36 @@ auto resultTypes = llvm::to_vector(resolveOp.getResultTypes()); assert(!resultTypes.empty() && "must have at least one result"); - // TODO(multi-device): emit get with derived ordinal or lookup with attr. - Value device = - IREE::HAL::DeviceType::resolveAny(resolveOp.getLoc(), rewriter); + // Get the affinity from the op or an ancestor. Note that there may be no + // affinity specified at all. + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(resolveOp); + + // We currently only handle HAL device affinities. + // We could make this an interface to select the device and allow users to + // provide their own affinities to convert to HAL. In the future users may + // also want to provide devices as function arguments post-initialization. + // For now we just have one way to specify device globals. 
+ auto deviceAffinityAttr = + dyn_cast_if_present<IREE::HAL::DeviceAffinityAttr>(affinityAttr); + if (!deviceAffinityAttr) { + resolveOp.emitOpError() << "failed to resolve affinity: only HAL device " + "affinities are supported"; + return rewriter.notifyMatchFailure( + resolveOp, "only HAL device affinities are supported"); + } + + // Get the device handle and queue. + // + // TODO(multi-device): specialized types; may need analysis we don't have + // or at least a symbol lookup. An alternative would be an optional type + // on the affinity in cases where we've evaluated it early but for now + // we assume all device types are unspecialized. + auto deviceType = rewriter.getType<IREE::HAL::DeviceType>(); + Value device = rewriter.create<IREE::Util::GlobalLoadOp>( + resolveOp.getLoc(), deviceType, + deviceAffinityAttr.getDevice().getValue(), + /*is_immutable=*/true); + int64_t queueMask = deviceAffinityAttr.getQueueMask(); SmallVector<Value> results; if (isa<IREE::HAL::DeviceType>(resultTypes[0])) { @@ -66,8 +75,8 @@ } if (resultTypes.size() > 1) { if (isa<IntegerType>(resultTypes[1])) { - results.push_back(buildQueueAffinityMask( - resolveOp.getLoc(), resolveOp.getAffinityAttr(), device, rewriter)); + results.push_back(rewriter.create<arith::ConstantIntOp>( + resolveOp.getLoc(), queueMask, 64)); } else { return rewriter.notifyMatchFailure( resolveOp, @@ -698,54 +707,67 @@ caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp)); }); - // Select the variant index. - Value selectedIndex = buildIfElseTree( - loc, caseExportOps.size(), - [&](Location loc, size_t i, OpBuilder &builder) { - auto exportOp = caseExportOps[i].second; - auto variantOp = - exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>(); - return variantOp.buildCondition(device, rewriter); - }, - rewriter); - - // Allow each variant to define how it is dispatched. 
- auto switchOp = rewriter.replaceOpWithNewOp<scf::IndexSwitchOp>( - dispatchOp, TypeRange{}, selectedIndex, caseIndices, - caseIndices.size()); - for (size_t i = 0; i < caseExportOps.size(); ++i) { - auto entryPointAttr = caseExportOps[i].first; - auto exportOp = caseExportOps[i].second; - auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); - auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); - + auto recordDispatch = [&](SymbolRefAttr entryPointAttr, + IREE::HAL::ExecutableExportOp exportOp, + OpBuilder &builder) { // Record push constants and buffer bindings. recordParameters(loc, affinityAttr, device, commandBuffer, exportOp, - dispatchOp, adaptor, caseBuilder); + dispatchOp, adaptor, builder); // Dispatch with a target-specific workgroup count. - auto caseWorkgroupCount = exportOp.calculateWorkgroupCount( - loc, device, adaptor.getWorkload(), caseBuilder); - Value executable = caseBuilder.create<IREE::HAL::ExecutableLookupOp>( - loc, caseBuilder.getType<IREE::HAL::ExecutableType>(), device, + auto workgroupCount = exportOp.calculateWorkgroupCount( + loc, device, adaptor.getWorkload(), builder); + Value executable = builder.create<IREE::HAL::ExecutableLookupOp>( + loc, builder.getType<IREE::HAL::ExecutableType>(), device, entryPointAttr.getRootReference().getValue()); - Value ordinal = caseBuilder.create<IREE::HAL::ExecutableExportOrdinalOp>( - loc, caseBuilder.getIndexType(), entryPointAttr); - auto flags = caseBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>( + Value ordinal = builder.create<IREE::HAL::ExecutableExportOrdinalOp>( + loc, builder.getIndexType(), entryPointAttr); + auto flags = builder.getAttr<IREE::HAL::DispatchFlagsAttr>( IREE::HAL::DispatchFlags::None); - caseBuilder.create<IREE::HAL::CommandBufferDispatchOp>( - loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0], - caseWorkgroupCount[1], caseWorkgroupCount[2], flags); + return builder.create<IREE::HAL::CommandBufferDispatchOp>( + loc, commandBuffer, executable, 
ordinal, workgroupCount[0], + workgroupCount[1], workgroupCount[2], flags); + }; - caseBuilder.create<scf::YieldOp>(loc); + // If there is only one variant we can emit that directly without a + // conditional check. The same result should occur later on but it saves + // a lot of IR during generation if we know we can avoid it. + if (caseExportOps.size() == 1) { + auto [entryPointAttr, exportOp] = caseExportOps.front(); + rewriter.replaceOp(dispatchOp, + recordDispatch(entryPointAttr, exportOp, rewriter)); + } else { + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseExportOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + auto exportOp = caseExportOps[i].second; + auto variantOp = + exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>(); + return variantOp.buildCondition(device, rewriter); + }, + rewriter); + + // Allow each variant to define how it is dispatched. + auto switchOp = rewriter.create<scf::IndexSwitchOp>( + loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size()); + for (size_t i = 0; i < caseExportOps.size(); ++i) { + auto [entryPointAttr, exportOp] = caseExportOps[i]; + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + recordDispatch(entryPointAttr, exportOp, caseBuilder); + caseBuilder.create<scf::YieldOp>(loc); + } + + // Fallback for no available variant. Today we just no-op as executable + // loading should have already failed. + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + defaultBuilder.create<scf::YieldOp>(loc); + + rewriter.replaceOp(dispatchOp, switchOp); } - // Fallback for no available variant. Today we just no-op as executable - // loading should have already failed. 
- auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); - auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); - defaultBuilder.create<scf::YieldOp>(loc); - return success(); }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir index 3f88bd1..bb2108f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir
@@ -1,15 +1,17 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @channel_create // CHECK-SAME: () -> !hal.channel util.func public @channel_create() -> !stream.channel { - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} : !hal.device + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant 3 // CHECK-DAG: %[[ID:.+]] = util.null : !util.buffer // CHECK-DAG: %[[GROUP:.+]] = util.buffer.constant : !util.buffer = "group" // CHECK-DAG: %[[DEFAULT:.+]] = arith.constant -1 // CHECK: %[[CHANNEL:.+]] = hal.channel.create device(%[[DEVICE]] : !hal.device) affinity(%[[AFFINITY]]) flags(0) id(%[[ID]]) group(%[[GROUP]]) rank(%[[DEFAULT]]) count(%[[DEFAULT]]) : !hal.channel - %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) group("group") : !stream.channel + %channel = stream.channel.create on(#hal.device.affinity<@device, [0, 1]>) group("group") : !stream.channel // CHECK: util.return %[[CHANNEL]] util.return %channel : !stream.channel }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index 7cdd991..941c15b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
@@ -3,12 +3,14 @@ // Today all memory control operations are ignored and we're just left with // the normal sequential execution barriers. +util.global private @device : !hal.device + // CHECK-LABEL: @cmdMemoryControl util.func public @cmdMemoryControl(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) { // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] stream.cmd.flush %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1} // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] @@ -22,13 +24,15 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @cmdFill util.func public @cmdFill(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index %c255_i32 = arith.constant 255 : i32 // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) { // CHECK-NEXT: hal.command_buffer.fill_buffer<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%arg0 : !hal.buffer)[%c0, %c128] // CHECK-SAME: pattern(%c255_i32 : i32) @@ -41,12 +45,14 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @cmdCopy util.func public @cmdCopy(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg4: 
!stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<staging>{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<staging>{%arg3}) { // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: source(%arg0 : !hal.buffer)[%c0] // CHECK-SAME: target(%arg2 : !hal.buffer)[%c0] @@ -60,12 +66,14 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @cmdCollective util.func public @cmdCollective(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<transient>, %arg3: index, %arg4: !stream.channel) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<transient>{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<transient>{%arg3}) { // Out-of-place all-reduce: // CHECK-NEXT: hal.command_buffer.collective @@ -127,12 +135,14 @@ // than we actually need and guard a lot more work than we otherwise would need // to. 
+util.global private @device : !hal.device + // CHECK-LABEL: @cmdExecute util.func public @cmdExecute(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) { stream.cmd.concurrent { // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} @@ -166,10 +176,6 @@ #executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> #executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> -#device_target_cpu = #hal.device.target<"llvm-cpu", [ - #executable_target_aarch64, - #executable_target_x86_64 -]> #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [ #hal.descriptor_set.layout<0, bindings = [ #hal.descriptor_set.binding<4, storage_buffer> @@ -219,6 +225,8 @@ } } +util.global private @device : !hal.device + // CHECK-LABEL: @cmdDispatch util.func public @cmdDispatch(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<external>, %arg3: index) -> !stream.timepoint { %c0 = arith.constant 0 : index @@ -229,7 +237,7 @@ %c5_i32 = arith.constant 5 : i32 %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as 
%arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) { // Switch for each executable variant by checking conditions and ranking: // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer> // CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") @@ -297,6 +305,8 @@ // Tests conversion of streamable calls and function declarations. // Expect a command buffer and a buffer + offset + length for each resource. +util.global private @device : !hal.device + // CHECK: util.func private @cmdFunc(%arg0: !hal.command_buffer, %arg1: !hal.buffer, %arg2: index, %arg3: index, %arg4: i32, %arg5: !hal.buffer, %arg6: index, %arg7: index, %arg8: !custom.type, %arg9: !hal.buffer, %arg10: index, %arg11: index) stream.cmd.func private @cmdFunc(%arg0[%arg1 for %arg2]: !stream.resource<*>, %arg3: i32, %arg4[%arg5 for %arg6]: !stream.resource<*>, %arg7: !custom.type, %arg8[%arg9 for %arg10]: !stream.resource<*>) @@ -310,7 +320,7 @@ // CHECK-DAG: %[[SIZE2:.+]] = arith.constant 102 %size2 = arith.constant 102 : index // CHECK: %[[COMMAND_BUFFER:.+]] = hal.command_buffer.create - %timepoint = stream.cmd.execute with(%arg0 as %stream0: !stream.resource<external>{%size0}, %arg2 as %stream1: !stream.resource<external>{%size1}, %arg4 as %stream2: !stream.resource<external>{%size2}) { + %timepoint = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %stream0: !stream.resource<external>{%size0}, %arg2 as %stream1: !stream.resource<external>{%size1}, %arg4 as %stream2: !stream.resource<external>{%size2}) { // CHECK: util.call @cmdFunc(%[[COMMAND_BUFFER]], %arg0, %c0, %[[SIZE0]], %arg1, %arg2, %c0, %[[SIZE1]], %arg3, %arg4, %c0, %[[SIZE2]]) : // CHECK-SAME: (!hal.command_buffer, !hal.buffer, index, index, i32, !hal.buffer, index, index, !custom.type, !hal.buffer, index, index) -> () stream.cmd.call @cmdFunc(ro %stream0[%c0 for 
%size0], %arg1, rw %stream1[%c0 for %size1], %arg3, wo %stream2[%c0 for %size2]) : (!stream.resource<external>{%size0}, i32, !stream.resource<external>{%size1}, !custom.type, !stream.resource<external>{%size2}) -> () @@ -324,12 +334,14 @@ // appropriate queue affinity mask. The final affinity is the result of ORing // the target affinities (0b01 | 0b10 = 0b11 = 3). +util.global private @device : !hal.device + // CHECK-LABEL: @cmdExecuteAffinities util.func public @cmdExecuteAffinities(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute on(#hal.affinity.queue<[0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device, [0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) { stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3} } => !stream.timepoint // CHECK: hal.device.queue.execute
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir index 5d73951..20a9c59 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir
@@ -1,19 +1,12 @@ // RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s -// CHECK-LABEL: @contextResolveAllocator -util.func public @contextResolveAllocator() -> !hal.allocator { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} - // CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator - %allocator = stream.context.resolve : !hal.allocator - // CHECK: util.return %[[ALLOCATOR]] - util.return %allocator : !hal.allocator -} +util.global private @device : !hal.device -// ----- - -// CHECK-LABEL: @contextResolveDevice -util.func public @contextResolveDevice() -> !hal.device { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} +// CHECK-LABEL: @contextResolveDefaultDevice +util.func public @contextResolveDefaultDevice() -> !hal.device attributes { + stream.affinity = #hal.device.affinity<@device> +} { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device %device = stream.context.resolve : !hal.device // CHECK: util.return %[[DEVICE]] util.return %device : !hal.device @@ -21,22 +14,65 @@ // ----- +util.global private @device : !hal.device + +// CHECK-LABEL: @contextResolveDevice +util.func public @contextResolveDevice() -> !hal.device { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + %device = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device + // CHECK: util.return %[[DEVICE]] + util.return %device : !hal.device +} + +// ----- + +util.global private @device : !hal.device + // CHECK-LABEL: @contextResolveDeviceQueueAffinityAny util.func public @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) { - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64 - %device, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64 + %device, %queue_affinity_any 
= stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] util.return %device, %queue_affinity_any : !hal.device, i64 } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @contextResolveDeviceQueueAffinity45 util.func public @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 - %device, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64 + %device, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] util.return %device, %queue_affinity_45 : !hal.device, i64 } + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @contextResolveAllocator +util.func public @contextResolveAllocator() -> !hal.allocator { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + %allocator = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.allocator + // CHECK: util.return %[[ALLOCATOR]] + util.return %allocator : !hal.allocator +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @contextResolveAllocatorQueueAffinity45 +util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 + %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64 + // CHECK: util.return 
%[[ALLOCATOR]], %[[QUEUE_AFFINITY]] + util.return %allocator, %queue_affinity_45 : !hal.allocator, i64 +}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir index 1182ee4..efa925a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir
@@ -1,44 +1,50 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @file_constant // CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer) util.func public @file_constant(%buffer: !util.buffer) { %c0 = arith.constant 0 : index %c1088 = arith.constant 1088 : index - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: = hal.ex.file.from_memory device(%[[DEVICE]] : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file - %file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file + %file = stream.file.constant on(#hal.device.affinity<@device>) %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file util.return } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @file_read // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer) util.func public @file_read(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource<variable>) -> !stream.timepoint { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 %c1088 = arith.constant 1088 : index - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: %[[SIGNAL:.+]] = hal.fence.create // CHECK: hal.device.queue.read<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0) - %signal = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource<variable>{%c1088} => !stream.timepoint + %signal = stream.file.read on(#hal.device.affinity<@device>) await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource<variable>{%c1088} 
=> !stream.timepoint // CHECK: util.return %[[SIGNAL]] util.return %signal : !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @file_write // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer) util.func public @file_write(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource<variable>) -> !stream.timepoint { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 %c1088 = arith.constant 1088 : index - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: %[[SIGNAL:.+]] = hal.fence.create // CHECK: hal.device.queue.write<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0) - %signal = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource<variable>{%c1088} -> !stream.file => !stream.timepoint + %signal = stream.file.write on(#hal.device.affinity<@device>) await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource<variable>{%c1088} -> !stream.file => !stream.timepoint // CHECK: util.return %[[SIGNAL]] util.return %signal : !stream.timepoint }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir index 6af93ee..09f7046 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
@@ -1,18 +1,22 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @resourceAlloc util.func public @resourceAlloc(%arg0: index) -> !stream.resource<transient> { // CHECK: %[[RET0:.+]] = hal.allocator.allocate // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%arg0} - %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%arg0} + %0 = stream.resource.alloc uninitialized on(#hal.device.affinity<@device>) : !stream.resource<transient>{%arg0} // CHECK: util.return %[[RET0]] util.return %0 : !stream.resource<transient> } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceAlloca // CHECK-SAME: (%[[SIZE:.+]]: index) util.func public @resourceAlloca(%size: index) -> (!stream.resource<transient>, !stream.timepoint) { @@ -26,13 +30,15 @@ // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%[[SIZE]]} - %0:2 = stream.resource.alloca uninitialized : !stream.resource<transient>{%size} => !stream.timepoint + %0:2 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) : !stream.resource<transient>{%size} => !stream.timepoint // CHECK: util.return %[[RET0]], %[[SIGNAL_FENCE]] util.return %0#0, %0#1 : !stream.resource<transient>, !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceAllocaAwait // CHECK-SAME: (%[[SIZE:.+]]: index, %[[WAIT_FENCE:.+]]: !hal.fence) util.func public @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource<transient>, !stream.timepoint) { @@ -45,13 +51,15 @@ // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%[[SIZE]]} - %0:2 = stream.resource.alloca uninitialized 
await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint + %0:2 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint // CHECK: util.return %[[RET0]], %[[SIGNAL_FENCE]] util.return %0#0, %0#1 : !stream.resource<transient>, !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceDealloca // CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer) util.func public @resourceDealloca(%size: index, %resource: !stream.resource<transient>) -> !stream.timepoint { @@ -62,14 +70,14 @@ // CHECK-SAME: wait(%[[WAIT_FENCE]]) // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) - %0 = stream.resource.dealloca %resource : !stream.resource<transient>{%size} => !stream.timepoint + %0 = stream.resource.dealloca on(#hal.device.affinity<@device>) %resource : !stream.resource<transient>{%size} => !stream.timepoint // CHECK: util.return %[[SIGNAL_FENCE]] util.return %0 : !stream.timepoint } // ----- -// TODO(#9572): implement stream ordered allocations. 
+util.global private @device : !hal.device // CHECK-LABEL: @resourceDeallocaAwait // CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer, %[[WAIT_FENCE:.+]]: !hal.fence) @@ -80,7 +88,7 @@ // CHECK-SAME: wait(%[[WAIT_FENCE]]) // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) - %0 = stream.resource.dealloca await(%await_timepoint) => %resource : !stream.resource<transient>{%size} => !stream.timepoint + %0 = stream.resource.dealloca on(#hal.device.affinity<@device>) await(%await_timepoint) => %resource : !stream.resource<transient>{%size} => !stream.timepoint // CHECK: util.return %[[SIGNAL_FENCE]] util.return %0 : !stream.timepoint } @@ -97,6 +105,8 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceTryMap util.func public @resourceTryMap(%arg0: !util.buffer) -> (i1, !stream.resource<constant>) { %c0 = arith.constant 0 : index @@ -105,7 +115,7 @@ // CHECK-SAME: source(%arg0 : !util.buffer)[%c0, %c128] // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}SharingImmutable") : i1, !hal. - %did_map, %mapping = stream.resource.try_map %arg0[%c0] : !util.buffer -> i1, !stream.resource<constant>{%c128} + %did_map, %mapping = stream.resource.try_map on(#hal.device.affinity<@device>) %arg0[%c0] : !util.buffer -> i1, !stream.resource<constant>{%c128} // CHECK: util.return %[[DID_IMPORT]], %[[IMPORTED]] util.return %did_map, %mapping : i1, !stream.resource<constant> }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir index 8a7b691..007f457 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
@@ -42,12 +42,14 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @timepointChainExternal // CHECK-SAME: (%[[TIMEPOINT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence) util.func public @timepointChainExternal(%timepoint: !stream.timepoint, %signal: !hal.fence) { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[TIMEPOINT]]) signal(%[[SIGNAL]]) - stream.timepoint.chain_external %timepoint => (%signal : !hal.fence) + stream.timepoint.chain_external on(#hal.device.affinity<@device>) %timepoint => (%signal : !hal.fence) util.return }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir index 1dbcc24..5805f71 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir
@@ -1,5 +1,7 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @tensorImportBuffer util.func public @tensorImportBuffer(%arg0: !hal.buffer, %arg1: index) -> !stream.resource<external> { %c20 = arith.constant 20 : index @@ -10,7 +12,7 @@ // CHECK-SAME: minimum_length(%c20) // CHECK-SAME: type(DeviceVisible) // CHECK-SAME: usage("Transfer{{.+}}Dispatch{{.+}}") - %0 = stream.tensor.import %arg0 : !hal.buffer -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20} + %0 = stream.tensor.import on(#hal.device.affinity<@device>) %arg0 : !hal.buffer -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20} // CHECK: util.return %arg0 util.return %0 : !stream.resource<external> } @@ -21,6 +23,8 @@ // when lowering into the stream dialect; here we only care about the storage // buffer itself. +util.global private @device : !hal.device + // CHECK-LABEL: @tensorImportBufferView util.func public @tensorImportBufferView(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource<external> { %c20 = arith.constant 20 : index @@ -32,23 +36,27 @@ // CHECK-SAME: minimum_length(%c20) // CHECK-SAME: type(DeviceVisible) // CHECK-SAME: usage("Transfer{{.+}}Dispatch{{.+}}") - %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20} + %0 = stream.tensor.import on(#hal.device.affinity<@device>) %arg0 : !hal.buffer_view -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20} // CHECK: util.return %[[BUFFER]] util.return %0 : !stream.resource<external> } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorExportBuffer util.func public @tensorExportBuffer(%arg0: !stream.resource<external>, %arg1: index) -> !hal.buffer { %c200 = arith.constant 200 : index - %0 = stream.tensor.export %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer + %0 = stream.tensor.export 
on(#hal.device.affinity<@device>) %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer // CHECK: util.return %arg0 : !hal.buffer util.return %0 : !hal.buffer } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorExportBufferView util.func public @tensorExportBufferView(%arg0: !stream.resource<external>, %arg1: index) -> !hal.buffer_view { %c200 = arith.constant 200 : index @@ -60,7 +68,7 @@ // CHECK-SAME: type(%[[ELEMENT_TYPE]]) // CHECK-SAME: encoding(%[[ENCODING_TYPE]]) // CHECK-SAME: : !hal.buffer_view - %0 = stream.tensor.export %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer_view + %0 = stream.tensor.export on(#hal.device.affinity<@device>) %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer_view // CHECK: util.return %[[VIEW]] util.return %0 : !hal.buffer_view }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index f5ac50d..c1d8c22 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -1122,25 +1122,24 @@ } //===----------------------------------------------------------------------===// -// #hal.affinity.queue<*> +// #hal.device.affinity<*> //===----------------------------------------------------------------------===// // static -Attribute AffinityQueueAttr::parse(AsmParser &p, Type type) { - int64_t mask = 0; - // `<` - if (failed(p.parseLess())) +Attribute DeviceAffinityAttr::parse(AsmParser &p, Type type) { + // `<@device` + StringAttr deviceName; + int64_t queueMask = -1; + if (failed(p.parseLess()) || failed(p.parseSymbolName(deviceName))) return {}; - // `*` (any) - if (succeeded(p.parseOptionalStar())) { - mask = -1; - } else { + if (succeeded(p.parseOptionalComma())) { // `[`queue_bit[, ...] `]` + queueMask = 0; if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { int64_t i = 0; if (failed(p.parseInteger(i))) return failure(); - mask |= 1ll << i; + queueMask |= 1ll << i; return success(); }))) { return {}; @@ -1149,19 +1148,18 @@ // `>` if (failed(p.parseGreater())) return {}; - return get(p.getContext(), mask); + return get(p.getContext(), FlatSymbolRefAttr::get(deviceName), queueMask); } -void AffinityQueueAttr::print(AsmPrinter &p) const { +void DeviceAffinityAttr::print(AsmPrinter &p) const { auto &os = p.getStream(); os << "<"; - int64_t mask = getMask(); - if (mask == -1) { - os << "*"; - } else { - os << "["; - for (int i = 0, j = 0; i < sizeof(mask) * 8; ++i) { - if (mask & (1ll << i)) { + os << getDevice(); + int64_t queueMask = getQueueMask(); + if (queueMask != -1) { + os << ", ["; + for (int i = 0, j = 0; i < sizeof(queueMask) * 8; ++i) { + if (queueMask & (1ll << i)) { if (j++ > 0) os << ", "; os << i; @@ -1172,45 +1170,62 @@ os << ">"; } -bool AffinityQueueAttr::isExecutableWith( +bool DeviceAffinityAttr::isExecutableWith( IREE::Stream::AffinityAttr other) const { if (!other) return true; - // Only compatible with other queue affinities today. 
When we extend the - // attributes to specify device targets we'd want to check here. - auto otherQueueAttr = llvm::dyn_cast_if_present<AffinityQueueAttr>(other); - if (!otherQueueAttr) + // Only compatible with the same exact devices today. We could support a + // peering model to allow operations to move across devices in a peered set + // but that may be best done at higher levels and avoided once we get to the + // "are these the same device" stage. + auto otherAffinityAttr = llvm::dyn_cast_if_present<DeviceAffinityAttr>(other); + if (!otherAffinityAttr || getDevice() != otherAffinityAttr.getDevice()) return false; // If this affinity is a subset of the target affinity then it can execute // with it. - if ((getMask() & otherQueueAttr.getMask()) == getMask()) + if ((getQueueMask() & otherAffinityAttr.getQueueMask()) == getQueueMask()) return true; // Otherwise not compatible. return false; } IREE::Stream::AffinityAttr -AffinityQueueAttr::joinOR(IREE::Stream::AffinityAttr other) const { +DeviceAffinityAttr::joinOR(IREE::Stream::AffinityAttr other) const { if (!other) return *this; if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { return nullptr; } - auto otherQueueAttr = llvm::dyn_cast_if_present<AffinityQueueAttr>(other); - return AffinityQueueAttr::get(getContext(), - getMask() | otherQueueAttr.getMask()); + auto otherAffinityAttr = llvm::dyn_cast_if_present<DeviceAffinityAttr>(other); + return DeviceAffinityAttr::get(getContext(), getDevice(), + getQueueMask() | + otherAffinityAttr.getQueueMask()); } IREE::Stream::AffinityAttr -AffinityQueueAttr::joinAND(IREE::Stream::AffinityAttr other) const { +DeviceAffinityAttr::joinAND(IREE::Stream::AffinityAttr other) const { if (!other) return *this; if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { return nullptr; } - auto otherQueueAttr = llvm::dyn_cast_if_present<AffinityQueueAttr>(other); - return AffinityQueueAttr::get(getContext(), - getMask() & otherQueueAttr.getMask()); + 
auto otherAffinityAttr = llvm::dyn_cast_if_present<DeviceAffinityAttr>(other); + return DeviceAffinityAttr::get(getContext(), getDevice(), + getQueueMask() & + otherAffinityAttr.getQueueMask()); +} + +bool DeviceAffinityAttr::isLegalToInline(Operation *inlineSite, + Operation *inlinable) const { + // Look up the affinity of the inlining target site and only allow inlining if + // it matches exactly. We could make a decision as to whether we allow + // inlining when queues are subsets (so if the target site allows any queue + // and the inlinable allows queue 2 then allow, etc). In the future we may + // want to allow util.scope restrictions within the inline target to keep + // queue specification tighter but today most queue masks are wildcarded + // anyway. + auto targetAffinityAttr = IREE::Stream::AffinityAttr::lookup(inlineSite); + return *this == targetAffinityAttr; } //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 69a1f15..b77c9a5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -983,41 +983,40 @@ } //===----------------------------------------------------------------------===// -// #hal.affinity.queue<*> +// #hal.device.affinity<*> //===----------------------------------------------------------------------===// -def HAL_AffinityQueueAttr : AttrDef<HAL_Dialect, "AffinityQueue", [ +def HAL_DeviceAffinityAttr : AttrDef<HAL_Dialect, "DeviceAffinity", [ DeclareAttrInterfaceMethods<Stream_AffinityAttr, [ "isExecutableWith", "joinOR", "joinAND", ]>, Util_HoistableAttrInterface, + DeclareAttrInterfaceMethods<Util_InliningPolicyAttrInterface, [ + "isLegalToInline", + ]>, ]> { - let mnemonic = "affinity.queue"; - let summary = [{specifies a set of allowed queues for an operation}]; + let mnemonic = "device.affinity"; + let summary = [{specifies a named device and optional queue affinity}]; let description = [{ - WIP; see [#10765](https://github.com/iree-org/iree/issues/10765). - This may change in the future to either be a nested attribute on a larger - affinity struct or be defined by an implementation of the affinity attr - interface. For now this allows higher levels of the stack to specify - queues such that the stream dialect can understand them and they can be - lowered into the HAL dialect. - Specifies that an annotated operation or scope is only allowed to execute on - the set of queues (0-64) provided. Operations will not run on other queues. + a specific device and optionally a set of queues (0-63) provided. + Operations will not run on other queues. If the queue mask is omitted then + any queue on the device is allowed to execute the specified operations. Example: ```mlir - // any queue - #hal.affinity.queue<*> - // queues 4 and 5 - #hal.affinity.queue<[4, 5]> + // Any queue on @device_a. + #hal.device.affinity<@device_a> + // Queues 4 and 5 on @device_b. 
+ #hal.device.affinity<@device_b, [4, 5]> ``` }]; let parameters = (ins - AttrParameter<"int64_t", "">:$mask + AttrParameter<"FlatSymbolRefAttr", "">:$device, + AttrParameter<"int64_t", "">:$queue_mask ); let hasCustomAssemblyFormat = 1;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index 47f8468..00f39f5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
@@ -1,4 +1,5 @@ // RUN: iree-opt --allow-unregistered-dialect --split-input-file --mlir-print-local-scope %s | FileCheck %s +// RUN: iree-opt --inline --allow-unregistered-dialect --split-input-file --mlir-print-local-scope %s | FileCheck %s --check-prefix=CHECK-INLINE // CHECK-LABEL: descriptor_set_layout_binding.basic "descriptor_set_layout_binding.basic"() { @@ -100,11 +101,52 @@ // ----- -"affinity.queue"() { - // CHECK: any = #hal.affinity.queue<*> - any = #hal.affinity.queue<*>, - // CHECK: q0 = #hal.affinity.queue<[0]> - q0 = #hal.affinity.queue<[0]>, - // CHECK: q123 = #hal.affinity.queue<[1, 2, 3]> - q123 = #hal.affinity.queue<[1, 2, 3]> +util.global private @device : !hal.device +"device.affinity"() { + // CHECK: device_any = #hal.device.affinity<@device> + device_any = #hal.device.affinity<@device>, + // CHECK: device_queue_0 = #hal.device.affinity<@device, [0]> + device_queue_0 = #hal.device.affinity<@device, [0]>, + // CHECK: device_queue_123 = #hal.device.affinity<@device, [1, 2, 3]> + device_queue_123 = #hal.device.affinity<@device, [1, 2, 3]> } : () -> () + +// ----- + +// Tests that differing device affinities block inlining. +// Here the @inline_target is using the default affinity specified on the +// module and only functions also using the default affinity or a matching +// specified affinity will be inlined. The #hal.device.affinity controls this +// behavior and in the future we could allow inlining of compatible devices, +// the same device on differing queues, etc. 
+ +builtin.module attributes { + stream.affinity = #hal.device.affinity<@device_a> +} { + util.global private @device_a : !hal.device + util.global private @device_b : !hal.device + // CHECK-INLINE: util.func public @inline_target + util.func public @inline_target() -> (i32, i32) { + // CHECK-INLINE-NOT: util.call @compat_inlinable + // CHECK-INLINE: %[[A:.+]] = arith.constant 0 + %a = util.call @compat_inlinable() : () -> i32 + // CHECK-INLINE: %[[B:.+]] = util.call @noncompat_inlinable + %b = util.call @noncompat_inlinable() : () -> i32 + // CHECK-INLINE: util.return %[[A]], %[[B]] + util.return %a, %b : i32, i32 + } + // CHECK-INLINE-NOT: util.func private @compat_inlinable + util.func private @compat_inlinable() -> i32 attributes { + stream.affinity = #hal.device.affinity<@device_a> + } { + %c0 = arith.constant 0 : i32 + util.return %c0 : i32 + } + // CHECK-INLINE: util.func private @noncompat_inlinable + util.func private @noncompat_inlinable() -> i32 attributes { + stream.affinity = #hal.device.affinity<@device_b> + } { + %c1 = arith.constant 1 : i32 + util.return %c1 : i32 + } +}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir index 6da4866..7408ad9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
@@ -144,9 +144,10 @@ // CHECK-SAME: %[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[LAYOUT0:.+]]: !hal.pipeline_layout, // CHECK-SAME: %[[LAYOUT1:.+]]: !hal.pipeline_layout -util.func public @executable_create(%device: !hal.device, - %layout0: !hal.pipeline_layout, - %layout1: !hal.pipeline_layout) { +util.func public @executable_create( + %device: !hal.device, + %layout0: !hal.pipeline_layout, + %layout1: !hal.pipeline_layout) { // CHECK: = hal.executable.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: target(@exe::@binary1) @@ -163,16 +164,17 @@ // CHECK-SAME: %[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[LAYOUT0:.+]]: !hal.descriptor_set_layout, // CHECK-SAME: %[[LAYOUT1:.+]]: !hal.descriptor_set_layout -util.func public @pipeline_layout_create(%device: !hal.device, - %layout0: !hal.descriptor_set_layout, - %layout1: !hal.descriptor_set_layout) { +util.func public @pipeline_layout_create( + %device: !hal.device, + %layout0: !hal.descriptor_set_layout, + %layout1: !hal.descriptor_set_layout) { // CHECK: hal.pipeline_layout.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: push_constants(1) // CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT1]]]) : !hal.pipeline_layout %0 = hal.pipeline_layout.create device(%device : !hal.device) - push_constants(1) - layouts([%layout0, %layout1]) : !hal.pipeline_layout + push_constants(1) + layouts([%layout0, %layout1]) : !hal.pipeline_layout util.return } @@ -197,8 +199,9 @@ // CHECK-LABEL: @unresolved_workload // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[WORKLOAD_0:.+]]: index, %[[WORKLOAD_1:.+]]: index) -util.func public @unresolved_workload(%device: !hal.device, - %workload_0: index, %workload_1: index) -> (index, index, index) { +util.func public @unresolved_workload( + %device: !hal.device, + %workload_0: index, %workload_1: index) -> (index, index, index) { // CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] = // CHECK-SAME: 
hal.executable.calculate_workgroups // CHECK-SAME: device(%[[DEVICE]] : !hal.device)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 82b6310..1de9b90 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -2,6 +2,8 @@ // Tests an end-to-end simple single-dispatch `dispatch(arg0, arg1) -> result`. +util.global private @device : !hal.device + #executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> #executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> @@ -66,7 +68,9 @@ // CHECK: util.func public @simpleDispatch // CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view -util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} { +util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes { + stream.affinity = #hal.device.affinity<@device> +} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c16 = arith.constant 16 : index @@ -76,8 +80,7 @@ // CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer - // (annoyingly out of order) - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device : !hal.device // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator // CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index 287822a..3cb5f76 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -15,8 +15,7 @@ // CHECK-LABEL: @exeLayoutLookup util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout { // CHECK: %[[LAYOUT:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout - %0 = hal.pipeline_layout.lookup device(%device : !hal.device) - layout(#hal.pipeline.layout<push_constants = 1, sets = [ + %0 = hal.pipeline_layout.lookup device(%device : !hal.device) layout(#hal.pipeline.layout<push_constants = 1, sets = [ #hal.descriptor_set.layout<0, bindings = [ #hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir index da75704..4106307 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
@@ -46,18 +46,21 @@ // ----- +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + // CHECK-LABEL: @dispatchAffinity // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index) util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<?x?x1024xf32>) { - // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.affinity.queue<[0]>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} - // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} + // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %0 = flow.dispatch @ex::@entry0(%input) { - stream.affinity = #hal.affinity.queue<[0]> + stream.affinity = #hal.device.affinity<@device_a> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3} - // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.affinity.queue<[1]>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]} - // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]} + // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %1 = flow.dispatch @ex::@entry1(%input) { - stream.affinity = #hal.affinity.queue<[1]> + stream.affinity = #hal.device.affinity<@device_b> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1} // 
return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]] util.return %0, %1 : tensor<?x?x1024xf32>, tensor<?x?x1024xf32>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index 9a1272f..7633d8c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -166,6 +166,8 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorLoad // CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<*>, %[[SOURCE_SIZE:.+]]: index) util.func public @tensorLoad(%source : tensor<2x3xi32>) -> i32 { @@ -173,10 +175,10 @@ %c1 = arith.constant 1 : index // CHECK: %[[T0:.+]] = stream.async.transfer // CHECK-SAME: %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]} - // CHECK-SAME: from(#hal.affinity.queue<[0, 1]>) -> !stream.resource<staging>{%[[SOURCE_SIZE]]} + // CHECK-SAME: from(#hal.device.affinity<@device>) -> !stream.resource<staging>{%[[SOURCE_SIZE]]} // CHECK: %[[T1:.+]] = stream.tensor.load %[[T0]][%c0, %c1] : tensor<2x3xi32> in !stream.resource<staging>{%[[SOURCE_SIZE]]} -> i32 %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32> attributes { - stream.affinity = #hal.affinity.queue<[0, 1]> + stream.affinity = #hal.device.affinity<@device> } // CHECK: util.return %[[T1]] util.return %0 : i32 @@ -184,6 +186,8 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorStore // CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index) util.func public @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> { @@ -191,13 +195,13 @@ %c1 = arith.constant 1 : index %c9 = arith.constant 9 : i32 // CHECK: %[[T0:.+]] = stream.async.transfer %[[TARGET]] : !stream.resource<*>{%[[TARGET_SIZE]]} - // CHECK-SAME: from(#hal.affinity.queue<[0, 1]>) -> !stream.resource<staging>{%[[TARGET_SIZE]]} + // CHECK-SAME: from(#hal.device.affinity<@device>) -> !stream.resource<staging>{%[[TARGET_SIZE]]} // CHECK: %[[T1:.+]] = stream.tensor.store %c9_i32, %[[T0]][%c0, %c1] : // CHECK-SAME: i32 -> tensor<2x3xi32> in %[[T0]] as !stream.resource<staging>{%[[TARGET_SIZE]]} // CHECK: %[[T2:.+]] = stream.async.transfer %[[T1]] : !stream.resource<staging>{%[[TARGET_SIZE]]} -> - // CHECK-SAME: to(#hal.affinity.queue<[0, 1]>) !stream.resource<*>{%[[TARGET_SIZE]]} + // CHECK-SAME: 
to(#hal.device.affinity<@device>) !stream.resource<*>{%[[TARGET_SIZE]]} %0 = flow.tensor.store %c9, %target[%c0, %c1] : tensor<2x3xi32> attributes { - stream.affinity = #hal.affinity.queue<[0, 1]> + stream.affinity = #hal.device.affinity<@device> } // CHECK: util.return %[[T2]] util.return %0 : tensor<2x3xi32>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index c994e65..cb36271 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -123,8 +123,8 @@ ); let assemblyFormat = [{ - (`on` `(` $affinity^ `)`)? (`uninitialized` $uninitialized^)? + (`on` `(` $affinity^ `)`)? attr-dict `:` type($result) `{` $storage_size `}` }]; @@ -1772,15 +1772,15 @@ } // OpGroupTensorOps //===----------------------------------------------------------------------===// -// Resource transfer ops +// Async (stream.async*) ops //===----------------------------------------------------------------------===// -def OpGroupResourceTransferOps : OpDocGroup { - let summary = "Resource transfer ops"; +def OpGroupAsyncOps : OpDocGroup { + let summary = "Async ops"; let description = ""; } -let opDocGroup = OpGroupResourceTransferOps in { +let opDocGroup = OpGroupAsyncOps in { def Stream_AsyncAllocaOp : Stream_Op<"async.alloca", [ DeclareOpInterfaceMethods<Stream_AffinityOp, [ @@ -2460,7 +2460,7 @@ let hasCanonicalizer = 1; } -} // OpGroupResourceTransferOps +} // OpGroupAsyncOps //===----------------------------------------------------------------------===// // Async control flow ops @@ -2855,15 +2855,15 @@ } // OpGroupAsyncControlFlowOps //===----------------------------------------------------------------------===// -// Explicit command ops +// Explicit command (stream.cmd.*) ops //===----------------------------------------------------------------------===// -def OpGroupExplicitCommandOps : OpDocGroup { +def OpGroupCmdOps : OpDocGroup { let summary = "Explicit command ops"; let description = ""; } -let opDocGroup = OpGroupExplicitCommandOps in { +let opDocGroup = OpGroupCmdOps in { def Stream_CmdFlushOp : Stream_Op<"cmd.flush", [ Stream_CmdPhaseOp, @@ -3531,7 +3531,7 @@ let hasCanonicalizer = 1; } -} // OpGroupExplicitCommandOps +} // OpGroupCmdOps //===----------------------------------------------------------------------===// // Synchronization ops
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir index 50c7c26..2d33e5e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
@@ -82,6 +82,8 @@ // This covers all_gather, all_reduce, and reduce_scatter variants. +util.global private @device : !hal.device + // CHECK-LABEL: @asyncCollectiveAllGather util.func private @asyncCollectiveAllGather( // CHECK-SAME: %[[CHANNEL:.+]]: !stream.channel, @@ -95,8 +97,8 @@ %recv = stream.async.alloca : !stream.resource<*>{%recv_size} // CHECK: = stream.async.collective<all_gather : f32>[%[[COUNT]]] %0 = stream.async.collective<all_gather : f32>[%count] - // CHECK-SAME: on(#hal.affinity.queue<[0]>) channel(%[[CHANNEL]]) - on(#hal.affinity.queue<[0]>) channel(%channel) + // CHECK-SAME: on(#hal.device.affinity<@device>) channel(%[[CHANNEL]]) + on(#hal.device.affinity<@device>) channel(%channel) // CHECK-SAME: %[[SEND]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]], %send[%c0 to %send_size for %send_size], // CHECK-SAME: %[[RECV]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] : @@ -110,6 +112,8 @@ // This covers broadcast and reduce variants. +util.global private @device : !hal.device + // CHECK-LABEL: @asyncCollectiveBroadcast util.func private @asyncCollectiveBroadcast( // CHECK-SAME: %[[CHANNEL:.+]]: !stream.channel, @@ -125,8 +129,8 @@ %recv = stream.async.alloca : !stream.resource<*>{%recv_size} // CHECK: = stream.async.collective<broadcast : f32>[%[[COUNT]]] %0 = stream.async.collective<broadcast : f32>[%count] - // CHECK-SAME: on(#hal.affinity.queue<[0]>) channel(%[[CHANNEL]]) source(%[[RANK]]) - on(#hal.affinity.queue<[0]>) channel(%channel) source(%rank) + // CHECK-SAME: on(#hal.device.affinity<@device>) channel(%[[CHANNEL]]) source(%[[RANK]]) + on(#hal.device.affinity<@device>) channel(%channel) source(%rank) // CHECK-SAME: %[[SEND]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]], %send[%c0 to %send_size for %send_size], // CHECK-SAME: %[[RECV]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] : @@ -147,10 +151,12 @@ // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @asyncTransferAffinities util.func private @asyncTransferAffinities(%arg0: 
!stream.resource<constant>, %arg1: index) -> !stream.resource<constant> { - // CHECK: = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.affinity.queue<[0]>) -> to(#hal.affinity.queue<[1]>) !stream.resource<constant>{%arg1} - %0 = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.affinity.queue<[0]>) -> to(#hal.affinity.queue<[1]>) !stream.resource<constant>{%arg1} + // CHECK: = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.device.affinity<@device, [0]>) -> to(#hal.device.affinity<@device, [1]>) !stream.resource<constant>{%arg1} + %0 = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} from(#hal.device.affinity<@device, [0]>) -> to(#hal.device.affinity<@device, [1]>) !stream.resource<constant>{%arg1} util.return %0 : !stream.resource<constant> }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir index 486a03f..a465546 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir
@@ -1,10 +1,12 @@ // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @channel_create // CHECK-SAME: (%[[RANK:.+]]: index, %[[COUNT:.+]]: index) util.func private @channel_create(%rank: index, %count: index) { - // CHECK: %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) rank(%[[RANK]]) count(%[[COUNT]]) : !stream.channel - %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) rank(%rank) count(%count) : !stream.channel + // CHECK: %channel = stream.channel.create on(#hal.device.affinity<@device>) rank(%[[RANK]]) count(%[[COUNT]]) : !stream.channel + %channel = stream.channel.create on(#hal.device.affinity<@device>) rank(%rank) count(%count) : !stream.channel util.return }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir index ab523ec..950643a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir
@@ -1,12 +1,14 @@ // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @context_resolve util.func private @context_resolve() { // CHECK: = stream.context.resolve : !hal.allocator %allocator = stream.context.resolve : !hal.allocator - // CHECK: = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64 - %device1, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64 - // CHECK: = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64 - %device0, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64 + // CHECK: = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 + %device1, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 + // CHECK: = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 + %device0, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 util.return }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir index 14e8fb2..ed1f338 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
@@ -16,8 +16,8 @@ stream.executable private @rebaseBindingsEx { stream.executable.export public @dispatch attributes {stream.resources = #aliasConfig} builtin.module { - // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, - // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) + // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, + // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) util.func public @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %operand: index) { %c0 = arith.constant 0 : index %c20 = arith.constant 20 : index @@ -39,7 +39,7 @@ } } } -// CHECK: util.func public @rebaseBindings(%[[OPERAND:.+]]: index) +// CHECK: util.func public @rebaseBindings(%[[OPERAND:.+]]: index) util.func public @rebaseBindings(%operand: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -97,8 +97,8 @@ stream.executable private @deduplicateBindingsEx { stream.executable.export public @dispatch attributes {stream.resources = #aliasConfig} builtin.module { - // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, - // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) + // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, + // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) util.func public @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %binding_c: !stream.binding, %operand: index) { %c0 = arith.constant 0 : index %c20 = arith.constant 20 : index @@ -127,7 +127,7 @@ } } } -// CHECK: util.func public @deduplicateBindings(%[[OPERAND:.+]]: index) +// CHECK: util.func public 
@deduplicateBindings(%[[OPERAND:.+]]: index) util.func public @deduplicateBindings(%operand: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir index a3b4ef6..a1509e3 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir
@@ -110,13 +110,15 @@ // TODO(#11249): support in-place collectives - when supported this will become // a negative test as we'd expect %send_recv to be used for both operands. +util.global private @device : !hal.device + // CHECK-LABEL: @tiedCollectivesTODO // CHECK-SAME: (%[[CHANNEL:.+]]: !stream.channel, %[[SEND_RECV:.+]]: !stream.resource<*>, %[[SEND_SIZE:.+]]: index, %[[RECV_SIZE:.+]]: index, %[[COUNT:.+]]: index) util.func private @tiedCollectivesTODO(%channel: !stream.channel, %send_recv: !stream.resource<*>, %send_size: index, %recv_size: index, %count: index) -> !stream.resource<*> { %c0 = arith.constant 0 : index - // CHECK: %[[RECV_CLONE:.+]] = stream.async.clone on(#hal.affinity.queue<[0]>) %[[SEND_RECV]] + // CHECK: %[[RECV_CLONE:.+]] = stream.async.clone on(#hal.device.affinity<@device>) %[[SEND_RECV]] // CHECK: %[[ALL_GATHER:.+]] = stream.async.collective<all_gather : f32>[%[[COUNT]]] - %0 = stream.async.collective<all_gather : f32>[%count] on(#hal.affinity.queue<[0]>) channel(%channel) + %0 = stream.async.collective<all_gather : f32>[%count] on(#hal.device.affinity<@device>) channel(%channel) // CHECK-SAME: %[[SEND_RECV]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]], %send_recv[%c0 to %send_size for %send_size], // CHECK-SAME: %[[RECV_CLONE]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] :
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir index 00f5c32..8266ca5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
@@ -223,27 +223,29 @@ // execution region. We expect them to be placed into packed slices and // allocated with the async stream-ordered alloca/dealloca ops. +util.global private @device : !hal.device + // CHECK-LABEL: @locals // CHECK-SAME: (%[[SIZE0:.+]]: index, %[[SIZE1:.+]]: index, %[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint) util.func public @locals(%size0: index, %size1: index, %await_timepoint: !stream.timepoint) -> !stream.timepoint { %c254_i32 = arith.constant 254 : i32 %c255_i32 = arith.constant 255 : i32 - // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack on(#hal.affinity.queue<[0]>) slices({ + // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack on(#hal.device.affinity<@device>) slices({ // CHECK-NEXT: [0, 0] = %[[SIZE0]], // CHECK-NEXT: [1, 1] = %[[SIZE1]] // CHECK-NEXT: }) - // CHECK-NEXT: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized on(#hal.affinity.queue<[0]>) await(%[[AWAIT_TIMEPOINT]]) => !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint + // CHECK-NEXT: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) await(%[[AWAIT_TIMEPOINT]]) => !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint // CHECK-NEXT: %[[AWAIT_JOIN:.+]] = stream.timepoint.join max(%[[AWAIT_TIMEPOINT]], %[[ALLOCA_TIMEPOINT]]) - // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute on(#hal.affinity.queue<[0]>) await(%[[AWAIT_JOIN]]) + // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute on(#hal.device.affinity<@device>) await(%[[AWAIT_JOIN]]) // CHECK-SAME: with(%[[ALLOCA]] as %[[CAPTURE:.+]]: !stream.resource<transient>{%[[SLICES]]#0}) - %result_timepoint = stream.async.execute on(#hal.affinity.queue<[0]>) await(%await_timepoint) => with() { + %result_timepoint = stream.async.execute on(#hal.device.affinity<@device>) await(%await_timepoint) => with() { // CHECK: stream.cmd.fill %c254_i32, %[[CAPTURE]][%[[SLICES]]#1 for %[[SIZE0]]] : i32 -> 
!stream.resource<transient>{%[[SLICES]]#0} %0 = stream.async.splat %c254_i32 : i32 -> !stream.resource<transient>{%size0} // CHECK: stream.cmd.fill %c255_i32, %[[CAPTURE]][%[[SLICES]]#2 for %[[SIZE1]]] : i32 -> !stream.resource<transient>{%[[SLICES]]#0} %1 = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%size1} stream.yield } => !stream.timepoint - // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca on(#hal.affinity.queue<[0]>) await(%[[EXEC_TIMEPOINT]]) => %[[ALLOCA]] : !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint + // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca on(#hal.device.affinity<@device>) await(%[[EXEC_TIMEPOINT]]) => %[[ALLOCA]] : !stream.resource<transient>{%[[SLICES]]#0} => !stream.timepoint // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[DEALLOCA_TIMEPOINT]], %[[EXEC_TIMEPOINT]]) => !stream.timepoint // CHECK: util.return %[[JOIN]] util.return %result_timepoint : !stream.timepoint
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir index 3ccd781..0f33b51 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir
@@ -38,6 +38,9 @@ // Dispatches with the same affinities should be placed into the same execution // regions. +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + // CHECK-LABEL: @partitioningWithAffinities // CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>) util.func public @partitioningWithAffinities(%arg0: !stream.resource<external>) -> !stream.resource<external> { @@ -48,31 +51,31 @@ %c255_i32 = arith.constant 255 : i32 // CHECK: %[[TRANSIENTS:.+]]:2, %[[TIMEPOINT0:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[0]>) + // CHECK-SAME: on(#hal.device.affinity<@device_a>) // CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource<external>{%c20}) // CHECK-SAME: -> (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) { // CHECK-NEXT: %[[SPLAT:.+]] = stream.async.splat %splat = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c1280} // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1](%[[ARG0_CAPTURE]][{{.+}}], %[[SPLAT]][{{.+}}]) - %dispatch0 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280} + %dispatch0 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280} // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1](%[[ARG0_CAPTURE]][{{.+}}], %[[SPLAT]][{{.+}}]) - %dispatch1 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20} + 
%dispatch1 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH0]], %[[DISPATCH1]] // CHECK-NEXT: } => !stream.timepoint // CHECK: %[[RESULT:.+]], %[[TIMEPOINT1:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[1]>) + // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: await(%[[TIMEPOINT0]]) // CHECK-SAME: with(%[[TRANSIENTS]]#0 as %[[TRANSIENT0_CAPTURE:.+]]: !stream.resource<transient>{%c1280}, // CHECK-SAME: %[[TRANSIENTS]]#1 as %[[TRANSIENT1_CAPTURE:.+]]: !stream.resource<transient>{%c20}) // CHECK-SAME: -> !stream.resource<external>{%c20} // CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1](%[[TRANSIENT0_CAPTURE]][{{.+}}], %[[TRANSIENT1_CAPTURE]][{{.+}}]) - %dispatch2 = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20} + %dispatch2 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH2]] // CHECK-NEXT: } => !stream.timepoint // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await - // CHECK-SAME: on(#hal.affinity.queue<[1]>) + // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: %[[TIMEPOINT1]] => %[[RESULT]] : !stream.resource<external>{%c20} // CHECK-NEXT: util.return %[[READY]] util.return %dispatch2 : !stream.resource<external> @@ -84,6 +87,10 @@ // dependencies. 
Unrelated dispatches with differing affinities should end up // in concurrently executable regions. +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device +util.global private @device_c : !hal.device + // CHECK-LABEL: @partitioningWithConcurrentAffinities // CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>) util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource<external>) -> !stream.resource<external> { @@ -94,23 +101,23 @@ %c255_i32 = arith.constant 255 : i32 // CHECK: %[[TRANSIENT0:.+]], %[[TIMEPOINT0:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[0]>) + // CHECK-SAME: on(#hal.device.affinity<@device_a>) // CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE0:.+]]: !stream.resource<external>{%c20}) // CHECK-SAME: !stream.resource<transient>{%c1280} // CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat %splat = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c1280} // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1](%[[ARG0_CAPTURE0]][{{.+}}], %[[SPLAT0]][{{.+}}]) - %dispatch0 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280} + %dispatch0 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c1280} // CHECK-NEXT: stream.yield %[[DISPATCH0]] // CHECK-NEXT: } => !stream.timepoint // CHECK: %[[TRANSIENT1:.+]], %[[TIMEPOINT1:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[1]>) + // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE1:.+]]: !stream.resource<external>{%c20}) // CHECK-SAME: -> 
!stream.resource<transient>{%c20} { // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1](%[[ARG0_CAPTURE1]][{{.+}}], %[[SPLAT1]][{{.+}}]) - %dispatch1 = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20} + %dispatch1 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> !stream.resource<transient>{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH1]] // CHECK-NEXT: } => !stream.timepoint @@ -121,12 +128,12 @@ // CHECK-SAME: with(%[[TRANSIENT0]] as %[[TRANSIENT0_CAPTURE:.+]]: !stream.resource<transient>{%c1280}, // CHECK-SAME: %[[TRANSIENT1]] as %[[TRANSIENT1_CAPTURE:.+]]: !stream.resource<transient>{%c20}) // CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1](%[[TRANSIENT0_CAPTURE]][{{.+}}], %[[TRANSIENT1_CAPTURE]][{{.+}}]) - %dispatch2 = stream.async.dispatch on(#hal.affinity.queue<[2]>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20} + %dispatch2 = stream.async.dispatch on(#hal.device.affinity<@device_c>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH2]] // CHECK-NEXT: } => !stream.timepoint // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await - // CHECK-SAME: on(#hal.affinity.queue<[2]>) + // CHECK-SAME: on(#hal.device.affinity<@device_c>) // CHECK-SAME: %[[TIMEPOINT2]] => 
%[[RESULT]] : !stream.resource<external>{%c20} // CHECK-NEXT: util.return %[[READY]] util.return %dispatch2 : !stream.resource<external>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir index 4082bbf..e289f07 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
@@ -142,12 +142,14 @@ // CHECK-LABEL: @hoist_dialect_attrs module @hoist_dialect_attrs { + // CHECK: util.global private @device + util.global private @device : !hal.device // CHECK: util.global private @[[HOISTED:[a-z0-9_]+]] - // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device> // CHECK: util.initializer - // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device> util.func public @main() -> tensor<i32> attributes { - hal.affinity = #hal.affinity.queue<[0, 1]> + stream.affinity = #hal.device.affinity<@device> } { %0 = arith.constant dense<3> : tensor<i32> %1 = "iree_unregistered.const_expr"(%0) : (tensor<i32>) -> tensor<i32>
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel index 42d582b..a3255ea 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel
@@ -22,6 +22,7 @@ ], deps = [ "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL:Utils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Modules/Check/IR",
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt index 161a143..3c20a5b 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt
@@ -22,6 +22,7 @@ MLIRTransformUtils MLIRTransforms iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::Conversion::StreamToHAL::Utils iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::VM::Conversion iree::compiler::Modules::Check::IR
diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp index 3dab905..d9db0b3 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp
@@ -7,6 +7,7 @@ #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h" #include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h" +#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h" #include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" @@ -68,8 +69,7 @@ state.addAttributes(srcOp->getAttrs()); // Add device argument. - // TODO(multi-device): support multiple devices in check tests . - Value device = IREE::HAL::DeviceType::resolveAny(srcOp->getLoc(), rewriter); + Value device = lookupDeviceFor(srcOp, rewriter); state.addOperands({device}); for (auto [srcOperand, dstOperand] :