Retire usage of the transform dialect at the Flow level for cpu/matmul.mlir and cuda/reduction.mlir (#10527)
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index f2c1dc3..aeee547 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -310,28 +310,39 @@
"scf.foreach_thread with rank > 3 does not lower to workgroup");
// Step 0. Outline the compute workload region and set up the workload
- // operands.
- auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp);
- if (failed(maybeWorkgroupCounts) ||
- llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) {
- return !getConstantIntValue(ofr).has_value();
- }))
- return foreachThreadOp->emitError(
- "unsupported dynamic workgroup_count atm --- need to slice out "
- "workgroup_count computation into ExecutableExport::workgroup_count. "
- "This region may require arbitrary computations and cannot magically "
- "match what the `stream.cmd.dispatch` has already imposed on us at a "
- "distance. For now we must specify the number of values properly when "
- "applying the topLevel tile_to_foreach_thread_op");
-
- SmallVector<int64_t> workgroupCounts;
- for (OpFoldResult ofr : *maybeWorkgroupCounts)
- workgroupCounts.push_back(getConstantIntValue(ofr).value());
- if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
- exportOp))) {
- return foreachThreadOp->emitOpError(
- "failed to populate workload region for dispatchOp: ")
- << exportOp;
+ // operands, if this has not been done already.
+ // Using `transform.iree.tile_to_foreach_thread_and_workgroup_count_region` is
+ // the preferred way to set up tiling and workgroup_count region **at the same
+ // time**.
+ //
+ // The block of code below will be retired once there is enough confidence we
+ // can do everything without it. This includes in particular providing custom
+ // fusion heuristics at the flow level: at this time, the only way to fully
+ // control fusion of more advanced cases is to use the transform dialect at
+ // the flow level and explicitly match the ops we want to fuse.
+ // Once fusion is customizable enough in perpetuity, we can retire this.
+ if (exportOp.getWorkgroupCount().empty()) {
+ auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp);
+ if (failed(maybeWorkgroupCounts) ||
+ llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) {
+ return !getConstantIntValue(ofr).has_value();
+ }))
+ return foreachThreadOp->emitError(
+ "unsupported dynamic workgroup_count atm --- need to slice out "
+ "workgroup_count computation into ExecutableExport::workgroup_count. "
+ "This region may require arbitrary computations and cannot magically "
+ "match what the `stream.cmd.dispatch` has already imposed on us at a "
+ "distance. For now we must specify the number of values properly "
+ "when applying the topLevel tile_to_foreach_thread_op");
+ SmallVector<int64_t> workgroupCounts;
+ for (OpFoldResult ofr : *maybeWorkgroupCounts)
+ workgroupCounts.push_back(getConstantIntValue(ofr).value());
+ if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
+ exportOp))) {
+ return foreachThreadOp->emitOpError(
+ "failed to populate workload region for dispatchOp: ")
+ << exportOp;
+ }
}
// Step 1. Create the workgroup id and count ops.
@@ -406,10 +417,6 @@
state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}
- if (!exportOp.getWorkgroupCount().empty())
- return emitDefaultSilenceableFailure(target)
- << "export op must have an empty workgroup count region that "
- "the transform fills --- the transform is not applied";
scf::ForeachThreadOp topLevelForeachThreadOp;
auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -436,8 +443,14 @@
return DiagnosedSilenceableFailure(success());
}
+void transform_dialect::ForeachThreadToWorkgroupOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getTarget(), effects);
+ transform::producesHandle(getTransformed(), effects);
+}
+
//===---------------------------------------------------------------------===//
-// TileToWorkgroupsOp
+// TileToForeachThreadAndWorkgroupCountRegion
//===---------------------------------------------------------------------===//
/// Lower the ops within the workgroup count region of `exportOp` that
@@ -495,41 +508,42 @@
return success();
}
-SmallVector<OpFoldResult>
-transform_dialect::TileToWorkgroupsOp::getMixedNumThreads() {
+SmallVector<OpFoldResult> transform_dialect::
+ TileToForeachThreadAndWorkgroupCountRegion::getMixedNumThreads() {
return getMixedSizes(getStaticNumThreads(), getNumThreads());
}
-SmallVector<OpFoldResult>
-transform_dialect::TileToWorkgroupsOp::getMixedTileSizes() {
+SmallVector<OpFoldResult> transform_dialect::
+ TileToForeachThreadAndWorkgroupCountRegion::getMixedTileSizes() {
return getMixedSizes(getStaticTileSizes(), getTileSizes());
}
-LogicalResult transform_dialect::TileToWorkgroupsOp::verify() {
+LogicalResult
+transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::verify() {
if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
return emitOpError("either num_threads or tile_sizes must be specified");
return success();
}
-void transform_dialect::TileToWorkgroupsOp::getEffects(
+void transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getTarget(), effects);
- transform::consumesHandle(getFunc(), effects);
transform::onlyReadsHandle(getTileSizes(), effects);
transform::onlyReadsHandle(getNumThreads(), effects);
transform::producesHandle(getResults(), effects);
}
-DiagnosedSilenceableFailure transform_dialect::TileToWorkgroupsOp::apply(
+DiagnosedSilenceableFailure
+transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::apply(
transform::TransformResults &transformResults,
transform::TransformState &state) {
- ArrayRef<Operation *> funcOps = state.getPayloadOps(getFunc());
- assert(funcOps.size() == 1 && "expected single func op in payload");
- FailureOr<IREE::HAL::ExecutableExportOp> exportOp =
- getEntryPoint(cast<func::FuncOp>(funcOps[0]));
+ ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+ assert(targetOps.size() == 1 && "expected single target op in payload");
+ auto funcOp = targetOps.front()->getParentOfType<func::FuncOp>();
+ FailureOr<IREE::HAL::ExecutableExportOp> exportOp = getEntryPoint(funcOp);
if (failed(exportOp)) {
state.getTopLevel()->emitOpError("couldn't find export op for func");
- return DiagnosedSilenceableFailure(reportUnknownTransformError(funcOps[0]));
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(funcOp));
}
SmallVector<OpFoldResult> mixedTileSizes = getMixedTileSizes();
@@ -539,8 +553,8 @@
reportUnknownTransformError(exportOp.value()));
}
- /// Lower the workgroup count region in keeping with the way dispatch regions
- /// are created by default in IREEs compilation flow.
+ /// Lower the workgroup count region in keeping with the way dispatch
+ /// regions are created by default in IREEs compilation flow.
IRRewriter rewriter(getContext());
if (failed(lowerWorkgroupCountComputingRegion(rewriter, exportOp.value(),
mixedTileSizes))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 216e25b..7571fb7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -101,7 +101,7 @@
def ForeachThreadToWorkgroupOp : Op<Transform_Dialect,
"iree.foreach_thread_to_workgroup",
[FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface,
TransformEachOpTrait]> {
let description = [{
@@ -153,8 +153,8 @@
}];
}
-def TileToWorkgroupsOp :
- Op<Transform_Dialect, "iree.tile_to_workgroups_op",
+def TileToForeachThreadAndWorkgroupCountRegion :
+ Op<Transform_Dialect, "iree.tile_to_foreach_thread_and_workgroup_count_region",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface]> {
@@ -162,14 +162,13 @@
Wrapper around `structured.tile_to_foreach_thread_op` for use within IREE.
In addition to tile and distribute using `scf.foreach_thread`, lowers the
- the `workgroup_count` region of the export op corresponding to the
- `func.func` to return the number of workgroups.
+ the `workgroup_count` region of the export op corresponding to the parent
+ `func.func` of the target to return the number of workgroups.
Please see the doc of `structured.tile_to_foreach_thread_op` for full
description of op semantics.
}];
let arguments = (ins PDL_Operation:$target,
- PDL_Operation:$func,
Variadic<PDL_Operation>:$num_threads,
Variadic<PDL_Operation>:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
@@ -178,7 +177,7 @@
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);
let assemblyFormat = [{
- $target $func oilist(
+ $target oilist(
`num_threads` custom<DynamicIndexList>($num_threads,
$static_num_threads,
"ShapedType::kDynamicSize") |
diff --git a/tests/transform_dialect/cpu/BUILD b/tests/transform_dialect/cpu/BUILD
index 7a7d652..88ac647 100644
--- a/tests/transform_dialect/cpu/BUILD
+++ b/tests/transform_dialect/cpu/BUILD
@@ -20,10 +20,8 @@
# transform dialect spec files are MLIR files that specify a transformation,
# they need to be included as data.
data = [
+ "matmul_codegen_custom_dispatch_formation_spec.mlir",
"matmul_codegen_default_spec.mlir",
- "matmul_codegen_spec.mlir",
- "matmul_dispatch_spec.mlir",
- "matmul_tiled_dispatch_spec.mlir",
],
tags = [
"noasan",
diff --git a/tests/transform_dialect/cpu/CMakeLists.txt b/tests/transform_dialect/cpu/CMakeLists.txt
index 5bfb667..ea778f5 100644
--- a/tests/transform_dialect/cpu/CMakeLists.txt
+++ b/tests/transform_dialect/cpu/CMakeLists.txt
@@ -22,10 +22,8 @@
iree-opt
iree-run-module
DATA
+ matmul_codegen_custom_dispatch_formation_spec.mlir
matmul_codegen_default_spec.mlir
- matmul_codegen_spec.mlir
- matmul_dispatch_spec.mlir
- matmul_tiled_dispatch_spec.mlir
LABELS
"noasan"
"nomsan"
diff --git a/tests/transform_dialect/cpu/matmul.mlir b/tests/transform_dialect/cpu/matmul.mlir
index 5b0cf05..e7a26e1 100644
--- a/tests/transform_dialect/cpu/matmul.mlir
+++ b/tests/transform_dialect/cpu/matmul.mlir
@@ -10,71 +10,45 @@
return %0 : !C_size
}
-// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-abi-transformation-pipeline \
-// RUN: --iree-flow-transformation-pipeline \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir | \
-// RUN: FileCheck %s --check-prefixes=DISPATCH
-
-// TODO: make this test drop transform dialect usage at the flow level and use:
-// --iree-flow-transformation-pipeline --iree-flow-convert-region-to-workgroups
-// Atm the 3rd flow.dispatch.tensor.load shows as readonly instead of readwrite.
-
-// DISPATCH: flow.executable private @matmul_static_dispatch_0 {
-// DISPATCH: flow.executable.export public @matmul_static_dispatch_0_matmul_3x3x5
-// DISPATCH: builtin.module {
-// DISPATCH: func.func @matmul_static_dispatch_0_matmul_3x3x5
-// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 5], strides = [1, 1] : !flow.dispatch.tensor<readonly:3x5xf32> -> tensor<3x5xf32>
-// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [5, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:5x3xf32> -> tensor<5x3xf32>
-// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 3], strides = [1, 1] : !flow.dispatch.tensor<readwrite:3x3xf32> -> tensor<3x3xf32>
-// DISPATCH: linalg.matmul ins({{.*}} : tensor<3x5xf32>, tensor<5x3xf32>) outs({{.*}} : tensor<3x3xf32>) -> tensor<3x3xf32>
-// DISPATCH: flow.dispatch.tensor.store {{.*}} offsets = [0, 0], sizes = [3, 3], strides = [1, 1] : tensor<3x3xf32> -> !flow.dispatch.tensor<readwrite:3x3xf32>
-// DISPATCH: return
-
-// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \
-// RUN: --iree-stream-transformation-pipeline --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
-// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
-// RUN: FileCheck %s --check-prefixes=CODEGEN
-
// Run with C++ dispatch region formation but transform dialect codegen
// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
// RUN: --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline \
-// RUN: --iree-flow-dispatch-via-region-ops --iree-flow-dispatch-via-region-ops-generate-workload-region=false \
-// RUN: --iree-stream-transformation-pipeline --iree-hal-configuration-pipeline | \
+// RUN: --iree-flow-dispatch-via-region-ops \
+// RUN: --iree-flow-dispatch-via-region-ops-generate-workload-region=false \
+// RUN: --iree-stream-transformation-pipeline \
+// RUN: --iree-hal-configuration-pipeline | \
// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
-// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
-// RUN: FileCheck %s --check-prefixes=CODEGEN
+// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_custom_dispatch_formation_spec.mlir | \
+// RUN: FileCheck %s --check-prefix=CODEGEN-CUSTOM-DISPATCH-FORMATION
-// CODEGEN: hal.executable private @matmul_static_dispatch_0 {
-// CODEGEN: hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
-// CODEGEN: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} {
-// CODEGEN: ^bb0(%{{.*}}: !hal.device):
-// CODEGEN: arith.constant 2 : index
-// CODEGEN: arith.constant 1 : index
-// CODEGEN: hal.return %{{.*}}, %{{.*}}, %{{.*}} : index, index, index
-// CODEGEN: }
-// CODEGEN: builtin.module {
-// CODEGEN: func.func @matmul_static_dispatch_0_matmul_3x3x5() {
-// CODEGEN: arith.constant 0 : index
-// CODEGEN: hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x5xf32>
-// CODEGEN: memref.assume_alignment %{{.*}}, 64 : memref<3x5xf32>
-// CODEGEN: hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset({{.*}}) alignment(64) : memref<5x3xf32>
-// CODEGEN: memref.assume_alignment %{{.*}}, 64 : memref<5x3xf32>
-// CODEGEN: hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x3xf32>
-// CODEGEN: memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32>
-// CODEGEN: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-// CODEGEN: affine.apply {{.*}}()[%workgroup_id_x]
-// CODEGEN: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref<?x5xf32, strided<[5, 1], offset: ?>>
-// CODEGEN: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref<?x3xf32, strided<[3, 1], offset: ?>>
-// CODEGEN: linalg.matmul ins(%{{.*}}, %{{.*}} : memref<?x5xf32, strided<[5, 1], offset: ?>>, memref<5x3xf32>) outs(%{{.*}} : memref<?x3xf32, strided<[3, 1], offset: ?>>)
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable private @matmul_static_dispatch_0 {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: ^bb0(%{{.*}}: !hal.device):
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[C2:.*]] = arith.constant 2 : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[C1:.*]] = arith.constant 1 : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.return %[[C2]], %[[C1]], %[[C1]] : index, index, index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: }
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: builtin.module {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: func.func @matmul_static_dispatch_0_matmul_3x3x5() {
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: arith.constant 0 : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x5xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<3x5xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset({{.*}}) alignment(64) : memref<5x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<5x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset({{.*}}) alignment(64) : memref<3x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: affine.apply {{.*}}()[%workgroup_id_x]
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref<?x5xf32, strided<[5, 1], offset: ?>>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref<?x3xf32, strided<[3, 1], offset: ?>>
+// CODEGEN-CUSTOM-DISPATCH-FORMATION: linalg.matmul ins(%{{.*}}, %{{.*}} : memref<?x5xf32, strided<[5, 1], offset: ?>>, memref<5x3xf32>) outs(%{{.*}} : memref<?x3xf32, strided<[3, 1], offset: ?>>)
// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
-// RUN: --iree-hal-configuration-pipeline | \
+// RUN: --iree-hal-configuration-pipeline | \
// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_default_spec.mlir | \
// RUN: FileCheck %s --check-prefixes=CODEGEN-DEFAULT
@@ -86,22 +60,4 @@
// CODEGEN-DEFAULT: %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
// CODEGEN-DEFAULT: hal.return %[[D0]], %[[C1]], %[[C1]]
-// RUN: iree-compile %s --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \
-// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
-// RUN: iree-run-module --entry_function=matmul_static \
-// RUN: --function_input="3x5xf32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" \
-// RUN: --function_input="5x3xf32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" \
-// RUN: --function_input="3x3xf32=0 0 0 0 0 0 0 0 0"| \
-// RUN: FileCheck %s --check-prefixes=EXEC
-
// EXEC: 3x3xf32=[5 5 5][5 5 5][5 5 5]
-
-// RUN: iree-compile --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_tiled_dispatch_spec.mlir \
-// RUN: --iree-flow-export-benchmark-funcs %s | \
-// RUN: iree-benchmark-module --device=local-task | \
-// RUN: FileCheck %s --check-prefixes=BENCHMARK-MODULE
-
-// When running iree-benchmark-module, we only check the existence of the func.
-// BENCHMARK-MODULE: matmul_static
diff --git a/tests/transform_dialect/cpu/matmul_codegen_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
similarity index 100%
rename from tests/transform_dialect/cpu/matmul_codegen_spec.mlir
rename to tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
diff --git a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
index d26eaf1..9d7b392 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
@@ -2,11 +2,12 @@
transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%variant_op: !pdl.operation):
- %func = transform.structured.match ops{["func.func"]} in %variant_op
- %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op
%foreach_thread, %tiled_generic =
- transform.iree.tile_to_workgroups_op %0 %func tile_sizes [2]
-
- %1 = transform.iree.bufferize %variant_op
+ transform.iree.tile_to_foreach_thread_and_workgroup_count_region %matmul tile_sizes [2]
+
+ %variant_op_2 = transform.iree.bufferize %variant_op
+ %func = transform.structured.match ops{["func.func"]} in %variant_op_2
+ transform.iree.foreach_thread_to_workgroup %func
}
diff --git a/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir
deleted file mode 100644
index eba247b..0000000
--- a/tests/transform_dialect/cpu/matmul_dispatch_spec.mlir
+++ /dev/null
@@ -1,9 +0,0 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %region_op = transform.iree.wrap_in_dispatch_region %0
- transform.iree.region_to_workgroups %region_op
- }
-}
diff --git a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
deleted file mode 100644
index 41dc2dd..0000000
--- a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-transform.structured.canonicalized_sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
- %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
-}
diff --git a/tests/transform_dialect/cuda/BUILD b/tests/transform_dialect/cuda/BUILD
index 8af33f6..4aeb99b 100644
--- a/tests/transform_dialect/cuda/BUILD
+++ b/tests/transform_dialect/cuda/BUILD
@@ -34,8 +34,9 @@
# they need to be included as data.
data = [
"reduction_codegen_spec.mlir",
- "reduction_dispatch_spec.mlir",
"softmax_codegen_spec.mlir",
+ # FIXME: This cannot be retired yet as there is some writeonly vs readwrite
+ # issue and we even end up emitting out of bounds accesses.
"softmax_dispatch_spec.mlir",
"softmax_fused_codegen_spec.mlir",
],
diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt
index d5a0fba..139817f 100644
--- a/tests/transform_dialect/cuda/CMakeLists.txt
+++ b/tests/transform_dialect/cuda/CMakeLists.txt
@@ -27,7 +27,6 @@
iree-run-module
DATA
reduction_codegen_spec.mlir
- reduction_dispatch_spec.mlir
softmax_codegen_spec.mlir
softmax_dispatch_spec.mlir
softmax_fused_codegen_spec.mlir
diff --git a/tests/transform_dialect/cuda/reduction.mlir b/tests/transform_dialect/cuda/reduction.mlir
index 5863591..4ea9300 100644
--- a/tests/transform_dialect/cuda/reduction.mlir
+++ b/tests/transform_dialect/cuda/reduction.mlir
@@ -24,7 +24,6 @@
// RUN: iree-opt %s --iree-hal-target-backends=cuda \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/reduction_dispatch_spec.mlir \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \
@@ -32,7 +31,6 @@
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/reduction_dispatch_spec.mlir \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_codegen_spec.mlir | \
// RUN: iree-run-module --entry_function=reduce --device=cuda |\
// RUN: FileCheck %s --check-prefix=EXEC
@@ -49,7 +47,7 @@
// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
// Distributed reduction: everyone loads then 5 xor + addf expected
- // CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
+ // CHECK: vector.transfer_read %{{.*}}[%[[TIDX]]]
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index 1015856..18db174 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -14,7 +14,7 @@
// First level of tiling + fusion parallelizes to blocks.
// The mapping to block ids can only happen after bufferization atm.
%foreach_thread_grid, %grid_combiner_op =
- transform.structured.tile_to_foreach_thread_op %combiner_op tile_sizes [1]
+ transform.iree.tile_to_foreach_thread_and_workgroup_count_region %combiner_op tile_sizes [1]
%not_combiner = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op
transform.structured.fuse_into_containing_op %not_combiner into %foreach_thread_grid
diff --git a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir b/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
deleted file mode 100644
index 7dc0521..0000000
--- a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: iree-opt %s
-
-// Dispatch reduction.
-transform.structured.canonicalized_sequence failures(propagate){
-^bb1(%variant_op: !pdl.operation):
- %root = transform.structured.match interface{LinalgOp}
- attributes{iterator_types = ["parallel", "reduction"]} in %variant_op
- %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
-
- // TODO: this could be replaced by a C++ only version.
- // Atm the IR produced is not the same so all pieces do not connect.
- %region_op = transform.iree.wrap_in_dispatch_region %root
- %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %fill into %region_op
- transform.iree.region_to_workgroups %region_op_2
-}
diff --git a/tests/transform_dialect/cuda/softmax.mlir b/tests/transform_dialect/cuda/softmax.mlir
index b6e4fed..7fa6c2c 100644
--- a/tests/transform_dialect/cuda/softmax.mlir
+++ b/tests/transform_dialect/cuda/softmax.mlir
@@ -2,7 +2,6 @@
// RUN: iree-opt %s --iree-hal-target-backends=cuda \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \
@@ -10,7 +9,6 @@
// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir | \
// RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \
// RUN: FileCheck %s
@@ -18,6 +16,10 @@
// RUN: iree-opt %s --iree-hal-target-backends=cuda \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
+///
+/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite
+/// issue and we even end up emitting out of bounds accesses.
+///
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
// RUN: --iree-stream-transformation-pipeline \
// RUN: --iree-hal-configuration-pipeline | \
@@ -26,6 +28,10 @@
// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
+///
+/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite
+/// issue and we even end up emitting out of bounds accesses.
+///
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \
// RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \
@@ -41,11 +47,13 @@
// CHECK-SHUFFLE: gpu.shuffle xor
// Execution only checks that @max_sub_exp runs.
-// CHECK: EXEC @max_sub_exp
+// CHECK: EXEC @max_sub_exp
+// CHECK: 16x128x128xf32=[
+// CHECK-SAME: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
-func.func @max_sub_exp() {
+func.func @max_sub_exp() -> !out_tensor_t {
%cst = arith.constant -3.40282347E+38 : f32
- %cst_0 = arith.constant dense<1.000000e+00> : !out_tensor_t
+ %cst_0 = arith.constant dense<1121212.000000e+00> : !out_tensor_t
%cst_1 = arith.constant dense<5.000000e+00> : !out_tensor_t
%0 = util.do_not_optimize(%cst_1) : !out_tensor_t
@@ -73,6 +81,5 @@
linalg.yield %7 : f32
} -> !out_tensor_t
- check.expect_almost_eq(%5, %cst_0) : !out_tensor_t
- return
+ return %5: !out_tensor_t
}
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 3666f20..4607570 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -1,82 +1,77 @@
// RUN: iree-opt %s
// Codegen
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- // First level of tiling + fusion parallelizes to blocks.
- // The mapping to block ids can only happen after bufferization atm
- %root = transform.structured.match interface{LinalgOp}
- attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
- %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
- %red = transform.structured.match interface{LinalgOp}
- attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
- %not_root = merge_handles %fill, %red
- %foreach_thread, %tiled_generic =
- transform.structured.tile_to_foreach_thread_op %root tile_sizes [1, 4]
- transform.structured.fuse_into_containing_op %not_root into %foreach_thread
-
- // Second level of tiling + fusion parallelizes to threads.
- // Leaving the reduction untiled on threadIdx.x makes it sequential on
- // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
- // introduced and opportunities for distributing vector ops across warps
- // appear.
- %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
- %reduction_linalg = transform.structured.match ops{["linalg.generic"]}
- attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
- %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
- attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
- %foreach_thread_reduction, %tiled_reduction_generic =
- transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1]
- (mapped to dims [2, 1, 0])
- // TODO: this fusion currently does not happen properly, this is related to the clone
- // behavior when fusing into scf.foreach_thread.
- // Once fixed we'll be able to fuse.
- // Fusion will save us one roundtrip to memory.
- // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction
- transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32]
- (mapped to dims [2, 1, 0])
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ // First level of tiling + fusion parallelizes to blocks.
+ // The mapping to block ids can only happen after bufferization atm
+ %root = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+ %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %red = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+ %not_root = merge_handles %fill, %red
+ %foreach_thread, %tiled_generic =
+ transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 4]
+ transform.structured.fuse_into_containing_op %not_root into %foreach_thread
+
+ // Second level of tiling + fusion parallelizes to threads.
+ // Leaving the reduction untiled on threadIdx.x makes it sequential on
+ // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
+ // introduced and opportunities for distributing vector ops across warps
+ // appear.
+ %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %reduction_linalg = transform.structured.match ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+ %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+ %foreach_thread_reduction, %tiled_reduction_generic =
+ transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1]
+ (mapped to dims [2, 1, 0])
+ // TODO: this fusion currently does not happen properly, this is related to the clone
+ // behavior when fusing into scf.foreach_thread.
+ // Once fixed we'll be able to fuse.
+ // Fusion will save us one roundtrip to memory.
+ // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction
+ transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32]
+ (mapped to dims [2, 1, 0])
- // Inability to tile reductions to scf.foreach_thread has 2 implications:
- // 1. since no scf.foreach_thread is present, no gpu.barrier is added.
- // This should be fixed independently: ops that are not nested in an scf.foreach_thread
- // should have a gpu.barrier. Later needs to be complemented by a barrier
- // removal pass.
- // 2. Similarly, needs to be predicated under an if threadIx == 0 to avoid
- // multiple threads updating the buffer inplace once bufferized.
- //
- // Instead, we can vectorize and go to vector SSA values that sidestep these
- // issues.
- // Everyone will race to the write while still computing the same value.
- //
- // That is still not good enough because we need to predicate this in order
- // to enable the parallel reduction on warps.
- %func = transform.structured.match ops{["func.func"]} in %variant_op
- %funcx = transform.iree.apply_patterns %func { rank_reducing }
- transform.structured.vectorize %funcx
+ // Inability to tile reductions to scf.foreach_thread has 2 implications:
+ // 1. since no scf.foreach_thread is present, no gpu.barrier is added.
+ // This should be fixed independently: ops that are not nested in an scf.foreach_thread
+ // should have a gpu.barrier. Later needs to be complemented by a barrier
+ // removal pass.
+ // 2. Similarly, needs to be predicated under an if threadIx == 0 to avoid
+ // multiple threads updating the buffer inplace once bufferized.
+ //
+ // Instead, we can vectorize and go to vector SSA values that sidestep these
+ // issues.
+ // Everyone will race to the write while still computing the same value.
+ //
+ // That is still not good enough because we need to predicate this in order
+ // to enable the parallel reduction on warps.
+ %func = transform.structured.match ops{["func.func"]} in %variant_op
+ %funcx = transform.iree.apply_patterns %func { rank_reducing }
+ transform.structured.vectorize %funcx
- // Bufferization is necessary for:
- // 1. lowering scf.foreach_thread to workgroup (block level parallelism)
- // 2. lowering scf.foreach_thread to gpu (thread level parallelism)
- // 3. introducing predication (due to 1. + 2.) which enables rewriting to
- // warp_execute_on_lane_0 and later vector distribution.
- %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
+ // Bufferization is necessary for:
+ // 1. lowering scf.foreach_thread to workgroup (block level parallelism)
+ // 2. lowering scf.foreach_thread to gpu (thread level parallelism)
+ // 3. introducing predication (due to 1. + 2.) which enables rewriting to
+ // warp_execute_on_lane_0 and later vector distribution.
+ %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
- %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
- %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
- transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
- { workgroup_size = [32, 4, 1] }
+ %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
+ transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+ { workgroup_size = [32, 4, 1] }
- %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
- %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
+ %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
- // Vector distribution needs to happen on buffers.
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
- %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
- transform.iree.vector.warp_distribute %end_func_2
- }
+ // Vector distribution needs to happen on buffers.
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+ %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+ transform.iree.vector.warp_distribute %end_func_2
}
-
-
diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken
new file mode 100644
index 0000000..68f891e
--- /dev/null
+++ b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken
@@ -0,0 +1,57 @@
+// RUN: iree-opt %s
+
+// Codegen
+transform.structured.canonicalized_sequence failures(propagate) {
+// transform.sequence %arg0 failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ // First level of tiling + fusion parallelizes to blocks.
+ // The mapping to block ids can only happen after bufferization atm
+ %root = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+ %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %red = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+ %not_root = merge_handles %fill, %red
+ %foreach_thread, %tiled_generic =
+ transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 1]
+ (mapped to dims [0, 1, 2])
+ transform.structured.fuse_into_containing_op %not_root into %foreach_thread
+
+ // Second level of tiling + fusion parallelizes to threads.
+ // Leaving the reduction untiled on threadIdx.x makes it sequential on
+ // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
+ // introduced and opportunities for distributing vector ops across warps
+ // appear.
+ %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %reduction_linalg = transform.structured.match ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+ %not_root_2 = merge_handles %fill_linalg, %reduction_linalg
+ %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+ %foreach_thread_2, %parallel_linalg_2 =
+ transform.structured.tile_to_foreach_thread_op %parallel_linalg tile_sizes [1, 1, 0]
+ (mapped to dims [2, 1, 0])
+ transform.structured.fuse_into_containing_op %not_root_2 into %foreach_thread_2
+
+ // Rank-reduce and vectorize.
+ %func = transform.structured.match ops{["func.func"]} in %variant_op
+ %funcx = transform.iree.apply_patterns %func { rank_reducing }
+ transform.structured.vectorize %funcx
+
+ // Bufferization is necessary for:
+ // 1. lowering scf.foreach_thread to workgroup (block level parallelism)
+ // 2. lowering scf.foreach_thread to gpu (thread level parallelism)
+ // 3. introducing predication (due to 1. + 2.) which enables rewriting to
+ // warp_execute_on_lane_0 and later vector distribution.
+ %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
+ %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
+ transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+ { workgroup_size = [32, 1, 1] }
+
+ // Vector distribution needs to happen on buffers.
+ %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+ %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+ transform.iree.vector.warp_distribute %end_func
+}