[Stream] Specialize encoding for TensorPhaseOp that have result_encoding (#19707)

There are three Stream ops that only have the `result_encoding` operand:
TensorEmptyOp, TensorSplatOp, TensorConstantOp. Only empty ops and splat
ops can support the specialization at this moment because they are pure
shape-like operations. For TensorConstantOp, we return a failure if the
encoding is present, because we do not know how to update the constant
with the layout at this moment. It could be done by adding interface
methods to `EncodingAttrInterface`.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
index 40571f3..c0b5e40 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
@@ -16,6 +16,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -52,6 +53,35 @@
   return results;
 }
 
+// Returns an updated encoding attribute if the type is a RankedTensorType
+// and an EncodingAttr is present. Otherwise, returns std::nullopt. The
+// method uses the EncodingLayoutAttrInterface from the layout resolvers
+// to resolve the layouts of the given `type`; returns the new encoding
+// with the resolved layouts.
+static std::optional<IREE::Encoding::EncodingAttr>
+getEncodingWithNewLayouts(Type type,
+                          const SetVector<Attribute> &layoutResolvers) {
+  auto rankedTensorType = dyn_cast<RankedTensorType>(type);
+  if (!rankedTensorType) {
+    return std::nullopt;
+  }
+  auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType);
+  if (!encodingAttr) {
+    return std::nullopt;
+  }
+  SmallVector<Attribute> layouts;
+  for (auto attr : layoutResolvers) {
+    auto encodingLayoutAttr =
+        dyn_cast<IREE::Encoding::EncodingLayoutAttrInterface>(attr);
+    if (!encodingLayoutAttr) {
+      layouts.push_back(attr);
+      continue;
+    }
+    layouts.push_back(encodingLayoutAttr.getLayout(rankedTensorType));
+  }
+  return encodingAttr.cloneWithLayouts(layouts);
+};
+
 // TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType.
 static RankedTensorType cloneWithEncoding(RankedTensorType type,
                                           Attribute encodingAttr) {
@@ -59,6 +89,58 @@
                                encodingAttr);
 }
 
+/// Updates the encoding of `sizeOfOp` with resolved layouts.
+static LogicalResult
+updateTensorSizeOfOp(RewriterBase &rewriter,
+                     IREE::Stream::TensorSizeOfOp sizeOfOp,
+                     const SetVector<Attribute> &layoutResolvers) {
+  auto encodingType = dyn_cast<RankedTensorType>(sizeOfOp.getEncoding());
+  std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
+      getEncodingWithNewLayouts(encodingType, layoutResolvers);
+  if (!encodingAttr) {
+    return success();
+  }
+  rewriter.modifyOpInPlace(sizeOfOp, [&] {
+    sizeOfOp.setEncoding(cloneWithEncoding(encodingType, encodingAttr.value()));
+  });
+  return success();
+}
+
+/// Returns failure if `op` has an encoding. The EncodingAttr has padding
+/// semantics; a constant op with such an encoding can not be resolved at
+/// this moment.
+static LogicalResult
+updateTensorConstantOp(RewriterBase &rewriter,
+                       IREE::Stream::TensorConstantOp op,
+                       const SetVector<Attribute> &layoutResolvers) {
+  auto encodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
+  if (!encodingType) {
+    return success();
+  }
+  if (IREE::Encoding::getEncodingAttr(encodingType)) {
+    return failure();
+  }
+  return success();
+}
+
+/// Updates the result_encoding for `op`. The op has to define a
+/// `result_encoding` parameter.
+template <typename OpTy>
+static LogicalResult
+updateResultEncoding(RewriterBase &rewriter, OpTy op,
+                     const SetVector<Attribute> &layoutResolvers) {
+  auto encodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
+  std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
+      getEncodingWithNewLayouts(encodingType, layoutResolvers);
+  if (!encodingAttr) {
+    return success();
+  }
+  rewriter.modifyOpInPlace(op, [&] {
+    op.setResultEncoding(cloneWithEncoding(encodingType, encodingAttr.value()));
+  });
+  return success();
+}
+
 static LogicalResult addLayoutsToTensorPhaseOps(
     ModuleOp moduleOp, FunctionOpInterface funcOp,
     IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
@@ -89,50 +171,18 @@
       return affinityOp.emitError("failed on making layout resolvers");
     }
 
-    // Returns an updated encoding attribute if an encoding attribute is present
-    // in the type. Otherwise, returns std::nullopt.
-    auto getEncodingWithNewLayouts =
-        [=](Type type) -> std::optional<IREE::Encoding::EncodingAttr> {
-      auto rankedTensorType = dyn_cast<RankedTensorType>(type);
-      if (!rankedTensorType) {
-        return std::nullopt;
-      }
-      auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType);
-      if (!encodingAttr) {
-        return std::nullopt;
-      }
-      SmallVector<Attribute> layouts;
-      for (auto attr : layoutResolvers) {
-        auto encodingLayoutAttr =
-            dyn_cast<IREE::Encoding::EncodingLayoutAttrInterface>(attr);
-        if (!encodingLayoutAttr) {
-          layouts.push_back(attr);
-          continue;
-        }
-        layouts.push_back(encodingLayoutAttr.getLayout(rankedTensorType));
-      }
-      return encodingAttr.cloneWithLayouts(layouts);
-    };
-
     // TODO(hanchung): Update other Stream operations.
     LogicalResult result =
         TypeSwitch<Operation *, LogicalResult>(affinityOp)
-            .Case<IREE::Stream::TensorSizeOfOp>([&](auto sizeOfOp) {
-              auto encodingType =
-                  dyn_cast<RankedTensorType>(sizeOfOp.getEncoding());
-              if (!encodingType) {
-                return success();
-              }
-              std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
-                  getEncodingWithNewLayouts(encodingType);
-              if (!encodingAttr) {
-                return success();
-              }
-              rewriter.modifyOpInPlace(sizeOfOp, [&] {
-                sizeOfOp.setEncoding(
-                    cloneWithEncoding(encodingType, encodingAttr.value()));
-              });
-              return success();
+            .Case<IREE::Stream::TensorSizeOfOp>([&](auto op) {
+              return updateTensorSizeOfOp(rewriter, op, layoutResolvers);
+            })
+            .Case<IREE::Stream::TensorEmptyOp, IREE::Stream::TensorSplatOp>(
+                [&](auto op) {
+                  return updateResultEncoding(rewriter, op, layoutResolvers);
+                })
+            .Case<IREE::Stream::TensorConstantOp>([&](auto op) {
+              return updateTensorConstantOp(rewriter, op, layoutResolvers);
             })
             .Default([](auto *op) { return failure(); });
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
index 5fab86a..19dcb37 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-stream-specialize-encodings %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-stream-specialize-encodings --verify-diagnostics %s | FileCheck %s
 
 //------------------------------------------------------------------------------
 // Stream ops that have TensorPhaseOp trait. This test suite tests that the
@@ -33,3 +33,52 @@
 // CHECK:         %[[D0_RES:.+]] = stream.tensor.sizeof {{.+}} tensor<?x?xf32, #[[$ENCODING0]]>
 // CHECK:         %[[D1_RES:.+]] = stream.tensor.sizeof {{.+}} tensor<?x?xf32, #[[$ENCODING1]]>
 // CHECK:         return %[[D0_RES]], %[[D1_RES]]
+
+// -----
+
+#map0 = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type =  matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
+module {
+  util.global private @device_a = #device_target_local_0_
+
+  util.func public @ops_with_result_encoding_only(%arg0: index, %arg1: index, %scalar_f32 : f32) {
+    %0 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<?x0xf32, #encoding>{%arg0} in !stream.resource<*>{%arg1}
+    %1 = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<?x5x64xf32>{%arg0} in !stream.resource<constant> = dense<0.000000e+00> : tensor<1x5x64xf32>
+    %2 = stream.tensor.splat on(#hal.device.affinity<@device_a>) %scalar_f32 : f32 -> tensor<?x1x10xf32, #encoding>{%arg0} in !stream.resource<*>{%arg1}
+    util.return
+  }
+}
+// CHECK:       #[[$ENCODING:.+]] = #iree_encoding.encoding
+// CHECK-SAME:    #iree_cpu.vmvx_encoding_layout
+// CHECK-SAME:    encoding_info = {innerDimsPos = [{{.+}}], innerTileSizes = [{{.+}}], outerDimsPerm = [{{.+}}]}
+// CHECK:       #[[TARGET:.+]] = #hal.device.target
+// CHECK:       util.global private @[[$DEVICE:.+]] = #[[TARGET]]
+// CHECK-LABEL: util.func public @ops_with_result_encoding_only
+// CHECK:         stream.tensor.empty on(#hal.device.affinity<@[[$DEVICE]]>) : tensor<?x0xf32, #[[$ENCODING]]>
+// CHECK:         stream.tensor.constant {{.+}} : tensor<1x5x64xf32>
+// CHECK:         stream.tensor.splat on(#hal.device.affinity<@[[$DEVICE]]>) {{.+}} -> tensor<?x1x10xf32, #[[$ENCODING]]>
+// CHECK:         return
+
+// -----
+
+// Checks that the stream.tensor.constant op with encoding is not supported.
+
+#map0 = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type =  matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
+module {
+  util.global private @device_a = #device_target_local_0_
+
+  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
+  util.func public @ops_with_result_encoding_only(%arg0: index) {
+    %0 = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<?x5x64xf32, #encoding>{%arg0} in !stream.resource<constant> = dense<0.000000e+00> : tensor<1x5x64xf32>
+    util.return
+  }
+}