[Codegen][GPU] Allow iree_gpu.barrier_region to take multiple operands/results (#18490)
The restriction to a single input and output was artificial, as this op
simply represents synchronization on input and output values.
Additionally, this removes the restriction to tensor/vector types,
although for the time being the op is still only used with those types.
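For illustration, a minimal sketch of the generalized form under the new
`ins(...)` assembly introduced below (`%a` and `%b` are placeholder values;
the swap mirrors the lowering test added in this change):

```mlir
%0:2 = iree_gpu.barrier_region ins(%a, %b : tensor<2xf32>, tensor<3xf32>) {
^bb0(%sa: tensor<2xf32>, %sb: tensor<3xf32>):
  iree_gpu.yield %sb, %sa : tensor<3xf32>, tensor<2xf32>
} : tensor<3xf32>, tensor<2xf32>
```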
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
index 780fc09..0705895 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -32,18 +33,18 @@
// Build a BarrierRegionOp with an empty region.
void BarrierRegionOp::build(OpBuilder &b, OperationState &result,
- Type resultType, Value dest) {
- result.addOperands(dest);
+ TypeRange resultTypes, ValueRange inputs) {
+ result.addOperands(inputs);
(void)result.addRegion();
- result.addTypes(resultType);
+ result.addTypes(resultTypes);
+ SmallVector<Location> blockArgLocs(inputs.size(), result.location);
Region *region = result.regions[0].get();
// `builder.createBlock` changes the insertion point within the block. Create
// a guard to reset the insertion point of the builder after it is destroyed.
OpBuilder::InsertionGuard guard(b);
- b.createBlock(region, region->end(), ArrayRef<Type>{dest.getType()},
- ArrayRef<Location>{result.location});
+ b.createBlock(region, region->end(), inputs.getTypes(), blockArgLocs);
}
LogicalResult BarrierRegionOp::verify() { return success(); }
@@ -51,19 +52,26 @@
LogicalResult BarrierRegionOp::verifyRegions() {
auto &region = getRegion();
Block &block = region.front();
- if (block.getNumArguments() != 1) {
- return emitError("expected the block to have a single argument");
+ if (block.getNumArguments() != getNumOperands()) {
+ return emitError(
+ "expected the block argument count to match operand count");
}
- if (block.getArgumentTypes()[0] != getDestType()) {
- return emitError("expected block to have single argument type of")
- << getDestType();
+ if (!llvm::all_of_zip(block.getArgumentTypes(), getOperandTypes(),
+ [](Type a, Type b) { return a == b; })) {
+ return emitError("expected block argument types to match operand types");
}
// Ensure that the region yields values of the right types.
auto yieldOp = llvm::cast<GPU::YieldOp>(block.getTerminator());
- if (yieldOp.getValue().getType() != getResult().getType()) {
- return emitOpError("expected yield type to match result type");
+ if (yieldOp->getNumOperands() != getNumResults()) {
+ return emitOpError(
+ "expected body to yield same number of values as results");
+ }
+
+ if (!llvm::all_of_zip(yieldOp->getOperandTypes(), getResultTypes(),
+ [](Type a, Type b) { return a == b; })) {
+ return emitError("expected yielded value types to match result types");
}
return success();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 95f53fca..f2d9586 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -28,8 +28,8 @@
]> {
let summary = "Synchronizes uses of a shared tensor.";
let description = [{
- This op is designed to represent synchronization of workers on a
- particular shared tensor. This operation naturally arises when combining
+ This op is designed to represent synchronization of workers on the operands
+ and results of the given region. This operation naturally arises when combining
the regions of producer-consumer `scf.forall` operations that share a
mapping type.
@@ -58,27 +58,26 @@
```mlir
%0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<4x128xf32>) {
- %ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
- %in = ...
- %2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
- %3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
- %4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
%alloc = bufferization.alloc_tensor {memory_space = #gpu.address_space<workgroup>}
: tensor<4x128xf32>
- %inserted_slice = tensor.insert_slice %in into %alloc[%2, %3] [2, 4] [1, 1]
- : tensor<2x4xf32> to tensor<4x128xf32>
- %slice = iree_gpu.barrier_region %inserted_slice {
+  %barrier = iree_gpu.barrier_region ins(%alloc : tensor<4x128xf32>) {
^bb0(%shared: tensor<4x128xf32>):
- %slice = tensor.extract_slice %shared[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
+ %ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
+ %in = ...
+ %2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
+ %3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
+ %inserted_slice = tensor.insert_slice %in into %shared[%2, %3] [2, 4] [1, 1]
+ : tensor<2x4xf32> to tensor<4x128xf32>
-      iree_gpu.yield %slice : tensor<4x16xf32>
-    } : tensor<4x128xf32> -> tensor<4x16xf32>
+      iree_gpu.yield %inserted_slice : tensor<4x128xf32>
+    } : tensor<4x128xf32>
+ %4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
+ %slice = tensor.extract_slice %barrier[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
```
- A barrier_region can be lowered to two barriers, one on the |dest| input
- operand, and a second one on the result. Note that the |dest| operand must
- bufferize with memory space `#gpu.address_space<workgroup>`.
+ A barrier_region can be lowered to two barriers, one on the input operands
+ and a second one on the results.
Motivation and Intended Use Cases:
@@ -92,26 +91,20 @@
}];
let arguments = (ins
- AnyRankedTensor:$dest
+ Variadic<AnyType>:$inputs
);
let regions = (region SizedRegion<1>:$region);
- let results = (outs AnyRankedTensorOrVector:$result);
+ let results = (outs Variadic<AnyType>:$results);
let assemblyFormat = [{
- $dest $region attr-dict
- `:` type($dest) `->` type($result)
+ (`ins` `(` $inputs^ `:` type($inputs) `)` )?
+ $region attr-dict `:` type($results)
}];
let builders = [
- OpBuilder<(ins "Type":$result_type, "Value":$dest)>
+ OpBuilder<(ins "TypeRange":$result_types, "ValueRange":$inputs)>
];
- let extraClassDeclaration = [{
- RankedTensorType getDestType() {
- return getDest().getType();
- }
- }];
-
let skipDefaultBuilders = 1;
let hasVerifier = 1;
let hasRegionVerifier = 1;
@@ -448,14 +441,16 @@
def IREEGPU_YieldOp : Op<IREEGPU_Dialect, "yield", [
Pure, ReturnLike, Terminator,
HasParent<"::mlir::iree_compiler::IREE::GPU::BarrierRegionOp">]> {
- let summary = "Yield a value from a region";
+ let summary = "Yield values from a region";
let description = [{
- This operation is used to yield a single value from a within a region.
+    This operation is used to yield values from within a region.
}];
- let arguments = (ins AnyType:$value);
- let assemblyFormat = "$value attr-dict `:` type($value)";
+ let arguments = (ins Variadic<AnyType>:$values);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+ let assemblyFormat =
+ [{ attr-dict ($values^ `:` type($values))? }];
}
#endif // IREE_CODEGEN_DIALECT_IREEGPUOPS
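As a sketch of the lowering described in the op documentation above, a
barrier_region lowers to one `iree_gpu.value_barrier` on all inputs, the
inlined body, and a second `value_barrier` on all yielded values (mirroring
the `multi_barrier_region` test added below):

```mlir
// Input:
%0:2 = iree_gpu.barrier_region ins(%arg0, %arg1 : tensor<2xf32>, tensor<3xf32>) {
^bb0(%in0: tensor<2xf32>, %in1: tensor<3xf32>):
  iree_gpu.yield %in1, %in0 : tensor<3xf32>, tensor<2xf32>
} : tensor<3xf32>, tensor<2xf32>

// After lowering (sketch): the second barrier's results replace %0#0 and %0#1.
%wb:2 = iree_gpu.value_barrier %arg0, %arg1 : tensor<2xf32>, tensor<3xf32>
%rb:2 = iree_gpu.value_barrier %wb#1, %wb#0 : tensor<3xf32>, tensor<2xf32>
```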
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index e706645..b174f8f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -1,39 +1,53 @@
// RUN: iree-opt %s --split-input-file | FileCheck %s
func.func @barrier_region(%init: tensor<6x6xf32>) -> tensor<3x2xf32> {
- %0 = iree_gpu.barrier_region %init {
+ %0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
^bb0(%intermediate: tensor<6x6xf32>):
%slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
iree_gpu.yield %slice : tensor<3x2xf32>
- } : tensor<6x6xf32> -> tensor<3x2xf32>
+ } : tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}
// CHECK-LABEL: func @barrier_region
// CHECK-SAME: %[[INIT:[A-Za-z0-9]+]]: tensor<6x6xf32>
-// CHECK: iree_gpu.barrier_region %[[INIT]] {
+// CHECK: iree_gpu.barrier_region ins(%[[INIT]] : tensor<6x6xf32>) {
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<6x6xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
// CHECK: iree_gpu.yield %[[SLICE]] : tensor<3x2xf32>
-// CHECK: } : tensor<6x6xf32> -> tensor<3x2xf32>
+// CHECK: } : tensor<3x2xf32>
// -----
-func.func @reshape_barrier_region(%init: tensor<12x12xf32>) -> tensor<2x1x3x2xf32> {
- %0 = iree_gpu.barrier_region %init {
+func.func @multi_result_barrier_region(%init: tensor<12x12xf32>) -> (tensor<2x1x3x2xf32>, index) {
+ %0:2 = iree_gpu.barrier_region ins(%init : tensor<12x12xf32>) {
^bb0(%intermediate: tensor<12x12xf32>):
%expand = tensor.expand_shape %intermediate [[0, 1], [2, 3]] output_shape [4, 3, 3, 4] : tensor<12x12xf32> into tensor<4x3x3x4xf32>
%slice = tensor.extract_slice %expand[0, 0, 0, 0] [2, 1, 3, 2] [1, 1, 1, 1] : tensor<4x3x3x4xf32> to tensor<2x1x3x2xf32>
- iree_gpu.yield %slice : tensor<2x1x3x2xf32>
- } : tensor<12x12xf32> -> tensor<2x1x3x2xf32>
- return %0 : tensor<2x1x3x2xf32>
+ %c0 = arith.constant 0 : index
+ iree_gpu.yield %slice, %c0 : tensor<2x1x3x2xf32>, index
+ } : tensor<2x1x3x2xf32>, index
+ return %0#0, %0#1 : tensor<2x1x3x2xf32>, index
}
-// CHECK-LABEL: func @reshape_barrier_region
-// CHECK: iree_gpu.barrier_region
-// CHECK: tensor.expand_shape
-// CHECK: tensor.extract_slice
-// CHECK: } : tensor<12x12xf32> -> tensor<2x1x3x2xf32>
+// CHECK-LABEL: func @multi_result_barrier_region
+// CHECK: %{{.*}}:2 = iree_gpu.barrier_region ins(%{{.*}} : tensor<12x12xf32>)
+// CHECK: } : tensor<2x1x3x2xf32>, index
+
+// -----
+
+func.func @multi_input_barrier_region(%x: index, %y: index) -> index {
+ %0 = iree_gpu.barrier_region ins(%x, %y : index, index) {
+ ^bb0(%ix: index, %iy: index):
+ %sum = arith.addi %ix, %iy : index
+ iree_gpu.yield %sum : index
+ } : index
+ return %0 : index
+}
+
+// CHECK-LABEL: func @multi_input_barrier_region
+// CHECK: %{{.*}} = iree_gpu.barrier_region ins(%{{.*}}, %{{.*}} : index, index)
+// CHECK: } : index
// -----
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
index b70768d..d1a3201 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
@@ -60,11 +60,11 @@
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1]
// CHECK: scf.yield %[[INSERT]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
-// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
+// CHECK: } : tensor<16x16xf32>
// CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: %[[MM:.+]] = linalg.matmul ins(%[[SHUFFLE]], %[[SHUFFLE]] : tensor<16x16xf32>, tensor<16x16xf32>)
// CHECK-SAME: outs(%[[OUTSLICE]] : tensor<16x16xf32>) -> tensor<16x16xf32>
@@ -124,8 +124,8 @@
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
-// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
+// CHECK: } : tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
// -----
@@ -180,12 +180,12 @@
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
-// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
+// CHECK: } : tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
// -----
@@ -253,8 +253,8 @@
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#0] [2, 128]
// CHECK: scf.yield %[[INSERT]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
-// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
+// CHECK: } : tensor<16x16xf32>
// CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
@@ -308,7 +308,7 @@
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
// CHECK: scf.yield
-// CHECK: iree_gpu.barrier_region %[[LOOP]]
+// CHECK: iree_gpu.barrier_region ins(%[[LOOP]]
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// -----
@@ -363,5 +363,5 @@
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
// CHECK: scf.yield
-// CHECK: iree_gpu.barrier_region %[[LOOP]]
+// CHECK: iree_gpu.barrier_region ins(%[[LOOP]]
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir
index 8e901c3..9f2a811 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_lower_barrier_region.mlir
@@ -1,11 +1,11 @@
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
func.func @barrier_region(%init: tensor<6x6xf32>, %x: index) -> tensor<3x2xf32> {
- %0 = iree_gpu.barrier_region %init {
+ %0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
^bb0(%intermediate: tensor<6x6xf32>):
%slice = tensor.extract_slice %intermediate[0, %x] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
iree_gpu.yield %slice : tensor<3x2xf32>
- } : tensor<6x6xf32> -> tensor<3x2xf32>
+ } : tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}
@@ -33,12 +33,12 @@
func.func @reshape_barrier_region(%init: tensor<12x12xf32>) -> vector<2x1x3x2xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
- %0 = iree_gpu.barrier_region %init {
+ %0 = iree_gpu.barrier_region ins(%init : tensor<12x12xf32>) {
^bb0(%intermediate: tensor<12x12xf32>):
%expand = tensor.expand_shape %intermediate [[0, 1], [2, 3]] output_shape [4, 3, 3, 4] : tensor<12x12xf32> into tensor<4x3x3x4xf32>
%read = vector.transfer_read %expand[%c0, %c0, %c0, %c0], %cst : tensor<4x3x3x4xf32>, vector<2x1x3x2xf32>
iree_gpu.yield %read : vector<2x1x3x2xf32>
- } : tensor<12x12xf32> -> vector<2x1x3x2xf32>
+ } : vector<2x1x3x2xf32>
return %0 : vector<2x1x3x2xf32>
}
@@ -59,3 +59,31 @@
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[WRITE_BARRIER]]
// CHECK: %[[READ:.+]] = vector.transfer_read %[[EXPAND]]
// CHECK: %[[READ_BARRIER:.+]] = iree_gpu.value_barrier %[[READ]]
+
+// -----
+
+func.func @multi_barrier_region(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>, tensor<2xf32>) {
+ %0:2 = iree_gpu.barrier_region ins(%arg0, %arg1 : tensor<2xf32>, tensor<3xf32>) {
+ ^bb0(%in0: tensor<2xf32>, %in1: tensor<3xf32>):
+ iree_gpu.yield %in1, %in0 : tensor<3xf32>, tensor<2xf32>
+ } : tensor<3xf32>, tensor<2xf32>
+ return %0#0, %0#1 : tensor<3xf32>, tensor<2xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.lower_barrier_region
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @multi_barrier_region
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3xf32>
+
+// CHECK: %[[WB:.+]]:2 = iree_gpu.value_barrier %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RB:.+]]:2 = iree_gpu.value_barrier %[[WB]]#1, %[[WB]]#0
+// CHECK: return %[[RB]]#0, %[[RB]]#1 : tensor<3xf32>, tensor<2xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir
index d7f8457..07bd94b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/vectorize_iree_gpu_ops.mlir
@@ -75,11 +75,11 @@
// -----
func.func @barrier_region(%init: tensor<6x6xf32>) -> tensor<3x2xf32> {
- %0 = iree_gpu.barrier_region %init {
+ %0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
^bb0(%intermediate: tensor<6x6xf32>):
%slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
iree_gpu.yield %slice : tensor<3x2xf32>
- } : tensor<6x6xf32> -> tensor<3x2xf32>
+ } : tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}
@@ -99,6 +99,38 @@
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
// CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} : tensor<3x2xf32>, vector<3x2xf32>
// CHECK: iree_gpu.yield %[[READ]] : vector<3x2xf32>
-// CHECK: } : tensor<6x6xf32> -> vector<3x2xf32>
+// CHECK: } : vector<3x2xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2xf32>
// CHECK: vector.transfer_write %[[SHUFFLE]], %[[EMPTY]]
+
+// -----
+
+func.func @multi_result_barrier_region(%init: tensor<6x6xf32>) -> (index, tensor<3x2xf32>) {
+ %0:2 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
+ ^bb0(%intermediate: tensor<6x6xf32>):
+ %slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
+ %c0 = arith.constant 0 : index
+ iree_gpu.yield %c0, %slice : index, tensor<3x2xf32>
+ } : index, tensor<3x2xf32>
+ return %0#0, %0#1 : index, tensor<3x2xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.vectorize_iree_gpu
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @multi_result_barrier_region
+// CHECK: %[[SHUFFLE:.+]]:2 = iree_gpu.barrier_region
+// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<6x6xf32>):
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
+// CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} : tensor<3x2xf32>, vector<3x2xf32>
+// CHECK: iree_gpu.yield %c0, %[[READ]] : index, vector<3x2xf32>
+// CHECK: } : index, vector<3x2xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2xf32>
+// CHECK: vector.transfer_write %[[SHUFFLE]]#1, %[[EMPTY]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 60360a7..54e840b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -13,6 +13,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
@@ -34,6 +35,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#define DEBUG_TYPE "iree-codegen-gpu-transforms"
@@ -128,8 +130,8 @@
}
(*consumerChain.begin())
->replaceUsesOfWith(source, barrierRegionOp.getBody()->getArgument(0));
- rewriter.replaceAllUsesExcept(extractSlice.getResult(), barrierRegionOp,
- terminator);
+ rewriter.replaceAllUsesExcept(extractSlice.getResult(),
+ barrierRegionOp.getResult(0), terminator);
}
LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
@@ -975,20 +977,19 @@
// Step 1. Synchronize the workers on the shared dest.
auto writeBarrier = rewriter.create<IREE::GPU::ValueBarrierOp>(
- loc, barrierRegionOp.getDest());
+ loc, barrierRegionOp.getInputs());
// Step 2. Inline the barrier op region.
auto terminator = barrierRegionOp.getBody()->getTerminator();
- Value replacement = terminator->getOperand(0);
rewriter.inlineBlockBefore(barrierRegionOp.getBody(), barrierRegionOp,
- {writeBarrier.getResult(0)});
- rewriter.setInsertionPointAfterValue(replacement);
- Value barrier;
+ writeBarrier.getResults());
+ rewriter.setInsertionPoint(terminator);
// Step 3. Synchronize the result value.
- barrier = rewriter.create<IREE::GPU::ValueBarrierOp>(loc, replacement)
- .getResult(0);
- rewriter.replaceAllUsesWith(barrierRegionOp.getResult(), barrier);
+ auto barrier = rewriter.create<IREE::GPU::ValueBarrierOp>(
+ loc, terminator->getOperands());
+ rewriter.replaceAllUsesWith(barrierRegionOp.getResults(),
+ barrier.getResults());
rewriter.eraseOp(terminator);
return success();
}
@@ -1068,48 +1069,79 @@
static LogicalResult
vectorizeStaticBarrierRegionResult(RewriterBase &rewriter,
IREE::GPU::BarrierRegionOp barrier) {
- auto tensorResultType =
- dyn_cast<RankedTensorType>(barrier.getResult().getType());
- if (!tensorResultType || !tensorResultType.hasStaticShape()) {
+ SmallVector<Type> newResultTypes(barrier->getResultTypes());
+ llvm::SmallBitVector vectorizationTargets(newResultTypes.size(), false);
+ for (auto [i, type] : llvm::enumerate(newResultTypes)) {
+ auto tensorResultType = dyn_cast<RankedTensorType>(type);
+ if (!tensorResultType || !tensorResultType.hasStaticShape()) {
+ continue;
+ }
+ vectorizationTargets[i] = true;
+ VectorType newResultType = VectorType::get(
+ tensorResultType.getShape(), tensorResultType.getElementType());
+ type = newResultType;
+ }
+
+ if (vectorizationTargets.none()) {
return failure();
}
- VectorType newResultType = VectorType::get(tensorResultType.getShape(),
- tensorResultType.getElementType());
-
- auto paddingValue = rewriter.create<arith::ConstantOp>(
- barrier.getLoc(), rewriter.getZeroAttr(newResultType.getElementType()));
-
auto newBarrier = rewriter.create<IREE::GPU::BarrierRegionOp>(
- barrier.getLoc(), newResultType, barrier.getDest());
-
+ barrier.getLoc(), newResultTypes, barrier.getInputs());
auto currentTerminator =
cast<IREE::GPU::YieldOp>(barrier.getBody()->getTerminator());
+ rewriter.setInsertionPointToEnd(newBarrier.getBody());
rewriter.mergeBlocks(barrier.getBody(), newBarrier.getBody(),
newBarrier.getBody()->getArguments());
- rewriter.setInsertionPointToEnd(newBarrier.getBody());
- auto innerRead = vector::createReadOrMaskedRead(
- rewriter, currentTerminator.getLoc(), currentTerminator->getOperand(0),
- newResultType.getShape(), paddingValue,
- /*useInBoundsInsteadOfMasking=*/true);
- rewriter.create<IREE::GPU::YieldOp>(currentTerminator->getLoc(), innerRead);
+ // Create the tensor -> vector conversions within the body of the new op.
+ SmallVector<Value> newYields = currentTerminator.getOperands();
+ for (auto [i, val] : llvm::enumerate(newYields)) {
+ if (!vectorizationTargets[i]) {
+ continue;
+ }
+
+ auto resultType = cast<VectorType>(newResultTypes[i]);
+ auto paddingValue = rewriter.create<arith::ConstantOp>(
+ barrier.getLoc(), rewriter.getZeroAttr(resultType.getElementType()));
+
+ auto innerRead =
+ vector::createReadOrMaskedRead(rewriter, currentTerminator.getLoc(),
+ val, resultType.getShape(), paddingValue,
+ /*useInBoundsInsteadOfMasking=*/true);
+ val = innerRead;
+ }
+
+ rewriter.create<IREE::GPU::YieldOp>(currentTerminator->getLoc(), newYields);
rewriter.eraseOp(currentTerminator);
rewriter.setInsertionPointAfter(newBarrier);
- // Create the write back to a tensor.
- auto empty = rewriter.create<tensor::EmptyOp>(
- barrier.getLoc(), tensorResultType.getShape(),
- tensorResultType.getElementType());
- int64_t rank = tensorResultType.getRank();
- auto zero = rewriter.create<arith::ConstantIndexOp>(barrier.getLoc(), 0);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- barrier,
- /*vector=*/newBarrier,
- /*source=*/empty,
- /*indices=*/SmallVector<Value>(rank, zero),
- /*inBounds=*/SmallVector<bool>(rank, true));
+ // Create the writes back to tensor types.
+ SmallVector<Value> replacements = newBarrier.getResults();
+ for (auto [i, val] : llvm::enumerate(replacements)) {
+ if (!vectorizationTargets[i]) {
+ continue;
+ }
+
+ auto tensorResultType =
+ cast<RankedTensorType>(barrier->getResultTypes()[i]);
+ auto empty = rewriter.create<tensor::EmptyOp>(
+ barrier.getLoc(), tensorResultType.getShape(),
+ tensorResultType.getElementType());
+ int64_t rank = tensorResultType.getRank();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(barrier.getLoc(), 0);
+ auto write = rewriter.create<vector::TransferWriteOp>(
+ barrier.getLoc(),
+ /*vector=*/val,
+ /*dest=*/empty,
+ /*indices=*/SmallVector<Value>(rank, zero),
+ /*inBounds=*/SmallVector<bool>(rank, true));
+ val = write->getResult(0);
+ }
+
+ rewriter.replaceOp(barrier, replacements);
+
return success();
}