[Stream] Implement SpecializeEncodings pass (1/n) (#19502)
There are three major changes in the revision:
- Introduce `AffinityAnalysisDialectInterface` Stream dialect interface.
It is used to fetch attributes that are defined by other dialects. In
the revision, HAL implements the dialect interface, and it can return
whatever attribute attached in HAL::ExecutableTarget attributes. The
main idea of the dialect interface is that Stream **does not** need to
depend on HAL to get the layout information.
- Add `cloneWithLayouts` method to the EncodingAttr. It is used in the
encoding specialization pass where it can resolve the layout
requirements and add it to the `layouts` field. The other optional
parameters are dropped because the layout is already resolved. It can be
a new Encoding dialect attribute because it is just describing the
layout. The stream tensor ops do not need to know the `op_type`,
`element_types` and `operand_index` parameters. It only needs the layout
information, and the attribute should implement the interface method.
- Partially implement the SpecializeEncodings pass. The responsibility
of the pass is large, so I decide to implement it incrementally. This
revision only implements the mechanism of updating stream tensor ops'
encoding, and only stream.tensor.sizeof op is supported. The rest of the
support for other stream tensor op can be added later on. The executable
duplication and the update of dispatch ops will be implemented in
subsequent PRs.
---------
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
index 593d9b8..b388b9c 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
@@ -113,6 +114,15 @@
AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
}
+EncodingAttr EncodingAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) {
+ MLIRContext *ctx = getContext();
+ return get(ctx, getOperandIndex(), getOpType(), getElementTypes(),
+ /*user_indexing_maps=*/ArrayAttr(),
+ /*bcast_map=*/AffineMapAttr(),
+ /*round_dims_to=*/DenseI64ArrayAttr(),
+ ArrayAttr::get(ctx, layouts));
+}
+
/// Returns the bit-width of the scalar type. If the type is complex, it returns
/// the type of individual elements * 2 (1 for real and 1 for complex).
static unsigned getTypeBitWidth(Type type) {
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
index 54829b6..434356a 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
@@ -113,6 +113,10 @@
/// Clones an encoding with a new bcast_map
EncodingAttr clone(AffineMap bcastMap);
+
+ /// Clones an encoding with a new layout list and drops other optional
+ /// parameters (because they are resolved).
+ EncodingAttr cloneWithLayouts(ArrayRef<Attribute> layouts);
}];
let genVerifyDecl = 0;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
index 3f80245..576e77d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
@@ -101,7 +101,9 @@
deps = [
":IR",
"//compiler/src/iree/compiler/Dialect/HAL:hal_imports",
+ "//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
index 846bcf0..e0b68bd 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -79,8 +79,10 @@
MLIRParser
MLIRSCFDialect
MLIRTransformUtils
+ iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::Conversion::HALToVM
iree::compiler::Dialect::HAL::hal_imports
+ iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::Conversion
PUBLIC
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index 00c2c6e..e28d08f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -6,13 +6,16 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/HAL/hal.imports.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -115,6 +118,29 @@
}
};
+class HALAffinityAnalysisDialectInterface
+ : public IREE::Stream::AffinityAnalysisDialectInterface {
+public:
+ using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface;
+ IREE::Stream::ResolveLayoutAttrFn
+ makeLayoutAttrResolver(ModuleOp moduleOp) const {
+ return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op,
+ SetVector<Attribute> &layoutAttrs) -> LogicalResult {
+ // This needs to be in the lambda because the moduleOp could be modified..
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return op->emitError("failed to run DeviceAnalysis");
+ }
+ SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
+ deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
+ resultSet);
+ // TODO(hanchung): Populate the EncodingLayoutAttr when it is ready.
+ layoutAttrs.insert(resultSet.begin(), resultSet.end());
+ return success();
+ };
+ };
+};
+
} // namespace
HALDialect::HALDialect(MLIRContext *context)
@@ -131,6 +157,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc"
>();
addInterfaces<HALInlinerInterface, HALOpAsmInterface,
+ HALAffinityAnalysisDialectInterface,
HALToVMConversionInterface>();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel
index 9959bd1..2fa22ed 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel
@@ -50,6 +50,7 @@
hdrs = [
"StreamDialect.h",
"StreamEnums.h.inc",
+ "StreamInterfaces.h",
"StreamOpInterfaces.h.inc",
"StreamOps.h",
"StreamOps.h.inc",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
index 286bb71..2f10910 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
@@ -16,6 +16,7 @@
HDRS
"StreamDialect.h"
"StreamEnums.h.inc"
+ "StreamInterfaces.h"
"StreamOpInterfaces.h.inc"
"StreamOps.h"
"StreamOps.h.inc"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h
new file mode 100644
index 0000000..d18b7a5
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h
@@ -0,0 +1,36 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMINTERACES_H_
+#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMINTERACES_H_
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+using ResolveLayoutAttrFn = std::function<LogicalResult(
+ AffinityAttr, Operation *, SetVector<Attribute> &)>;
+
+class AffinityAnalysisDialectInterface
+ : public DialectInterface::Base<AffinityAnalysisDialectInterface> {
+public:
+ AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// The `moduleOp` must remain live and unmodified for as long as the returned
+ /// capture is. Otherwise, it will likely be incorrect or crash if the module
+ /// op is mutated, especially when module scope analysis is run.
+ virtual ResolveLayoutAttrFn
+ makeLayoutAttrResolver(ModuleOp moduleOp) const = 0;
+};
+
+} // namespace mlir::iree_compiler::IREE::Stream
+
+#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAM_INTERFACES_H_
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
index 90e22a6..d7ef2a2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
@@ -39,6 +39,7 @@
"ScheduleConcurrency.cpp",
"ScheduleExecution.cpp",
"SpecializeDispatches.cpp",
+ "SpecializeEncodings.cpp",
"VerifyAffinities.cpp",
"VerifyAsyncAccessRanges.cpp",
"VerifyLowerings.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 9e15b84..b905053 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -40,6 +40,7 @@
"ScheduleConcurrency.cpp"
"ScheduleExecution.cpp"
"SpecializeDispatches.cpp"
+ "SpecializeEncodings.cpp"
"VerifyAffinities.cpp"
"VerifyAsyncAccessRanges.cpp"
"VerifyLowerings.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 2234c62..69e65fe 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -22,6 +22,15 @@
"the pipeline for debugging."),
llvm::cl::init(false));
+// TODO(hanchung): Enable the pass by default once the implementation is done.
+static llvm::cl::opt<bool> clSpecializeEncodings(
+ "iree-stream-experimental-specialize-encodings",
+ llvm::cl::desc(
+ "Enables SpecializeEncodingPass in Stream pass pipeline. This pass is "
+ "currently under development, so it is not enabled by default. It can "
+ "only handle limited cases at this moment."),
+ llvm::cl::init(false));
+
namespace mlir::iree_compiler::IREE::Stream {
using FunctionLikeNest =
@@ -140,6 +149,10 @@
// Tensor lowering and resource management
//----------------------------------------------------------------------------
+ if (clSpecializeEncodings) {
+ passManager.addPass(IREE::Stream::createSpecializeEncodingsPass());
+ }
+
// Lower stream.tensor.* ops to stream.async.* ops based on
// affinity/configuration assigned during placement.
FunctionLikeNest(passManager)
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 3aec709..3dcbbb5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -414,6 +414,16 @@
];
}
+def SpecializeEncodingsPass :
+ Pass<"iree-stream-specialize-encodings", "mlir::ModuleOp"> {
+ let summary = "Specializes data-tiling encodings based on device analysis.";
+ let description = [{
+ Attaches layouts to encodings and duplicates executables based on device
+ analysis.
+ TODO: Unpack the context. The pass is not fully implemented yet.
+ }];
+}
+
def AnnotateDispatchArgumentsPass :
Pass<"iree-stream-annotate-dispatch-arguments", "mlir::ModuleOp"> {
let summary = "Annotates dispatch arguments with potential values derived from dispatch sites.";
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
new file mode 100644
index 0000000..b177bed
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
@@ -0,0 +1,169 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+#define DEBUG_TYPE "iree-stream-specialize-encodings"
+
+#define GEN_PASS_DEF_SPECIALIZEENCODINGSPASS
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"
+
+namespace {
+/// Returns a stably sorted list of dialect interfaces of T for all dialects
+/// used within the given module.
+template <typename T>
+SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
+ SmallPtrSet<const T *, 4> resultSet;
+ for (auto dialect : moduleOp.getContext()->getLoadedDialects()) {
+ auto *dialectInterface = dialect->getRegisteredInterface<T>();
+ if (!dialectInterface)
+ continue;
+ resultSet.insert(dialectInterface);
+ }
+
+ // NOTE: to ensure deterministic output we sort the result so that imports are
+ // always added in a consistent order.
+ SmallVector<const T *> results = {resultSet.begin(), resultSet.end()};
+ llvm::sort(
+ results, +[](const T *a, const T *b) {
+ return a->getDialect()->getNamespace().compare(
+ b->getDialect()->getNamespace()) < 0;
+ });
+ return results;
+}
+
+// TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType.
+static RankedTensorType cloneWithEncoding(RankedTensorType type,
+ Attribute encodingAttr) {
+ return RankedTensorType::get(type.getShape(), type.getElementType(),
+ encodingAttr);
+}
+
+static LogicalResult addLayoutsToTensorPhaseOps(
+ ModuleOp moduleOp, FunctionOpInterface funcOp,
+ IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
+ SmallVector<IREE::Stream::AffinityOpInterface> candidates;
+ funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
+ // Only need to update encoding types for ops that have TensorPhaseOp trait.
+ if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
+ return;
+ }
+
+ // Bail out if the operation does not have an affinity attribute.
+ auto affinityAttr = affinityOp.getAffinityAttr();
+ if (!affinityAttr) {
+ return;
+ }
+ candidates.push_back(affinityOp);
+ });
+
+ if (candidates.empty()) {
+ return success();
+ }
+
+ IRRewriter rewriter(funcOp.getContext());
+ for (auto affinityOp : candidates) {
+ auto affinityAttr = affinityOp.getAffinityAttr();
+ SetVector<Attribute> layouts;
+ if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layouts))) {
+ return affinityOp.emitError("failed on making layouts");
+ }
+
+ // Returns an updated encoding attribute if an encoding attribute is present
+ // in the type. Otherwise, returns std::nullopt.
+ auto getEncodingWithNewLayouts =
+ [=](Type type) -> std::optional<IREE::Encoding::EncodingAttr> {
+ auto rankedTensorType = dyn_cast<RankedTensorType>(type);
+ if (!rankedTensorType) {
+ return std::nullopt;
+ }
+ auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType);
+ if (!encodingAttr) {
+ return std::nullopt;
+ }
+ return encodingAttr.cloneWithLayouts(layouts.getArrayRef());
+ };
+
+ // TODO(hanchung): Update other Stream operations.
+ LogicalResult result =
+ TypeSwitch<Operation *, LogicalResult>(affinityOp)
+ .Case<IREE::Stream::TensorSizeOfOp>([&](auto sizeOfOp) {
+ auto encodingType =
+ dyn_cast<RankedTensorType>(sizeOfOp.getEncoding());
+ if (!encodingType) {
+ return success();
+ }
+ std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
+ getEncodingWithNewLayouts(encodingType);
+ if (!encodingAttr) {
+ return success();
+ }
+ rewriter.modifyOpInPlace(sizeOfOp, [&] {
+ sizeOfOp.setEncoding(
+ cloneWithEncoding(encodingType, encodingAttr.value()));
+ });
+ return success();
+ })
+ .Default([](auto *op) { return failure(); });
+
+ if (failed(result)) {
+ return failure();
+ }
+ }
+ return success();
+}
+} // namespace
+
+struct SpecializeEncodingsPass
+ : public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> {
+ void runOnOperation() override {
+ ModuleOp moduleOp = getOperation();
+ auto usedDialects = gatherUsedDialectInterfaces<
+ IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
+ if (usedDialects.size() != 1) {
+ moduleOp.emitError("expected only one dialect implementing "
+ "AffinityAnalysisDialectInterface");
+ return signalPassFailure();
+ }
+
+ llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps;
+ for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) {
+ executableOps[executableOp.getName()] = executableOp;
+ }
+
+ IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr =
+ usedDialects[0]->makeLayoutAttrResolver(moduleOp);
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ if (failed(addLayoutsToTensorPhaseOps(moduleOp, funcOp,
+ resolveLayoutAttr))) {
+ funcOp.emitError(
+ "failed on adding layouts to Stream::TensorPhaseOp with encodings");
+ return signalPassFailure();
+ }
+
+ // TODO(hanchung): Duplicate executables and update dispatch ops.
+ }
+ }
+};
+
+} // namespace mlir::iree_compiler::IREE::Stream
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
index 87d6bea..722ce78 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
@@ -46,6 +46,7 @@
"schedule_concurrency.mlir",
"schedule_execution.mlir",
"specialize_dispatches.mlir",
+ "specialize_encodings.mlir",
"verify_affinities.mlir",
"verify_async_access_ranges.mlir",
],
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index 8c4ca85..6eb964a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -44,6 +44,7 @@
"schedule_concurrency.mlir"
"schedule_execution.mlir"
"specialize_dispatches.mlir"
+ "specialize_encodings.mlir"
"verify_affinities.mlir"
"verify_async_access_ranges.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
new file mode 100644
index 0000000..1ae03e6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt --split-input-file --iree-stream-specialize-encodings %s | FileCheck %s
+
+//------------------------------------------------------------------------------
+// Stream ops that have TensorPhaseOp trait. This test suite tests that the
+// encoding is updated that carries resolved layouts.
+//------------------------------------------------------------------------------
+
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding_layout = #iree_cpu.vmvx_encoding_layout<>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32]>
+module {
+ util.global private @device_a = #device_target_local_0_
+
+ util.func public @tensor_sizeof(%d0: index, %d1: index) -> index {
+ %size = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?xf32, #encoding>{%d0, %d1} : index
+ util.return %size : index
+ }
+}
+// CHECK: #[[EXECUTABLE:.+]] = #hal.executable.target<"vmvx",
+// CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding
+// CHECK-SAME: layouts = [#[[EXECUTABLE]]]
+// CHECK-LABEL: util.func public @tensor_sizeof
+// CHECK: %[[RES:.+]] = stream.tensor.sizeof {{.+}} tensor<?x?xf32, #[[$ENCODING]]>
+// CHECK: return %[[RES]]