[Codegen] Add op for copying tensor operands (#17256)

This is useful for providing an anchor to tile a copy of a particular
tensor operand. The advantage of this over alternatives like pad and
pack is that it does not require any fixup operations on the destination
of a linalg op.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index dac7e3c..e7ae1ef 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -49,6 +49,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -248,6 +249,44 @@
   populatePrepareVectorToMMAPatterns(patterns, getUseNvGpu());
 }
 
+//===----------------------------------------------------------------------===//
+// CopyTensorOperandOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform_dialect::CopyTensorOperandOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  int64_t operandIndex = getOperandIndex();
+  if (operandIndex > target->getNumOperands()) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Operand index out of range");
+  }
+  Value operand = target->getOperand(operandIndex);
+  auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+  if (!tensorType) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Non tensor type operand to copy");
+  }
+  rewriter.setInsertionPoint(target);
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      target->getLoc(),
+      tensor::getMixedSizes(rewriter, target->getLoc(), operand),
+      tensorType.getElementType());
+  Operation *copy =
+      rewriter.create<linalg::CopyOp>(target->getLoc(), operand, empty);
+  target->setOperand(operandIndex, copy->getResult(0));
+  results.push_back(copy);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::CopyTensorOperandOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===---------------------------------------------------------------------===//
 // ForallToWorkgroupOp
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index e931823..13fd709 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -154,6 +154,43 @@
   let assemblyFormat = "attr-dict";
 }
 
+def CopyTensorOperandOp :  Op<Transform_Dialect, "iree.copy_tensor_operand",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     FunctionalStyleTransformOpTrait,
+     TransformEachOpTrait,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Hoist static allocations";
+  let description = [{
+    Inserts a copy of the specified operand of the target operation.
+
+    #### Return modes
+    Returns a handle to the new copy.
+
+    It does not consume the target handle and emits a definite failure if the
+    operand index is out of range or if the operand is not a tensor type.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$target,
+    I64Attr:$operand_index);
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = [{
+    $target `[` $operand_index `]` attr-dict
+    `:` functional-type(operands, results)
+  }];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def GpuDistributeSharedMemoryCopyOp :
   Op<Transform_Dialect, "iree.gpu_distribute_shared_memory_copy",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 2d09efc..5b09cad 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -65,6 +65,7 @@
             "test_partitionable_loops_interface.mlir",
             "tile_and_distribute_to_workgroups.mlir",
             "transform_buffer_opt.mlir",
+            "transform_copy_operand.mlir",
             "transform_match_partial_reduction.mlir",
             "transform_ops_invalid.mlir",
             "transpose_canonicalization.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index df9fd29..b284856 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -61,6 +61,7 @@
     "test_partitionable_loops_interface.mlir"
     "tile_and_distribute_to_workgroups.mlir"
     "transform_buffer_opt.mlir"
+    "transform_copy_operand.mlir"
     "transform_match_partial_reduction.mlir"
     "transform_ops_invalid.mlir"
     "transpose_canonicalization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_copy_operand.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_copy_operand.mlir
new file mode 100644
index 0000000..e97b99b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_copy_operand.mlir
@@ -0,0 +1,22 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+func.func @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  return %arg0 : tensor<?xf32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.return"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.iree.copy_tensor_operand %0 [0] : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: @main
+//  CHECK-SAME:   (%[[ARG:.+]]: tensor<?xf32>)
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG]], %c0 : tensor<?xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+//       CHECK:   %[[COPY:.+]] = linalg.copy
+//  CHECK-SAME:     ins(%[[ARG]] : tensor<?xf32>)
+//  CHECK-SAME:     outs(%[[EMPTY]] : tensor<?xf32>)
+//       CHECK:   return %[[COPY]] : tensor<?xf32>