Adding stream.dispatch.workgroup.* info ops. (#15889)
These mirror the flow and hal ops and allow us to remove a few more flow
ops from the flow->stream->hal path. Until we remove/replace the flow
dispatch tensor load/store ops we can't do the trivial conversions but
as of this PR those (and the workgroup count codegen op) are the only
ops that still survive from flow.
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index a73cf1c..194ff82 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -567,9 +567,9 @@
/// form
/// ```
/// %dim = arith.constant ... : index
-/// %id = flow.dispatch.workgroup.id[%dim]
-/// %count = flow.dispatch.workgroup.count[%dim]
-/// %size = flow.dispatch.workgroup.size[%dim]
+/// %id = stream.dispatch.workgroup.id[%dim]
+/// %count = stream.dispatch.workgroup.count[%dim]
+/// %size = stream.dispatch.workgroup.size[%dim]
/// %offset = affine.apply
/// affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)>(%lb)[%id, %size]
/// %new_step = affine.apply
@@ -655,7 +655,7 @@
int numLoops = op.getLoopIteratorTypes().size();
SmallVector<int64_t> fixedTileSizes(tileSizes);
fixedTileSizes.resize(numLoops, /*default=*/0);
- SmallVector<int64_t> fixedTileScalableFlags(tileScalableFlags);
+ SmallVector<bool> fixedTileScalableFlags(tileScalableFlags);
fixedTileScalableFlags.resize(numLoops, /*default=*/false);
if (!llvm::is_contained(fixedTileScalableFlags, true)) {
// Non-scalable case: All constant tile sizes.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 83bf836..0b0ba53 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -493,15 +493,17 @@
} // namespace
-static LogicalResult convertFlowInfoOps(IREE::HAL::ExecutableOp executableOp) {
+static LogicalResult
+convertDispatchWorkgroupInfoOps(IREE::HAL::ExecutableOp executableOp) {
RewritePatternSet patterns(executableOp.getContext());
patterns.insert<
ConvertReturnPattern,
- ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupIDOp,
+ ConvertDispatchWorkgroupInfoPattern<IREE::Stream::DispatchWorkgroupIDOp,
IREE::HAL::InterfaceWorkgroupIDOp>,
- ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupCountOp,
- IREE::HAL::InterfaceWorkgroupCountOp>,
- ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupSizeOp,
+ ConvertDispatchWorkgroupInfoPattern<
+ IREE::Stream::DispatchWorkgroupCountOp,
+ IREE::HAL::InterfaceWorkgroupCountOp>,
+ ConvertDispatchWorkgroupInfoPattern<IREE::Stream::DispatchWorkgroupSizeOp,
IREE::HAL::InterfaceWorkgroupSizeOp>,
InlineConstantWorkgroupSizePattern>(executableOp.getContext());
return applyPatternsAndFoldGreedily(executableOp, std::move(patterns));
@@ -595,9 +597,9 @@
return signalPassFailure();
}
- // Convert interface-related flow.dispatch.* ops to their hal.interface.*
- // versions.
- if (failed(convertFlowInfoOps(executableOp))) {
+ // Convert interface-related stream.dispatch.* ops to their
+ // hal.interface.* versions.
+ if (failed(convertDispatchWorkgroupInfoOps(executableOp))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 4e73fb2..75d6284 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -756,6 +756,25 @@
});
}
+template <typename FlowOpT, typename StreamOpT>
+static void replaceDispatchWorkgroupInfoOp(FlowOpT op,
+ PatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<StreamOpT>(op, op.getResult().getType(),
+ op.getDimension());
+}
+
+template <typename FlowOpT, typename StreamOpT>
+struct ConvertDispatchWorkgroupInfoOp : public OpConversionPattern<FlowOpT> {
+ using OpConversionPattern<FlowOpT>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(FlowOpT op, typename FlowOpT::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<StreamOpT>(op, op.getResult().getType(),
+ adaptor.getDimension());
+ return success();
+ }
+};
+
struct ConvertExecutableOp
: public OpConversionPattern<IREE::Flow::ExecutableOp> {
using OpConversionPattern::OpConversionPattern;
@@ -828,6 +847,31 @@
funcOp.setType(rewriter.getFunctionType(newTypes, {}));
}
+
+ // Walk the module and replace some ops that we don't rely on the pattern
+ // rewriter for today. This is pretty shady and a side-effect of
+ // recursively marking the stream executable contents as legal - if we
+ // didn't do that (and converted all flow ops) we could drop this logic
+  // and rely only on the patterns.
+ moduleOp.walk([&](Operation *op) {
+ TypeSwitch<Operation *>(op)
+ .Case<IREE::Flow::DispatchWorkgroupIDOp>([&](auto op) {
+ replaceDispatchWorkgroupInfoOp<
+ IREE::Flow::DispatchWorkgroupIDOp,
+ IREE::Stream::DispatchWorkgroupIDOp>(op, rewriter);
+ })
+ .Case<IREE::Flow::DispatchWorkgroupCountOp>([&](auto op) {
+ replaceDispatchWorkgroupInfoOp<
+ IREE::Flow::DispatchWorkgroupCountOp,
+ IREE::Stream::DispatchWorkgroupCountOp>(op, rewriter);
+ })
+ .Case<IREE::Flow::DispatchWorkgroupSizeOp>([&](auto op) {
+ replaceDispatchWorkgroupInfoOp<
+ IREE::Flow::DispatchWorkgroupSizeOp,
+ IREE::Stream::DispatchWorkgroupSizeOp>(op, rewriter);
+ })
+ .Default([&](auto *op) {});
+ });
}
rewriter.eraseOp(flowOp);
@@ -868,6 +912,14 @@
patterns.insert<ConvertDispatchOp>(typeConverter, context);
patterns.insert<ConvertFuncOp, ConvertCallOp>(typeConverter, context);
patterns.insert<ConvertExecutableOp>(typeConverter, context);
+ patterns.insert<
+ ConvertDispatchWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp,
+ IREE::Stream::DispatchWorkgroupIDOp>,
+ ConvertDispatchWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupCountOp,
+ IREE::Stream::DispatchWorkgroupCountOp>,
+ ConvertDispatchWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupSizeOp,
+ IREE::Stream::DispatchWorkgroupSizeOp>>(
+ typeConverter, context);
patterns.insert<ConvertReturnOp>(typeConverter, context);
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/executable_ops.mlir
index 36305db..d221ddf 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/executable_ops.mlir
@@ -125,14 +125,18 @@
builtin.module {
// CHECK: func.func @dispatch(%[[DIM:.+]]: index, %[[INPUT:.+]]: !stream.binding, %[[OUTPUT:.+]]: !stream.binding)
func.func @dispatch(%dim: index, %input: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>, %output: !flow.dispatch.tensor<writeonly:tensor<?xf32>>) {
+ // CHECK-DAG: stream.dispatch.workgroup.size[0] : index
+ %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
+ // CHECK-DAG: stream.dispatch.workgroup.id[0] : index
+ %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
+ // CHECK-DAG: stream.dispatch.workgroup.count[0] : index
+ %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
+
// CHECK-DAG: %[[TIED_INPUT:.+]] = stream.binding.subspan %[[INPUT]][%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<1x?xf32>>{%[[DIM]]}
- // CHECK-DAG: %[[TIED_OUTPUT:.+]] = stream.binding.subspan %[[OUTPUT]][%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%[[DIM]]}
%tied_input = flow.dispatch.tie_shape %input : !flow.dispatch.tensor<readonly:tensor<1x?xf32>>{%dim}
+ // CHECK-DAG: %[[TIED_OUTPUT:.+]] = stream.binding.subspan %[[OUTPUT]][%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%[[DIM]]}
%tied_output = flow.dispatch.tie_shape %output : !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%dim}
- %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
- %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
- %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
%5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_0, %workgroup_id_0]
%6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_0, %workgroup_count_0]
scf.for %arg3 = %5 to %dim step %6 {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 189f7af..0d006b1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -3716,6 +3716,59 @@
}
//===----------------------------------------------------------------------===//
+// stream.dispatch.workgroup.*
+//===----------------------------------------------------------------------===//
+
+static void getAsmResultNamesForDispatchWorkgroupInfoOp(
+ StringRef prefix, const APInt &dimension, Value result,
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(result, (prefix + std::to_string(dimension.getZExtValue())).str());
+}
+
+static LogicalResult verifyDispatchWorkgroupInfoOp(Operation *op,
+                                                   uint64_t dimension) {
+  if (dimension >= 3) {
+    return op->emitOpError()
+           << "dimension " << dimension
+           << " out of bounds of dispatch dimensions; expected [0, 3)";
+  }
+  return success();
+}
+
+void DispatchWorkgroupIDOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_id_", getDimension(),
+ getResult(), setNameFn);
+}
+
+LogicalResult DispatchWorkgroupIDOp::verify() {
+ return verifyDispatchWorkgroupInfoOp(getOperation(),
+ getDimension().getZExtValue());
+}
+
+void DispatchWorkgroupCountOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ getAsmResultNamesForDispatchWorkgroupInfoOp(
+ "workgroup_count_", getDimension(), getResult(), setNameFn);
+}
+
+LogicalResult DispatchWorkgroupCountOp::verify() {
+ return verifyDispatchWorkgroupInfoOp(getOperation(),
+ getDimension().getZExtValue());
+}
+
+void DispatchWorkgroupSizeOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_size_", getDimension(),
+ getResult(), setNameFn);
+}
+
+LogicalResult DispatchWorkgroupSizeOp::verify() {
+ return verifyDispatchWorkgroupInfoOp(getOperation(),
+ getDimension().getZExtValue());
+}
+
+//===----------------------------------------------------------------------===//
// stream.yield
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 08ac630..82aace0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -4066,8 +4066,8 @@
}
def Stream_BindingSubspanOp : Stream_PureOp<"binding.subspan", [
- Util_ShapeAwareOp,
- ]> {
+ Util_ShapeAwareOp,
+]> {
let summary = [{returns an alias to a subspan of interface binding data}];
let description = [{
Returns a subview to a tensor or memref-like type from a binding. The same
@@ -4096,6 +4096,113 @@
}];
}
+def Stream_DispatchWorkgroupIDOp : Stream_PureOp<"dispatch.workgroup.id", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{returns the index of the current workgroup in the grid}];
+ let description = [{
+ The global workgroup ID of the current workgroup in the range of
+ `[0, stream.dispatch.workgroup.count)` along each dimension.
+
+ Represented as a 3D grid classically written as XYZ.
+ Corresponds to the `WorkgroupId` SPIR-V built-in and the `blockIdx` CUDA
+ built-in variable.
+
+ ```mlir
+ %x = stream.dispatch.workgroup.id[0] : index
+ %y = stream.dispatch.workgroup.id[1] : index
+ %z = stream.dispatch.workgroup.id[2] : index
+ ```
+ }];
+
+ let arguments = (ins IndexAttr:$dimension);
+ let results = (outs Stream_Dim:$result);
+
+ let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)";
+
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
+ }]>,
+ ];
+
+ let hasVerifier = 1;
+}
+
+def Stream_DispatchWorkgroupCountOp : Stream_PureOp<"dispatch.workgroup.count", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{returns the total workgroup count of the grid}];
+ let description = [{
+ The total number of workgroups along each dimension in the dispatch grid.
+
+ Represented as a 3D grid classically written as XYZ.
+ Corresponds to the `NumWorkgroups` SPIR-V built-in and the `gridDim` CUDA
+ built-in variable.
+
+ ```mlir
+ %x = stream.dispatch.workgroup.count[0] : index
+ %y = stream.dispatch.workgroup.count[1] : index
+ %z = stream.dispatch.workgroup.count[2] : index
+ ```
+ }];
+
+ let arguments = (ins IndexAttr:$dimension);
+ let results = (outs Stream_Dim:$result);
+
+ let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)";
+
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
+ }]>,
+ ];
+
+ let hasVerifier = 1;
+}
+
+def Stream_DispatchWorkgroupSizeOp : Stream_PureOp<"dispatch.workgroup.size", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{returns the size of each workgroup in invocations}];
+ let description = [{
+ The number of local invocations within the current workgroup along each
+ dimension. Depending on backend this may map to the SIMT thread count or
+ inner loop nest parameters.
+
+ Workgroup sizes are not determined at the stream dialect level as they are
+ dependent on the target backend determined when lowering into the HAL. It's
+ still possible to use the symbolic workgroup size inside of dispatch
+ executables as a placeholder for the resolved value once in the HAL.
+
+ Represented as a 3D grid classically written as XYZ.
+ Corresponds to the `WorkgroupSize` SPIR-V built-in and the `blockDim` CUDA
+ built-in variable.
+
+ ```mlir
+ %x = stream.dispatch.workgroup.size[0] : index
+ %y = stream.dispatch.workgroup.size[1] : index
+ %z = stream.dispatch.workgroup.size[2] : index
+ ```
+ }];
+
+ let arguments = (ins IndexAttr:$dimension);
+ let results = (outs Stream_Dim:$result);
+
+ let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)";
+
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
+ }]>,
+ ];
+
+ let hasVerifier = 1;
+}
+
} // OpGroupExecutableOps
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
index 0146b93..e7c7aa0 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
@@ -55,9 +55,9 @@
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<4xi32>>
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<4xi32>>
%2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<4xi32>>
- %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
- %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
- %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
+ %workgroup_size_0 = stream.dispatch.workgroup.size[0] : index
+ %workgroup_id_0 = stream.dispatch.workgroup.id[0] : index
+ %workgroup_count_0 = stream.dispatch.workgroup.count[0] : index
%3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
%4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
scf.for %arg3 = %3 to %c4 step %4 {
@@ -86,9 +86,9 @@
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<3xi32>>
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<3xi32>>
%2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<3xi32>>
- %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
- %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
- %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
+ %workgroup_size_0 = stream.dispatch.workgroup.size[0] : index
+ %workgroup_id_0 = stream.dispatch.workgroup.id[0] : index
+ %workgroup_count_0 = stream.dispatch.workgroup.count[0] : index
%3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
%4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
scf.for %arg3 = %3 to %c3 step %4 {
diff --git a/samples/custom_dispatch/cpu/embedded/example_stream.mlir b/samples/custom_dispatch/cpu/embedded/example_stream.mlir
index 49adade..2620ca7 100644
--- a/samples/custom_dispatch/cpu/embedded/example_stream.mlir
+++ b/samples/custom_dispatch/cpu/embedded/example_stream.mlir
@@ -131,7 +131,7 @@
// particular workgroup is in the grid. In this example we use a
// workgroup size of 64x1x1 (which is exceedingly small for CPUs but
// useful for demonstration).
- %workgroup_id_x = flow.dispatch.workgroup.id[0] : index
+ %workgroup_id_x = stream.dispatch.workgroup.id[0] : index
%tid = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
// Bindings are accessed by reference.
@@ -165,7 +165,7 @@
%dim: index) {
%c0 = arith.constant 0 : index
- %workgroup_id_x = flow.dispatch.workgroup.id[0] : index
+ %workgroup_id_x = stream.dispatch.workgroup.id[0] : index
%tid = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
// Same as above but note that we're treating %binding1 as read/write.
diff --git a/samples/custom_dispatch/cpu/plugin/standalone_example.mlir b/samples/custom_dispatch/cpu/plugin/standalone_example.mlir
index 56bf393..7bc4d33 100644
--- a/samples/custom_dispatch/cpu/plugin/standalone_example.mlir
+++ b/samples/custom_dispatch/cpu/plugin/standalone_example.mlir
@@ -43,11 +43,11 @@
builtin.module {
// External function declaration using a user-chosen calling convention.
func.func private @simple_mul_workgroup(
- %binding0: memref<f32>,
+ %binding0: memref<f32>,
%binding0_offset : index,
%binding1: memref<f32>,
%binding1_offset : index,
- %binding2: memref<f32>,
+ %binding2: memref<f32>,
%binding2_offset : index,
%dim: index, %tid: index) attributes {
// We can include some additional fields on the parameters struct as
@@ -72,7 +72,7 @@
// particular workgroup is in the grid. In this example we use a
// workgroup size of 64x1x1 (which is exceedingly small for CPUs but
// useful for demonstration).
- %workgroup_id_x = flow.dispatch.workgroup.id[0] : index
+ %workgroup_id_x = stream.dispatch.workgroup.id[0] : index
%tid = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
// Bindings are accessed by reference.
@@ -89,7 +89,7 @@
: memref<?xf32> -> memref<f32>, index, index, index
%base2, %offset2, %size2, %stride2 = memref.extract_strided_metadata %memref2
: memref<?xf32> -> memref<f32>, index, index, index
-
+
// Call the externally defined C function with an (almost) plain C
// calling convention (see above for details about the mess memrefs
// turn into). This will be fetched at runtime from the plugin binary.
diff --git a/samples/custom_dispatch/cpu/plugin/system_example.mlir b/samples/custom_dispatch/cpu/plugin/system_example.mlir
index 586ab00..1d82b77 100644
--- a/samples/custom_dispatch/cpu/plugin/system_example.mlir
+++ b/samples/custom_dispatch/cpu/plugin/system_example.mlir
@@ -83,7 +83,7 @@
// particular workgroup is in the grid. In this example we use a
// workgroup size of 64x1x1 (which is exceedingly small for CPUs but
// useful for demonstration).
- %workgroup_id_x = flow.dispatch.workgroup.id[0] : index
+ %workgroup_id_x = stream.dispatch.workgroup.id[0] : index
%tid = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
// Bindings are accessed by reference.
@@ -97,7 +97,7 @@
: memref<?xf32> -> memref<f32>, index, index, index
%base2, %offset2, %size2, %stride2 = memref.extract_strided_metadata %memref2
: memref<?xf32> -> memref<f32>, index, index, index
-
+
// Call the externally defined C function with an (almost) plain C
// calling convention. This will be fetched at runtime from the plugin binary.