[VectorDistribution] Allow 0-d vectors in scf.for distribution (#19317)
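
This removes the isNonZeroRank guards from the scf.for distribution patterns:
init args, yielded operands, and the final replacements now go through
getDistributed / ToSIMDOp for 0-d vectors as well, since a 0-d vector with an
all-empty nested layout distributes to another 0-d vector. A new test covers
an scf.for whose iter_arg is a vector<i32> carrying an empty
#iree_vector_ext.nested_layout.

Roughly, the distributed loop should come out as below. This is a sketch
pieced together from the CHECK lines in the new test, not verbatim compiler
output, and the SSA names are illustrative:

```mlir
%root = arith.constant dense<0> : vector<i32>
%out = scf.for %i = %c0 to %c128 step %c1
    iter_args(%arg0 = %root) -> (vector<i32>) {
  // Vector operands defined outside the loop are moved into SIMT form;
  // for a 0-d vector the distributed type is unchanged.
  %b_simt = iree_vector_ext.to_simt %b : vector<i32> -> vector<i32>
  %c = arith.muli %arg0, %b_simt : vector<i32>
  %a_simt = iree_vector_ext.to_simt %a : vector<i32> -> vector<i32>
  %d = arith.addi %c, %a_simt : vector<i32>
  scf.yield %d : vector<i32>
}
```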
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index d72ac17..276b7fe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -145,10 +145,8 @@
SmallVector<Value> newInitArgs;
for (Value initArg : forOp.getInitArgs()) {
if (auto vectorInitArg = dyn_cast<VectorValue>(initArg)) {
- if (isNonZeroRank(vectorInitArg)) {
- initArg =
- getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]);
- }
+ initArg =
+ getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]);
}
newInitArgs.push_back(initArg);
}
@@ -193,14 +191,8 @@
SmallVector<Value> operands;
for (Value operand : yieldOp->getOperands()) {
if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
- // Distributing the operand requires it to have a non-zero rank, meaning
- // it must have at least one dimension. If the vector has a non-zero
- // rank, the operand is distributed according to the provided layout
- // signature.
- if (isNonZeroRank(vectorOperand)) {
- operand = DistributionPattern::getDistributed(
- rewriter, vectorOperand, signature[vectorOperand]);
- }
+ operand = DistributionPattern::getDistributed(rewriter, vectorOperand,
+ signature[vectorOperand]);
}
operands.push_back(operand);
}
@@ -223,10 +215,8 @@
for (auto [bbArg, oldInit] : llvm::zip_equal(bbArgs, oldInits)) {
Value val = bbArg;
if (auto oldVectorInit = dyn_cast<VectorValue>(oldInit)) {
- if (isNonZeroRank(oldVectorInit)) {
- val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
- oldVectorInit.getLoc(), oldVectorInit.getType(), val);
- }
+ val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+ oldVectorInit.getLoc(), oldVectorInit.getType(), val);
}
replacements.push_back(val);
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index 343366e..972c8c3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -62,6 +62,40 @@
return %out : vector<16x16xi32>
}

+#layout_0d = #iree_vector_ext.nested_layout<
+ subgroup_tile = [],
+ batch_tile = [],
+ outer_tile = [],
+ thread_tile = [],
+ element_tile = [],
+
+ subgroup_strides = [],
+ thread_strides = []
+>
+
+// CHECK-LABEL: @distribute_scf_for_0d
+func.func @distribute_scf_for_0d(%a: vector<i32>, %b: vector<i32>) -> vector<i32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c128 = arith.constant 128 : index
+ %cst_0 = arith.constant 0 : i32
+ // CHECK: %[[ROOT:.*]] = arith.constant dense<0> : vector<i32>
+ %root = arith.constant dense<0> : vector<i32>
+ %rootl = iree_vector_ext.to_layout %root to layout(#layout_0d) : vector<i32>
+ // CHECK: iter_args(%[[ARG0:.*]] = %[[ROOT]]) -> (vector<i32>)
+ %out = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg0 = %rootl) -> (vector<i32>) {
+ // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<i32> -> vector<i32>
+ // CHECK-DAG: %[[C:.*]] = arith.muli %[[ARG0]], %[[B]] {{.*}} : vector<i32>
+ %c = arith.muli %arg0, %b : vector<i32>
+ // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<i32> -> vector<i32>
+ // CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] {{.*}} : vector<i32>
+ %d = arith.addi %c, %a : vector<i32>
+ // CHECK: scf.yield %[[D]] : vector<i32>
+ scf.yield %d : vector<i32>
+ }
+ return %out : vector<i32>
+}
+
builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op