Add vector_ext dialect (#15599)
This patch introduces the vector_ext dialect. The purpose of this
dialect is to have a place for experimenting with things beyond what the
upstream vector dialect provides.
In this particular PR, two new features are added
1. An explicit IR representation of the high dimensional layout that
looks like this #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY",
2>, 4>, 4> The nesting makes clear what the innermost dimensions are and
their corresponding shapes.
2. Adds a layout conflict resolution operator. During layout analysis,
this operator can be used to resolve any differences in layout. The
lowering of this operator is not provided but the semantics are that
given a vector with an existing layout and a desired layout, the
operator transforms the vector to the desired layout.
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 297fd7f..ec3db19 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -66,6 +66,7 @@
"//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
+ "//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"@llvm-project//mlir:IR",
],
)
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index 475a06f..d725933 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -53,6 +53,7 @@
IREELinalgExtTransforms
IREELinalgTransformDialect
IREELinalgTransformDialectPasses
+ IREEVectorExtDialect
MLIRIR
iree::compiler::Bindings::Native::Transforms
iree::compiler::Bindings::TFLite::Transforms
diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
index 7d19e14..b454421 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
@@ -16,6 +16,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Interfaces/Interfaces.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
@@ -49,6 +50,7 @@
IREE::Util::UtilDialect,
IREE::VM::VMDialect,
IREE::VMVX::VMVXDialect,
+ IREE::VectorExt::IREEVectorExtDialect,
IREE::Vulkan::VulkanDialect>();
// clang-format on
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 68c5395..5cba3b0 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -32,6 +32,7 @@
"include/iree-dialects/Dialect/Input/*.td",
"include/iree-dialects/Dialect/LinalgExt/IR/*.td",
"include/iree-dialects/Dialect/LinalgExt/Passes/*.td",
+ "include/iree-dialects/Dialect/VectorExt/IR/*.td",
]),
)
@@ -42,6 +43,7 @@
"include/iree-dialects/Dialect/LinalgExt/IR/*.td",
"include/iree-dialects/Dialect/LinalgExt/Passes/*.td",
"include/iree-dialects/Dialect/LinalgTransform/*.td",
+ "include/iree-dialects/Dialect/VectorExt/IR/*.td",
"python/iree/compiler/dialects/*.td",
]),
includes = ["include"],
@@ -619,6 +621,84 @@
)
################################################################################
+# IREEVectorExt Dialect
+################################################################################
+
+gentbl_cc_library(
+ name = "IREEVectorExtIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ [
+ "--dialect=iree_vector_ext",
+ "--gen-dialect-decls",
+ ],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h.inc",
+ ),
+ (
+ [
+ "--dialect=iree_vector_ext",
+ "--gen-dialect-defs",
+ ],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc",
+ ),
+ (
+ ["--gen-attrdef-decls"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc",
+ ),
+ (
+ ["--gen-attrdef-defs"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc",
+ ),
+ (
+ ["--gen-enum-decls"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc",
+ ),
+ (
+ ["--gen-enum-defs"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc",
+ ),
+ (
+ ["--gen-op-decls"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc",
+ ),
+ (
+ ["--gen-op-defs"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc",
+ ),
+ (
+ ["--gen-typedef-decls"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtTypes.h.inc",
+ ),
+ (
+ ["--gen-typedef-defs"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtTypes.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td",
+ deps = [
+ ":TdFiles",
+ ],
+)
+
+cc_library(
+ name = "IREEVectorExtDialect",
+ srcs = glob([
+ "lib/Dialect/VectorExt/IR/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/iree-dialects/Dialect/VectorExt/IR/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":IREEVectorExtIncGen",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ ],
+)
+
+################################################################################
# CAPI
################################################################################
@@ -681,6 +761,7 @@
":IREELinalgExtTransformOps",
":IREELinalgTransformDialect",
":IREELinalgTransformDialectPasses",
+ ":IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
index 16d52d4..18881bd 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Input)
add_subdirectory(LinalgExt)
add_subdirectory(LinalgTransform)
+add_subdirectory(VectorExt)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt
new file mode 100644
index 0000000..9ba3d84
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt
new file mode 100644
index 0000000..0b29b25
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt
@@ -0,0 +1,36 @@
+function(_add_dialect)
+ set(LLVM_TARGET_DEFINITIONS VectorExtOps.td)
+ mlir_tablegen(VectorExtAttrs.h.inc -gen-attrdef-decls)
+ mlir_tablegen(VectorExtAttrs.cpp.inc -gen-attrdef-defs)
+ mlir_tablegen(VectorExtEnums.h.inc -gen-enum-decls)
+ mlir_tablegen(VectorExtEnums.cpp.inc -gen-enum-defs)
+ mlir_tablegen(VectorExtOps.h.inc -gen-op-decls)
+ mlir_tablegen(VectorExtOps.cpp.inc -gen-op-defs)
+ mlir_tablegen(VectorExtTypes.h.inc -gen-typedef-decls)
+ mlir_tablegen(VectorExtTypes.cpp.inc -gen-typedef-defs)
+ mlir_tablegen(VectorExtDialect.h.inc --gen-dialect-decls --dialect=iree_vector_ext)
+ mlir_tablegen(VectorExtDialect.cpp.inc --gen-dialect-defs --dialect=iree_vector_ext)
+ add_public_tablegen_target(IREEVectorExtIncGen)
+ add_dependencies(mlir-headers IREEVectorExtIncGen)
+endfunction()
+
+function(_add_doc)
+ set(LLVM_TARGET_DEFINITIONS VectorExtOps.td)
+ set(_FLAGS
+ "--strip-prefix=::mlir::iree_compiler::IREE::"
+ )
+ mlir_tablegen(VectorExtOps.md -gen-dialect-doc ${_FLAGS})
+ set(GEN_DOC_FILE ${IREE_DIALECTS_BINARY_DIR}/docs/Dialects/VectorExtOps.md)
+ add_custom_command(
+ OUTPUT ${GEN_DOC_FILE}
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${CMAKE_CURRENT_BINARY_DIR}/VectorExtOps.md
+ ${GEN_DOC_FILE}
+ DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/VectorExtOps.md)
+ add_custom_target(VectorExtOpsDocGen DEPENDS ${GEN_DOC_FILE})
+ add_dependencies(iree-dialects-doc VectorExtOpsDocGen)
+endfunction()
+
+_add_dialect()
+_add_doc()
+
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
new file mode 100644
index 0000000..1de6d81
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
@@ -0,0 +1,72 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_VECTOREXT_BASE
+#define IREE_DIALECT_VECTOREXT_BASE
+
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
+
+//===----------------------------------------------------------------------===//
+// Dialect definition
+//===----------------------------------------------------------------------===//
+
+def IREEVectorExt_Dialect : Dialect {
+ let name = "iree_vector_ext";
+ let cppNamespace = "::mlir::iree_compiler::IREE::VectorExt";
+ let summary = [{
+ IREE Vector Extensions.
+ }];
+ let description = [{
+ A dialect designed for experimenting with vector operations
+ beyond what is currently available in the Vector Dialect.
+ }];
+ let useDefaultAttributePrinterParser = 1;
+}
+
+//===---------------------------------------------------------------------===//
+// Vector layout attributes
+//===---------------------------------------------------------------------===//
+
+class IREEVectorExt_Attr<string name, list<Trait> traits = []>
+ : AttrDef<IREEVectorExt_Dialect, name, traits>;
+
+def PerDimLayoutAttr : IREEVectorExt_Attr<"PerDimLayout"> {
+ let mnemonic = "per_dim_layout";
+ let summary = [{high-dimensional vector register layout for a given vector dimension}];
+ let description = [{
+ This attribute describes the per dimension register layout for a given vector
+ that could be prescribed by an operator such as matrix multiplication.
+ This is a way to explicitly represent the layout in the IR
+ when it is in the SIMD form prior to converting to the SIMT form so that
+ we can reason about layouts, propagating layouts and layout conflicts.
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"std::string", "labels for the high dimensional layout dims">:$labels,
+ ArrayRefParameter<"int64_t", "shapes for the high dimensional layout dims">:$shapes
+ );
+ let hasCustomAssemblyFormat = 1;
+ let genVerifyDecl = 0;
+}
+
+def LayoutAttr : IREEVectorExt_Attr<"Layout"> {
+ let mnemonic = "layout";
+ let summary = [{high-dimensional vector register layout for a given vector}];
+ let description = [{
+ This contains a complete specification of the layout for a given vector,
+ whereas the attribute above only specifies the per dimension layout.
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"PerDimLayoutAttr", "layout for each dimension of the vector">:$layouts
+ );
+ let hasCustomAssemblyFormat = 1;
+ let genVerifyDecl = 0;
+}
+
+#endif // IREE_DIALECT_VECTOREXT_BASE
+
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h
new file mode 100644
index 0000000..82bdccc
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h
@@ -0,0 +1,17 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_
+#define IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+// clang-format off: must be included after all LLVM/MLIR headers
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h.inc" // IWYU pragma: keep
+// clang-format on
+
+#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTDIALECT_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
new file mode 100644
index 0000000..b1f4b6f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -0,0 +1,30 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
+#define IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
+
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+// clang-format off
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc" // IWYU pragma: export
+
+#define GET_ATTRDEF_CLASSES
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc" // IWYU pragma: export
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc" // IWYU pragma: export
+
+// clang-format on
+
+#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
new file mode 100644
index 0000000..77476d0
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td
@@ -0,0 +1,47 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_VECTOREXT_OPS
+#define IREE_DIALECT_VECTOREXT_OPS
+
+include "iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td"
+
+//===----------------------------------------------------------------------===//
+// Base class.
+//===----------------------------------------------------------------------===//
+
+class IREEVectorExt_PureOp<string mnemonic, list<Trait> traits = []> :
+ Op<IREEVectorExt_Dialect, mnemonic, traits> {
+}
+
+//===----------------------------------------------------------------------===//
+// Layout ops.
+//===----------------------------------------------------------------------===//
+
+def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> {
+ let summary = "Layout Conflict Resolution operator";
+ let description = [{
+ The layout conflict resolution operator takes a vector and a
+ desired layout and transforms the vector to one with the
+ desired layout.
+ }];
+ let arguments = (ins
+ AnyVector:$input,
+ LayoutAttr:$sourceLayout,
+ LayoutAttr:$desiredLayout
+ );
+ let results = (outs
+ AnyVector:$output
+ );
+ let extraClassDeclaration = [{
+
+ }];
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
+ let hasVerifier = 1;
+}
+
+#endif // IREE_DIALECT_VECTOREXT_OPS
+
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
index 16d52d4..18881bd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Input)
add_subdirectory(LinalgExt)
add_subdirectory(LinalgTransform)
+add_subdirectory(VectorExt)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt
new file mode 100644
index 0000000..9ba3d84
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
new file mode 100644
index 0000000..8f8f9bb
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_library(IREEVectorExtDialect
+ VectorExtDialect.cpp
+ VectorExtOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREEVectorExtIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+)
+
+iree_dialects_target_includes(IREEVectorExtDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
new file mode 100644
index 0000000..06740ee
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -0,0 +1,90 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::VectorExt;
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc" // IWYU pragma: keep
+
+#define GET_ATTRDEF_CLASSES
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc" // IWYU pragma: keep
+
+void IREEVectorExtDialect::initialize() {
+
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc"
+ >();
+
+#define GET_OP_LIST
+ addOperations<
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc"
+ >();
+}
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc"
+
+// Parses an attribute with syntax
+// <"BatchX"<"VecX", 2>, 4>
+Attribute PerDimLayoutAttr::parse(AsmParser &parser, Type type) {
+ SmallVector<std::string> dimNames;
+ SmallVector<int64_t> dimShapes;
+ std::string name;
+ while (!(parser.parseOptionalLess() || parser.parseOptionalString(&name))) {
+ dimNames.push_back(name);
+ }
+ int64_t dim;
+ while (!(parser.parseOptionalComma() || parser.parseInteger(dim) ||
+ parser.parseGreater())) {
+ dimShapes.push_back(dim);
+ }
+ std::reverse(dimShapes.begin(), dimShapes.end());
+ return PerDimLayoutAttr::get(parser.getContext(), dimNames, dimShapes);
+}
+
+void PerDimLayoutAttr::print(AsmPrinter &printer) const {
+ for (auto label : getLabels())
+ printer << "<" << label;
+ for (auto shape : llvm::reverse(getShapes()))
+ printer << ", " << shape << ">";
+}
+
+// Parses an attribute with syntax
+// #layout<<"BatchX"<"VecX", 2>, 4>, <"BatchY"<"VecZ", 4>,2>>>
+Attribute LayoutAttr::parse(AsmParser &parser, Type type) {
+ if (parser.parseLess())
+ return {};
+ SmallVector<PerDimLayoutAttr> layout;
+ PerDimLayoutAttr perDimLayout;
+ while (!(parser.parseAttribute<PerDimLayoutAttr>(perDimLayout, type))) {
+ layout.push_back(perDimLayout);
+ if (parser.parseOptionalComma())
+ break;
+ }
+ if ((parser.parseGreater()))
+ return {};
+ return LayoutAttr::get(parser.getContext(), layout);
+}
+
+static void printArray(AsmPrinter &printer,
+ ArrayRef<PerDimLayoutAttr> layouts) {
+ printer << "<";
+ for (auto layout : llvm::enumerate(layouts)) {
+ printer << layout.value();
+ if (layout.index() < layouts.size() - 1)
+ printer << ", ";
+ }
+ printer << ">";
+}
+
+void LayoutAttr::print(AsmPrinter &printer) const {
+ printArray(printer, getLayouts());
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
new file mode 100644
index 0000000..2d45345
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -0,0 +1,50 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::VectorExt;
+namespace IREE = mlir::iree_compiler::IREE;
+
+//===----------------------------------------------------------------------===//
+// LayoutConflictResolutionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult validateLayout(Operation *op, StringRef label, LayoutAttr layout,
+ ArrayRef<int64_t> inputShape) {
+ for (auto perDimLayout : llvm::enumerate(layout.getLayouts())) {
+ ArrayRef<int64_t> shape = perDimLayout.value().getShapes();
+ int64_t computedShape =
+ std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
+ int64_t expectedShape = inputShape[perDimLayout.index()];
+ if (computedShape != expectedShape) {
+ return op->emitError("The " + label +
+ " layout shape does not match the input shape. "
+ "Expected shape to be ")
+ << std::to_string(expectedShape) << ", got "
+ << std::to_string(computedShape);
+ }
+ }
+ return success();
+}
+
+// Validate that the desired layout has the same shape as the input.
+LogicalResult LayoutConflictResolutionOp::verify() {
+ Operation *op = getOperation();
+ ArrayRef<int64_t> inputShape =
+ cast<VectorType>(getInput().getType()).getShape();
+ if (succeeded(validateLayout(op, "source", getSourceLayout(), inputShape)))
+ return validateLayout(op, "desired", getDesiredLayout(), inputShape);
+ return failure();
+}
+
+// clang-format off
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep
+// clang-format: on
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
new file mode 100644
index 0000000..79cf669
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-dialects-opt --split-input-file --verify-diagnostics %s
+
+#row_layout1 = #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 1>, 1>, 1>
+#col_layout1 = #iree_vector_ext.per_dim_layout<"BatchY"<"LaneY"<"VecX", 4>, 2>, 4>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
+func.func @invalid_desired_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
+ %cst_0 = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
+ // expected-error @+1 {{The desired layout shape does not match the input shape. Expected shape to be 32, got 1}}
+ %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout1, sourceLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
+ return %2 : vector<32x32xf16>
+}
+
+// -----
+
+#row_layout1 = #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 1>, 1>, 1>
+#col_layout1 = #iree_vector_ext.per_dim_layout<"BatchY"<"LaneY"<"VecX", 4>, 2>, 4>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
+func.func @invalid_source_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
+ %cst_0 = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
+ // expected-error @+1 {{The source layout shape does not match the input shape. Expected shape to be 32, got 1}}
+ %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout2, sourceLayout = #layout1} : vector<32x32xf16> -> vector<32x32xf16>
+ return %2 : vector<32x32xf16>
+}
+
+// -----
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
new file mode 100644
index 0000000..63a0f0b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -0,0 +1,22 @@
+// RUN: iree-dialects-opt --split-input-file %s | FileCheck %s
+
+#row_layout1 = #iree_vector_ext.per_dim_layout<"BatchX"<"LaneX"<"VecY", 2>, 4>, 4>
+#col_layout1 = #iree_vector_ext.per_dim_layout<"BatchY"<"LaneY"<"VecX", 4>, 2>, 4>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+#layout2 = #iree_vector_ext.layout<#col_layout1, #row_layout1>
+func.func @specify_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
+ %cst_0 = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
+ %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #layout1, desiredLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
+ return %2 : vector<32x32xf16>
+}
+
+// CHECK-LABEL: func.func @specify_layout
+// CHECK: iree_vector_ext.layout_conflict_resolution
+// CHECK: desiredLayout = #iree_vector_ext.layout<#iree_vector_ext.per_dim_layout<BatchY<LaneY<VecX, 4>, 2>, 4>,
+// CHECK-SAME: #iree_vector_ext.per_dim_layout<BatchX<LaneX<VecY, 2>, 4>, 4>>
+// CHECK: sourceLayout = #iree_vector_ext.layout<#iree_vector_ext.per_dim_layout<BatchX<LaneX<VecY, 2>, 4>, 4>,
+// CHECK-SAME: #iree_vector_ext.per_dim_layout<BatchY<LaneY<VecX, 4>, 2>, 4>>
+
+// -----
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 548b60e..5789bc5 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -8,6 +8,7 @@
IREELinalgTransformDialect
IREELinalgTransformDialectPasses
IREETransformsTestPasses
+ IREEVectorExtDialect
# Core dialects.
MLIRAffineDialect
MLIRArithDialect
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index b7e73f3..085c31e 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -10,6 +10,7 @@
#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
@@ -59,6 +60,7 @@
// Local dialects
mlir::iree_compiler::IREE::Input::IREEInputDialect,
mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
+ mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect,
// Upstream dialects
mlir::async::AsyncDialect,
mlir::arith::ArithDialect,