[compiler] Mesh -> IREE conversion (#16199)
Add conversion of the upstream MLIR Mesh Dialect to IREE Flow. Handle
global collective channel construction through to the runtime.
The supported operations are:
* retrieval of rank/process index
* all-gather
* all-reduce
* all-to-all
* reduce-scatter
Dynamic device meshes are not supported.
Some operations do not support dynamic tensor dimensions.
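
For example, a single-mesh all-reduce such as

```mlir
mesh.mesh @mesh_1d(shape = 2)

func.func @f(%arg0 : tensor<1xi8>) -> tensor<1xi8> {
  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] reduction = <sum>
      : tensor<1xi8> -> tensor<1xi8>
  return %0 : tensor<1xi8>
}
```

is converted to roughly the following (modulo exact SSA names; see the lit
tests in this change for the authoritative form), with the now-unused mesh
op removed:

```mlir
func.func @f(%arg0 : tensor<1xi8>) -> tensor<1xi8> {
  %channel = flow.channel.default : !flow.channel
  %empty = tensor.empty() : tensor<1xi8>
  %0 = flow.collective.all_reduce sum, ui8, %empty, %arg0, %channel
      : (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %empty as tensor<1xi8>
  return %0 : tensor<1xi8>
}
```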
diff --git a/build_tools/bazel/build_test_all.sh b/build_tools/bazel/build_test_all.sh
index 8558b36..045e20b 100755
--- a/build_tools/bazel/build_test_all.sh
+++ b/build_tools/bazel/build_test_all.sh
@@ -30,6 +30,7 @@
IREE_READ_REMOTE_BAZEL_CACHE="${IREE_READ_REMOTE_BAZEL_CACHE:-1}"
IREE_WRITE_REMOTE_BAZEL_CACHE="${IREE_WRITE_REMOTE_BAZEL_CACHE:-0}"
BAZEL_BIN="${BAZEL_BIN:-$(which bazel)}"
+SANDBOX_BASE="${SANDBOX_BASE:-}"
if (( ${IREE_WRITE_REMOTE_BAZEL_CACHE} == 1 && ${IREE_READ_REMOTE_BAZEL_CACHE} != 1 )); then
echo "Can't have 'IREE_WRITE_REMOTE_BAZEL_CACHE' (${IREE_WRITE_REMOTE_BAZEL_CACHE}) set without 'IREE_READ_REMOTE_BAZEL_CACHE' (${IREE_READ_REMOTE_BAZEL_CACHE})"
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
index 2d5726d..65851fa 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
@@ -45,55 +45,6 @@
// mode combinations for cross-replica and cross partition communication. See
// the stablehlo specification for more details about the different modes.
-static std::optional<IREE::Flow::CollectiveElementType>
-convertToFlowCollectiveElementType(Type type) {
- if (type.isF32()) {
- return IREE::Flow::CollectiveElementType::Float32;
- }
-
- if (type.isInteger(32)) {
- if (type.isSignedInteger()) {
- return IREE::Flow::CollectiveElementType::Sint32;
- }
- return IREE::Flow::CollectiveElementType::Uint32;
- }
-
- if (type.isF16()) {
- return IREE::Flow::CollectiveElementType::Float16;
- }
-
- if (type.isInteger(8)) {
- if (type.isSignedInteger()) {
- return IREE::Flow::CollectiveElementType::Sint8;
- }
- return IREE::Flow::CollectiveElementType::Uint8;
- }
-
- if (type.isInteger(16)) {
- if (type.isSignedInteger()) {
- return IREE::Flow::CollectiveElementType::Sint16;
- }
- return IREE::Flow::CollectiveElementType::Uint16;
- }
-
- if (type.isBF16()) {
- return IREE::Flow::CollectiveElementType::BFloat16;
- }
-
- if (type.isF64()) {
- return IREE::Flow::CollectiveElementType::Float64;
- }
-
- if (type.isInteger(64)) {
- if (type.isSignedInteger()) {
- return IREE::Flow::CollectiveElementType::Sint64;
- }
- return IREE::Flow::CollectiveElementType::Uint64;
- }
-
- return std::nullopt;
-}
-
static std::optional<IREE::Flow::CollectiveReductionOp>
convertToFlowCollectiveReductionOp(const Operation &op) {
if (isa<mlir::stablehlo::AddOp>(op)) {
@@ -113,17 +64,6 @@
return std::nullopt;
}
-static IREE::Flow::CollectiveElementTypeAttr
-getCollectiveElementTypeAttr(MLIRContext *context, RankedTensorType type) {
- std::optional<IREE::Flow::CollectiveElementType> collectiveElemType =
- convertToFlowCollectiveElementType(type.getElementType());
- if (!collectiveElemType) {
- return IREE::Flow::CollectiveElementTypeAttr();
- }
- return IREE::Flow::CollectiveElementTypeAttr::get(context,
- *collectiveElemType);
-}
-
template <typename T>
static LogicalResult checkCollectiveAttrs(T op, PatternRewriter &rewriter) {
// Note that the channel handle attribute consists of two 64-bit values,
@@ -553,7 +493,7 @@
// Get the collective element type attribute.
auto resultType = cast<RankedTensorType>(op.getResult().getType());
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
- getCollectiveElementTypeAttr(op.getContext(), resultType);
+ IREE::Flow::getCollectiveElementTypeAttr(resultType);
if (!elementTypeAttr) {
return rewriter.notifyMatchFailure(
op, "unsupported element type for collective op");
@@ -649,7 +589,7 @@
// Get the collective element type attribute.
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
- getCollectiveElementTypeAttr(op.getContext(), inputType);
+ IREE::Flow::getCollectiveElementTypeAttr(inputType);
if (!elementTypeAttr) {
return rewriter.notifyMatchFailure(op, "unsupported input type");
}
@@ -738,7 +678,7 @@
// Get the collective element type attribute.
auto resultType = cast<RankedTensorType>(op.getType());
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
- getCollectiveElementTypeAttr(op.getContext(), resultType);
+ IREE::Flow::getCollectiveElementTypeAttr(resultType);
if (!elementTypeAttr) {
return rewriter.notifyMatchFailure(
op, "unsupported element type for collective op");
@@ -842,7 +782,7 @@
// Get the collective element type attribute.
auto resultType = cast<RankedTensorType>(op.getResult().getType());
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
- getCollectiveElementTypeAttr(op.getContext(), resultType);
+ IREE::Flow::getCollectiveElementTypeAttr(resultType);
if (!elementTypeAttr) {
return rewriter.notifyMatchFailure(op, "unsupported input type");
}
@@ -932,7 +872,7 @@
// Get the collective element type attribute.
IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
- getCollectiveElementTypeAttr(op.getContext(), inputType);
+ IREE::Flow::getCollectiveElementTypeAttr(inputType);
if (!elementTypeAttr) {
return rewriter.notifyMatchFailure(op, "unsupported input type");
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel
new file mode 100644
index 0000000..32305c5
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel
@@ -0,0 +1,57 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "MeshToFlow",
+ srcs = [
+ "MeshToFlow.cpp",
+ ],
+ hdrs = [
+ "MeshToFlow.h",
+ ],
+ deps = [
+ ":PassesIncGen",
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "//compiler/src/iree/compiler/Utils",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:ArithUtils",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:LinalgUtils",
+ "@llvm-project//mlir:MeshDialect",
+ "@llvm-project//mlir:MeshTransforms",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["--gen-pass-decls"],
+ "Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/CMakeLists.txt
new file mode 100644
index 0000000..f1028f2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/CMakeLists.txt
@@ -0,0 +1,49 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ MeshToFlow
+ HDRS
+ "MeshToFlow.h"
+ SRCS
+ "MeshToFlow.cpp"
+ DEPS
+ ::PassesIncGen
+ LLVMSupport
+ MLIRAffineDialect
+ MLIRArithUtils
+ MLIRIR
+ MLIRLinalgDialect
+ MLIRLinalgUtils
+ MLIRMeshDialect
+ MLIRMeshTransforms
+ MLIRPass
+ MLIRSupport
+ MLIRTensorDialect
+ MLIRTransforms
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::Utils
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ PassesIncGen
+ TD_FILE
+ "Passes.td"
+ OUTS
+ --gen-pass-decls Passes.h.inc
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.cpp
new file mode 100644
index 0000000..d9294a4
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.cpp
@@ -0,0 +1,737 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h"
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/Folding.h"
+#include "iree/compiler/Utils/Indexing.h"
+#include "iree/compiler/Utils/OpVisitor.h"
+#include "iree/compiler/Utils/Permutation.h"
+#include "iree/compiler/Utils/SmallVectorDenseMapInfo.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-mesh-to-flow"
+
+namespace mlir::iree_compiler::IREE::Flow {
+
+#define GEN_PASS_CLASSES
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.h.inc" // IWYU pragma: keep
+
+static bool hasMoreThanOneMesh(Operation *op) {
+ int meshCount = 0;
+ op->walk([&meshCount](mesh::MeshOp mesh) {
+ ++meshCount;
+ return meshCount > 1 ? WalkResult::interrupt() : WalkResult::advance();
+ });
+ return meshCount > 1;
+}
+
+static bool isDefaultChannel(mesh::MeshOp mesh,
+ ArrayRef<mesh::MeshAxis> meshAxes) {
+ if (mesh.getRank() != static_cast<int64_t>(meshAxes.size())) {
+ return false;
+ }
+ return isIdentityPermutation(meshAxes);
+}
+
+static Value getDefaultChannel(mesh::MeshOp mesh, bool useNamedDefaultChannels,
+ ImplicitLocOpBuilder &builder) {
+ if (useNamedDefaultChannels)
+ return builder.create<IREE::Flow::ChannelDefaultOp>(mesh.getSymName());
+ else
+ return builder.create<IREE::Flow::ChannelDefaultOp>();
+}
+
+// Remove from `values` the elements whose indices are present in `filter`.
+static SmallVector<Value> filterOutByIndex(ArrayRef<Value> values,
+ ArrayRef<mesh::MeshAxis> filter) {
+ SmallVector<Value> res;
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (!llvm::is_contained(filter, i)) {
+ res.push_back(values[i]);
+ }
+ }
+ return res;
+}
+
+static CollectiveReductionOp convertReductionKind(mesh::Partial reduction) {
+ switch (reduction) {
+ case mesh::Partial::Max:
+ return CollectiveReductionOp::ReductionMaximum;
+ case mesh::Partial::Min:
+ return CollectiveReductionOp::ReductionMinimum;
+ case mesh::Partial::Sum:
+ return CollectiveReductionOp::ReductionSum;
+ default:
+ assert(false);
+ return CollectiveReductionOp::None;
+ }
+}
+
+static CollectiveReductionOpAttr
+convertReductionKind(mesh::PartialAttr reduction) {
+ return CollectiveReductionOpAttr::get(
+ reduction.getContext(), convertReductionKind(reduction.getValue()));
+}
+
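+// Creates the channel for (mesh, meshAxes) by splitting the default channel
+// of the mesh. Example: for mesh shape 3x4x5x6 with meshAxes = [2, 1] (as in
+// the @static_4d_mesh_grouping_along_axes_2_1 lit test), the device with
+// multi-index (1, 2, 3, 4) computes
+//   key   = linearize((idx[2], idx[1]), (5, 4)) = 3 * 4 + 2 = 14
+//   color = linearize((idx[0], idx[3]), (3, 6)) = 1 * 6 + 4 = 10
+// i.e. it becomes rank 14 within group 10, following MPI_Comm_split
+// semantics.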
+static Value buildChannelCreation(mesh::MeshOp mesh,
+ ArrayRef<mesh::MeshAxis> meshAxes,
+ bool useNamedDefaultChannels,
+ ImplicitLocOpBuilder &builder) {
+ assert(mesh);
+ Value meshChannel = getDefaultChannel(mesh, useNamedDefaultChannels, builder);
+ SmallVector<Value> meshProcessMultiIndex =
+ builder.create<mesh::ProcessMultiIndexOp>(mesh).getResults();
+ SmallVector<Value> meshShape =
+ builder.create<mesh::MeshShapeOp>(mesh).getResults();
+ SmallVector<Value> reorderedMeshIndex =
+ permute(ArrayRef<Value>(meshProcessMultiIndex), meshAxes);
+ SmallVector<Value> reorderedMeshShape =
+ permute(ArrayRef<Value>(meshShape), meshAxes);
+ SmallVector<Value> groupIndex =
+ filterOutByIndex(meshProcessMultiIndex, meshAxes);
+ SmallVector<Value> groupsShape = filterOutByIndex(meshShape, meshAxes);
+ OpFoldResult reorderedProcessLinearIndex =
+ linearIndexFromShape(toOpFoldResults(reorderedMeshIndex),
+ toOpFoldResults(reorderedMeshShape), builder);
+ OpFoldResult color = linearIndexFromShape(
+ toOpFoldResults(groupIndex), toOpFoldResults(groupsShape), builder);
+ return builder.create<ChannelSplitOp>(
+ meshChannel,
+ getValueOrCreateConstantIndexOp(builder, builder.getLoc(), color),
+ getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
+ reorderedProcessLinearIndex));
+}
+
+static SmallString<64> getChannelName(mesh::MeshOp mesh,
+ ArrayRef<mesh::MeshAxis> axes) {
+ SmallString<64> res;
+ llvm::raw_svector_ostream stream(res);
+ stream << "_mesh_" << mesh.getSymName();
+ if (axes.empty()) {
+ return res;
+ }
+
+ stream << "_axes";
+ for (mesh::MeshAxis axis : axes) {
+ stream << "_" << axis;
+ }
+
+ return res;
+}
+
+static void buildChannelInitializer(mesh::MeshOp mesh,
+ ArrayRef<mesh::MeshAxis> meshAxes,
+ bool useNamedDefaultChannels,
+ ImplicitLocOpBuilder &builder) {
+ Util::InitializerOp initOp = builder.create<Util::InitializerOp>();
+ Block *block = builder.createBlock(&initOp.getBody());
+ ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToStart(block);
+ Value channel =
+ buildChannelCreation(mesh, meshAxes, useNamedDefaultChannels, builder);
+ builder.create<Util::GlobalStoreOp>(channel, getChannelName(mesh, meshAxes));
+ builder.create<Util::ReturnOp>();
+}
+
+// Construct a Flow channel inside `module` using
+// util.global and util.initializer.
+static void buildGlobalChannelCreation(mesh::MeshOp mesh,
+ ArrayRef<mesh::MeshAxis> meshAxes,
+ bool useNamedDefaultChannels,
+ ModuleOp module, OpBuilder &opBuilder) {
+ if (isDefaultChannel(mesh, meshAxes)) {
+ return;
+ }
+
+ ImplicitLocOpBuilder builder(mesh.getLoc(), opBuilder);
+ ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToStart(&module.getBodyRegion().getBlocks().front());
+
+ auto channelName = getChannelName(mesh, meshAxes);
+ builder.create<Util::GlobalOp>(
+ builder.getStringAttr("private"), channelName,
+ IREE::Flow::ChannelType::get(builder.getContext()), false, TypedAttr(),
+ IREE::Util::InlineNeverAttr::get(builder.getContext()));
+ buildChannelInitializer(mesh, meshAxes, useNamedDefaultChannels, builder);
+}
+
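+// Loads the channel for (mesh, meshAxes) from the util.global created by
+// buildGlobalChannelCreation, or builds the default channel directly when
+// the full mesh in its natural axis order is used.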
+static Value buildCachedChannelLoading(mesh::MeshOp mesh,
+ ArrayRef<mesh::MeshAxis> meshAxes,
+ bool useNamedDefaultChannels,
+ ImplicitLocOpBuilder &builder) {
+ if (isDefaultChannel(mesh, meshAxes)) {
+ return getDefaultChannel(mesh, useNamedDefaultChannels, builder);
+ }
+ return builder.create<Util::GlobalLoadOp>(
+ ChannelType::get(builder.getContext()), getChannelName(mesh, meshAxes));
+}
+
+// Returns the !flow.channel corresponding to the mesh and mesh axes used in
+// the op.
+template <typename MeshCollectiveOp>
+static Value buildCachedChannelLoading(
+ MeshCollectiveOp op, SymbolTableCollection &symbolTableCollection,
+ bool useNamedDefaultChannels, ImplicitLocOpBuilder &builder) {
+ ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointAfter(op);
+
+ mesh::MeshOp mesh = mesh::getMesh(op, symbolTableCollection);
+ return buildCachedChannelLoading(mesh, op.getMeshAxes(),
+ useNamedDefaultChannels, builder);
+}
+
+static SmallVector<mesh::MeshAxis> getAllMeshAxes(mesh::MeshOp mesh) {
+ SmallVector<mesh::MeshAxis> res(mesh.getRank());
+ std::iota(res.begin(), res.end(), 0);
+ return res;
+}
+
+static Value buildCachedChannelLoading(
+ mesh::ProcessLinearIndexOp op, SymbolTableCollection &symbolTableCollection,
+ bool useNamedDefaultChannels, ImplicitLocOpBuilder &builder) {
+ ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointAfter(op);
+
+ mesh::MeshOp mesh = mesh::getMesh(op, symbolTableCollection);
+ return buildCachedChannelLoading(mesh, getAllMeshAxes(mesh),
+ useNamedDefaultChannels, builder);
+}
+
+static TypedValue<RankedTensorType>
+buildTranspose(Value v, ArrayRef<int64_t> transposeVector,
+ ImplicitLocOpBuilder &builder) {
+ RankedTensorType type = v.getType().cast<RankedTensorType>();
+ SmallVector<int64_t> transposedShape =
+ permute(type.getShape(), transposeVector);
+ Value target =
+ builder.create<tensor::EmptyOp>(transposedShape, type.getElementType());
+ return builder.create<linalg::TransposeOp>(v, target, transposeVector)
+ ->getResult(0)
+ .cast<TypedValue<RankedTensorType>>();
+}
+
+static SmallVector<int64_t> transpose(ArrayRef<int64_t> shape, int64_t axisA,
+ int64_t axisB) {
+ SmallVector<int64_t> res = llvm::to_vector(shape);
+ std::swap(res[axisA], res[axisB]);
+ return res;
+}
+
+static RankedTensorType transpose(RankedTensorType type, int64_t axisA,
+ int64_t axisB) {
+ SmallVector<int64_t> newShape = transpose(type.getShape(), axisA, axisB);
+ return type.clone(newShape);
+}
+
+static TypedValue<RankedTensorType>
+buildTranspose(Value v, int64_t axisA, int64_t axisB,
+ ImplicitLocOpBuilder &builder) {
+ int64_t rank = v.getType().cast<RankedTensorType>().getRank();
+ SmallVector<int64_t> transposeVector(rank);
+ std::iota(transposeVector.begin(), transposeVector.end(), 0);
+ std::swap(transposeVector[axisA], transposeVector[axisB]);
+ return buildTranspose(v, transposeVector, builder);
+}
+
+// (..., splitAxisSize, ...) ->
+// (..., splitCount, splitAxisSize / splitCount, ...)
+static TypedValue<RankedTensorType>
+splitAxis(TypedValue<RankedTensorType> tensor, int64_t splitAxis,
+ int64_t splitCount, ImplicitLocOpBuilder &builder) {
+ ArrayRef<int64_t> shape = tensor.getType().getShape();
+ SmallVector<int64_t> newShape;
+ newShape.reserve(shape.size() + 1);
+ for (int64_t i = 0; i < tensor.getType().getRank(); ++i) {
+ if (i != splitAxis) {
+ newShape.push_back(shape[i]);
+ continue;
+ }
+ newShape.push_back(splitCount);
+ if (ShapedType::isDynamic(shape[i])) {
+ newShape.push_back(ShapedType::kDynamic);
+ } else {
+ assert(shape[i] % splitCount == 0);
+ newShape.push_back(shape[i] / splitCount);
+ }
+ }
+
+ RankedTensorType resultType = tensor.getType().clone(newShape);
+ std::optional<SmallVector<ReassociationIndices>> reassociation =
+ getReassociationIndicesForReshape(tensor.getType(), resultType);
+ return builder.create<tensor::ExpandShapeOp>(resultType, tensor,
+ reassociation.value());
+}
+
+// Transposes the input tensor by moving an axis to a new position, inserting
+// it there.
+static TypedValue<RankedTensorType>
+moveAxis(TypedValue<RankedTensorType> tensor, int64_t axis, int64_t destination,
+ ImplicitLocOpBuilder &builder) {
+ SmallVector<int64_t> permutation =
+ makeMovePermutation(tensor.getType().getRank(), axis, destination);
+ return buildTranspose(tensor, permutation, builder);
+}
+
+static SmallVector<int64_t> collapseAxesN(ArrayRef<int64_t> shape,
+ size_t firstAxis, size_t n) {
+ assert(firstAxis + n <= shape.size());
+ assert(n > 1);
+ SmallVector<int64_t> res;
+ std::copy(shape.begin(), shape.begin() + firstAxis, std::back_inserter(res));
+ size_t collapsedAxisSize = std::accumulate(
+ shape.begin() + firstAxis + 1, shape.begin() + firstAxis + n,
+ shape[firstAxis], [](size_t a, size_t b) { return a * b; });
+ res.push_back(collapsedAxisSize);
+ std::copy(shape.begin() + firstAxis + n, shape.end(),
+ std::back_inserter(res));
+ return res;
+}
+
+// Collapses `n` axes starting with axis `firstAxis`.
+// Example:
+// tensor shape = (1, 2, 3, 4), firstAxis = 1, n = 2
+// The resulting tensor has shape (1, 6, 4).
+static TypedValue<RankedTensorType>
+collapseAxesN(TypedValue<RankedTensorType> tensor, int64_t firstAxis, int64_t n,
+ ImplicitLocOpBuilder &builder) {
+ ArrayRef<int64_t> shape = tensor.getType().getShape();
+ SmallVector<int64_t> newShape = collapseAxesN(shape, firstAxis, n);
+ std::optional<SmallVector<ReassociationIndices>> reassociation =
+ getReassociationIndicesForCollapse(shape, newShape);
+ return builder.create<tensor::CollapseShapeOp>(tensor, reassociation.value());
+}
+
+// Splits an axis into 2 new dimensions, then moves the new splitCount axis
+// and collapses it into collapseAxis.
+// The shape of the tensor and its transformations:
+// (..., splitAxisSize, ..., collapseAxisSize, ...)
+// -> split ->
+// (..., splitCount, splitAxisSize / splitCount, ..., collapseAxisSize, ...)
+// -> move ->
+// (..., splitAxisSize / splitCount, ..., splitCount, collapseAxisSize, ...)
+// -> collapse ->
+// (..., splitAxisSize / splitCount, ..., splitCount * collapseAxisSize, ...)
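+// Example (matching the all_to_all lit test): tensor<1x12x3x4x5> with
+// splitAxis = 1, collapseAxis = 3, splitCount = 6:
+// (1, 12, 3, 4, 5) -> (1, 6, 2, 3, 4, 5) -> (1, 2, 3, 6, 4, 5)
+// -> (1, 2, 3, 24, 5)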
+static TypedValue<RankedTensorType>
+splitMoveCollapse(TypedValue<RankedTensorType> tensor, int64_t splitAxis,
+ int64_t collapseAxis, int64_t splitCount,
+ ImplicitLocOpBuilder &builder) {
+ TypedValue<RankedTensorType> v =
+ IREE::Flow::splitAxis(tensor, splitAxis, splitCount, builder);
+ v = moveAxis(v, splitAxis, collapseAxis, builder);
+ return collapseAxesN(v, collapseAxis, 2, builder);
+}
+
+namespace {
+
+template <typename Op>
+struct MeshToFlowCollectiveRewritePatternBase : OpRewritePattern<Op> {
+ template <typename... OpRewritePatternArgs>
+ MeshToFlowCollectiveRewritePatternBase(
+ SymbolTableCollection &symbolTableCollection,
+ bool useNamedDefaultChannels,
+ OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern<Op>(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+ symbolTableCollection(symbolTableCollection),
+ useNamedDefaultChannels(useNamedDefaultChannels) {}
+
+protected:
+ SymbolTableCollection &symbolTableCollection;
+ bool useNamedDefaultChannels;
+};
+
+struct MeshAllReduceToFlow
+ : MeshToFlowCollectiveRewritePatternBase<mesh::AllReduceOp> {
+ using MeshToFlowCollectiveRewritePatternBase::
+ MeshToFlowCollectiveRewritePatternBase;
+
+ LogicalResult matchAndRewrite(mesh::AllReduceOp op,
+ PatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ builder.setInsertionPointAfter(op.getOperation());
+ Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+ useNamedDefaultChannels, builder);
+ Value target = builder.create<tensor::EmptyOp>(
+ op.getResult().getType().getShape(),
+ op.getResult().getType().getElementType());
+ auto flowAllReduce = builder.create<CollectiveAllReduceOp>(
+ convertReductionKind(op.getReductionAttr()),
+ getCollectiveElementTypeAttr(op.getResult().getType()), target,
+ op.getOperand(), channel);
+ rewriter.replaceAllUsesWith(op.getResult(), flowAllReduce.getResult());
+ rewriter.eraseOp(op.getOperation());
+ return success();
+ }
+};
+
+struct MeshAllGatherToFlow
+ : MeshToFlowCollectiveRewritePatternBase<mesh::AllGatherOp> {
+ using MeshToFlowCollectiveRewritePatternBase::
+ MeshToFlowCollectiveRewritePatternBase;
+
+ LogicalResult matchAndRewrite(mesh::AllGatherOp op,
+ PatternRewriter &rewriter) const override {
+ if (ShapedType::isDynamicShape(
+ op.getOperand().getType().cast<RankedTensorType>().getShape()) ||
+ ShapedType::isDynamicShape(op.getResult().getType().getShape())) {
+ // TODO: add dynamic support.
+ return rewriter.notifyMatchFailure(op->getLoc(),
+ "Dynamic tensor case is unsupported.");
+ }
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ builder.setInsertionPointAfter(op.getOperation());
+ Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+ useNamedDefaultChannels, builder);
+
+ int64_t gatherAxis = op.getGatherAxis().getSExtValue();
+
+ // When gather axis != 0, we need to transpose between 0 and
+ // gather axis before and after the flow all-gather op.
+ Value flowAllGatherOperand =
+ buildTranspose(op.getOperand(), 0, gatherAxis, builder);
+
+ RankedTensorType flowAllGatherResultType = transpose(
+ op.getResult().getType().cast<RankedTensorType>(), 0, gatherAxis);
+ Value target = builder.create<tensor::EmptyOp>(
+ flowAllGatherResultType.getShape(),
+ op.getResult().getType().getElementType());
+ auto flowAllGather = builder.create<CollectiveAllGatherOp>(
+ getCollectiveElementTypeAttr(flowAllGatherResultType), target,
+ flowAllGatherOperand, channel);
+
+ Value res = buildTranspose(flowAllGather, 0, gatherAxis, builder);
+
+ rewriter.replaceAllUsesWith(op.getResult(), res);
+ rewriter.eraseOp(op.getOperation());
+ return success();
+ }
+};
+
+struct MeshAllToAllToFlow
+ : MeshToFlowCollectiveRewritePatternBase<mesh::AllToAllOp> {
+ using MeshToFlowCollectiveRewritePatternBase::
+ MeshToFlowCollectiveRewritePatternBase;
+
+ LogicalResult matchAndRewrite(mesh::AllToAllOp op,
+ PatternRewriter &rewriter) const override {
+ if (ShapedType::isDynamicShape(
+ op.getOperand().getType().cast<RankedTensorType>().getShape()) ||
+ ShapedType::isDynamicShape(op.getResult().getType().getShape())) {
+ // TODO: add dynamic support.
+ return rewriter.notifyMatchFailure(op->getLoc(),
+ "Dynamic tensor case is unsupported.");
+ }
+
+ mesh::MeshOp mesh = mesh::getMesh(op, symbolTableCollection);
+ assert(!ShapedType::isDynamicShape(mesh.getShape()));
+ int64_t splitCount =
+ mesh::collectiveProcessGroupSize(op.getMeshAxes(), mesh.getShape());
+    if (ShapedType::isDynamic(splitCount)) {
+ // TODO: add dynamic support.
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ "Dynamic split count induced by a dynamic mesh is unsupported.");
+ }
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ builder.setInsertionPointAfter(op.getOperation());
+
+ Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+ useNamedDefaultChannels, builder);
+
+ int64_t splitAxis = op.getSplitAxis().getSExtValue();
+
+ TypedValue<RankedTensorType> splitAxisAsMostOuter =
+ buildTranspose(op.getOperand(), 0, splitAxis, builder);
+
+ Value target = builder.create<tensor::EmptyOp>(
+ splitAxisAsMostOuter.getType().getShape(),
+ splitAxisAsMostOuter.getType().getElementType());
+ auto flowAllToAll = builder.create<CollectiveAllToAllOp>(
+ getCollectiveElementTypeAttr(splitAxisAsMostOuter.getType()), target,
+ splitAxisAsMostOuter, channel);
+
+ TypedValue<RankedTensorType> splitAxisBackInItsPlace =
+ buildTranspose(flowAllToAll, 0, splitAxis, builder);
+
+ int64_t concatAxis = op.getConcatAxis().getSExtValue();
+ Value res = splitMoveCollapse(splitAxisBackInItsPlace, splitAxis,
+ concatAxis, splitCount, builder);
+
+ rewriter.replaceAllUsesWith(op.getResult(), res);
+ rewriter.eraseOp(op.getOperation());
+ return success();
+ }
+};
+
+struct MeshProcessLinearIndexToFlow
+ : MeshToFlowCollectiveRewritePatternBase<mesh::ProcessLinearIndexOp> {
+ using MeshToFlowCollectiveRewritePatternBase::
+ MeshToFlowCollectiveRewritePatternBase;
+
+ LogicalResult matchAndRewrite(mesh::ProcessLinearIndexOp op,
+ PatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ builder.setInsertionPointAfter(op.getOperation());
+ Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+ useNamedDefaultChannels, builder);
+ Value newIndex =
+ builder.create<ChannelRankOp>(builder.getIndexType(), channel);
+ rewriter.replaceAllUsesWith(op.getResult(), newIndex);
+ return success();
+ }
+};
+
+struct MeshReduceScatterToFlow
+ : MeshToFlowCollectiveRewritePatternBase<mesh::ReduceScatterOp> {
+ using MeshToFlowCollectiveRewritePatternBase::
+ MeshToFlowCollectiveRewritePatternBase;
+
+ LogicalResult matchAndRewrite(mesh::ReduceScatterOp op,
+ PatternRewriter &rewriter) const override {
+ if (ShapedType::isDynamicShape(
+ op.getOperand().getType().cast<RankedTensorType>().getShape()) ||
+ ShapedType::isDynamicShape(op.getResult().getType().getShape())) {
+ // TODO: add dynamic support.
+ return rewriter.notifyMatchFailure(op->getLoc(),
+ "Dynamic tensor case is unsupported.");
+ }
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ builder.setInsertionPointAfter(op.getOperation());
+ Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+ useNamedDefaultChannels, builder);
+
+ int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+
+ // When scatter axis != 0, we need to transpose between 0 and
+ // scatter axis before and after the flow reduce-scatter op.
+ Value flowReduceScatterOperand =
+ buildTranspose(op.getOperand(), 0, scatterAxis, builder);
+ RankedTensorType flowReduceScatterResultType = transpose(
+ op.getResult().getType().cast<RankedTensorType>(), 0, scatterAxis);
+
+ Value target = builder.create<tensor::EmptyOp>(
+ flowReduceScatterResultType.getShape(),
+ op.getResult().getType().getElementType());
+ auto flowReduceScatter = builder.create<CollectiveReduceScatterOp>(
+ convertReductionKind(op.getReductionAttr()),
+ getCollectiveElementTypeAttr(flowReduceScatterResultType), target,
+ flowReduceScatterOperand, channel);
+
+ Value res = buildTranspose(flowReduceScatter, 0, scatterAxis, builder);
+
+ rewriter.replaceAllUsesWith(op.getResult(), res);
+ rewriter.eraseOp(op.getOperation());
+ return success();
+ }
+};
+
+using MeshAndAxesSet =
+ DenseSet<std::tuple<mesh::MeshOp, SmallVector<mesh::MeshAxis>>>;
+
+template <typename Op>
+struct CollectiveOpVisitor {
+ CollectiveOpVisitor(MeshAndAxesSet &meshAndAxesSet,
+ SymbolTableCollection &symbolTableCollection)
+ : meshAndAxesSet(meshAndAxesSet),
+ symbolTableCollection(symbolTableCollection) {}
+ void operator()(Op op) {
+ meshAndAxesSet.insert(std::make_tuple(
+ symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+ op, op.getMeshAttr()),
+ llvm::to_vector(op.getMeshAxes())));
+ }
+
+private:
+ MeshAndAxesSet &meshAndAxesSet;
+ SymbolTableCollection &symbolTableCollection;
+};
+
+template <typename Op>
+struct CollectiveOpWithoutMeshAxesVisitor {
+ CollectiveOpWithoutMeshAxesVisitor(
+ MeshAndAxesSet &meshAndAxesSet,
+ SymbolTableCollection &symbolTableCollection)
+ : meshAndAxesSet(meshAndAxesSet),
+ symbolTableCollection(symbolTableCollection) {}
+ void operator()(Op op) {
+ mesh::MeshOp mesh =
+ symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+ op, op.getMeshAttr());
+ meshAndAxesSet.insert(std::make_tuple(mesh, getAllMeshAxes(mesh)));
+ }
+
+private:
+ MeshAndAxesSet &meshAndAxesSet;
+ SymbolTableCollection &symbolTableCollection;
+};
+
+static void populateMeshAndAxes(Operation *op, MeshAndAxesSet &meshAndAxesSet,
+                                SymbolTableCollection &symbolTableCollection) {
+ OpVisitorCollection opVisitors;
+ opVisitors.emplaceVisitors<
+ CollectiveOpVisitor<mesh::AllGatherOp>,
+ CollectiveOpVisitor<mesh::AllReduceOp>,
+ CollectiveOpVisitor<mesh::AllToAllOp>,
+ CollectiveOpVisitor<mesh::ReduceScatterOp>,
+ CollectiveOpWithoutMeshAxesVisitor<mesh::ProcessLinearIndexOp>>(
+ meshAndAxesSet, symbolTableCollection);
+
+ op->walk([&opVisitors](Operation *op) {
+ opVisitors(op);
+ return WalkResult::advance();
+ });
+}
+
+static void createChannels(ModuleOp moduleOp,
+ SymbolTableCollection &symbolTableCollection,
+ MeshAndAxesSet &meshAndAxesSet,
+ bool useNamedDefaultChannels) {
+ populateMeshAndAxes(moduleOp, meshAndAxesSet, symbolTableCollection);
+
+ OpBuilder builder(moduleOp->getContext());
+
+ // Sort for deterministic testing with FileCheck.
+ auto meshAndAxesSetSorted = llvm::to_vector(meshAndAxesSet);
+ llvm::sort(meshAndAxesSetSorted, [](auto &a, auto &b) {
+ int nameCompareRes =
+ std::get<0>(a).getSymName().compare(std::get<0>(b).getSymName());
+ if (nameCompareRes == 0)
+ return std::get<1>(a) < std::get<1>(b);
+ return nameCompareRes < 0;
+ });
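+  // Each creation inserts at the start of the module, so iterate in reverse
+  // to get ascending order in the final IR.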
+ for (auto &[mesh, meshAxes] : llvm::make_range(meshAndAxesSetSorted.rbegin(),
+ meshAndAxesSetSorted.rend())) {
+ buildGlobalChannelCreation(mesh, meshAxes, useNamedDefaultChannels,
+ moduleOp, builder);
+ }
+}
+
+static LogicalResult
+convertCollectives(ModuleOp moduleOp,
+ SymbolTableCollection &symbolTableCollection,
+ bool useNamedDefaultChannels) {
+ RewritePatternSet patterns(moduleOp->getContext());
+ IREE::Flow::populateMeshToFlowCollectivesPatterns(
+ patterns, symbolTableCollection, useNamedDefaultChannels);
+ return applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
+static void removeMeshOps(MeshAndAxesSet &meshAndAxesSet) {
+ auto meshRange =
+ llvm::map_range(meshAndAxesSet, [](auto &v) { return std::get<0>(v); });
+ DenseSet<mesh::MeshOp> meshOpsSet(std::begin(meshRange), std::end(meshRange));
+ for (mesh::MeshOp op : meshOpsSet) {
+ if (op)
+ op.erase();
+ }
+}
+
+struct ConvertMeshToFlowPass
+ : public IREE::Flow::ConvertMeshToFlowBase<ConvertMeshToFlowPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registerMeshToFlowDependencies(registry);
+ }
+
+ void runOnOperation() override {
+ // Run only on the top module.
+ if (getOperation()->getParentOp() != nullptr) {
+ return;
+ }
+
+ MeshAndAxesSet meshAndAxesSet;
+ SymbolTableCollection symbolTableCollection;
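+    // With multiple meshes, each mesh's default channel must be
+    // distinguished by the mesh symbol name; with a single mesh the
+    // anonymous default channel is used.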
+ bool useNamedDefaultChannels = hasMoreThanOneMesh(getOperation());
+
+ createChannels(getOperation(), symbolTableCollection, meshAndAxesSet,
+ useNamedDefaultChannels);
+ if (failed(convertCollectives(getOperation(), symbolTableCollection,
+ useNamedDefaultChannels))) {
+ return signalPassFailure();
+ }
+
+ // Cleanup mesh definition ops that are no longer referenced.
+ removeMeshOps(meshAndAxesSet);
+ }
+};
+
+} // namespace
+
+void populateMeshToFlowCollectivesPatterns(
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection,
+ bool useNamedDefaultChannels) {
+ patterns.add<MeshAllGatherToFlow, MeshAllReduceToFlow, MeshAllToAllToFlow,
+ MeshReduceScatterToFlow, MeshProcessLinearIndexToFlow>(
+ symbolTableCollection, useNamedDefaultChannels, patterns.getContext());
+ mesh::populateFoldingPatterns(patterns, symbolTableCollection);
+ mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+ symbolTableCollection);
+}
+
+std::unique_ptr<Pass> createConvertMeshToFlowPass() {
+ return std::make_unique<ConvertMeshToFlowPass>();
+}
+
+void registerMeshToFlowDependencies(DialectRegistry &registry) {
+ registry.insert<affine::AffineDialect, FlowDialect, linalg::LinalgDialect,
+ mesh::MeshDialect, tensor::TensorDialect>();
+}
+
+void registerMeshToFlowPasses() { registerPasses(); }
+
+} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h
new file mode 100644
index 0000000..acc0572
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h
@@ -0,0 +1,27 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class DialectRegistry;
+class SymbolTableCollection;
+class OpPassManager;
+class Pass;
+namespace iree_compiler::IREE::Flow {
+
+std::unique_ptr<Pass> createConvertMeshToFlowPass();
+
+void populateMeshToFlowCollectivesPatterns(
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection,
+ bool useNamedDefaultChannels);
+
+void registerMeshToFlowDependencies(DialectRegistry &registry);
+void registerMeshToFlowPasses();
+
+} // namespace iree_compiler::IREE::Flow
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.td
new file mode 100644
index 0000000..c593328
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.td
@@ -0,0 +1,53 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_FLOW_MESHTOFLOW_PASSES
+#define IREE_DIALECT_FLOW_MESHTOFLOW_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ConvertMeshToFlow :
+ Pass<"iree-convert-mesh-to-flow", "mlir::ModuleOp"> {
+ let summary = "Convert Mesh dialect operations to IREE Flow.";
+ let description = [{
+ Each mesh corresponds to a default Flow channel with the same group name.
+ ```
+ mesh.mesh @mesh_1(shape = 2x3)
+ ```
+ ```
+ %channel = flow.channel.default "mesh_1" : !flow.channel
+ ```
+    If there is only one mesh in the program, the name is omitted:
+ ```
+ %channel = flow.channel.default : !flow.channel
+ ```
+
+ Each (mesh, mesh_axes) pair partitions and orders the devices into disjoint
+ groups, each corresponding to a Flow channel to perform a collective
+ operation.
+    For example:
+ ```
+ mesh.mesh @mesh(shape = 2x3x4x5)
+ ...
+    %1 = mesh.all_reduce %0 on @mesh mesh_axes = [2, 0]
+        : tensor<10x20xf32> -> tensor<10x20xf32>
+ ```
+ For more information see
+ [Mesh dialect](https://mlir.llvm.org/docs/Dialects/Mesh/#device-groups).
+
+ The mesh partition and device ordering determines the values for the
+ `color` and `key` in the corresponding `flow.channel.split` operation used
+ to create the channel.
+ For more information on the meaning of `color` and `key` see
+ [MPI_Comm_split](https://www.mpi-forum.org/docs/mpi-4.1/mpi41-report/node188.htm#Node188)
+ in the MPI standard.
+
+ Each Flow channel is wrapped in an IREE `util.global` and its construction
+ is done only once with `util.initializer`.
+ }];
+ let constructor = "mlir::iree_compiler::IREE::Flow::createConvertMeshToFlowPass()";
+}
+
+#endif // IREE_DIALECT_FLOW_MESHTOFLOW_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel
new file mode 100644
index 0000000..72882ed
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel
@@ -0,0 +1,29 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "channel_creation.mlir",
+ "collectives.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/CMakeLists.txt
new file mode 100644
index 0000000..449e06b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel#
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "channel_creation.mlir"
+ "collectives.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/channel_creation.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/channel_creation.mlir
new file mode 100644
index 0000000..fe9cbae
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/channel_creation.mlir
@@ -0,0 +1,130 @@
+// RUN: iree-opt --split-input-file --iree-convert-mesh-to-flow --cse %s | FileCheck %s
+
+// CHECK-LABEL: module @static_1d_mesh_grouping_along_axis_0
+module @static_1d_mesh_grouping_along_axis_0 {
+
+  // No channel initialization is expected, since the default channel is used.
+ // CHECK-NOT: util.global private @_mesh_mesh_1d_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+ mesh.mesh @mesh_1d(shape = 2)
+
+ func.func @f(
+ %arg0 : tensor<1xi8>) -> tensor<1xi8> {
+ %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] reduction = <sum>
+ : tensor<1xi8> -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @static_2d_mesh_grouping_along_axis_1
+module @static_2d_mesh_grouping_along_axis_1 {
+
+ // CHECK: util.global private @_mesh_mesh_2d_axes_1 {inlining_policy = #util.inline.never} : !flow.channel
+ // CHECK: util.initializer {
+ // CHECK-DAG: %[[AXIS_1_SIZE:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[AXIS_0_SIZE:.*]] = arith.constant 3 : index
+ // CHECK-DAG: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default : !flow.channel
+ // CHECK: %[[CHANNEL_RANK:.*]] = flow.channel.rank %[[DEFAULT_CHANNEL]] : index
+ // CHECK: %[[COLOR_AND_KEY:.*]]:2 = affine.delinearize_index %[[CHANNEL_RANK]] into
+ // CHECK-SAME: (%[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]) : index, index
+ // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+ // CHECK-SAME: %[[DEFAULT_CHANNEL]], %[[COLOR_AND_KEY]]#0, %[[COLOR_AND_KEY]]#1 : !flow.channel -> !flow.channel
+ // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh_2d_axes_1 : !flow.channel
+ mesh.mesh @mesh_2d(shape = 3x4)
+
+ func.func @f(%input : tensor<1xi8>) -> tensor<1xi8> {
+ %out = mesh.all_reduce %input on @mesh_2d mesh_axes = [1] : tensor<1xi8> -> tensor<1xi8>
+ return %out : tensor<1xi8>
+ }
+}
+
+// -----
+
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+
+// CHECK-LABEL: module @static_4d_mesh_grouping_along_axes_2_1
+module @static_4d_mesh_grouping_along_axes_2_1 {
+
+ // CHECK: util.global private @_mesh_mesh_4d_axes_2_1 {inlining_policy = #util.inline.never} : !flow.channel
+ // CHECK: util.initializer {
+ // CHECK-DAG: %[[AXIS_3_SIZE:.*]] = arith.constant 6 : index
+ // CHECK-DAG: %[[AXIS_2_SIZE:.*]] = arith.constant 5 : index
+ // CHECK-DAG: %[[AXIS_1_SIZE:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[AXIS_0_SIZE:.*]] = arith.constant 3 : index
+ // CHECK-DAG: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default : !flow.channel
+ // CHECK: %[[CHANNEL_RANK:.*]] = flow.channel.rank %[[DEFAULT_CHANNEL]] : index
+ // CHECK: %[[DEVICE_MULTI_IDX:.*]]:4 = affine.delinearize_index %[[CHANNEL_RANK]] into
+ // CHECK-SAME: (%[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]], %[[AXIS_2_SIZE]], %[[AXIS_3_SIZE]]) : index, index, index, index
+ // CHECK: %[[IN_GROUP_IDX:.*]] = affine.apply
+ // CHECK-SAME: #map()[%[[DEVICE_MULTI_IDX]]#2, %[[AXIS_1_SIZE]], %[[DEVICE_MULTI_IDX]]#1]
+ // CHECK: %[[GROUP_IDX:.*]] = affine.apply
+ // CHECK-SAME: #map()[%[[DEVICE_MULTI_IDX]]#0, %[[AXIS_3_SIZE]], %[[DEVICE_MULTI_IDX]]#3]
+ // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+ // CHECK-SAME: %[[DEFAULT_CHANNEL]], %[[GROUP_IDX]], %[[IN_GROUP_IDX]] : !flow.channel -> !flow.channel
+ // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh_4d_axes_2_1 : !flow.channel
+ mesh.mesh @mesh_4d(shape = 3x4x5x6)
+
+ func.func @f(%input : tensor<1xi8>) -> tensor<1xi8> {
+ %out = mesh.all_reduce %input on @mesh_4d mesh_axes = [2, 1] : tensor<1xi8> -> tensor<1xi8>
+ return %out : tensor<1xi8>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @multiple_different_channels
+module @multiple_different_channels {
+
+ // CHECK-DAG: util.global private @_mesh_mesh_2d_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+ // CHECK-DAG: util.global private @_mesh_mesh_2d_axes_1 {inlining_policy = #util.inline.never} : !flow.channel
+ mesh.mesh @mesh_2d(shape = 3x4)
+
+ func.func @f(%input : tensor<1xi8>) -> (tensor<1xi8>, tensor<1xi8>) {
+ %out0 = mesh.all_reduce %input on @mesh_2d mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+ %out1 = mesh.all_reduce %input on @mesh_2d mesh_axes = [1] : tensor<1xi8> -> tensor<1xi8>
+ return %out0, %out1 : tensor<1xi8>, tensor<1xi8>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @same_channel_used_multiple_times
+module @same_channel_used_multiple_times {
+
+ // CHECK: util.global private @_mesh_mesh_2d_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+ mesh.mesh @mesh_2d(shape = 3x4)
+
+ func.func @f(%input0 : tensor<1xi8>, %input1 : tensor<1xi8>) -> (tensor<1xi8>, tensor<1xi8>) {
+ %out0 = mesh.all_reduce %input0 on @mesh_2d mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+ %out1 = mesh.all_reduce %input1 on @mesh_2d mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+ return %out0, %out1 : tensor<1xi8>, tensor<1xi8>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @multiple_meshes
+module @multiple_meshes {
+
+ // CHECK: util.global private @_mesh_mesh1_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+ // CHECK: util.initializer {
+ // CHECK: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default "mesh1" : !flow.channel
+ // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+ // CHECK-SAME: %[[DEFAULT_CHANNEL]], %{{.*}}, %{{.*}} : !flow.channel -> !flow.channel
+ // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh1_axes_0 : !flow.channel
+ mesh.mesh @mesh1(shape = 1x2)
+ // CHECK: util.global private @_mesh_mesh2_axes_1 {inlining_policy = #util.inline.never} : !flow.channel
+ // CHECK: util.initializer {
+ // CHECK: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default "mesh2" : !flow.channel
+ // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+ // CHECK-SAME: %[[DEFAULT_CHANNEL]], %{{.*}}, %{{.*}} : !flow.channel -> !flow.channel
+ // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh2_axes_1 : !flow.channel
+ mesh.mesh @mesh2(shape = 3x4)
+
+ func.func @f(%input0 : tensor<1xi8>, %input1 : tensor<1xi8>) -> (tensor<1xi8>, tensor<1xi8>) {
+ %out0 = mesh.all_reduce %input0 on @mesh1 mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+ %out1 = mesh.all_reduce %input1 on @mesh2 mesh_axes = [1] : tensor<1xi8> -> tensor<1xi8>
+ return %out0, %out1 : tensor<1xi8>, tensor<1xi8>
+ }
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/collectives.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/collectives.mlir
new file mode 100644
index 0000000..60d16ab
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/collectives.mlir
@@ -0,0 +1,143 @@
+// RUN: iree-opt --split-input-file --iree-convert-mesh-to-flow --cse %s | FileCheck %s
+
+mesh.mesh @mesh_2d(shape = 3x4)
+
+// CHECK-LABEL: func @all_gather_non_default_channel
+func.func @all_gather_non_default_channel(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xi8>
+ %arg0 : tensor<3x4xi8>) -> tensor<3x16xi8> {
+ // CHECK-DAG: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_2d_axes_1 : !flow.channel
+ // CHECK-DAG: %[[TRANSPOSED_OPERAND_INIT_VAL:.*]] = tensor.empty() : tensor<4x3xi8>
+ // CHECK: %[[TRANSPOSED_OPERAND:.*]] = linalg.transpose
+ // CHECK-SAME: ins(%[[ARG]] : tensor<3x4xi8>) outs(%[[TRANSPOSED_OPERAND_INIT_VAL]] : tensor<4x3xi8>) permutation = [1, 0]
+ // CHECK: %[[ALL_GATHER_INITIAL_VAL:.*]] = tensor.empty() : tensor<16x3xi8>
+ // CHECK: %[[ALL_GATHER_RES:.*]] = flow.collective.all_gather ui8,
+ // CHECK-SAME: %[[ALL_GATHER_INITIAL_VAL]], %[[TRANSPOSED_OPERAND]], %[[CHANNEL]]
+ // CHECK-SAME: (tensor<16x3xi8>, tensor<4x3xi8>, !flow.channel) -> %[[ALL_GATHER_INITIAL_VAL]] as tensor<16x3xi8>
+ // CHECK: %[[RES_INIT_VAL:.*]] = tensor.empty() : tensor<3x16xi8>
+ // CHECK: %[[RES:.*]] = linalg.transpose
+ // CHECK-SAME: ins(%[[ALL_GATHER_RES]] : tensor<16x3xi8>) outs(%[[RES_INIT_VAL]] : tensor<3x16xi8>) permutation = [1, 0]
+ %0 = mesh.all_gather %arg0 on @mesh_2d mesh_axes = [1] gather_axis = 1
+ : tensor<3x4xi8> -> tensor<3x16xi8>
+ // CHECK: return %[[RES]] : tensor<3x16xi8>
+ return %0 : tensor<3x16xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @all_reduce_sum_default_channel
+func.func @all_reduce_sum_default_channel(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+ %arg0 : tensor<1xi8>) -> tensor<1xi8> {
+ // CHECK: %[[CHANNEL:.*]] = flow.channel.default : !flow.channel
+ // CHECK: %[[INITIAL_VAL:.*]] = tensor.empty() : tensor<1xi8>
+ // CHECK: %[[RES:.*]] = flow.collective.all_reduce sum, ui8, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
+ // CHECK-SAME: (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xi8>
+ %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0]
+ : tensor<1xi8> -> tensor<1xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_2d(shape = 2x2)
+
+// CHECK-LABEL: func @all_reduce_min_non_default_channel
+func.func @all_reduce_min_non_default_channel(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+ %arg0 : tensor<1xi8>) -> tensor<1xi8> {
+ // CHECK-DAG: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_2d_axes_1_0 : !flow.channel
+ // CHECK-DAG: %[[INITIAL_VAL:.*]] = tensor.empty() : tensor<1xi8>
+ // CHECK: %[[RES:.*]] = flow.collective.all_reduce minimum, ui8, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
+ // CHECK-SAME: (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xi8>
+ %0 = mesh.all_reduce %arg0 on @mesh_2d mesh_axes = [1, 0] reduction = <min>
+ : tensor<1xi8> -> tensor<1xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @all_reduce_f32
+func.func @all_reduce_f32(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+ %arg0 : tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK-DAG: %[[CHANNEL:.*]] = flow.channel.default : !flow.channel
+ // CHECK-DAG: %[[INITIAL_VAL:.*]] = tensor.empty() : tensor<1xf32>
+ // CHECK: %[[RES:.*]] = flow.collective.all_reduce sum, f32, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
+ // CHECK-SAME: (tensor<1xf32>, tensor<1xf32>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xf32>
+ %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0]
+ : tensor<1xf32> -> tensor<1xf32>
+ // CHECK: return %[[RES]] : tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index {
+ // CHECK: %[[CHANNEL:.*]] = flow.channel.default : !flow.channel
+ // CHECK: %[[RES:.*]] = flow.channel.rank %[[CHANNEL]] : index
+ %0 = mesh.process_linear_index on @mesh_1d : index
+ // CHECK: return %[[RES]] : index
+ return %0 : index
+}
+
+// -----
+
+mesh.mesh @mesh_3d(shape = 2x3x4)
+
+// CHECK-LABEL: func @all_to_all_non_default_channel
+func.func @all_to_all_non_default_channel(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1x12x3x4x5xf32>
+ %arg0 : tensor<1x12x3x4x5xf32>) -> tensor<1x2x3x24x5xf32> {
+ // CHECK: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_3d_axes_1_0 : !flow.channel
+ // CHECK: %[[SPLIT_AXIS_AT_FRONT:.*]] = linalg.transpose ins(%[[ARG]] : tensor<1x12x3x4x5xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<12x1x3x4x5xf32>) permutation = [1, 0, 2, 3, 4]
+ // CHECK: %[[FLOW_ALL_TO_ALL:.*]] = flow.collective.all_to_all f32, %{{.*}}, %[[SPLIT_AXIS_AT_FRONT]], %_mesh_mesh_3d_axes_1_0 :
+ // CHECK-SAME: (tensor<12x1x3x4x5xf32>, tensor<12x1x3x4x5xf32>, !flow.channel) -> %0 as tensor<12x1x3x4x5xf32>
+ // CHECK: %[[SPLIT_AXIS_BACK_IN_ITS_PLACE:.*]] = linalg.transpose ins(%1 : tensor<12x1x3x4x5xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<1x12x3x4x5xf32>) permutation = [1, 0, 2, 3, 4]
+ // CHECK: %[[SPLIT_AXIS_IS_SPLIT:.*]] = tensor.expand_shape %[[SPLIT_AXIS_BACK_IN_ITS_PLACE]]
+ // CHECK-SAME-LITERAL: [[0], [1, 2], [3], [4], [5]] : tensor<1x12x3x4x5xf32> into tensor<1x6x2x3x4x5xf32>
+ // CHECK: %[[MOVED_SPLIT_COUNT_AXIS:.*]] = linalg.transpose ins(%[[SPLIT_AXIS_IS_SPLIT]] : tensor<1x6x2x3x4x5xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<1x2x3x6x4x5xf32>) permutation = [0, 2, 3, 1, 4, 5]
+ // CHECK: %[[COLLAPSED_SPLIT_COUNT_INTO_CONCAT_AXIS:.*]] = tensor.collapse_shape %[[MOVED_SPLIT_COUNT_AXIS]]
+ // CHECK-SAME-LITERAL: [[0], [1], [2], [3, 4], [5]] : tensor<1x2x3x6x4x5xf32> into tensor<1x2x3x24x5xf32>
+ %0 = mesh.all_to_all %arg0 on @mesh_3d mesh_axes = [1, 0] split_axis = 1 concat_axis = 3
+ : tensor<1x12x3x4x5xf32> -> tensor<1x2x3x24x5xf32>
+ // CHECK: return %[[COLLAPSED_SPLIT_COUNT_INTO_CONCAT_AXIS]] : tensor<1x2x3x24x5xf32>
+ return %0 : tensor<1x2x3x24x5xf32>
+}
+
+// -----
+
+mesh.mesh @mesh_2d(shape = 2x2)
+
+// CHECK-LABEL: func @reduce_scatter_non_default_channel
+func.func @reduce_scatter_non_default_channel(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x2xi8>
+ %arg0 : tensor<3x2xi8>) -> tensor<3x1xi8> {
+ // CHECK-DAG: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_2d_axes_0 : !flow.channel
+ // CHECK-DAG: %[[TRANSPOSED_OPERAND_INIT_VAL:.*]] = tensor.empty() : tensor<2x3xi8>
+ // CHECK: %[[TRANSPOSED_OPERAND:.*]] = linalg.transpose
+ // CHECK-SAME: ins(%[[ARG]] : tensor<3x2xi8>) outs(%[[TRANSPOSED_OPERAND_INIT_VAL]] : tensor<2x3xi8>) permutation = [1, 0]
+ // CHECK: %[[REDUCE_SCATTER_INITIAL_VAL:.*]] = tensor.empty() : tensor<1x3xi8>
+ // CHECK: %[[REDUCE_SCATTER_RES:.*]] = flow.collective.reduce_scatter sum, ui8,
+ // CHECK-SAME: %[[REDUCE_SCATTER_INITIAL_VAL]], %[[TRANSPOSED_OPERAND]], %[[CHANNEL]]
+ // CHECK-SAME: (tensor<1x3xi8>, tensor<2x3xi8>, !flow.channel) -> %[[REDUCE_SCATTER_INITIAL_VAL]] as tensor<1x3xi8>
+ // CHECK: %[[RES_INIT_VAL:.*]] = tensor.empty() : tensor<3x1xi8>
+ // CHECK: %[[RES:.*]] = linalg.transpose
+ // CHECK-SAME: ins(%[[REDUCE_SCATTER_RES]] : tensor<1x3xi8>) outs(%[[RES_INIT_VAL]] : tensor<3x1xi8>) permutation = [1, 0]
+ %0 = mesh.reduce_scatter %arg0 on @mesh_2d mesh_axes = [0] scatter_axis = 1
+ : tensor<3x2xi8> -> tensor<3x1xi8>
+ // CHECK: return %[[RES]] : tensor<3x1xi8>
+ return %0 : tensor<3x1xi8>
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 1ce78c8..6266615 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -1947,6 +1948,16 @@
setNameFn(getResult(), "default_channel");
}
+void ChannelDefaultOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef group) {
+ ChannelDefaultOp::build(odsBuilder, odsState,
+ odsBuilder.getStringAttr(group));
+}
+
+void ChannelDefaultOp::build(OpBuilder &odsBuilder, OperationState &odsState) {
+ ChannelDefaultOp::build(odsBuilder, odsState, StringAttr());
+}
+
//===----------------------------------------------------------------------===//
// flow.channel.split
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 716a043..3f335d2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1732,6 +1732,11 @@
`:` type($result)
attr-dict-with-keyword
}];
+
+ let builders = [
+ OpBuilder<(ins "StringRef":$group)>,
+ OpBuilder<(ins)>
+ ];
}
def FLOW_ChannelSplitOp : FLOW_Op<"channel.split", [
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
index 4e898cd..4f93ea7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
@@ -8,6 +8,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
// clang-format off: must be included after all LLVM/MLIR headers.
@@ -219,4 +220,60 @@
}
}
+std::optional<IREE::Flow::CollectiveElementType>
+convertToFlowCollectiveElementType(Type type) {
+ if (isa<FloatType>(type)) {
+ if (type.isF16()) {
+ return IREE::Flow::CollectiveElementType::Float16;
+ }
+ if (type.isBF16()) {
+ return IREE::Flow::CollectiveElementType::BFloat16;
+ }
+ if (type.isF32()) {
+ return IREE::Flow::CollectiveElementType::Float32;
+ }
+ if (type.isF64()) {
+ return IREE::Flow::CollectiveElementType::Float64;
+ }
+ } else if (isa<IntegerType>(type)) {
+ if (type.isInteger(8)) {
+ if (type.isSignedInteger()) {
+ return IREE::Flow::CollectiveElementType::Sint8;
+ }
+ return IREE::Flow::CollectiveElementType::Uint8;
+ }
+ if (type.isInteger(16)) {
+ if (type.isSignedInteger()) {
+ return IREE::Flow::CollectiveElementType::Sint16;
+ }
+ return IREE::Flow::CollectiveElementType::Uint16;
+ }
+ if (type.isInteger(32)) {
+ if (type.isSignedInteger()) {
+ return IREE::Flow::CollectiveElementType::Sint32;
+ }
+ return IREE::Flow::CollectiveElementType::Uint32;
+ }
+ if (type.isInteger(64)) {
+ if (type.isSignedInteger()) {
+ return IREE::Flow::CollectiveElementType::Sint64;
+ }
+ return IREE::Flow::CollectiveElementType::Uint64;
+ }
+ }
+
+ return std::nullopt;
+}
+
+IREE::Flow::CollectiveElementTypeAttr
+getCollectiveElementTypeAttr(RankedTensorType type) {
+ std::optional<IREE::Flow::CollectiveElementType> collectiveElemType =
+ convertToFlowCollectiveElementType(type.getElementType());
+ if (!collectiveElemType) {
+ return IREE::Flow::CollectiveElementTypeAttr();
+ }
+ return IREE::Flow::CollectiveElementTypeAttr::get(type.getContext(),
+ *collectiveElemType);
+}
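+
+// For example, a tensor<4xf32> maps to CollectiveElementType::Float32, while
+// an element type with no correspondence (e.g. i1) yields a null attribute.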
+
} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
index 80151c5..3cc34ff 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
@@ -171,4 +171,18 @@
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h.inc" // IWYU pragma: keep
// clang-format on
+namespace mlir::iree_compiler::IREE::Flow {
+
+// Create an attribute corresponding to the underlying numeric element type.
+// If there is no such correspondence, a null attribute is returned.
+IREE::Flow::CollectiveElementTypeAttr
+getCollectiveElementTypeAttr(RankedTensorType type);
+
+// Convert the numeric type `type` to the corresponding enum value.
+// If there is no correspondence, std::nullopt is returned.
+std::optional<IREE::Flow::CollectiveElementType>
+convertToFlowCollectiveElementType(Type type);
+
+} // namespace mlir::iree_compiler::IREE::Flow
+
#endif // IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index a4600af..a7a9e30 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -66,6 +66,7 @@
"//runtime/src/iree/schemas/instruments:dispatch_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
+ "@llvm-project//mlir:AffineTransforms",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:ControlFlowDialect",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 4a94132..01381be 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -42,6 +42,7 @@
::PassesIncGen
LLVMSupport
MLIRAffineToStandard
+ MLIRAffineTransforms
MLIRArithDialect
MLIRBufferizationDialect
MLIRControlFlowDialect
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 7146c1d..a732114 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
@@ -405,6 +406,10 @@
FunctionLikeNest(passManager)
.addPass(IREE::HAL::createElideRedundantCommandsPass);
+ // TODO: Maybe this should be part of the affine lowering pass.
+ // Remove this once it is added there.
+ // https://github.com/llvm/llvm-project/issues/78458
+ passManager.addPass(affine::createAffineExpandIndexOpsPass());
// Fixup workgroup count calculations that may have used the affine dialect.
// Kind of random here but can happen if the benchmarking code does things.
passManager.addPass(mlir::createLowerAffinePass());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel
index f7c94d6..0484ee7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel
@@ -23,6 +23,7 @@
deps = [
"//compiler/src/iree/compiler/Dialect/Stream/Conversion",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Util/Conversion",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
index f8e7852..ff93dbd 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
@@ -24,6 +24,7 @@
MLIRTransforms
iree::compiler::Dialect::Stream::Conversion
iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
index c29fa9a..c73549f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -8,7 +8,9 @@
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -226,6 +228,10 @@
patterns
.insert<GlobalOpExpansion, GlobalLoadOpExpansion, GlobalStoreOpExpansion>(
expansionState, typeConverter, context);
+ patterns.add<GenericConvertTypesPattern<IREE::Util::GlobalOp>,
+ GenericConvertTypesPattern<IREE::Util::GlobalLoadOp>,
+ GenericConvertTypesPattern<IREE::Util::GlobalStoreOp>>(
+ typeConverter, context);
}
void populateUtilToStreamConversionPatterns(MLIRContext *context,
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
index 6971f1e..0614d22 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
@@ -7,7 +7,11 @@
#ifndef IREE_COMPILER_DIALECT_UTIL_CONVERSION_CONVERSIONPATTERNS_H_
#define IREE_COMPILER_DIALECT_UTIL_CONVERSION_CONVERSIONPATTERNS_H_
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
@@ -18,7 +22,7 @@
LogicalResult
matchAndRewrite(T op, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type> resultTypes;
+ SmallVector<Type> newResultTypes;
for (auto oldType : op.getOperation()->getResultTypes()) {
SmallVector<Type> newTypes;
if (failed(this->getTypeConverter()->convertType(oldType, newTypes))) {
@@ -26,13 +30,45 @@
}
// TODO(benvanik): figure out this silly expansion stuff. Seems broken.
// resultTypes.append(newTypes);
- resultTypes.push_back(newTypes.front());
+ newResultTypes.push_back(newTypes.front());
}
- auto newOp = rewriter.create<T>(op.getLoc(), resultTypes,
- adaptor.getOperands(), op->getAttrs());
+
+ SmallVector<NamedAttribute> newAttrs;
+ if (failed(convertTypeAttributes(op->getAttrs(), newAttrs))) {
+ return rewriter.notifyMatchFailure(op,
+ "failed converting type attributes");
+ }
+
+ if (newResultTypes == op->getResultTypes() &&
+ op->getOperands() == adaptor.getOperands() &&
+ newAttrs == op->getAttrs()) {
+ return rewriter.notifyMatchFailure(op, "op does not need transformation");
+ }
+
+ auto newOp = rewriter.create<T>(op.getLoc(), newResultTypes,
+ adaptor.getOperands(), newAttrs);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
+
+protected:
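+ // Converts TypeAttr attribute values through the pattern's type converter,
+ // passing all other attributes through unchanged. Fails if a contained type
+ // cannot be converted.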
+ LogicalResult convertTypeAttributes(ArrayRef<NamedAttribute> attrs,
+ SmallVector<NamedAttribute> &res) const {
+ for (NamedAttribute attr : attrs) {
+ TypeAttr oldType = dyn_cast<TypeAttr>(attr.getValue());
+ if (!oldType) {
+ res.push_back(attr);
+ continue;
+ }
+
+ Type newType = this->getTypeConverter()->convertType(oldType.getValue());
+ if (!newType) {
+ return failure();
+ }
+ res.push_back(NamedAttribute(attr.getName(), TypeAttr::get(newType)));
+ }
+ return success();
+ }
};
template <typename OpT>
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 6113201..823558d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -772,7 +772,6 @@
custom<TypeOrAttr>($type, $initial_value)
}];
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
"StringRef":$name,
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 806cd54..508ead4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -79,6 +79,10 @@
FunctionLikeNest(passManager)
.addPass(mlir::createLoopInvariantCodeMotionPass)
.addPass(mlir::createConvertSCFToCFPass)
+ // TODO: Maybe this should be part of the affine lowering pass.
+ // Remove this once it is added there.
+ // https://github.com/llvm/llvm-project/issues/78458
+ .addPass(affine::createAffineExpandIndexOpsPass)
.addPass(mlir::createLowerAffinePass)
.addPass(mlir::arith::createArithUnsignedWhenEquivalentPass);
diff --git a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
index 043bfcf..c214d54 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
@@ -59,6 +59,7 @@
deps = [
":PassHeaders",
":PassesIncGen",
+ "//compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
index 6f1716c..1f1dc4f 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -67,6 +67,7 @@
MLIRTensorDialect
MLIRTensorUtils
MLIRTransforms
+ iree::compiler::Dialect::Flow::Conversion::MeshToFlow
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp b/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp
index 08728c0..f4750b7 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/InputConversion/Common/Passes.h"
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
@@ -22,6 +23,7 @@
passManager.addPass(createIREEImportPublicPass());
passManager.addPass(createImportMLProgramPass());
passManager.addPass(createSanitizeModuleNamesPass());
+ passManager.addPass(IREE::Flow::createConvertMeshToFlowPass());
}
void registerCommonInputConversionPasses() {
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 637a66b..2324a70 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -37,6 +37,7 @@
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces",
"//compiler/src/iree/compiler/ConstEval",
+ "//compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h
index d9849da..93816cb 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h
@@ -18,6 +18,7 @@
#include "iree/compiler/Bindings/Native/Transforms/Passes.h"
#include "iree/compiler/Bindings/TFLite/Transforms/Passes.h"
#include "iree/compiler/ConstEval/Passes.h"
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
@@ -53,6 +54,7 @@
GlobalOptimization::registerGlobalOptimizationPipeline();
Preprocessing::registerPreprocessingPasses();
IREE::Flow::registerFlowPasses();
+ IREE::Flow::registerMeshToFlowPasses();
IREE::HAL::registerHALPasses();
IREE::HAL::Inline::registerHALInlinePasses();
IREE::HAL::Loader::registerHALLoaderPasses();
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
index ddc310d..873a311 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -31,6 +31,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
@@ -67,6 +68,7 @@
LLVM::LLVMDialect,
linalg::LinalgDialect,
math::MathDialect,
+ mesh::MeshDialect,
memref::MemRefDialect,
ml_program::MLProgramDialect,
pdl::PDLDialect,
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index b3e1de8..4ebbdaf 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -54,8 +54,6 @@
// Affine
affine::registerAffinePasses();
- affine::registerAffineLoopFusionPass();
- affine::registerAffinePipelineDataTransferPass();
registerConvertAffineToStandardPass();
// Arm SME
diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel
index 6b8ef5e..c094aa1 100644
--- a/compiler/src/iree/compiler/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -33,11 +33,16 @@
"ElementPackingUtils.h",
"EquivalenceUtils.h",
"FlatbufferUtils.h",
+ "Folding.h",
"IndexSet.h",
+ "Indexing.h",
"ModuleUtils.h",
+ "OpVisitor.h",
"OptionUtils.h",
"PassUtils.h",
"PatternUtils.h",
+ "Permutation.h",
+ "SmallVectorDenseMapInfo.h",
"StringUtils.h",
"ToolUtils.h",
"TracingUtils.h",
diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt
index 01939cb..042552a 100644
--- a/compiler/src/iree/compiler/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -18,11 +18,16 @@
"ElementPackingUtils.h"
"EquivalenceUtils.h"
"FlatbufferUtils.h"
+ "Folding.h"
"IndexSet.h"
+ "Indexing.h"
"ModuleUtils.h"
+ "OpVisitor.h"
"OptionUtils.h"
"PassUtils.h"
"PatternUtils.h"
+ "Permutation.h"
+ "SmallVectorDenseMapInfo.h"
"StringUtils.h"
"ToolUtils.h"
"TracingUtils.h"
diff --git a/compiler/src/iree/compiler/Utils/Folding.h b/compiler/src/iree/compiler/Utils/Folding.h
new file mode 100644
index 0000000..397f983
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Folding.h
@@ -0,0 +1,32 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_FOLDING_H_
+#define IREE_COMPILER_UTILS_FOLDING_H_
+
+#include <iterator>
+#include <utility>
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/OpDefinition.h"
+namespace mlir::iree_compiler {
+
+// Convert a `Value` or an `Attribute` range to a range of `OpFoldResult`.
+template <typename Range, typename OutIt>
+void toOpFoldResults(Range &&range, OutIt outIt) {
+ llvm::transform(std::forward<Range>(range), outIt,
+ [](auto v) { return OpFoldResult(v); });
+}
+
+template <typename Range>
+SmallVector<OpFoldResult> toOpFoldResults(Range &&range) {
+ SmallVector<OpFoldResult> res;
+ toOpFoldResults(std::forward<Range>(range), std::back_inserter(res));
+ return res;
+}
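+
+// Usage sketch (illustrative): wrapping the results of an op.
+//   SmallVector<OpFoldResult> ofrs = toOpFoldResults(op->getResults());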
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_FOLDING_H_
diff --git a/compiler/src/iree/compiler/Utils/Indexing.h b/compiler/src/iree/compiler/Utils/Indexing.h
new file mode 100644
index 0000000..58484e8
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Indexing.h
@@ -0,0 +1,54 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_INDEXING_H_
+#define IREE_COMPILER_UTILS_INDEXING_H_
+
+#include <algorithm>
+#include <iterator>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler {
+
+// Construct IR that extracts the linear index from a multi-index according to
+// a shape.
+inline OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> shape,
+ ImplicitLocOpBuilder &builder) {
+ assert(multiIndex.size() == shape.size());
+ SmallVector<AffineExpr> shapeAffine;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+ }
+
+ SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(stridesAffine.size());
+ llvm::transform(stridesAffine, std::back_inserter(strides),
+ [&builder, &shape](AffineExpr strideExpr) {
+ return affine::makeComposedFoldedAffineApply(
+ builder, builder.getLoc(), strideExpr, shape);
+ });
+
+ auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+ OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+ return affine::makeComposedFoldedAffineApply(
+ builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
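+
+// For a static shape (2, 3) the strides fold to (3, 1), so a multi-index
+// (i, j) linearizes to the affine apply of i * 3 + j; e.g. (1, 2) yields 5.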
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_INDEXING_H_
diff --git a/compiler/src/iree/compiler/Utils/OpVisitor.h b/compiler/src/iree/compiler/Utils/OpVisitor.h
new file mode 100644
index 0000000..247bc7a
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/OpVisitor.h
@@ -0,0 +1,109 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <functional>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
+
+namespace mlir::iree_compiler {
+
+// Calls a collection of callbacks for an operation.
+// For concrete ops each callback is called only if its concrete op type
+// matches the given operation.
+//
+// Operation* op = ...;
+// OpVisitorCollection visitors = ...;
+// visitors.emplaceVisitors<MyVisitor1, MyVisitor2>(
+// visitorConstructorArg1,
+// visitorConstructorArg2);
+// visitors.insertVisitors(
+// [](ConcreteOp op) { ... }, // Call only for ConcreteOp
+// [](Operation* op) { ... } // Call for all op types
+// );
+// visitors(op);
+struct OpVisitorCollection {
+ void operator()(Operation *op) {
+ for (auto &fn : everyOpFns) {
+ fn(op);
+ }
+
+ auto it = opFnMap.find(op->getName().getTypeID());
+ if (it == opFnMap.end()) {
+ return;
+ }
+
+ for (auto &fn : it->second) {
+ fn(op);
+ }
+ }
+
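+ // Registers a visitor for the explicitly specified concrete op type `Op`.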
+ template <typename Op, typename Fn>
+ void insertVisitor(Fn &&fn) {
+ opFnMap[TypeID::get<Op>()].emplace_back(
+ ConcreteOpFn<Op, Fn>(std::forward<Fn>(fn)));
+ }
+
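+ // Registers a visitor whose concrete op type is deduced from the callable's
+ // first argument.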
+ template <typename Fn,
+ typename = std::enable_if_t<!std::is_invocable_v<Fn, Operation *>>>
+ void insertVisitor(Fn &&fn) {
+ insertVisitor<
+ std::decay_t<typename llvm::function_traits<Fn>::template arg_t<0>>>(
+ std::forward<Fn>(fn));
+ }
+
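+ // Registers a type-erased visitor that is invoked for every operation.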
+ template <typename Fn,
+ typename = std::enable_if_t<std::is_invocable_v<Fn, Operation *>>,
+ typename = void>
+ void insertVisitor(Fn &&fn) {
+ everyOpFns.emplace_back(std::forward<Fn>(fn));
+ }
+
+ template <typename Fn, typename... RestFns>
+ void insertVisitors(Fn &&fn, RestFns &&...restFns) {
+ insertVisitor(std::forward<Fn>(fn));
+ insertVisitors(std::forward<RestFns>(restFns)...);
+ }
+
+ template <typename Fn>
+ void insertVisitors(Fn &&fn) {
+ insertVisitor(std::forward<Fn>(fn));
+ }
+
+ template <typename... Fns, typename... ConstructorArgs>
+ void emplaceVisitors(ConstructorArgs &&...args) {
+ (emplaceVisitor<Fns>(std::forward<ConstructorArgs>(args)...), ...);
+ }
+
+private:
+ template <typename Fn, typename... ConstructorArgs>
+ void emplaceVisitor(ConstructorArgs &&...args) {
+ insertVisitor(Fn(std::forward<ConstructorArgs>(args)...));
+ }
+
+ template <typename Op, typename Fn>
+ struct ConcreteOpFn {
+
+ template <typename FnArg>
+ ConcreteOpFn(FnArg &&fn) : fn(std::forward<FnArg>(fn)) {}
+
+ void operator()(Operation *op) { fn(llvm::cast<Op>(op)); }
+
+ private:
+ Fn fn;
+ };
+
+ DenseMap<TypeID, SmallVector<std::function<void(Operation *)>>> opFnMap;
+ SmallVector<std::function<void(Operation *)>> everyOpFns;
+};
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Utils/Permutation.h b/compiler/src/iree/compiler/Utils/Permutation.h
new file mode 100644
index 0000000..0cd3df1
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Permutation.h
@@ -0,0 +1,107 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_PERMUTATION_H_
+#define IREE_COMPILER_UTILS_PERMUTATION_H_
+
+#include <iterator>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler {
+
+// Example: values = (1, 2, 3), permutation = (2, 1, 0)
+// output = (3, 2, 1).
+// TODO: Make applyPermutation in mlir/Dialect/Utils/IndexingUtils.h generic
+// and use it instead.
+template <typename ValuesIt, typename PermutationRange, typename OutIt>
+void permute(ValuesIt valuesBegin, ValuesIt valuesEnd,
+ PermutationRange &&permutation, OutIt outBegin) {
+ assert(static_cast<size_t>(std::distance(valuesBegin, valuesEnd)) >=
+ llvm::adl_size(permutation));
+ llvm::transform(permutation, outBegin,
+ [valuesBegin](auto i) { return valuesBegin[i]; });
+}
+
+template <typename ValuesRange, typename PermutationRange, typename OutIt>
+void permute(ValuesRange &&values, PermutationRange &&permutation,
+ OutIt outBegin) {
+ permute(llvm::adl_begin(std::forward<ValuesRange>(values)),
+ llvm::adl_end(std::forward<ValuesRange>(values)), permutation,
+ outBegin);
+}
+
+template <typename T, typename Index>
+SmallVector<T> permute(ArrayRef<T> values, ArrayRef<Index> permutation) {
+ SmallVector<T> res;
+ permute(values, permutation, std::back_inserter(res));
+ return res;
+}
+
+// Check if the range is a sequence of numbers starting from 0.
+// Example: (0, 1, 2, 3).
+// TODO: Generalize isIdentityPermutation in MLIR, which currently accepts
+// only int64_t, and delete this.
+template <typename Range>
+bool isIdentityPermutation(Range &&range) {
+ using ValueType = std::decay_t<decltype(*std::begin(range))>;
+ ValueType i = static_cast<ValueType>(0);
+ return llvm::all_of(std::forward<Range>(range), [&i](ValueType v) {
+ bool res = (v == i);
+ ++i;
+ return res;
+ });
+}
+
+// Make a permutation that moves src to dst.
+// Example with size = 5, src = 1, dst = 3.
+// output = (0, 2, 3, 1, 4).
+// Example with size = 2, src = 0, dst = 1.
+// output = (1, 0).
+template <typename T, typename OutIt>
+void makeMovePermutation(T size, T src, T dst, OutIt outBegin) {
+ assert(src < size && dst < size && size > static_cast<T>(0));
+ T outSize = 0;
+ for (T i = 0; i < size; ++i) {
+ if (outSize == dst) {
+ *outBegin = src;
+ ++outBegin;
+ ++outSize;
+ }
+ if (i == src) {
+ ++i;
+ if (i >= size) {
+ break;
+ }
+ }
+
+ *outBegin = i;
+ ++outBegin;
+ ++outSize;
+ }
+
+ if (size != outSize) {
+ *outBegin = src;
+ ++outBegin;
+ }
+}
+
+template <typename T>
+SmallVector<T> makeMovePermutation(T size, T src, T dst) {
+ SmallVector<T> res;
+ res.reserve(size);
+ makeMovePermutation(size, src, dst, std::back_inserter(res));
+ return res;
+}
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_PERMUTATION_H_
diff --git a/compiler/src/iree/compiler/Utils/SmallVectorDenseMapInfo.h b/compiler/src/iree/compiler/Utils/SmallVectorDenseMapInfo.h
new file mode 100644
index 0000000..6bf3f79
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/SmallVectorDenseMapInfo.h
@@ -0,0 +1,52 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_SMALLVECTORDENSEMAPINFO_H_
+#define IREE_COMPILER_UTILS_SMALLVECTORDENSEMAPINFO_H_
+
+#include <numeric>
+#include <utility>
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace llvm {
+template <typename T, unsigned N>
+struct DenseMapInfo<SmallVector<T, N>> {
+ static SmallVector<T, N> getEmptyKey() {
+ return SmallVector<T, N>(1, llvm::DenseMapInfo<T>::getEmptyKey());
+ }
+
+ static SmallVector<T, N> getTombstoneKey() {
+ return SmallVector<T, N>(1, llvm::DenseMapInfo<T>::getTombstoneKey());
+ }
+
+ static unsigned getHashValue(const SmallVector<T, N> &v) {
+ hash_code hash = llvm::DenseMapInfo<T>::getHashValue(
+ llvm::DenseMapInfo<T>::getEmptyKey());
+ hash = std::accumulate(v.begin(), v.end(), hash,
+ [](hash_code hash, const T &element) {
+ return hash_combine(hash, element);
+ });
+ return hash;
+ }
+
+ static bool isEqual(const SmallVector<T, N> &lhs,
+ const SmallVector<T, N> &rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+
+ return llvm::all_of_zip(lhs, rhs, [](const T &lhs, const T &rhs) {
+ return DenseMapInfo<T>::isEqual(lhs, rhs);
+ });
+ }
+};
+} // namespace llvm
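+
+// Usage sketch (illustrative): keying a DenseMap by a vector of mesh axes.
+//   llvm::DenseMap<llvm::SmallVector<int64_t>, int> channelForAxes;
+//   channelForAxes[llvm::SmallVector<int64_t>{0, 1}] = 0;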
+
+#endif // IREE_COMPILER_UTILS_SMALLVECTORDENSEMAPINFO_H_
diff --git a/compiler/src/iree/compiler/Utils/test/BUILD.bazel b/compiler/src/iree/compiler/Utils/test/BUILD.bazel
new file mode 100644
index 0000000..f48694f
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/test/BUILD.bazel
@@ -0,0 +1,22 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_test")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_test(
+ name = "utils",
+ testonly = True,
+ srcs = ["UtilsTest.cpp"],
+ deps = [
+ "//compiler/src/iree/compiler/Utils",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Utils/test/CMakeLists.txt b/compiler/src/iree/compiler/Utils/test/CMakeLists.txt
new file mode 100644
index 0000000..7c26dca
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Utils/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_test(
+ NAME
+ utils
+ SRCS
+ "UtilsTest.cpp"
+ DEPS
+ gmock
+ gtest
+ iree::compiler::Utils
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Utils/test/UtilsTest.cpp b/compiler/src/iree/compiler/Utils/test/UtilsTest.cpp
new file mode 100644
index 0000000..44810ca
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/test/UtilsTest.cpp
@@ -0,0 +1,26 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "iree/compiler/Utils/Permutation.h"
+
+using namespace mlir::iree_compiler;
+using namespace testing;
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+TEST(Permutation, MakeMovePermutation) {
+ EXPECT_THAT(makeMovePermutation(1, 0, 0), ElementsAre(0));
+ EXPECT_THAT(makeMovePermutation(2, 0, 1), ElementsAre(1, 0));
+ EXPECT_THAT(makeMovePermutation(5, 1, 3), ElementsAre(0, 2, 3, 1, 4));
+ EXPECT_THAT(makeMovePermutation(3, 1, 2), ElementsAre(0, 2, 1));
+ EXPECT_THAT(makeMovePermutation(3, 2, 0), ElementsAre(2, 0, 1));
+}
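+
+// Illustrative extra check (not exercised elsewhere) for the
+// isIdentityPermutation helper from Permutation.h.
+TEST(Permutation, IsIdentityPermutation) {
+ EXPECT_TRUE(isIdentityPermutation(llvm::SmallVector<int>{0, 1, 2}));
+ EXPECT_FALSE(isIdentityPermutation(llvm::SmallVector<int>{1, 0, 2}));
+ EXPECT_FALSE(isIdentityPermutation(llvm::SmallVector<int>{0, 2, 1}));
+}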
diff --git a/tests/e2e/collectives/CMakeLists.txt b/tests/e2e/collectives/CMakeLists.txt
index c98f3d9..b0d1ab1 100644
--- a/tests/e2e/collectives/CMakeLists.txt
+++ b/tests/e2e/collectives/CMakeLists.txt
@@ -13,7 +13,7 @@
if(IREE_TARGET_BACKEND_CUDA AND IREE_HAL_DRIVER_CUDA)
iree_py_test(
NAME
- collectives_test_single_gpu
+ collectives_test_1_gpu
SRCS
"collectives_test.py"
ARGS
@@ -27,7 +27,7 @@
iree_py_test(
NAME
- collectives_test_multi_gpu
+ collectives_test_2_gpus
SRCS
"collectives_test.py"
ARGS
@@ -41,4 +41,19 @@
# To properly test for 2 ranks we need 2 GPUs.
"requires-multiple-devices"
)
+
+ iree_py_test(
+ NAME
+ collectives_test_4_gpus
+ SRCS
+ "collectives_test.py"
+ ARGS
+ "-k" "FourRanks"
+ "--target_backend=cuda"
+ "--driver=cuda"
+ LABELS
+ "requires-gpu-nvidia"
+ "driver=cuda"
+ "requires-multiple-devices"
+ )
endif()
diff --git a/tests/e2e/collectives/collectives_test.py b/tests/e2e/collectives/collectives_test.py
index a28e082..fa995f6 100644
--- a/tests/e2e/collectives/collectives_test.py
+++ b/tests/e2e/collectives/collectives_test.py
@@ -11,13 +11,13 @@
import iree.runtime
from iree.runtime.array_interop import DeviceArray
import os
-from typing import List, Tuple, TypeVar
+from typing import List, Tuple
import numpy as np
import tempfile
import subprocess
import test_utils
-ArrayLike = TypeVar("ArrayLike")
+ArrayLike = object
def parse_args():
@@ -92,7 +92,10 @@
def run_test(
- mlir: str, inputs: List[List[ArrayLike]], expected_outputs: List[List[ArrayLike]]
+ mlir: str,
+ inputs: List[List[ArrayLike]],
+ expected_outputs: List[List[ArrayLike]],
+ mlir_input_type: iree.compiler.InputType | str = iree.compiler.InputType.AUTO,
):
with tempfile.TemporaryDirectory() as tmp_dir:
module_filepath = os.path.join(tmp_dir, "module.vmfb")
@@ -100,14 +103,15 @@
input_str=mlir,
output_file=module_filepath,
target_backends=[args.target_backend],
- input_type="stablehlo",
+ input_type=mlir_input_type,
+ extra_args=["--iree-hal-cuda-llvm-target-arch", "sm_53"],
)
num_ranks = len(inputs)
# Ranks on the 0th axis.
outputs = run_ranks(
num_ranks=num_ranks,
- function="all_reduce_sum",
+ function="main",
driver=args.driver,
module_filepath=module_filepath,
inputs=inputs,
@@ -125,7 +129,7 @@
all_reduce([1, 2, 3, 4]) == [1, 2, 3, 4].
"""
stablehlo_mlir = """
- func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> {
+ func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
%out = "stablehlo.all_reduce"(%input) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%sum = stablehlo.add %arg0, %arg1 : tensor<f32>
@@ -138,7 +142,60 @@
"""
inputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
expected_outputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
- run_test(mlir=stablehlo_mlir, inputs=inputs, expected_outputs=expected_outputs)
+ run_test(
+ mlir=stablehlo_mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ mlir_input_type=iree.compiler.InputType.STABLEHLO,
+ )
+
+ def test_mesh_all_reduce(self):
+ """
+ Test the trivial case of all_reduce with one rank.
+ all_reduce([1, 2, 3, 4]) == [1, 2, 3, 4].
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 1)
+
+ func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
+ %out = mesh.all_reduce %input on @mesh mesh_axes = [0] : tensor<4xf32> -> tensor<4xf32>
+ return %out : tensor<4xf32>
+ }
+ """
+ inputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
+ expected_outputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
+ run_test(mlir=mlir, inputs=inputs, expected_outputs=expected_outputs)
+
+ def test_mesh_all_to_all(self):
+ """
+ Test the trivial case of all_to_all on a single-device 1D mesh,
+ grouping along mesh dimension 0; the device contents are unchanged.
+
+ Device contents before operation:
+ [[1, 2], [3, 4]]
+
+ Device contents after operation:
+ [[1, 2], [3, 4]]
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 1)
+
+ func.func @main(%input : tensor<2x2xf32>) -> tensor<2x2xf32> {
+ %out = mesh.all_to_all %input on @mesh mesh_axes = [0]
+ split_axis = 0 concat_axis = 1 : tensor<2x2xf32> -> tensor<2x2xf32>
+ return %out : tensor<2x2xf32>
+ }
+ """
+ inputs = [
+ [np.array([[1, 2], [3, 4]], dtype=np.float32)],
+ ]
+ expected_outputs = [
+ [np.array([[1, 2], [3, 4]], dtype=np.float32)],
+ ]
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
class TwoRanks(unittest.TestCase):
@@ -147,7 +204,7 @@
Test all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) == [6, 8, 10, 12].
"""
stablehlo_mlir = """
- func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> {
+ func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
%out = "stablehlo.all_reduce"(%input) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%sum = stablehlo.add %arg0, %arg1 : tensor<f32>
@@ -162,8 +219,234 @@
[np.array([1, 2, 3, 4], dtype=np.float32)],
[np.array([5, 6, 7, 8], dtype=np.float32)],
]
- expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]]
- run_test(mlir=stablehlo_mlir, inputs=inputs, expected_outputs=expected_outputs)
+ expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]] * 2
+ run_test(
+ mlir=stablehlo_mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ mlir_input_type=iree.compiler.InputType.STABLEHLO,
+ )
+
+ def test_mesh_all_reduce_1d_mesh(self):
+ """
+ Test all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) == [6, 8, 10, 12].
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 2)
+
+ func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
+ %out = mesh.all_reduce %input on @mesh mesh_axes = [0] : tensor<4xf32> -> tensor<4xf32>
+ return %out : tensor<4xf32>
+ }
+ """
+ inputs = [
+ [np.array([1, 2, 3, 4], dtype=np.float32)],
+ [np.array([5, 6, 7, 8], dtype=np.float32)],
+ ]
+ expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]] * 2
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
+
+ def test_mesh_all_reduce_3d_mesh(self):
+ """
+ Test all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) == [6, 8, 10, 12].
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 1x2x1)
+
+ func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
+ %out = mesh.all_reduce %input on @mesh mesh_axes = [1] : tensor<4xf32> -> tensor<4xf32>
+ return %out : tensor<4xf32>
+ }
+ """
+ inputs = [
+ [np.array([1, 2, 3, 4], dtype=np.float32)],
+ [np.array([5, 6, 7, 8], dtype=np.float32)],
+ ]
+ expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]] * 2
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
+
+
+class FourRanks(unittest.TestCase):
+ def test_mesh_all_reduce_on_2d_mesh_along_axis_1(self):
+ """
+ Test on a 2x2 device mesh reduction along dimension 1.
+ Mesh devices:
+ axis 1
+ ------>
+ 0 1
+ 2 3
+
+ Device contents before operation:
+ [1, 2] [3, 4]
+ [5, 6] [7, 8]
+
+ Device contents after operation:
+ [ 4, 6] [ 4, 6]
+ [12, 14] [12, 14]
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 2x2)
+
+ func.func @main(%input : tensor<2xf32>) -> tensor<2xf32> {
+ %out = mesh.all_reduce %input on @mesh mesh_axes = [1] : tensor<2xf32> -> tensor<2xf32>
+ return %out : tensor<2xf32>
+ }
+ """
+ inputs = [
+ [np.array([1, 2], dtype=np.float32)],
+ [np.array([3, 4], dtype=np.float32)],
+ [np.array([5, 6], dtype=np.float32)],
+ [np.array([7, 8], dtype=np.float32)],
+ ]
+ expected_outputs = [
+ [np.array([4, 6], dtype=np.float32)],
+ [np.array([4, 6], dtype=np.float32)],
+ [np.array([12, 14], dtype=np.float32)],
+ [np.array([12, 14], dtype=np.float32)],
+ ]
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
+
+ def test_mesh_all_reduce_on_2d_mesh_along_axis_0(self):
+ """
+ Test on a 2x2 device mesh reduction along dimension 0.
+ Mesh devices:
+ axis 1
+ ------>
+ 0 1
+ 2 3
+
+ Device contents before operation:
+ [1, 2] [3, 4]
+ [5, 6] [7, 8]
+
+ Device contents after operation:
+ [6, 8] [10, 12]
+ [6, 8] [10, 12]
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 2x2)
+
+ func.func @main(%input : tensor<2xf32>) -> tensor<2xf32> {
+ %out = mesh.all_reduce %input on @mesh mesh_axes = [0] : tensor<2xf32> -> tensor<2xf32>
+ return %out : tensor<2xf32>
+ }
+ """
+ inputs = [
+ [np.array([1, 2], dtype=np.float32)],
+ [np.array([3, 4], dtype=np.float32)],
+ [np.array([5, 6], dtype=np.float32)],
+ [np.array([7, 8], dtype=np.float32)],
+ ]
+ expected_outputs = [
+ [np.array([6, 8], dtype=np.float32)],
+ [np.array([10, 12], dtype=np.float32)],
+ [np.array([6, 8], dtype=np.float32)],
+ [np.array([10, 12], dtype=np.float32)],
+ ]
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
+
+ def test_mesh_all_reduce_on_4d_mesh_along_1_axis(self):
+ """
+ Test on a 1x2x1x2 device mesh reduction along mesh dimension 1.
+ Mesh devices:
+ axis 3
+ ------>
+ 0 1 | axis 1
+ 2 3 ↓
+
+ Device contents before operation:
+ [1, 2] [3, 4]
+ [5, 6] [7, 8]
+
+ Device contents after operation:
+ [6, 8] [10, 12]
+ [6, 8] [10, 12]
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 1x2x1x2)
+
+ func.func @main(%input : tensor<2xf32>) -> tensor<2xf32> {
+ %out = mesh.all_reduce %input on @mesh mesh_axes = [1] : tensor<2xf32> -> tensor<2xf32>
+ return %out : tensor<2xf32>
+ }
+ """
+ inputs = [
+ [np.array([1, 2], dtype=np.float32)],
+ [np.array([3, 4], dtype=np.float32)],
+ [np.array([5, 6], dtype=np.float32)],
+ [np.array([7, 8], dtype=np.float32)],
+ ]
+ expected_outputs = [
+ [np.array([6, 8], dtype=np.float32)],
+ [np.array([10, 12], dtype=np.float32)],
+ [np.array([6, 8], dtype=np.float32)],
+ [np.array([10, 12], dtype=np.float32)],
+ ]
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
+
+ def test_mesh_all_to_all_on_4d_mesh_along_1_axis(self):
+ """
+ Test on a 1x2x1x2 device mesh, grouping along mesh dimension 1.
+ Mesh devices:
+ axis 3
+ ------>
+ 0 1 | axis 1
+ 2 3 ↓
+
+ Device contents before operation:
+ [[1], [2]] [[3], [4]]
+ [[5], [6]] [[7], [8]]
+
+ Device contents after operation:
+ [[1, 5]] [[3, 7]]
+ [[2, 6]] [[4, 8]]
+ """
+ mlir = """
+ mesh.mesh @mesh(shape = 1x2x1x2)
+
+ func.func @main(%input : tensor<2x1xf32>) -> tensor<1x2xf32> {
+ %out = mesh.all_to_all %input on @mesh mesh_axes = [1]
+ split_axis = 0 concat_axis = 1 : tensor<2x1xf32> -> tensor<1x2xf32>
+ return %out : tensor<1x2xf32>
+ }
+ """
+ inputs = [
+ [np.array([[1], [2]], dtype=np.float32)],
+ [np.array([[3], [4]], dtype=np.float32)],
+ [np.array([[5], [6]], dtype=np.float32)],
+ [np.array([[7], [8]], dtype=np.float32)],
+ ]
+ expected_outputs = [
+ [np.array([[1, 5]], dtype=np.float32)],
+ [np.array([[3, 7]], dtype=np.float32)],
+ [np.array([[2, 6]], dtype=np.float32)],
+ [np.array([[4, 8]], dtype=np.float32)],
+ ]
+ run_test(
+ mlir=mlir,
+ inputs=inputs,
+ expected_outputs=expected_outputs,
+ )
if __name__ == "__main__":