[VectorDistribution] Fix 0-rank vector.broadcast distribution (#19007)
Fixes: https://github.com/iree-org/iree/issues/18955
Also removes locally carried revert
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 80647b9..164d900 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -305,9 +305,9 @@
auto vectorType = VectorType::get(distShape, elementType);
VectorValue srcVector = dyn_cast<VectorValue>(broadcastOp.getSource());
- // If the srcVector is a scalar (like f32) or a rank-0 vector (like
- // vector<f32>), we proceed with the scalar distribution branch.
- if (!srcVector || !isNonZeroRank(srcVector)) {
+ // If the srcVector is a scalar (like f32) we proceed with the scalar
+ // distribution branch.
+ if (!srcVector) {
// The way distribution currently works, there is no partial thread
// distribution, so a scalar is available to all threads. Scalar
// distribution is simply a broadcast from scalar to the distributed
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 7e927b4..a883180 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -132,16 +132,14 @@
for (auto [opResult, replacement] :
llvm::zip_equal(op->getOpResults(), values)) {
// If this value is a vector type, it must be converted back to simd.
- if (auto replacementType = dyn_cast<VectorType>(replacement.getType())) {
- if (replacementType.getRank() != 0) {
- auto oldResult = cast<VectorValue>(opResult);
- // Create a toSIMD op to convert the value back to the simd.
- rewriter.setInsertionPointAfterValue(oldResult);
- Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
- oldResult.getLoc(), oldResult.getType(), replacement);
- // Add to replacements.
- replacement = toSIMD;
- }
+ if (isa<VectorType>(replacement.getType())) {
+ auto oldResult = cast<VectorValue>(opResult);
+ // Create a toSIMD op to convert the value back to the simd.
+ rewriter.setInsertionPointAfterValue(oldResult);
+ Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+ oldResult.getLoc(), oldResult.getType(), replacement);
+ // Add to replacements.
+ replacement = toSIMD;
}
replacements.push_back(replacement);
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index 71448ef..98455c9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -790,6 +790,47 @@
thread_tile = [4, 16, 8],
element_tile = [1, 4, 4],
subgroup_strides = [4, 2, 1],
+ thread_strides = [128, 8, 1]
+>
+
+func.func @zero_rank_broadcast(%src: vector<f16>) -> (vector<32x256x64xf16>) {
+ %bcast = vector.broadcast %src : vector<f16> to vector<32x256x64xf16>
+ %bcastl = iree_vector_ext.to_layout %bcast to layout(#layout) : vector<32x256x64xf16>
+ return %bcastl : vector<32x256x64xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @zero_rank_broadcast
+// CHECK-SAME: (%[[SRC:.*]]: vector<f16>)
+// CHECK: %[[SRC_SIMT:.*]] = iree_vector_ext.to_simt %[[SRC]] : vector<f16>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC_SIMT]]
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f16 to vector<1x4x4xf16>
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: vector.insert %[[BCAST]], %{{.*}}
+// CHECK: %[[OUT:.*]] = vector.insert %[[BCAST]], %{{.*}}
+// CHECK: iree_vector_ext.to_simd %[[OUT]] : vector<2x2x1x2x1x1x1x4x4xf16> -> vector<32x256x64xf16>
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+ subgroup_tile = [2, 2, 2],
+ batch_tile = [2, 2, 1],
+ outer_tile = [2, 1, 1],
+ thread_tile = [4, 16, 8],
+ element_tile = [1, 4, 4],
+ subgroup_strides = [4, 2, 1],
thread_strides = [128, 8, 1]
>
diff --git a/third_party/llvm-project b/third_party/llvm-project
index ac39504..889525f 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit ac39504813f8c52f10c0e364485569bff5a5f7a1
+Subproject commit 889525fa99b251dc962edb516e0108088ba7e44d