[NFC] Split VectorExt attributes into a separate file (#16284)
In preparation for potentially adding more attribute definitions.
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 359dfe6..8dfa939 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -648,22 +648,6 @@
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc",
),
(
- ["--gen-attrdef-decls"],
- "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc",
- ),
- (
- ["--gen-attrdef-defs"],
- "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc",
- ),
- (
- ["--gen-enum-decls"],
- "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc",
- ),
- (
- ["--gen-enum-defs"],
- "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc",
- ),
- (
["--gen-op-decls"],
"include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h.inc",
),
@@ -688,6 +672,34 @@
)
gentbl_cc_library(
+ name = "IREEVectorExtAttrsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["--gen-attrdef-decls"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.h.inc",
+ ),
+ (
+ ["--gen-attrdef-defs"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc",
+ ),
+ (
+ ["--gen-enum-decls"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.h.inc",
+ ),
+ (
+ ["--gen-enum-defs"],
+ "include/iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td",
+ deps = [
+ ":TdFiles",
+ ],
+)
+
+gentbl_cc_library(
name = "IREEVectorExtInterfacesIncGen",
strip_include_prefix = "include",
tbl_outs = [
@@ -717,6 +729,7 @@
]),
includes = ["include"],
deps = [
+ ":IREEVectorExtAttrsIncGen",
":IREEVectorExtIncGen",
":IREEVectorExtInterfacesIncGen",
"@llvm-project//llvm:Support",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt
index c1ce00b..1e0b7b6 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/CMakeLists.txt
@@ -8,10 +8,6 @@
function(_add_dialect)
set(LLVM_TARGET_DEFINITIONS VectorExtOps.td)
- mlir_tablegen(VectorExtAttrs.h.inc -gen-attrdef-decls)
- mlir_tablegen(VectorExtAttrs.cpp.inc -gen-attrdef-defs)
- mlir_tablegen(VectorExtEnums.h.inc -gen-enum-decls)
- mlir_tablegen(VectorExtEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorExtOps.h.inc -gen-op-decls)
mlir_tablegen(VectorExtOps.cpp.inc -gen-op-defs)
mlir_tablegen(VectorExtTypes.h.inc -gen-typedef-decls)
@@ -20,6 +16,14 @@
mlir_tablegen(VectorExtDialect.cpp.inc --gen-dialect-defs --dialect=iree_vector_ext)
add_public_tablegen_target(IREEVectorExtIncGen)
add_dependencies(mlir-headers IREEVectorExtIncGen)
+
+ set(LLVM_TARGET_DEFINITIONS VectorExtAttrs.td)
+ mlir_tablegen(VectorExtAttrs.h.inc -gen-attrdef-decls)
+ mlir_tablegen(VectorExtAttrs.cpp.inc -gen-attrdef-defs)
+ mlir_tablegen(VectorExtEnums.h.inc -gen-enum-decls)
+ mlir_tablegen(VectorExtEnums.cpp.inc -gen-enum-defs)
+ add_public_tablegen_target(IREEVectorExtAttrsIncGen)
+ add_dependencies(mlir-headers IREEVectorExtAttrsIncGen)
endfunction()
function(_add_doc)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
new file mode 100644
index 0000000..688d5f4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -0,0 +1,88 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_VECTOREXT_ATTRS
+#define IREE_DIALECT_VECTOREXT_ATTRS
+
+include "iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td"
+
+//===---------------------------------------------------------------------===//
+// Vector layout attributes
+//===---------------------------------------------------------------------===//
+
+// Defines the batch dimensions for the original SIMD tensor.
+// By convention, X is along rows and Y along columns.
+def BATCHX : I32EnumAttrCase<"BATCHX", 0>;
+def BATCHY : I32EnumAttrCase<"BATCHY", 1>;
+// Defines the vector dimension.
+def VECTORX : I32EnumAttrCase<"VECTORX", 2>;
+def VECTORY : I32EnumAttrCase<"VECTORY", 3>;
+def VECTORZ : I32EnumAttrCase<"VECTORZ", 4>;
+// Defines the lane dimensions.
+def LANEX : I32EnumAttrCase<"LANEX", 5>;
+def LANEY : I32EnumAttrCase<"LANEY", 6>;
+def LANEZ : I32EnumAttrCase<"LANEZ", 7>;
+
+def LayoutDimension : IREEVectorExt_I32EnumAttr<"LayoutDimension",
+ "Describes the dimension of the high-dimensional layout", [
+ BATCHX,
+ BATCHY,
+ VECTORX,
+ VECTORY,
+ VECTORZ,
+ LANEX,
+ LANEY,
+ LANEZ,
+ ]>;
+
+def LayoutDimensionAttr : IREEVectorExt_EnumAttr<LayoutDimension, "dimension">;
+
+def PerDimLayoutAttr : IREEVectorExt_Attr<"PerDimLayout"> {
+ let mnemonic = "per_dim_layout";
+ let summary = [{high-dimensional vector register layout for a given vector dimension}];
+ let description = [{
+ This attribute describes the per dimension register layout for a given vector
+ that could be prescribed by an operator such as matrix multiplication.
+ This is a way to explicitly represent the layout in the IR
+ when it is in the SIMD form prior to converting to the SIMT form so that
+ we can reason about layouts, propagating layouts and layout conflicts.
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"LayoutDimensionAttr", "labels for the high dimensional layout dims">:$labels,
+ ArrayRefParameter<"int64_t", "shapes for the high dimensional layout dims">:$shapes
+ );
+ let assemblyFormat = "`<``[` $labels `]``,` `[` $shapes `]``>`";
+ let genVerifyDecl = 0;
+ let extraClassDeclaration = [{
+ std::optional<int64_t> getShape(const LayoutDimension &dim);
+ bool contains(const LayoutDimension &dim);
+ }];
+}
+
+def LayoutAttr : IREEVectorExt_Attr<"Layout",
+ [ DeclareAttrInterfaceMethods<VectorLayoutInterface> ]> {
+ let mnemonic = "layout";
+ let summary = [{high-dimensional vector register layout for a given vector}];
+ let description = [{
+ This contains a complete specification of the layout for a given vector,
+ whereas the attribute above only specifies the per dimension layout.
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"PerDimLayoutAttr", "layout for each dimension of the vector">:$layouts
+ );
+ let assemblyFormat = "`<`$layouts`>`";
+ let genVerifyDecl = 0;
+ let extraClassDeclaration = [{
+ std::optional<int64_t> getBatchDim(int64_t dim);
+
+ // Returns the grid of lane ids. Assumes a valid layout.
+ ::std::tuple<int64_t, int64_t, int64_t> getLaneGrid();
+ PerDimLayoutAttr getDimLayout(int64_t dim) const;
+ }];
+}
+
+#endif // IREE_DIALECT_VECTOREXT_ATTRS
+
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
index 5509811..3169fed 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtBase.td
@@ -29,10 +29,13 @@
beyond what is currently available in the Vector Dialect.
}];
let useDefaultAttributePrinterParser = 1;
+ let extraClassDeclaration = [{
+ void registerAttributes();
+ }];
}
//===---------------------------------------------------------------------===//
-// Vector layout attributes
+// Vector layout attribute helpers
//===---------------------------------------------------------------------===//
class IREEVectorExt_Attr<string name, list<Trait> traits = []>
@@ -47,76 +50,5 @@
class IREEVectorExt_EnumAttr<EnumAttrInfo enumInfo, string name = "">
: EnumAttr<IREEVectorExt_Dialect, enumInfo, name>;
-// Defines the batch dimensions for the original SIMD tensor.
-// By convention, X is along rows and Y along columns.
-def BATCHX : I32EnumAttrCase<"BATCHX", 0>;
-def BATCHY : I32EnumAttrCase<"BATCHY", 1>;
-// Defines the vector dimension.
-def VECTORX : I32EnumAttrCase<"VECTORX", 2>;
-def VECTORY : I32EnumAttrCase<"VECTORY", 3>;
-def VECTORZ : I32EnumAttrCase<"VECTORZ", 4>;
-// Defines the lane dimensions.
-def LANEX : I32EnumAttrCase<"LANEX", 5>;
-def LANEY : I32EnumAttrCase<"LANEY", 6>;
-def LANEZ : I32EnumAttrCase<"LANEZ", 7>;
-
-def LayoutDimension : IREEVectorExt_I32EnumAttr<"LayoutDimension",
- "Describes the dimension of the high-dimensional layout", [
- BATCHX,
- BATCHY,
- VECTORX,
- VECTORY,
- VECTORZ,
- LANEX,
- LANEY,
- LANEZ,
- ]>;
-
-def LayoutDimensionAttr : IREEVectorExt_EnumAttr<LayoutDimension, "dimension">;
-
-def PerDimLayoutAttr : IREEVectorExt_Attr<"PerDimLayout"> {
- let mnemonic = "per_dim_layout";
- let summary = [{high-dimensional vector register layout for a given vector dimension}];
- let description = [{
- This attribute describes the per dimension register layout for a given vector
- that could be prescribed by an operator such as matrix multiplication.
- This is a way to explicitly represent the layout in the IR
- when it is in the SIMD form prior to converting to the SIMT form so that
- we can reason about layouts, propagating layouts and layout conflicts.
- }];
- let parameters = (ins
- ArrayRefParameter<"LayoutDimensionAttr", "labels for the high dimensional layout dims">:$labels,
- ArrayRefParameter<"int64_t", "shapes for the high dimensional layout dims">:$shapes
- );
- let assemblyFormat = "`<``[` $labels `]``,` `[` $shapes `]``>`";
- let genVerifyDecl = 0;
- let extraClassDeclaration = [{
- std::optional<int64_t> getShape(const LayoutDimension &dim);
- bool contains(const LayoutDimension &dim);
- }];
-}
-
-def LayoutAttr : IREEVectorExt_Attr<"Layout",
- [ DeclareAttrInterfaceMethods<VectorLayoutInterface> ]> {
- let mnemonic = "layout";
- let summary = [{high-dimensional vector register layout for a given vector}];
- let description = [{
- This contains a complete specification of the layout for a given vector,
- whereas the attribute above only specifies the per dimension layout.
- }];
- let parameters = (ins
- ArrayRefParameter<"PerDimLayoutAttr", "layout for each dimension of the vector">:$layouts
- );
- let assemblyFormat = "`<`$layouts`>`";
- let genVerifyDecl = 0;
- let extraClassDeclaration = [{
- std::optional<int64_t> getBatchDim(int64_t dim);
-
- // Returns the grid of lane ids. Assumes a valid layout.
- ::std::tuple<int64_t, int64_t, int64_t> getLaneGrid();
- PerDimLayoutAttr getDimLayout(int64_t dim) const;
- }];
-}
-
#endif // IREE_DIALECT_VECTOREXT_BASE
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
index 8f8f9bb..6e965d1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(IREEVectorExtDialect
+ VectorExtAttrs.cpp
VectorExtDialect.cpp
VectorExtOps.cpp
@@ -6,6 +7,7 @@
${IREE_DIALECTS_SOURCE_DIR}/include
DEPENDS
+ IREEVectorExtAttrsIncGen
IREEVectorExtIncGen
LINK_LIBS PUBLIC
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
new file mode 100644
index 0000000..66f60ba
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -0,0 +1,157 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
+
+using namespace mlir;
+
+namespace mlir::iree_compiler::IREE::VectorExt {
+
+bool PerDimLayoutAttr::contains(const LayoutDimension &dim) {
+ for (LayoutDimensionAttr label : getLabels()) {
+ if (label.getValue() == dim)
+ return true;
+ }
+ return false;
+}
+
+std::optional<int64_t> PerDimLayoutAttr::getShape(const LayoutDimension &dim) {
+ for (auto value : llvm::zip(getLabels(), getShapes())) {
+ if (dim == std::get<0>(value).getValue())
+ return std::get<1>(value);
+ }
+ return std::nullopt;
+}
+
+// Get the SIMT Vector shape in the order specified by dims. If no dims are
+// specified, then return an empty vector.
+bool LayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
+ for (auto perDimLayout : llvm::enumerate(getLayouts())) {
+ ArrayRef<int64_t> layoutShape = perDimLayout.value().getShapes();
+ int64_t computedShape = std::reduce(layoutShape.begin(), layoutShape.end(),
+ 1, std::multiplies<int64_t>());
+ int64_t expectedShape = shape[perDimLayout.index()];
+ if (computedShape != expectedShape) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// Project out the layout for the specified dimensions
+// resulting in the layout for a lower dimensional vector.
+VectorLayoutInterface LayoutAttr::project(ArrayRef<bool> projectedDims) const {
+ assert(projectedDims.size() == getLayouts().size() &&
+ "projectedDims size must match layout size");
+
+ ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
+ assert(projectedDims.size() == layouts.size());
+ SmallVector<PerDimLayoutAttr> newLayouts;
+ for (auto pair : llvm::zip(projectedDims, layouts)) {
+ if (!std::get<0>(pair))
+ newLayouts.push_back(std::get<1>(pair));
+ }
+ return LayoutAttr::get(getContext(), newLayouts);
+}
+
+// Permute the layout according to the provided permutation
+// vector. The dimensionality of the layout remains the same.
+VectorLayoutInterface LayoutAttr::permute(ArrayRef<int64_t> permutation) const {
+ assert(permutation.size() == getLayouts().size() &&
+ "permutation size must match layout size");
+
+ ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
+ assert(permutation.size() == layouts.size());
+ SmallVector<PerDimLayoutAttr> newLayouts;
+ for (unsigned index : permutation) {
+ assert(index >= 0 && index < layouts.size());
+ newLayouts.push_back(layouts[index]);
+ }
+ return LayoutAttr::get(getContext(), newLayouts);
+}
+
+// This function returns the distributed shape of the SIMT
+// vector and evaluates it in the following order:
+// BATCHX, BATCHY, VECTORY, VECTORX
+// The vector dimensions are combined into a single SIMT
+// vector dimension.
+SmallVector<int64_t> LayoutAttr::getDistributedShape() const {
+ SmallVector<LayoutDimension> labels{
+ LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+ LayoutDimension::VECTORY, LayoutDimension::VECTORX};
+ SmallVector<int64_t> simtVectorShape;
+ std::optional<int64_t> vectorShape;
+ for (LayoutDimension dim : labels) {
+ ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
+ for (PerDimLayoutAttr layout : layouts) {
+ if (!layout.contains(dim))
+ continue;
+ int64_t shape = layout.getShape(dim).value();
+ if (isVectorDimension(dim)) {
+ vectorShape = shape * vectorShape.value_or(1);
+ continue;
+ }
+ simtVectorShape.push_back(shape);
+ }
+ }
+ if (vectorShape)
+ simtVectorShape.push_back(vectorShape.value());
+ return simtVectorShape;
+}
+
+PerDimLayoutAttr LayoutAttr::getDimLayout(int64_t dim) const {
+ assert(dim >= 0 && dim < getLayouts().size());
+ return getLayouts()[dim];
+}
+
+std::optional<int64_t> LayoutAttr::getBatchDim(int64_t dim) {
+ assert(dim < getLayouts().size());
+ PerDimLayoutAttr layout = getDimLayout(dim);
+ for (auto [name, shape] :
+ llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
+ if (isBatchDimension(name.getValue()))
+ return shape;
+ }
+ return std::nullopt;
+}
+
+std::tuple<int64_t, int64_t, int64_t> LayoutAttr::getLaneGrid() {
+ int64_t laneX = 1;
+ int64_t laneY = 1;
+ int64_t laneZ = 1;
+ for (PerDimLayoutAttr dimLayout : getLayouts()) {
+ // Note that valid layouts only include at most one instance of each
+ // dimension type, so this is simply doing assignment on the first instance
+ // of each lane index, not an accumulative product.
+ auto maybeXShape = dimLayout.getShape(LayoutDimension::LANEX);
+ laneX *= maybeXShape.value_or(1);
+ auto maybeYShape = dimLayout.getShape(LayoutDimension::LANEY);
+ laneY *= maybeYShape.value_or(1);
+ auto maybeZShape = dimLayout.getShape(LayoutDimension::LANEZ);
+ laneZ *= maybeZShape.value_or(1);
+ }
+ return std::make_tuple(laneX, laneY, laneZ);
+}
+
+} // namespace mlir::iree_compiler::IREE::VectorExt
+
+using namespace mlir::iree_compiler::IREE::VectorExt;
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc" // IWYU pragma: keep
+
+#define GET_ATTRDEF_CLASSES
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc" // IWYU pragma: keep
+
+void IREEVectorExtDialect::registerAttributes() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc"
+ >();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
index 4e53f7b..904d35a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -11,13 +11,15 @@
#include <numeric>
using namespace mlir;
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc"
+
using namespace mlir::iree_compiler::IREE::VectorExt;
-#include "iree-dialects/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc" // IWYU pragma: keep
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrInterfaces.cpp.inc"
-#define GET_ATTRDEF_CLASSES
-#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc" // IWYU pragma: keep
- //
+namespace mlir::iree_compiler::IREE::VectorExt {
+
struct IREEVectorExtDialectOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
@@ -31,10 +33,7 @@
void IREEVectorExtDialect::initialize() {
addInterfaces<IREEVectorExtDialectOpAsmInterface>();
- addAttributes<
-#define GET_ATTRDEF_LIST
-#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc"
- >();
+ registerAttributes();
#define GET_OP_LIST
addOperations<
@@ -42,132 +41,4 @@
>();
}
-#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.cpp.inc"
-
-#include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrInterfaces.cpp.inc"
-
-bool PerDimLayoutAttr::contains(const LayoutDimension &dim) {
- for (LayoutDimensionAttr label : getLabels()) {
- if (label.getValue() == dim)
- return true;
- }
- return false;
-}
-
-std::optional<int64_t> PerDimLayoutAttr::getShape(const LayoutDimension &dim) {
- for (auto value : llvm::zip(getLabels(), getShapes())) {
- if (dim == std::get<0>(value).getValue())
- return std::get<1>(value);
- }
- return std::nullopt;
-}
-
-// Get the SIMT Vector shape in the order specified by dims. If no dims are
-// specified, then return an empty vector.
-bool LayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
- for (auto perDimLayout : llvm::enumerate(getLayouts())) {
- ArrayRef<int64_t> layoutShape = perDimLayout.value().getShapes();
- int64_t computedShape = std::reduce(layoutShape.begin(), layoutShape.end(),
- 1, std::multiplies<int64_t>());
- int64_t expectedShape = shape[perDimLayout.index()];
- if (computedShape != expectedShape) {
- return false;
- }
- }
- return true;
-}
-
-// Project out the layout for the specified dimensions
-// resulting in the layout for a lower dimensional vector.
-VectorLayoutInterface LayoutAttr::project(ArrayRef<bool> projectedDims) const {
- assert(projectedDims.size() == getLayouts().size() &&
- "projectedDims size must match layout size");
-
- ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
- assert(projectedDims.size() == layouts.size());
- SmallVector<PerDimLayoutAttr> newLayouts;
- for (auto pair : llvm::zip(projectedDims, layouts)) {
- if (!std::get<0>(pair))
- newLayouts.push_back(std::get<1>(pair));
- }
- return LayoutAttr::get(getContext(), newLayouts);
-}
-
-// Permute the layout according to the provided permutation
-// vector. The dimensionality of the layout remains the same.
-VectorLayoutInterface LayoutAttr::permute(ArrayRef<int64_t> permutation) const {
- assert(permutation.size() == getLayouts().size() &&
- "permutation size must match layout size");
-
- ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
- assert(permutation.size() == layouts.size());
- SmallVector<PerDimLayoutAttr> newLayouts;
- for (unsigned index : permutation) {
- assert(index >= 0 && index < layouts.size());
- newLayouts.push_back(layouts[index]);
- }
- return LayoutAttr::get(getContext(), newLayouts);
-}
-
-// This function returns the distributed shape of the SIMT
-// vector and evaluates it in the following order:
-// BATCHX, BATCHY, VECTORY, VECTORX
-// The vector dimensions are combined into a single SIMT
-// vector dimension.
-SmallVector<int64_t> LayoutAttr::getDistributedShape() const {
- SmallVector<LayoutDimension> labels{
- LayoutDimension::BATCHX, LayoutDimension::BATCHY,
- LayoutDimension::VECTORY, LayoutDimension::VECTORX};
- SmallVector<int64_t> simtVectorShape;
- std::optional<int64_t> vectorShape;
- for (LayoutDimension dim : labels) {
- ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
- for (PerDimLayoutAttr layout : layouts) {
- if (!layout.contains(dim))
- continue;
- int64_t shape = layout.getShape(dim).value();
- if (isVectorDimension(dim)) {
- vectorShape = shape * vectorShape.value_or(1);
- continue;
- }
- simtVectorShape.push_back(shape);
- }
- }
- if (vectorShape)
- simtVectorShape.push_back(vectorShape.value());
- return simtVectorShape;
-}
-
-PerDimLayoutAttr LayoutAttr::getDimLayout(int64_t dim) const {
- assert(dim >= 0 && dim < getLayouts().size());
- return getLayouts()[dim];
-}
-
-std::optional<int64_t> LayoutAttr::getBatchDim(int64_t dim) {
- assert(dim < getLayouts().size());
- PerDimLayoutAttr layout = getDimLayout(dim);
- for (auto [name, shape] :
- llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
- if (isBatchDimension(name.getValue()))
- return shape;
- }
- return std::nullopt;
-}
-
-std::tuple<int64_t, int64_t, int64_t> LayoutAttr::getLaneGrid() {
- int64_t laneX = 1;
- int64_t laneY = 1;
- int64_t laneZ = 1;
- for (PerDimLayoutAttr dimLayout : getLayouts()) {
- // Note that valid layouts only include at most one instance of each
- // dimension type, so this is simply doing assignment on the first instance
- // of each lane index, not an accumulative product.
- auto maybeXShape = dimLayout.getShape(LayoutDimension::LANEX);
- laneX *= maybeXShape.value_or(1);
- auto maybeYShape = dimLayout.getShape(LayoutDimension::LANEY);
- laneY *= maybeYShape.value_or(1);
- auto maybeZShape = dimLayout.getShape(LayoutDimension::LANEZ);
- laneZ *= maybeZShape.value_or(1);
- }
- return std::make_tuple(laneX, laneY, laneZ);
-}
+} // namespace mlir::iree_compiler::IREE::VectorExt