[VectorDistribution] Replace layout_resolution with to_layout (#18027)
This patch replaces the layout_resolution operator with a new
"to_layout" operation, representing a layout cast on the result. This
allows the operation to be used as an anchor and a conversion operation.
This operation will be used in later patches to set layout anchors in IR
and preserve them.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 4b3ae90..b953d78 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -750,7 +750,7 @@
/// sequence of multiplications and additions.
///
struct DistributeLayoutConflictResolutions final
- : OpDistributionPattern<IREE::VectorExt::LayoutConflictResolutionOp> {
+ : OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
using OpDistributionPattern::OpDistributionPattern;
VectorValue reshapeVector(Location loc, RewriterBase &rewriter,
@@ -792,10 +792,9 @@
return newVector;
}
- LogicalResult
- matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp,
- DistributionSignature &signature,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
@@ -837,13 +836,12 @@
/// especially used when we don't have an optimized way
/// to resolve the conflict.
struct DistributeLayoutConflictToSharedMemory final
- : OpDistributionPattern<IREE::VectorExt::LayoutConflictResolutionOp> {
+ : OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
using OpDistributionPattern::OpDistributionPattern;
- LogicalResult
- matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp,
- DistributionSignature &signature,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
auto loc = resolutionOp.getLoc();
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index e5ecf59..2654eae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -643,7 +643,7 @@
%vcst = arith.constant dense<0.0> : vector<32x16xf16>
// CHECK-COUNT-8: vector.load %[[MEM]]
%vec = vector.transfer_read %a[%c0, %c0], %cst {"__vector_layout_test_anchor_result_0" = #layout1} : memref<32x16xf16>, vector<32x16xf16>
- // CHECK: iree_vector_ext.layout_conflict_resolution {{.*}}
+ // CHECK: iree_vector_ext.to_layout {{.*}}
%vec2 = arith.addf %vec, %vcst : vector<32x16xf16>
// CHECK-COUNT-16: vector.store {{.*}}, vector<1xf16>
vector.transfer_write %vec2, %b[%c0, %c0] {in_bounds = [true, true],
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index a8aae48..69c7ff9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -1152,7 +1152,7 @@
mlir::FunctionOpInterface funcOp) {
funcOp.walk([&](Operation *op) {
// Do not emit remarks for conflict operations.
- if (isa<VectorExt::LayoutConflictResolutionOp>(op)) {
+ if (isa<VectorExt::ToLayoutOp>(op)) {
return;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 394d740..98ad063 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -217,9 +217,8 @@
Value input = opOperand.get();
// Create a resolution operation. This conflict should be handeled later by
// someone else, not this analysis.
- Operation *resolveOp =
- builder.create<IREE ::VectorExt::LayoutConflictResolutionOp>(
- input.getLoc(), input.getType(), input, vectorLayout, rhs);
+ Operation *resolveOp = builder.create<IREE::VectorExt::ToLayoutOp>(
+ input.getLoc(), input.getType(), input, rhs);
Value resolvedValue = resolveOp->getResult(0);
opOperand.set(resolvedValue);
@@ -1015,9 +1014,9 @@
continue;
}
- // Do not annotate resolve_conflict operations since they already have
+ // Do not annotate to_layout operations since they already have
// this information in their attributes.
- if (isa<IREE::VectorExt::LayoutConflictResolutionOp>(op)) {
+ if (isa<IREE::VectorExt::ToLayoutOp>(op)) {
continue;
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 6c9f90a..77d0e95 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -1676,10 +1676,8 @@
if (!parentOp || (parentOp->getNumResults() != 1))
continue;
parentOp->setAttr("__vector_layout_test_anchor_result_0", readLayout);
- Value resolvedOperand =
- rewriter.create<VectorExt::LayoutConflictResolutionOp>(
- contract.getLoc(), operand.getType(), operand, layout,
- readLayout);
+ Value resolvedOperand = rewriter.create<VectorExt::ToLayoutOp>(
+ contract.getLoc(), operand.getType(), operand, layout);
contract.setOperand(operandIndices[i], resolvedOperand);
}
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
index fdaecc6..f21607a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
@@ -24,23 +24,24 @@
// Layout ops.
//===----------------------------------------------------------------------===//
-def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> {
- let summary = "Layout Conflict Resolution operator";
+def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
+ Pure,
+ AllTypesMatch<["input", "output"]>
+ ]> {
+ let summary = "Layout conversion operator";
let description = [{
- The layout conflict resolution operator takes a vector and a
- desired layout and transforms the vector to one with the
- desired layout.
+ The layout conversion operator takes a shaped value and a layout and
+ transforms the value to have that layout.
}];
let arguments = (ins
AnyVector:$input,
- VectorLayoutInterface:$sourceLayout,
- VectorLayoutInterface:$desiredLayout
+ VectorLayoutInterface:$layout
);
let results = (outs
AnyVector:$output
);
let extraClassDeclaration = [{}];
- let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
+ let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)";
let hasVerifier = 1;
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index dd52c2c..c694398 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -17,15 +17,9 @@
// LayoutConflictResolutionOp
//===----------------------------------------------------------------------===//
-// Validate that the desired layout has the same shape as the input.
-LogicalResult LayoutConflictResolutionOp::verify() {
- if (getSourceLayout().isValidLayout(getInput()).failed()) {
- return failure();
- }
- if (getDesiredLayout().isValidLayout(getOutput()).failed()) {
- return failure();
- }
- return success();
+// Validate that the layout has the same shape as the input.
+LogicalResult ToLayoutOp::verify() {
+ return getLayout().isValidLayout(getInput());
}
// to_simd -> to_simt
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
index 739f3d4..7a848362 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
@@ -3,28 +3,12 @@
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
-#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
-func.func @invalid_desired_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
+func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
- %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
- %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout1, sourceLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
- return %2 : vector<32x32xf16>
-}
-
-// -----
-
-#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]>
-#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
-#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
-#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
-func.func @invalid_source_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
- %cst_0 = arith.constant 0.0 : f16
- %c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
- // expected-error @-1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
- %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout2, sourceLayout = #layout1} : vector<32x32xf16> -> vector<32x32xf16>
+ %2 = iree_vector_ext.to_layout %result to #layout1 : vector<32x32xf16>
return %2 : vector<32x32xf16>
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 49fe27c..5d8018d 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -2,22 +2,18 @@
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 4, 4]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
-#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
#layout2 = #iree_vector_ext.layout<#col_layout1, #row_layout1>
func.func @specify_layout(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
- %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #layout1, desiredLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
+ %2 = iree_vector_ext.to_layout %result to #layout2 : vector<32x32xf16>
return %2 : vector<32x32xf16>
}
// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.layout<<[ BATCHY, LANEY, VECTORX], [4, 2, 4]>, <[ BATCHX, LANEX, VECTORY], [2, 4, 4]>>
-// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.layout<<[ BATCHX, LANEX, VECTORY], [2, 4, 4]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>
// CHECK-LABEL: func.func @specify_layout
-// CHECK: iree_vector_ext.layout_conflict_resolution
-// CHECK-SAME: desiredLayout = #[[LAYOUT0]]
-// CHECK-SAME: sourceLayout = #[[LAYOUT1]]
+// CHECK: iree_vector_ext.to_layout {{.*}} to #[[LAYOUT0]]
// -----