[VectorExt] Add layout iterator classes (#16004)
This PR adds iterator classes to iterate over the layout and to
concatenate iterators with other frozen iterators. This is required when
distributing reductions. Also adds a test to check for correctness.
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index a4553b3..37f6d4d 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -756,6 +756,7 @@
":IREELinalgExtDialect",
":IREELinalgTransformDialect",
":IREELinalgTransformDialectPasses",
+ ":IREEVectorExtDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
index f8794d2..c1beca8 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -29,4 +29,70 @@
// clang-format on
+namespace mlir::iree_compiler::IREE::VectorExt {
+
+/// Dimensional Strided Iterator class used to represent
+/// an iterator through a single dimension of the layout.
+class DimensionalIterator {
+public:
+ DimensionalIterator(int64_t position = 0, int64_t stride = 1)
+ : position(position), stride(stride) {}
+ int64_t operator*() const { return position; }
+ DimensionalIterator &operator++() {
+ position += stride;
+ return *this;
+ }
+ bool operator!=(const DimensionalIterator &other) const {
+ return position != other.position;
+ }
+
+private:
+ int64_t position, stride;
+};
+
+/// Dimensional Range class used to represent the range of
+/// a particular dimension of the layout. Can be iterated on
+/// using a DimensionalIterator.
+class DimensionalRange {
+public:
+ DimensionalRange() {}
+ DimensionalRange(int64_t start, int64_t stop, int64_t step = 1)
+ : start(start), stop(stop), step(step) {}
+ DimensionalIterator begin() const { return DimensionalIterator(start, step); }
+ DimensionalIterator end() const { return DimensionalIterator(stop, step); }
+
+private:
+ int64_t start, stop, step;
+};
+
+// Iterator class for LayoutAttrs and PerDimLayoutAttrs.
+// Provides O(1) access to state for any given dimension.
+// Also preserves insertion order.
+// Layout iterators skip lane dimensions as these are not
+// required during distribution.
+class LayoutIterator {
+public:
+ using State = llvm::MapVector<LayoutDimension, DimensionalIterator>;
+ using DimensionMapping =
+ llvm::DenseMap<int64_t, SmallVector<LayoutDimension>>;
+ void maybeFreezeAndConcatenate(const LayoutIterator &frozenIterator);
+ LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides);
+ LayoutIterator(PerDimLayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides);
+ void apply(std::function<void(const LayoutIterator::State &)>);
+ LayoutIterator &operator++();
+ State getState() const { return state; }
+
+private:
+ void initialize(PerDimLayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides);
+ bool iterationComplete();
+ State state;
+ llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
+ DimensionMapping simdDimensionToLayoutDimension;
+ DenseSet<LayoutDimension> frozenDimensions;
+};
+
+} // namespace mlir::iree_compiler::IREE::VectorExt
+
#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 49c77b0..1c57d06 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -52,7 +52,84 @@
return {};
}
+LayoutIterator &LayoutIterator::operator++() {
+ for (auto &[dim, it] : state) {
+ if (frozenDimensions.contains(dim))
+ continue;
+ if (it != ranges[dim].end()) {
+ ++it;
+ break;
+ }
+ it = ranges[dim].begin();
+ }
+ return *this;
+}
+
+void LayoutIterator::maybeFreezeAndConcatenate(
+ const LayoutIterator &frozenIterator) {
+ for (auto &[frozenDim, frozenIt] : frozenIterator.getState()) {
+ if (!state.contains(frozenDim)) {
+ frozenDimensions.insert(frozenDim);
+ state[frozenDim] = frozenIt;
+ }
+ }
+}
+
+static bool isLaneDimension(LayoutDimension dim) {
+ return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
+ (dim == LayoutDimension::LANEZ);
+}
+
+void LayoutIterator::initialize(PerDimLayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides) {
+ auto reversedLabels = llvm::reverse(attr.getLabels());
+ auto reversedShapes = llvm::reverse(attr.getShapes());
+ for (auto [nameAttr, shape] : llvm::zip(reversedLabels, reversedShapes)) {
+ LayoutDimension dim = nameAttr.getValue();
+ if (isLaneDimension(dim))
+ continue;
+ int64_t stride = strides.contains(dim) ? strides[dim] : 1;
+ ranges[dim] = DimensionalRange(0, shape - 1, stride);
+ state[dim] = ranges[dim].begin();
+ }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides) {
+ for (PerDimLayoutAttr perDimAttr : attr.getLayouts()) {
+ initialize(perDimAttr, strides);
+ }
+}
+
+LayoutIterator::LayoutIterator(PerDimLayoutAttr &attr,
+ DenseMap<LayoutDimension, int64_t> strides) {
+ initialize(attr, strides);
+}
+
+/// The iterator is done when it returns back to
+/// its begin state.
+bool LayoutIterator::iterationComplete() {
+ bool complete{true};
+ for (auto &[dim, it] : state) {
+ if (frozenDimensions.contains(dim))
+ continue;
+ if (it != ranges[dim].begin()) {
+ complete = false;
+ break;
+ }
+ }
+ return complete;
+}
+
+void LayoutIterator::apply(
+ std::function<void(const LayoutIterator::State &)> callback) {
+ do {
+ callback(state);
+ ++(*this);
+ } while (!iterationComplete());
+}
+
// clang-format off
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep
-// clang-format: on
+// clang-format on
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir
new file mode 100644
index 0000000..08166a0
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-dialects-opt --split-input-file --test-vector-ext-iterators %s | FileCheck %s
+
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1,
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+func.func @iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> {
+ %cst_0 = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16>
+ return %result : vector<4x4xf16>
+}
+
+// -----
+
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0,
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+func.func @frozen_iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> {
+ %cst_0 = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_frozen_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16>
+ return %result : vector<4x4xf16>
+}
diff --git a/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
index e31af32..b130282 100644
--- a/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(Transforms)
+add_subdirectory(VectorExt)
diff --git a/llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt
new file mode 100644
index 0000000..21f2fc5
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_library(IREEVectorExtTestPasses
+ TestIterators.cpp
+
+ DEPENDS
+ mlir-headers
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ IREEVectorExtDialect
+ MLIRPass
+ )
+
diff --git a/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
new file mode 100644
index 0000000..fce0953
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
@@ -0,0 +1,90 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::VectorExt;
+
+namespace {
+
+static const StringRef kIteratorMarker = "__test_iterator_layout__";
+static const StringRef kFrozenIteratorMarker =
+ "__test_frozen_iterator_layout__";
+
+struct TestVectorExtIteratorPass
+ : public PassWrapper<TestVectorExtIteratorPass, Pass> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorExtIteratorPass)
+ TestVectorExtIteratorPass() = default;
+ TestVectorExtIteratorPass(const TestVectorExtIteratorPass &other)
+ : PassWrapper(other) {}
+ StringRef getArgument() const final { return "test-vector-ext-iterators"; }
+ StringRef getDescription() const final {
+ return "Test VectorExt Iterator pass.";
+ }
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+ // Prints the layout so that LIT can test it for correctness.
+ static void printFn(const LayoutIterator::State &state) {
+ for (const auto &[dim, it] : state) {
+ llvm::outs() << stringifyLayoutDimension(dim).str() + ":" +
+ std::to_string(*it) + ", ";
+ }
+ llvm::outs() << "\n";
+ }
+ void testIterator(Operation *op) {
+ auto layout = dyn_cast_or_null<LayoutAttr>(op->getAttr(kIteratorMarker));
+ DenseMap<LayoutDimension, int64_t> strides;
+ LayoutIterator iterator(layout, strides);
+ iterator.apply(printFn);
+ }
+ LayoutDimensionAttr createLayoutDimensionAttr(MLIRContext *ctx,
+ LayoutDimension dim) {
+ return LayoutDimensionAttr::get(ctx, dim);
+ }
+ LayoutIterator
+ createFrozenIterator(MLIRContext *ctx,
+ DenseMap<LayoutDimension, int64_t> &strides) {
+ SmallVector<LayoutDimensionAttr> labels{
+ createLayoutDimensionAttr(ctx, LayoutDimension::VECTORZ),
+ createLayoutDimensionAttr(ctx, LayoutDimension::VECTORX)};
+ auto newLayout =
+ LayoutAttr::get(ctx, {PerDimLayoutAttr::get(ctx, labels[0], {1}),
+ PerDimLayoutAttr::get(ctx, labels[1], {1})});
+ return LayoutIterator(newLayout, strides);
+ }
+ void testFrozenIterator(Operation *op) {
+ auto layout =
+ dyn_cast_or_null<LayoutAttr>(op->getAttr(kFrozenIteratorMarker));
+ DenseMap<LayoutDimension, int64_t> strides;
+ LayoutIterator iterator(layout, strides);
+ auto frozenIterator = createFrozenIterator(op->getContext(), strides);
+ iterator.maybeFreezeAndConcatenate(frozenIterator);
+ iterator.apply(printFn);
+ }
+ void runOnOperation() override {
+ getOperation()->walk([&](Operation *op) {
+ if (op->hasAttr(kIteratorMarker)) {
+ return testIterator(op);
+ }
+ if (op->hasAttr(kFrozenIteratorMarker)) {
+ return testFrozenIterator(op);
+ }
+ });
+ }
+};
+
+} // namespace
+
+namespace mlir::test_ext {
+void registerVectorExtTestPasses() {
+ PassRegistration<TestVectorExtIteratorPass>();
+}
+} // namespace mlir::test_ext
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 5789bc5..4b3f125 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -9,6 +9,7 @@
IREELinalgTransformDialectPasses
IREETransformsTestPasses
IREEVectorExtDialect
+ IREEVectorExtTestPasses
# Core dialects.
MLIRAffineDialect
MLIRArithDialect
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index 085c31e..c6c1b1d 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -47,6 +47,7 @@
namespace test_ext {
/// Test passes, do not deserve an include.
void registerTestListenerPasses();
+void registerVectorExtTestPasses();
} // namespace test_ext
} // namespace mlir
@@ -88,6 +89,7 @@
mlir::linalg::transform::registerDropSchedulePass();
// Local test passes.
mlir::test_ext::registerTestListenerPasses();
+ mlir::test_ext::registerVectorExtTestPasses();
// External models.
mlir::func::registerInlinerExtension(registry);