[Codegen][GPU] Allow iree_gpu.tensor_barrier to take vectors (#17479)
This allows synchronizing on vectors as well as tensors, with similar
semantics. In a typical lowering flow, a vector barrier is the read-side
counterpart of a tensor barrier: a tensor barrier represents a wait until
all writes to a shared allocation have finished, while a vector barrier
represents a wait until all threads have read the values they need from
that shared allocation.
Also renames the operation from iree_gpu.tensor_barrier to iree_gpu.value_barrier for clarity.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index b6dde8f..911a44d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2580,7 +2580,7 @@
%c0 = arith.constant 0 : index
%alloc = bufferization.alloc_tensor() : tensor<2xf32>
%tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
- %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
+ %barrier = iree_gpu.value_barrier %tmp : tensor<2xf32>
%res = vector.transfer_read %barrier[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
return %res : vector<2xf32>
}
@@ -2601,7 +2601,7 @@
%alloc = bufferization.alloc_tensor() : tensor<2xf32>
%loop = scf.for %arg0 = %c0 to %c10 step %c1 iter_args(%init = %alloc) -> tensor<2xf32> {
%tmp = vector.transfer_write %cst, %init[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
- %barrier = iree_gpu.tensor_barrier %tmp : tensor<2xf32>
+ %barrier = iree_gpu.value_barrier %tmp : tensor<2xf32>
scf.yield %barrier : tensor<2xf32>
}
%res = vector.transfer_read %loop[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
@@ -2614,3 +2614,23 @@
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: }
// CHECK: vector.transfer_read %[[ALLOC]]
+
+// -----
+
+func.func @vector_barrier() -> vector<2xf32> {
+ %cst = arith.constant dense<0.0> : vector<2xf32>
+ %cst0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %alloc = bufferization.alloc_tensor() : tensor<2xf32>
+ %tmp = vector.transfer_write %cst, %alloc[%c0] {in_bounds = [true]} : vector<2xf32>, tensor<2xf32>
+ %read = vector.transfer_read %tmp[%c0], %cst0 {in_bounds = [true]} : tensor<2xf32>, vector<2xf32>
+ %barrier = iree_gpu.value_barrier %read : vector<2xf32>
+ return %barrier : vector<2xf32>
+}
+
+// Verify that the dual-modes of `value_barrier` are adhered to.
+// CHECK-LABEL: func @vector_barrier()
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2xf32>
+// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]]
+// CHECK-NEXT: %[[RD:.+]] = vector.transfer_read %[[ALLOC]]
+// CHECK-NEXT: iree_gpu.value_barrier %[[RD]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
index 3950ccc..2e106e1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
@@ -237,12 +237,4 @@
return success();
}
-//===----------------------------------------------------------------------===//
-// TensorBarrierOp
-//===----------------------------------------------------------------------===//
-
-MutableOperandRange TensorBarrierOp::getDpsInitsMutable() {
- return getInputMutable();
-}
-
} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
index 4c360a7..6b18e5c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
@@ -352,36 +352,39 @@
}
//===----------------------------------------------------------------------===//
-// TensorBarrierOp
+// ValueBarrierOp
//===----------------------------------------------------------------------===//
-def IREEGPU_TensorBarrierOp : Op<IREEGPU_Dialect, "tensor_barrier", [
+def IREEGPU_ValueBarrierOp : Op<IREEGPU_Dialect, "value_barrier", [
Pure,
- DeclareOpInterfaceMethods<DestinationStyleOpInterface>,
AllTypesMatch<["input", "result"]>,
]> {
let summary = "Shuffles a private tensor across a shared allocation";
let description = [{
- This operation acts as a barrier on a tensor value. It takes a single
- tensor operand and produces an equivalent tensor. This does not have copy
- and/or data movement semantics and simply represents a barrier on all writes
- to the input tensor.
+ This operation acts as a barrier on a value semantic SSA value (tensor or
+ vector). It takes a single operand and produces a value equivalent to the
+ input. This does not have copy and/or data movement semantics and simply
+ represents a barrier on all writes in the tensor case, and a barrier until
+ all threads acquire the input vector in the vector case.
This operation is a no-op when not present in a parallel context. This
operation is pure as it only requires synchronization for the value it
produces.
}];
- let arguments = (ins AnyRankedTensor:$input);
- let results = (outs AnyRankedTensor:$result);
+ let arguments = (ins AnyRankedTensorOrVector:$input);
+ let results = (outs AnyRankedTensorOrVector:$result);
let assemblyFormat = [{
$input attr-dict `:` type($result)
}];
let extraClassDeclaration = [{
- RankedTensorType getInputType() {
- return getInput().getType();
+ bool hasTensorSemantics() {
+ return isa<::mlir::RankedTensorType>(getInput().getType());
+ }
+ ::mlir::ShapedType getInputType() {
+ return ::llvm::cast<::mlir::ShapedType>(getInput().getType());
}
}];
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index 08b2f17..e2744fb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -134,10 +134,21 @@
// -----
func.func @tensor_barrier(%input: tensor<?xf16>) -> tensor<?xf16> {
- %out = iree_gpu.tensor_barrier %input : tensor<?xf16>
+ %out = iree_gpu.value_barrier %input : tensor<?xf16>
return %out : tensor<?xf16>
}
// CHECK-LABEL: func @tensor_barrier
// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: tensor<?xf16>
-// CHECK: iree_gpu.tensor_barrier %[[INPUT]] : tensor<?xf16>
+// CHECK: iree_gpu.value_barrier %[[INPUT]] : tensor<?xf16>
+
+// -----
+
+func.func @vector_barrier(%input: vector<8xf16>) -> vector<8xf16> {
+ %out = iree_gpu.value_barrier %input : vector<8xf16>
+ return %out : vector<8xf16>
+}
+
+// CHECK-LABEL: func @vector_barrier
+// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: vector<8xf16>
+// CHECK: iree_gpu.value_barrier %[[INPUT]] : vector<8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index f8cb914..8dcabd1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -26,6 +26,15 @@
}
//===---------------------------------------------------------------------===//
+// ApplyLowerValueBarrierOp
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyLowerValueBarrierOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ IREE::GPU::populateIREEGPULowerValueBarrierPatterns(patterns);
+}
+
+//===---------------------------------------------------------------------===//
// ApplyUnrollMultiMmaOp
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index 08e13ec..dc69083 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -14,6 +14,19 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def ApplyLowerValueBarrierOp : Op<Transform_Dialect,
+ "apply_patterns.iree.lower_value_barrier",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Populate patterns to convert value barriers on vectors into gpu.barrier ops.
+ Barriers on tensors are ignored.
+ }];
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyUnrollMultiMmaOp : Op<Transform_Dialect,
"apply_patterns.iree.unroll_multi_mma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index 11c0c8e..f78f6d0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -18,6 +18,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "lower_vector_barrier.mlir",
"transform_fuse_forall.mlir",
"vectorize_multi_mma.mlir",
"unroll_multi_mma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index 686ce93..f3e2e40 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "lower_vector_barrier.mlir"
"transform_fuse_forall.mlir"
"unroll_multi_mma.mlir"
"vectorize_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir
new file mode 100644
index 0000000..b8391a3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_vector_barrier.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+func.func @lower_value_barrier(%input: vector<4xf32>) -> vector<4xf32> {
+ %0 = iree_gpu.value_barrier %input : vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.lower_value_barrier
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @lower_value_barrier
+// CHECK-SAME: %[[INPUT:[A-Za-z0-9]+]]: vector<4xf32>
+// CHECK-NEXT: gpu.barrier
+// CHECK-NEXT: return %[[INPUT]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
index 1071d91..4dc7621 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
@@ -24,9 +24,9 @@
/// Bufferization of iree_gpu.tensor_barrier. Always just bufferizes in place
/// and replaces with a barrier.
-struct TensorBarrierOpBufferizationInterface
+struct ValueBarrierOpBufferizationInterface
: public BufferizableOpInterface::ExternalModel<
- TensorBarrierOpBufferizationInterface, IREE::GPU::TensorBarrierOp> {
+ ValueBarrierOpBufferizationInterface, IREE::GPU::ValueBarrierOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// This op never needs to bufferize to a copy.
@@ -47,8 +47,11 @@
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
- auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
+ auto barrierOp = cast<IREE::GPU::ValueBarrierOp>(op);
assert(value == barrierOp.getResult() && "invalid value");
+ if (!barrierOp.hasTensorSemantics()) {
+ return failure();
+ }
auto srcMemrefType = bufferization::getBufferType(barrierOp.getInput(),
options, invocationStack);
if (failed(srcMemrefType))
@@ -58,7 +61,10 @@
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
- auto barrierOp = cast<IREE::GPU::TensorBarrierOp>(op);
+ auto barrierOp = cast<IREE::GPU::ValueBarrierOp>(op);
+ if (!barrierOp.hasTensorSemantics()) {
+ return failure();
+ }
FailureOr<Value> buffer =
getBuffer(rewriter, barrierOp.getInput(), options);
if (failed(buffer)) {
@@ -78,8 +84,8 @@
void registerIREEGPUBufferizationInterfaces(DialectRegistry ®istry) {
registry.addExtension(
+[](MLIRContext *context, IREE::GPU::IREEGPUDialect *dialect) {
- IREE::GPU::TensorBarrierOp::attachInterface<
- TensorBarrierOpBufferizationInterface>(*context);
+ IREE::GPU::ValueBarrierOp::attachInterface<
+ ValueBarrierOpBufferizationInterface>(*context);
});
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index a01ea45..5e982a8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -11,8 +11,6 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -406,4 +404,28 @@
patterns.add<VectorizeStaticMultiMmaOpPattern>(patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// ValueBarrierOp Lowering
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LowerValueBarrierPattern
+ : public OpRewritePattern<IREE::GPU::ValueBarrierOp> {
+ using OpRewritePattern<IREE::GPU::ValueBarrierOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IREE::GPU::ValueBarrierOp barrier,
+ PatternRewriter &rewriter) const override {
+ if (barrier.hasTensorSemantics()) {
+ return failure();
+ }
+ rewriter.create<gpu::BarrierOp>(barrier.getLoc());
+ rewriter.replaceOp(barrier, barrier.getInput());
+ return success();
+ }
+};
+} // namespace
+
+void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns) {
+ patterns.add<LowerValueBarrierPattern>(patterns.getContext());
+}
+
} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 56d1589..0e2afa3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -39,6 +39,8 @@
void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);
+void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
+
} // namespace mlir::iree_compiler::IREE::GPU
#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_