[VectorDistribution] Replace layout_resolution with to_layout (#18027)

This patch replaces the layout_resolution operator with a new
"to_layout" operation, representing a layout cast on the result. This
allows the operation to be used as an anchor and a conversion operation.

This operation will be used in later patches to set layout anchors in IR
and preserve them.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 4b3ae90..b953d78 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -750,7 +750,7 @@
 ///     sequence of multiplications and additions.
 ///
 struct DistributeLayoutConflictResolutions final
-    : OpDistributionPattern<IREE::VectorExt::LayoutConflictResolutionOp> {
+    : OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
   using OpDistributionPattern::OpDistributionPattern;
 
   VectorValue reshapeVector(Location loc, RewriterBase &rewriter,
@@ -792,10 +792,9 @@
     return newVector;
   }
 
-  LogicalResult
-  matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp,
-                  DistributionSignature &signature,
-                  PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
     VectorValue vector = resolutionOp.getInput();
     VectorValue result = resolutionOp.getOutput();
     LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
@@ -837,13 +836,12 @@
 /// especially used when we don't have an optimized way
 /// to resolve the conflict.
 struct DistributeLayoutConflictToSharedMemory final
-    : OpDistributionPattern<IREE::VectorExt::LayoutConflictResolutionOp> {
+    : OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
   using OpDistributionPattern::OpDistributionPattern;
 
-  LogicalResult
-  matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp,
-                  DistributionSignature &signature,
-                  PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
     auto loc = resolutionOp.getLoc();
     VectorValue vector = resolutionOp.getInput();
     VectorValue result = resolutionOp.getOutput();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index e5ecf59..2654eae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -643,7 +643,7 @@
   %vcst = arith.constant dense<0.0> : vector<32x16xf16>
   // CHECK-COUNT-8: vector.load %[[MEM]]
   %vec = vector.transfer_read  %a[%c0, %c0], %cst {"__vector_layout_test_anchor_result_0" = #layout1} : memref<32x16xf16>, vector<32x16xf16>
-  // CHECK: iree_vector_ext.layout_conflict_resolution {{.*}}
+  // CHECK: iree_vector_ext.to_layout {{.*}}
   %vec2 = arith.addf %vec, %vcst : vector<32x16xf16>
   // CHECK-COUNT-16: vector.store {{.*}}, vector<1xf16>
   vector.transfer_write %vec2, %b[%c0, %c0] {in_bounds = [true, true],
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index a8aae48..69c7ff9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -1152,7 +1152,7 @@
                               mlir::FunctionOpInterface funcOp) {
   funcOp.walk([&](Operation *op) {
     // Do not emit remarks for conflict operations.
-    if (isa<VectorExt::LayoutConflictResolutionOp>(op)) {
+    if (isa<VectorExt::ToLayoutOp>(op)) {
       return;
     }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 394d740..98ad063 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -217,9 +217,8 @@
   Value input = opOperand.get();
   // Create a resolution operation. This conflict should be handeled later by
   // someone else, not this analysis.
-  Operation *resolveOp =
-      builder.create<IREE ::VectorExt::LayoutConflictResolutionOp>(
-          input.getLoc(), input.getType(), input, vectorLayout, rhs);
+  Operation *resolveOp = builder.create<IREE::VectorExt::ToLayoutOp>(
+      input.getLoc(), input.getType(), input, rhs);
   Value resolvedValue = resolveOp->getResult(0);
   opOperand.set(resolvedValue);
 
@@ -1015,9 +1014,9 @@
         continue;
       }
 
-      // Do not annotate resolve_conflict operations since they already have
+      // Do not annotate to_layout operations since they already have
       // this information in their attributes.
-      if (isa<IREE::VectorExt::LayoutConflictResolutionOp>(op)) {
+      if (isa<IREE::VectorExt::ToLayoutOp>(op)) {
         continue;
       }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 6c9f90a..77d0e95 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -1676,10 +1676,8 @@
         if (!parentOp || (parentOp->getNumResults() != 1))
           continue;
         parentOp->setAttr("__vector_layout_test_anchor_result_0", readLayout);
-        Value resolvedOperand =
-            rewriter.create<VectorExt::LayoutConflictResolutionOp>(
-                contract.getLoc(), operand.getType(), operand, layout,
-                readLayout);
+        Value resolvedOperand = rewriter.create<VectorExt::ToLayoutOp>(
+            contract.getLoc(), operand.getType(), operand, layout);
         contract.setOperand(operandIndices[i], resolvedOperand);
       }
     }
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
index fdaecc6..f21607a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
@@ -24,23 +24,24 @@
 // Layout ops.
 //===----------------------------------------------------------------------===//
 
-def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> {
-  let summary = "Layout Conflict Resolution operator";
+def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
+  Pure,
+  AllTypesMatch<["input", "output"]>
+  ]> {
+  let summary = "Layout conversion operator";
   let description = [{
-    The layout conflict resolution operator takes a vector and a
-    desired layout and transforms the vector to one with the
-    desired layout.
+    The layout conversion operator takes a shaped value and a layout and
+    transforms the value to have that layout.
   }];
   let arguments = (ins
     AnyVector:$input,
-    VectorLayoutInterface:$sourceLayout,
-    VectorLayoutInterface:$desiredLayout
+    VectorLayoutInterface:$layout
   );
   let results = (outs
     AnyVector:$output
   );
   let extraClassDeclaration = [{}];
-  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
+  let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)";
   let hasVerifier = 1;
 }
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index dd52c2c..c694398 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -17,15 +17,9 @@
 // LayoutConflictResolutionOp
 //===----------------------------------------------------------------------===//
 
-// Validate that the desired layout has the same shape as the input.
-LogicalResult LayoutConflictResolutionOp::verify() {
-  if (getSourceLayout().isValidLayout(getInput()).failed()) {
-    return failure();
-  }
-  if (getDesiredLayout().isValidLayout(getOutput()).failed()) {
-    return failure();
-  }
-  return success();
+// Validate that the layout has the same shape as the input.
+LogicalResult ToLayoutOp::verify() {
+  return getLayout().isValidLayout(getInput());
 }
 
 // to_simd -> to_simt
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
index 739f3d4..7a848362 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
@@ -3,28 +3,12 @@
 #row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]>
 #col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
 #layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
-#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
-func.func @invalid_desired_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
+func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
   %cst_0 = arith.constant 0.0 : f16
   %c0 = arith.constant 0 : index
-  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
   // expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX,  LANEX,  VECTORY], [1, 1, 1]>, <[ BATCHY,  LANEY,  VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
-  %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout1, sourceLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
-  return %2 : vector<32x32xf16>
-}
-
-// -----
-
-#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]>
-#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
-#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
-#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
-func.func @invalid_source_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
-  %cst_0 = arith.constant 0.0 : f16
-  %c0 = arith.constant 0 : index
   %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
-  // expected-error @-1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX,  LANEX,  VECTORY], [1, 1, 1]>, <[ BATCHY,  LANEY,  VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
-  %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout2, sourceLayout = #layout1} : vector<32x32xf16> -> vector<32x32xf16>
+  %2 = iree_vector_ext.to_layout %result to #layout1 : vector<32x32xf16>
   return %2 : vector<32x32xf16>
 }
 
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 49fe27c..5d8018d 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -2,22 +2,18 @@
 
 #row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 4, 4]>
 #col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
-#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
 #layout2 = #iree_vector_ext.layout<#col_layout1, #row_layout1>
 func.func @specify_layout(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
   %cst_0 = arith.constant 0.0 : f16
   %c0 = arith.constant 0 : index
   %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
-  %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #layout1, desiredLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
+  %2 = iree_vector_ext.to_layout %result to #layout2 : vector<32x32xf16>
   return %2 : vector<32x32xf16>
 }
 
 // CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.layout<<[ BATCHY,  LANEY,  VECTORX], [4, 2, 4]>, <[ BATCHX,  LANEX,  VECTORY], [2, 4, 4]>>
-// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.layout<<[ BATCHX,  LANEX,  VECTORY], [2, 4, 4]>, <[ BATCHY,  LANEY,  VECTORX], [4, 2, 4]>>
 // CHECK-LABEL: func.func @specify_layout
-// CHECK:      iree_vector_ext.layout_conflict_resolution
-// CHECK-SAME:         desiredLayout = #[[LAYOUT0]]
-// CHECK-SAME:         sourceLayout = #[[LAYOUT1]]
+// CHECK:      iree_vector_ext.to_layout {{.*}} to #[[LAYOUT0]]
 
 // -----