[VectorExt] Add layout iterator classes (#16004)

This PR adds iterator classes for iterating over a layout and for
concatenating an iterator with the dimensions of another, frozen iterator.
These are required when distributing reductions. It also adds a test that
checks the resulting iteration order.
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index a4553b3..37f6d4d 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -756,6 +756,7 @@
         ":IREELinalgExtDialect",
         ":IREELinalgTransformDialect",
         ":IREELinalgTransformDialectPasses",
+        ":IREEVectorExtDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Rewrite",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
index f8794d2..c1beca8 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -29,4 +29,70 @@
 
 // clang-format on
 
+namespace mlir::iree_compiler::IREE::VectorExt {
+
+/// Strided iterator over a single dimension of a layout. Dereferencing yields
+/// the current position; pre-increment advances it by `stride`.
+class DimensionalIterator {
+public:
+  DimensionalIterator(int64_t position = 0, int64_t stride = 1)
+      : position(position), stride(stride) {}
+  int64_t operator*() const { return position; } // Current position.
+  DimensionalIterator &operator++() {
+    position += stride; // No bounds check is performed here.
+    return *this;
+  }
+  bool operator!=(const DimensionalIterator &other) const {
+    return position != other.position; // Strides are not compared.
+  }
+
+private:
+  int64_t position, stride;
+};
+
+/// Range of positions along one dimension of a layout, described by `start`,
+/// `stop` and `step`. begin()/end() return DimensionalIterators positioned at
+/// `start` and `stop`. A default-constructed range has start = stop = 0, step = 1.
+class DimensionalRange {
+public:
+  DimensionalRange() = default; // Members take their in-class initializers.
+  DimensionalRange(int64_t start, int64_t stop, int64_t step = 1)
+      : start(start), stop(stop), step(step) {}
+  DimensionalIterator begin() const { return DimensionalIterator(start, step); }
+  DimensionalIterator end() const { return DimensionalIterator(stop, step); }
+
+private:
+  int64_t start = 0, stop = 0, step = 1; // Never left indeterminate.
+};
+
+/// Iterator over the non-lane dimensions of a LayoutAttr or PerDimLayoutAttr;
+/// lane dimensions are skipped as they are not required during distribution.
+/// The per-dimension state lives in a MapVector, which provides O(1) access
+/// to any given dimension and preserves insertion order. Dimensions adopted
+/// through maybeFreezeAndConcatenate() stay fixed while the others advance.
+class LayoutIterator {
+public:
+  using State = llvm::MapVector<LayoutDimension, DimensionalIterator>;
+  using DimensionMapping =
+      llvm::DenseMap<int64_t, SmallVector<LayoutDimension>>;
+  void maybeFreezeAndConcatenate(const LayoutIterator &frozenIterator);
+  LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides);
+  LayoutIterator(PerDimLayoutAttr &attr,
+                 DenseMap<LayoutDimension, int64_t> strides);
+  void apply(std::function<void(const LayoutIterator::State &)>);
+  LayoutIterator &operator++(); // Steps to the next state.
+  State getState() const { return state; } // Returns a copy of the state.
+
+private:
+  void initialize(PerDimLayoutAttr &attr,
+                  DenseMap<LayoutDimension, int64_t> strides);
+  bool iterationComplete(); // True when all live dims are at their begin.
+  State state; // Current iterator for each dimension.
+  llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
+  DimensionMapping simdDimensionToLayoutDimension; // NOTE(review): unused here?
+  DenseSet<LayoutDimension> frozenDimensions; // Dims that ++ never advances.
+};
+
+} // namespace mlir::iree_compiler::IREE::VectorExt
+
 #endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
index 49c77b0..1c57d06 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp
@@ -52,7 +52,84 @@
   return {};
 }
 
+LayoutIterator &LayoutIterator::operator++() { // Odometer-style increment.
+  for (auto &[dim, it] : state) {
+    if (frozenDimensions.contains(dim))
+      continue; // Frozen dimensions keep their position.
+    if (it != ranges[dim].end()) { // Not yet at this dimension's `stop`.
+      ++it;
+      break; // Only one dimension advances per increment.
+    }
+    it = ranges[dim].begin(); // At `stop`: rewind and carry to the next dim.
+  }
+  return *this; // If every live dim was at `stop`, all are rewound by now.
+}
+
+void LayoutIterator::maybeFreezeAndConcatenate(
+    const LayoutIterator &frozenIterator) {
+  for (const auto &[dim, pinnedIt] : frozenIterator.getState()) {
+    if (state.contains(dim))
+      continue; // Already iterated here, so it stays live.
+    frozenDimensions.insert(dim); // Skipped by ++ and by completion checks.
+    state[dim] = pinnedIt;        // Appended after the existing dimensions.
+  }
+}
+
+static bool isLaneDimension(LayoutDimension dim) { // LANEX, LANEY or LANEZ.
+  return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
+         (dim == LayoutDimension::LANEZ);
+}
+
+void LayoutIterator::initialize(PerDimLayoutAttr &attr, // Adds attr's dims.
+                                DenseMap<LayoutDimension, int64_t> strides) {
+  auto reversedLabels = llvm::reverse(attr.getLabels()); // Last label first.
+  auto reversedShapes = llvm::reverse(attr.getShapes());
+  for (auto [nameAttr, shape] : llvm::zip(reversedLabels, reversedShapes)) {
+    LayoutDimension dim = nameAttr.getValue();
+    if (isLaneDimension(dim))
+      continue; // Lane dimensions are not iterated.
+    int64_t stride = strides.contains(dim) ? strides[dim] : 1; // Assumed > 0.
+    ranges[dim] = DimensionalRange(0, ((shape - 1) / stride) * stride, stride);
+    state[dim] = ranges[dim].begin(); // Stop was rounded so ++ can reach it.
+  }
+}
+
+LayoutIterator::LayoutIterator(LayoutAttr &attr, // Covers every per-dim layout.
+                               DenseMap<LayoutDimension, int64_t> strides) {
+  for (PerDimLayoutAttr perDimAttr : attr.getLayouts()) {
+    initialize(perDimAttr, strides); // Registers this layout's non-lane dims.
+  }
+}
+
+LayoutIterator::LayoutIterator(PerDimLayoutAttr &attr, // Single-layout form.
+                               DenseMap<LayoutDimension, int64_t> strides) {
+  initialize(attr, strides);
+}
+
+/// Reports whether every non-frozen dimension sits at the start of its range,
+/// i.e. the state the iterator wraps around to after its last element.
+bool LayoutIterator::iterationComplete() {
+  for (auto &[layoutDim, cursor] : state) {
+    // Frozen dimensions never advance, so they cannot signal progress.
+    if (frozenDimensions.contains(layoutDim))
+      continue;
+    // A live dimension away from its first position: still mid-iteration.
+    if (cursor != ranges[layoutDim].begin())
+      return false;
+  }
+  // No live dimension has moved (also the case when there are none).
+  return true;
+}
+
+void LayoutIterator::apply(
+    std::function<void(const LayoutIterator::State &)> callback) {
+  for (bool wrapped = false; !wrapped; wrapped = iterationComplete()) {
+    callback(state); // Always runs at least once, on the current state.
+    ++(*this);       // `wrapped` is re-evaluated only after this step.
+  }
+}
+
 // clang-format off
 #define GET_OP_CLASSES
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep
-// clang-format: on
+// clang-format on
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir
new file mode 100644
index 0000000..08166a0
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-dialects-opt --split-input-file --test-vector-ext-iterators %s | FileCheck %s
+
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1,
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+func.func @iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> { // The __test_iterator_layout__ attribute below selects the plain iterator test.
+  %cst_0 = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16>
+  return %result : vector<4x4xf16>
+}
+
+// -----
+
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0,
+// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0,
+#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]>
+#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]>
+#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
+func.func @frozen_iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> { // The __test_frozen_iterator_layout__ attribute below selects the frozen iterator test.
+  %cst_0 = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_frozen_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16>
+  return %result : vector<4x4xf16>
+}
diff --git a/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
index e31af32..b130282 100644
--- a/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(Transforms)
+add_subdirectory(VectorExt)
diff --git a/llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt
new file mode 100644
index 0000000..21f2fc5
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_library(IREEVectorExtTestPasses
+  TestIterators.cpp
+
+  DEPENDS
+  mlir-headers
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  IREEVectorExtDialect
+  MLIRPass
+  )
+
diff --git a/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
new file mode 100644
index 0000000..fce0953
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
@@ -0,0 +1,90 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::VectorExt;
+
+namespace {
+
+static const StringRef kIteratorMarker = "__test_iterator_layout__";
+static const StringRef kFrozenIteratorMarker =
+    "__test_frozen_iterator_layout__";
+
+struct TestVectorExtIteratorPass // Prints LayoutIterator states for LIT.
+    : public PassWrapper<TestVectorExtIteratorPass, Pass> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorExtIteratorPass)
+  TestVectorExtIteratorPass() = default;
+  TestVectorExtIteratorPass(const TestVectorExtIteratorPass &other)
+      : PassWrapper(other) {}
+  StringRef getArgument() const final { return "test-vector-ext-iterators"; }
+  StringRef getDescription() const final {
+    return "Test VectorExt Iterator pass.";
+  }
+  bool canScheduleOn(RegisteredOperationName opName) const override {
+    return true; // Can be scheduled on any operation.
+  }
+  // Prints one line per state as "DIM:position, " pairs for LIT to check.
+  static void printFn(const LayoutIterator::State &state) {
+    for (const auto &[dim, it] : state) {
+      // Stream each piece directly rather than building a temporary string.
+      llvm::outs() << stringifyLayoutDimension(dim) << ":" << *it << ", ";
+    }
+    llvm::outs() << "\n";
+  }
+  void testIterator(Operation *op) { // Expects a LayoutAttr marker on `op`.
+    auto layout = dyn_cast_or_null<LayoutAttr>(op->getAttr(kIteratorMarker));
+    DenseMap<LayoutDimension, int64_t> strides; // Empty: every stride is 1.
+    LayoutIterator iterator(layout, strides);
+    iterator.apply(printFn);
+  }
+  LayoutDimensionAttr createLayoutDimensionAttr(MLIRContext *ctx,
+                                                LayoutDimension dim) {
+    return LayoutDimensionAttr::get(ctx, dim);
+  }
+  LayoutIterator // Iterator over VECTORZ and VECTORX, each with shape 1.
+  createFrozenIterator(MLIRContext *ctx,
+                       DenseMap<LayoutDimension, int64_t> &strides) {
+    SmallVector<LayoutDimensionAttr> labels{
+        createLayoutDimensionAttr(ctx, LayoutDimension::VECTORZ),
+        createLayoutDimensionAttr(ctx, LayoutDimension::VECTORX)};
+    auto newLayout =
+        LayoutAttr::get(ctx, {PerDimLayoutAttr::get(ctx, labels[0], {1}),
+                              PerDimLayoutAttr::get(ctx, labels[1], {1})});
+    return LayoutIterator(newLayout, strides);
+  }
+  void testFrozenIterator(Operation *op) { // Expects a LayoutAttr marker.
+    auto layout =
+        dyn_cast_or_null<LayoutAttr>(op->getAttr(kFrozenIteratorMarker));
+    DenseMap<LayoutDimension, int64_t> strides;
+    LayoutIterator iterator(layout, strides);
+    auto frozenIterator = createFrozenIterator(op->getContext(), strides);
+    iterator.maybeFreezeAndConcatenate(frozenIterator); // Adds missing dims.
+    iterator.apply(printFn);
+  }
+  void runOnOperation() override {
+    getOperation()->walk([&](Operation *op) { // Non-LayoutAttr markers skip.
+      if (op->getAttrOfType<LayoutAttr>(kIteratorMarker)) {
+        return testIterator(op);
+      }
+      if (op->getAttrOfType<LayoutAttr>(kFrozenIteratorMarker)) {
+        return testFrozenIterator(op);
+      }
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir::test_ext {
+void registerVectorExtTestPasses() { // Registers the iterator test pass.
+  PassRegistration<TestVectorExtIteratorPass>();
+}
+} // namespace mlir::test_ext
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 5789bc5..4b3f125 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -9,6 +9,7 @@
   IREELinalgTransformDialectPasses
   IREETransformsTestPasses
   IREEVectorExtDialect
+  IREEVectorExtTestPasses
   # Core dialects.
   MLIRAffineDialect
   MLIRArithDialect
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index 085c31e..c6c1b1d 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -47,6 +47,7 @@
 namespace test_ext {
 /// Test passes, do not deserve an include.
 void registerTestListenerPasses();
+void registerVectorExtTestPasses();
 } // namespace test_ext
 } // namespace mlir
 
@@ -88,6 +89,7 @@
   mlir::linalg::transform::registerDropSchedulePass();
   // Local test passes.
   mlir::test_ext::registerTestListenerPasses();
+  mlir::test_ext::registerVectorExtTestPasses();
 
   // External models.
   mlir::func::registerInlinerExtension(registry);