[compiler] Mesh -> IREE conversion (#16199)

Add conversion of the upstream MLIR Mesh dialect to IREE Flow. Collective
channels are constructed as globals and handled all the way through to the
runtime.

The supported operations are:
* retrieval of rank/process index
* all-gather
* all-reduce
* all-to-all
* reduce-scatter

Dynamic device meshes are not supported.
Some operations do not support dynamic tensor dimensions.
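
For example (a sketch mirroring the lit tests added below), an all-reduce
over a 1D mesh:

  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0]
    : tensor<1xi8> -> tensor<1xi8>

lowers to roughly:

  %channel = flow.channel.default : !flow.channel
  %init = tensor.empty() : tensor<1xi8>
  %0 = flow.collective.all_reduce sum, ui8, %init, %arg0, %channel
    : (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %init as tensor<1xi8>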
diff --git a/build_tools/bazel/build_test_all.sh b/build_tools/bazel/build_test_all.sh
index 8558b36..045e20b 100755
--- a/build_tools/bazel/build_test_all.sh
+++ b/build_tools/bazel/build_test_all.sh
@@ -30,6 +30,7 @@
 IREE_READ_REMOTE_BAZEL_CACHE="${IREE_READ_REMOTE_BAZEL_CACHE:-1}"
 IREE_WRITE_REMOTE_BAZEL_CACHE="${IREE_WRITE_REMOTE_BAZEL_CACHE:-0}"
 BAZEL_BIN="${BAZEL_BIN:-$(which bazel)}"
+SANDBOX_BASE="${SANDBOX_BASE:-}"
 
 if (( ${IREE_WRITE_REMOTE_BAZEL_CACHE} == 1 && ${IREE_READ_REMOTE_BAZEL_CACHE} != 1 )); then
   echo "Can't have 'IREE_WRITE_REMOTE_BAZEL_CACHE' (${IREE_WRITE_REMOTE_BAZEL_CACHE}) set without 'IREE_READ_REMOTE_BAZEL_CACHE' (${IREE_READ_REMOTE_BAZEL_CACHE})"
diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
index 2d5726d..65851fa 100644
--- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
+++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp
@@ -45,55 +45,6 @@
 // mode combinations for cross-replica and cross partition communication. See
 // the stablehlo specification for more details about the different modes.
 
-static std::optional<IREE::Flow::CollectiveElementType>
-convertToFlowCollectiveElementType(Type type) {
-  if (type.isF32()) {
-    return IREE::Flow::CollectiveElementType::Float32;
-  }
-
-  if (type.isInteger(32)) {
-    if (type.isSignedInteger()) {
-      return IREE::Flow::CollectiveElementType::Sint32;
-    }
-    return IREE::Flow::CollectiveElementType::Uint32;
-  }
-
-  if (type.isF16()) {
-    return IREE::Flow::CollectiveElementType::Float16;
-  }
-
-  if (type.isInteger(8)) {
-    if (type.isSignedInteger()) {
-      return IREE::Flow::CollectiveElementType::Sint8;
-    }
-    return IREE::Flow::CollectiveElementType::Uint8;
-  }
-
-  if (type.isInteger(16)) {
-    if (type.isSignedInteger()) {
-      return IREE::Flow::CollectiveElementType::Sint16;
-    }
-    return IREE::Flow::CollectiveElementType::Uint16;
-  }
-
-  if (type.isBF16()) {
-    return IREE::Flow::CollectiveElementType::BFloat16;
-  }
-
-  if (type.isF64()) {
-    return IREE::Flow::CollectiveElementType::Float64;
-  }
-
-  if (type.isInteger(64)) {
-    if (type.isSignedInteger()) {
-      return IREE::Flow::CollectiveElementType::Sint64;
-    }
-    return IREE::Flow::CollectiveElementType::Uint64;
-  }
-
-  return std::nullopt;
-}
-
 static std::optional<IREE::Flow::CollectiveReductionOp>
 convertToFlowCollectiveReductionOp(const Operation &op) {
   if (isa<mlir::stablehlo::AddOp>(op)) {
@@ -113,17 +64,6 @@
   return std::nullopt;
 }
 
-static IREE::Flow::CollectiveElementTypeAttr
-getCollectiveElementTypeAttr(MLIRContext *context, RankedTensorType type) {
-  std::optional<IREE::Flow::CollectiveElementType> collectiveElemType =
-      convertToFlowCollectiveElementType(type.getElementType());
-  if (!collectiveElemType) {
-    return IREE::Flow::CollectiveElementTypeAttr();
-  }
-  return IREE::Flow::CollectiveElementTypeAttr::get(context,
-                                                    *collectiveElemType);
-}
-
 template <typename T>
 static LogicalResult checkCollectiveAttrs(T op, PatternRewriter &rewriter) {
   // Note that the channel handle attribute consists of two 64-bit values,
@@ -553,7 +493,7 @@
     // Get the collective element type attribute.
     auto resultType = cast<RankedTensorType>(op.getResult().getType());
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
-        getCollectiveElementTypeAttr(op.getContext(), resultType);
+        IREE::Flow::getCollectiveElementTypeAttr(resultType);
     if (!elementTypeAttr) {
       return rewriter.notifyMatchFailure(
           op, "unsupported element type for collective op");
@@ -649,7 +589,7 @@
 
     // Get the collective element type attribute.
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
-        getCollectiveElementTypeAttr(op.getContext(), inputType);
+        IREE::Flow::getCollectiveElementTypeAttr(inputType);
     if (!elementTypeAttr) {
       return rewriter.notifyMatchFailure(op, "unsupported input type");
     }
@@ -738,7 +678,7 @@
     // Get the collective element type attribute.
     auto resultType = cast<RankedTensorType>(op.getType());
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
-        getCollectiveElementTypeAttr(op.getContext(), resultType);
+        IREE::Flow::getCollectiveElementTypeAttr(resultType);
     if (!elementTypeAttr) {
       return rewriter.notifyMatchFailure(
           op, "unsupported element type for collective op");
@@ -842,7 +782,7 @@
     // Get the collective element type attribute.
     auto resultType = cast<RankedTensorType>(op.getResult().getType());
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
-        getCollectiveElementTypeAttr(op.getContext(), resultType);
+        IREE::Flow::getCollectiveElementTypeAttr(resultType);
     if (!elementTypeAttr) {
       return rewriter.notifyMatchFailure(op, "unsupported input type");
     }
@@ -932,7 +872,7 @@
 
     // Get the collective element type attribute.
     IREE::Flow::CollectiveElementTypeAttr elementTypeAttr =
-        getCollectiveElementTypeAttr(op.getContext(), inputType);
+        IREE::Flow::getCollectiveElementTypeAttr(inputType);
     if (!elementTypeAttr) {
       return rewriter.notifyMatchFailure(op, "unsupported input type");
     }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel
new file mode 100644
index 0000000..32305c5
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel
@@ -0,0 +1,57 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_compiler_cc_library(
+    name = "MeshToFlow",
+    srcs = [
+        "MeshToFlow.cpp",
+    ],
+    hdrs = [
+        "MeshToFlow.h",
+    ],
+    deps = [
+        ":PassesIncGen",
+        "//compiler/src/iree/compiler/Dialect/Flow/IR",
+        "//compiler/src/iree/compiler/Dialect/Util/IR",
+        "//compiler/src/iree/compiler/Utils",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ArithUtils",
+        "@llvm-project//mlir:DialectUtils",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgUtils",
+        "@llvm-project//mlir:MeshDialect",
+        "@llvm-project//mlir:MeshTransforms",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+iree_gentbl_cc_library(
+    name = "PassesIncGen",
+    tbl_outs = [
+        (
+            ["--gen-pass-decls"],
+            "Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "Passes.td",
+    deps = [
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/CMakeLists.txt
new file mode 100644
index 0000000..f1028f2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/CMakeLists.txt
@@ -0,0 +1,49 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/BUILD.bazel    #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    MeshToFlow
+  HDRS
+    "MeshToFlow.h"
+  SRCS
+    "MeshToFlow.cpp"
+  DEPS
+    ::PassesIncGen
+    LLVMSupport
+    MLIRAffineDialect
+    MLIRArithUtils
+    MLIRIR
+    MLIRLinalgDialect
+    MLIRLinalgUtils
+    MLIRMeshDialect
+    MLIRMeshTransforms
+    MLIRPass
+    MLIRSupport
+    MLIRTensorDialect
+    MLIRTransforms
+    iree::compiler::Dialect::Flow::IR
+    iree::compiler::Dialect::Util::IR
+    iree::compiler::Utils
+  PUBLIC
+)
+
+iree_tablegen_library(
+  NAME
+    PassesIncGen
+  TD_FILE
+    "Passes.td"
+  OUTS
+    --gen-pass-decls Passes.h.inc
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.cpp
new file mode 100644
index 0000000..d9294a4
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.cpp
@@ -0,0 +1,737 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h"
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/Folding.h"
+#include "iree/compiler/Utils/Indexing.h"
+#include "iree/compiler/Utils/OpVisitor.h"
+#include "iree/compiler/Utils/Permutation.h"
+#include "iree/compiler/Utils/SmallVectorDenseMapInfo.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-mesh-to-flow"
+
+namespace mlir::iree_compiler::IREE::Flow {
+
+#define GEN_PASS_CLASSES
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.h.inc" // IWYU pragma: keep
+
+static bool hasMoreThanOneMesh(Operation *op) {
+  int meshCount = 0;
+  op->walk([&meshCount](mesh::MeshOp mesh) {
+    ++meshCount;
+    return meshCount > 1 ? WalkResult::interrupt() : WalkResult::advance();
+  });
+  return meshCount > 1;
+}
+
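+// The default channel covers all devices of the mesh in their natural order,
+// i.e. the op's mesh axes form the identity permutation over all of the
+// mesh's axes.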
+static bool isDefaultChannel(mesh::MeshOp mesh,
+                             ArrayRef<mesh::MeshAxis> meshAxes) {
+  if (mesh.getRank() != static_cast<int64_t>(meshAxes.size())) {
+    return false;
+  }
+  return isIdentityPermutation(meshAxes);
+}
+
+static Value getDefaultChannel(mesh::MeshOp mesh, bool useNamedDefaultChannels,
+                               ImplicitLocOpBuilder &builder) {
+  if (useNamedDefaultChannels)
+    return builder.create<IREE::Flow::ChannelDefaultOp>(mesh.getSymName());
+  else
+    return builder.create<IREE::Flow::ChannelDefaultOp>();
+}
+
+// Remove the elements of `values` whose indices appear in `filter`.
+static SmallVector<Value> filterOutByIndex(ArrayRef<Value> values,
+                                           ArrayRef<mesh::MeshAxis> filter) {
+  SmallVector<Value> res;
+  for (size_t i = 0; i < values.size(); ++i) {
+    if (!llvm::is_contained(filter, i)) {
+      res.push_back(values[i]);
+    }
+  }
+  return res;
+}
+
+static CollectiveReductionOp convertReductionKind(mesh::Partial reduction) {
+  switch (reduction) {
+  case mesh::Partial::Max:
+    return CollectiveReductionOp::ReductionMaximum;
+  case mesh::Partial::Min:
+    return CollectiveReductionOp::ReductionMinimum;
+  case mesh::Partial::Sum:
+    return CollectiveReductionOp::ReductionSum;
+  default:
+    assert(false);
+    return CollectiveReductionOp::None;
+  }
+}
+
+static CollectiveReductionOpAttr
+convertReductionKind(mesh::PartialAttr reduction) {
+  return CollectiveReductionOpAttr::get(
+      reduction.getContext(), convertReductionKind(reduction.getValue()));
+}
+
+static Value buildChannelCreation(mesh::MeshOp mesh,
+                                  ArrayRef<mesh::MeshAxis> meshAxes,
+                                  bool useNamedDefaultChannels,
+                                  ImplicitLocOpBuilder &builder) {
+  assert(mesh);
+  Value meshChannel = getDefaultChannel(mesh, useNamedDefaultChannels, builder);
+  SmallVector<Value> meshProcessMultiIndex =
+      builder.create<mesh::ProcessMultiIndexOp>(mesh).getResults();
+  SmallVector<Value> meshShape =
+      builder.create<mesh::MeshShapeOp>(mesh).getResults();
+  SmallVector<Value> reorderedMeshIndex =
+      permute(ArrayRef<Value>(meshProcessMultiIndex), meshAxes);
+  SmallVector<Value> reorderedMeshShape =
+      permute(ArrayRef<Value>(meshShape), meshAxes);
+  SmallVector<Value> groupIndex =
+      filterOutByIndex(meshProcessMultiIndex, meshAxes);
+  SmallVector<Value> groupsShape = filterOutByIndex(meshShape, meshAxes);
+  OpFoldResult reorderedProcessLinearIndex =
+      linearIndexFromShape(toOpFoldResults(reorderedMeshIndex),
+                           toOpFoldResults(reorderedMeshShape), builder);
+  OpFoldResult color = linearIndexFromShape(
+      toOpFoldResults(groupIndex), toOpFoldResults(groupsShape), builder);
+  return builder.create<ChannelSplitOp>(
+      meshChannel,
+      getValueOrCreateConstantIndexOp(builder, builder.getLoc(), color),
+      getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
+                                      reorderedProcessLinearIndex));
+}
+
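+// Returns the name of the util.global that caches the channel for the given
+// mesh and axes, e.g. "_mesh_mesh_2d_axes_1" for mesh @mesh_2d and axes [1].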
+static SmallString<64> getChannelName(mesh::MeshOp mesh,
+                                      ArrayRef<mesh::MeshAxis> axes) {
+  SmallString<64> res;
+  llvm::raw_svector_ostream stream(res);
+  stream << "_mesh_" << mesh.getSymName();
+  if (axes.empty()) {
+    return res;
+  }
+
+  stream << "_axes";
+  for (mesh::MeshAxis axis : axes) {
+    stream << "_" << axis;
+  }
+
+  return res;
+}
+
+static void buildChannelInitializer(mesh::MeshOp mesh,
+                                    ArrayRef<mesh::MeshAxis> meshAxes,
+                                    bool useNamedDefaultChannels,
+                                    ImplicitLocOpBuilder &builder) {
+  Util::InitializerOp initOp = builder.create<Util::InitializerOp>();
+  Block *block = builder.createBlock(&initOp.getBody());
+  ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPointToStart(block);
+  Value channel =
+      buildChannelCreation(mesh, meshAxes, useNamedDefaultChannels, builder);
+  builder.create<Util::GlobalStoreOp>(channel, getChannelName(mesh, meshAxes));
+  builder.create<Util::ReturnOp>();
+}
+
+// Construct a Flow channel inside `module` using
+// util.global and util.initializer.
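+// The generated IR is roughly (see the lit tests for the exact form):
+//   util.global private @_mesh_<mesh>_axes_<axes> : !flow.channel
+//   util.initializer {
+//     %default = flow.channel.default : !flow.channel
+//     ...compute color and key from the process index and mesh shape...
+//     %channel = flow.channel.split %default, %color, %key
+//     util.global.store %channel, @_mesh_<mesh>_axes_<axes>
+//   }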
+static void buildGlobalChannelCreation(mesh::MeshOp mesh,
+                                       ArrayRef<mesh::MeshAxis> meshAxes,
+                                       bool useNamedDefaultChannels,
+                                       ModuleOp module, OpBuilder &opBuilder) {
+  if (isDefaultChannel(mesh, meshAxes)) {
+    return;
+  }
+
+  ImplicitLocOpBuilder builder(mesh.getLoc(), opBuilder);
+  ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPointToStart(&module.getBodyRegion().getBlocks().front());
+
+  auto channelName = getChannelName(mesh, meshAxes);
+  builder.create<Util::GlobalOp>(
+      builder.getStringAttr("private"), channelName,
+      IREE::Flow::ChannelType::get(builder.getContext()), false, TypedAttr(),
+      IREE::Util::InlineNeverAttr::get(builder.getContext()));
+  buildChannelInitializer(mesh, meshAxes, useNamedDefaultChannels, builder);
+}
+
+static Value buildCachedChannelLoading(mesh::MeshOp mesh,
+                                       ArrayRef<mesh::MeshAxis> meshAxes,
+                                       bool useNamedDefaultChannels,
+                                       ImplicitLocOpBuilder &builder) {
+  if (isDefaultChannel(mesh, meshAxes)) {
+    return getDefaultChannel(mesh, useNamedDefaultChannels, builder);
+  }
+  return builder.create<Util::GlobalLoadOp>(
+      ChannelType::get(builder.getContext()), getChannelName(mesh, meshAxes));
+}
+
+// Returns the !flow.channel corresponding to the mesh and mesh axes used in
+// the op.
+template <typename MeshCollectiveOp>
+static Value buildCachedChannelLoading(
+    MeshCollectiveOp op, SymbolTableCollection &symbolTableCollection,
+    bool useNamedDefaultChannels, ImplicitLocOpBuilder &builder) {
+  ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPointAfter(op);
+
+  mesh::MeshOp mesh = mesh::getMesh(op, symbolTableCollection);
+  return buildCachedChannelLoading(mesh, op.getMeshAxes(),
+                                   useNamedDefaultChannels, builder);
+}
+
+static SmallVector<mesh::MeshAxis> getAllMeshAxes(mesh::MeshOp mesh) {
+  SmallVector<mesh::MeshAxis> res(mesh.getRank());
+  std::iota(res.begin(), res.end(), 0);
+  return res;
+}
+
+static Value buildCachedChannelLoading(
+    mesh::ProcessLinearIndexOp op, SymbolTableCollection &symbolTableCollection,
+    bool useNamedDefaultChannels, ImplicitLocOpBuilder &builder) {
+  ImplicitLocOpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPointAfter(op);
+
+  mesh::MeshOp mesh = mesh::getMesh(op, symbolTableCollection);
+  return buildCachedChannelLoading(mesh, getAllMeshAxes(mesh),
+                                   useNamedDefaultChannels, builder);
+}
+
+static TypedValue<RankedTensorType>
+buildTranspose(Value v, ArrayRef<int64_t> transposeVector,
+               ImplicitLocOpBuilder &builder) {
+  RankedTensorType type = v.getType().cast<RankedTensorType>();
+  SmallVector<int64_t> transposedShape =
+      permute(type.getShape(), transposeVector);
+  Value target =
+      builder.create<tensor::EmptyOp>(transposedShape, type.getElementType());
+  return builder.create<linalg::TransposeOp>(v, target, transposeVector)
+      ->getResult(0)
+      .cast<TypedValue<RankedTensorType>>();
+}
+
+static SmallVector<int64_t> transpose(ArrayRef<int64_t> shape, int64_t axisA,
+                                      int64_t axisB) {
+  SmallVector<int64_t> res = llvm::to_vector(shape);
+  std::swap(res[axisA], res[axisB]);
+  return res;
+}
+
+static RankedTensorType transpose(RankedTensorType type, int64_t axisA,
+                                  int64_t axisB) {
+  SmallVector<int64_t> newShape = transpose(type.getShape(), axisA, axisB);
+  return type.clone(newShape);
+}
+
+static TypedValue<RankedTensorType>
+buildTranspose(Value v, int64_t axisA, int64_t axisB,
+               ImplicitLocOpBuilder &builder) {
+  int64_t rank = v.getType().cast<RankedTensorType>().getRank();
+  SmallVector<int64_t> transposeVector(rank);
+  std::iota(transposeVector.begin(), transposeVector.end(), 0);
+  std::swap(transposeVector[axisA], transposeVector[axisB]);
+  return buildTranspose(v, transposeVector, builder);
+}
+
+// (..., splitAxisSize, ...) ->
+// (..., splitCount, splitAxisSize / splitCount, ...)
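+// Example: shape (3, 12, 5) with splitAxis = 1 and splitCount = 4 becomes
+// (3, 4, 3, 5).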
+static TypedValue<RankedTensorType>
+splitAxis(TypedValue<RankedTensorType> tensor, int64_t splitAxis,
+          int64_t splitCount, ImplicitLocOpBuilder &builder) {
+  ArrayRef<int64_t> shape = tensor.getType().getShape();
+  SmallVector<int64_t> newShape;
+  newShape.reserve(shape.size() + 1);
+  for (int64_t i = 0; i < tensor.getType().getRank(); ++i) {
+    if (i != splitAxis) {
+      newShape.push_back(shape[i]);
+      continue;
+    }
+    newShape.push_back(splitCount);
+    if (ShapedType::isDynamic(shape[i])) {
+      newShape.push_back(ShapedType::kDynamic);
+    } else {
+      assert(shape[i] % splitCount == 0);
+      newShape.push_back(shape[i] / splitCount);
+    }
+  }
+
+  RankedTensorType resultType = tensor.getType().clone(newShape);
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForReshape(tensor.getType(), resultType);
+  return builder.create<tensor::ExpandShapeOp>(resultType, tensor,
+                                               reassociation.value());
+}
+
+// Transposes the input tensor by moving `axis` to position `destination`,
+// shifting the axes in between.
+static TypedValue<RankedTensorType>
+moveAxis(TypedValue<RankedTensorType> tensor, int64_t axis, int64_t destination,
+         ImplicitLocOpBuilder &builder) {
+  SmallVector<int64_t> permutation =
+      makeMovePermutation(tensor.getType().getRank(), axis, destination);
+  return buildTranspose(tensor, permutation, builder);
+}
+
+static SmallVector<int64_t> collapseAxesN(ArrayRef<int64_t> shape,
+                                          size_t firstAxis, size_t n) {
+  assert(firstAxis + n <= shape.size());
+  assert(n > 1);
+  SmallVector<int64_t> res;
+  std::copy(shape.begin(), shape.begin() + firstAxis, std::back_inserter(res));
+  size_t collapsedAxisSize = std::accumulate(
+      shape.begin() + firstAxis + 1, shape.begin() + firstAxis + n,
+      shape[firstAxis], [](size_t a, size_t b) { return a * b; });
+  res.push_back(collapsedAxisSize);
+  std::copy(shape.begin() + firstAxis + n, shape.end(),
+            std::back_inserter(res));
+  return res;
+}
+
+// Collapses `n` axes starting with axis `firstAxis`.
+// Example:
+// tensor shape = (1, 2, 3, 4), firstAxis = 1, n = 2
+// The resulting tensor is with shape (1, 6, 4).
+static TypedValue<RankedTensorType>
+collapseAxesN(TypedValue<RankedTensorType> tensor, int64_t firstAxis, int64_t n,
+              ImplicitLocOpBuilder &builder) {
+  ArrayRef<int64_t> shape = tensor.getType().getShape();
+  SmallVector<int64_t> newShape = collapseAxesN(shape, firstAxis, n);
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(shape, newShape);
+  return builder.create<tensor::CollapseShapeOp>(tensor, reassociation.value());
+}
+
+// Splits an axis into 2 new dimensions, then moves the new splitCount axis
+// and collapses it into collapseAxis.
+// The shape of the tensor and its transformations:
+// (..., splitAxisSize, ..., collapseAxisSize, ...)
+// -> split ->
+// (..., splitCount, splitAxisSize / splitCount, ..., collapseAxisSize, ...)
+// -> move ->
+// (..., splitAxisSize / splitCount, ..., splitCount, collapseAxisSize, ...)
+// -> collapse ->
+// (..., splitAxisSize / splitCount, ..., splitCount * collapseAxisSize, ...)
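+// Example (mirroring the all_to_all lit test): tensor<1x12x3x4x5xf32> with
+// splitAxis = 1, collapseAxis = 3 and splitCount = 6 goes
+// (1, 12, 3, 4, 5) -> (1, 6, 2, 3, 4, 5) -> (1, 2, 3, 6, 4, 5)
+// -> (1, 2, 3, 24, 5).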
+static TypedValue<RankedTensorType>
+splitMoveCollapse(TypedValue<RankedTensorType> tensor, int64_t splitAxis,
+                  int64_t collapseAxis, int64_t splitCount,
+                  ImplicitLocOpBuilder &builder) {
+  TypedValue<RankedTensorType> v =
+      IREE::Flow::splitAxis(tensor, splitAxis, splitCount, builder);
+  v = moveAxis(v, splitAxis, collapseAxis, builder);
+  return collapseAxesN(v, collapseAxis, 2, builder);
+}
+
+namespace {
+
+template <typename Op>
+struct MeshToFlowCollectiveRewritePatternBase : OpRewritePattern<Op> {
+  template <typename... OpRewritePatternArgs>
+  MeshToFlowCollectiveRewritePatternBase(
+      SymbolTableCollection &symbolTableCollection,
+      bool useNamedDefaultChannels,
+      OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern<Op>(
+            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection),
+        useNamedDefaultChannels(useNamedDefaultChannels) {}
+
+protected:
+  SymbolTableCollection &symbolTableCollection;
+  bool useNamedDefaultChannels;
+};
+
+struct MeshAllReduceToFlow
+    : MeshToFlowCollectiveRewritePatternBase<mesh::AllReduceOp> {
+  using MeshToFlowCollectiveRewritePatternBase::
+      MeshToFlowCollectiveRewritePatternBase;
+
+  LogicalResult matchAndRewrite(mesh::AllReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+    Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+                                              useNamedDefaultChannels, builder);
+    Value target = builder.create<tensor::EmptyOp>(
+        op.getResult().getType().getShape(),
+        op.getResult().getType().getElementType());
+    auto flowAllReduce = builder.create<CollectiveAllReduceOp>(
+        convertReductionKind(op.getReductionAttr()),
+        getCollectiveElementTypeAttr(op.getResult().getType()), target,
+        op.getOperand(), channel);
+    rewriter.replaceAllUsesWith(op.getResult(), flowAllReduce.getResult());
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+struct MeshAllGatherToFlow
+    : MeshToFlowCollectiveRewritePatternBase<mesh::AllGatherOp> {
+  using MeshToFlowCollectiveRewritePatternBase::
+      MeshToFlowCollectiveRewritePatternBase;
+
+  LogicalResult matchAndRewrite(mesh::AllGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    if (ShapedType::isDynamicShape(
+            op.getOperand().getType().cast<RankedTensorType>().getShape()) ||
+        ShapedType::isDynamicShape(op.getResult().getType().getShape())) {
+      // TODO: add dynamic support.
+      return rewriter.notifyMatchFailure(op->getLoc(),
+                                         "Dynamic tensor case is unsupported.");
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+    Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+                                              useNamedDefaultChannels, builder);
+
+    int64_t gatherAxis = op.getGatherAxis().getSExtValue();
+
+    // When gather axis != 0, we need to transpose between 0 and
+    // gather axis before and after the flow all-gather op.
+    Value flowAllGatherOperand =
+        buildTranspose(op.getOperand(), 0, gatherAxis, builder);
+
+    RankedTensorType flowAllGatherResultType = transpose(
+        op.getResult().getType().cast<RankedTensorType>(), 0, gatherAxis);
+    Value target = builder.create<tensor::EmptyOp>(
+        flowAllGatherResultType.getShape(),
+        op.getResult().getType().getElementType());
+    auto flowAllGather = builder.create<CollectiveAllGatherOp>(
+        getCollectiveElementTypeAttr(flowAllGatherResultType), target,
+        flowAllGatherOperand, channel);
+
+    Value res = buildTranspose(flowAllGather, 0, gatherAxis, builder);
+
+    rewriter.replaceAllUsesWith(op.getResult(), res);
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+struct MeshAllToAllToFlow
+    : MeshToFlowCollectiveRewritePatternBase<mesh::AllToAllOp> {
+  using MeshToFlowCollectiveRewritePatternBase::
+      MeshToFlowCollectiveRewritePatternBase;
+
+  LogicalResult matchAndRewrite(mesh::AllToAllOp op,
+                                PatternRewriter &rewriter) const override {
+    if (ShapedType::isDynamicShape(
+            op.getOperand().getType().cast<RankedTensorType>().getShape()) ||
+        ShapedType::isDynamicShape(op.getResult().getType().getShape())) {
+      // TODO: add dynamic support.
+      return rewriter.notifyMatchFailure(op->getLoc(),
+                                         "Dynamic tensor case is unsupported.");
+    }
+
+    mesh::MeshOp mesh = mesh::getMesh(op, symbolTableCollection);
+    assert(!ShapedType::isDynamicShape(mesh.getShape()));
+    int64_t splitCount =
+        mesh::collectiveProcessGroupSize(op.getMeshAxes(), mesh.getShape());
+    if (ShapedType::isDynamic(splitCount)) {
+      // TODO: add dynamic support.
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "Dynamic split count induced by a dynamic mesh is unsupported.");
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+
+    Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+                                              useNamedDefaultChannels, builder);
+
+    int64_t splitAxis = op.getSplitAxis().getSExtValue();
+
+    TypedValue<RankedTensorType> splitAxisAsMostOuter =
+        buildTranspose(op.getOperand(), 0, splitAxis, builder);
+
+    Value target = builder.create<tensor::EmptyOp>(
+        splitAxisAsMostOuter.getType().getShape(),
+        splitAxisAsMostOuter.getType().getElementType());
+    auto flowAllToAll = builder.create<CollectiveAllToAllOp>(
+        getCollectiveElementTypeAttr(splitAxisAsMostOuter.getType()), target,
+        splitAxisAsMostOuter, channel);
+
+    TypedValue<RankedTensorType> splitAxisBackInItsPlace =
+        buildTranspose(flowAllToAll, 0, splitAxis, builder);
+
+    int64_t concatAxis = op.getConcatAxis().getSExtValue();
+    Value res = splitMoveCollapse(splitAxisBackInItsPlace, splitAxis,
+                                  concatAxis, splitCount, builder);
+
+    rewriter.replaceAllUsesWith(op.getResult(), res);
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+struct MeshProcessLinearIndexToFlow
+    : MeshToFlowCollectiveRewritePatternBase<mesh::ProcessLinearIndexOp> {
+  using MeshToFlowCollectiveRewritePatternBase::
+      MeshToFlowCollectiveRewritePatternBase;
+
+  LogicalResult matchAndRewrite(mesh::ProcessLinearIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+    Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+                                              useNamedDefaultChannels, builder);
+    Value newIndex =
+        builder.create<ChannelRankOp>(builder.getIndexType(), channel);
+    rewriter.replaceAllUsesWith(op.getResult(), newIndex);
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+struct MeshReduceScatterToFlow
+    : MeshToFlowCollectiveRewritePatternBase<mesh::ReduceScatterOp> {
+  using MeshToFlowCollectiveRewritePatternBase::
+      MeshToFlowCollectiveRewritePatternBase;
+
+  LogicalResult matchAndRewrite(mesh::ReduceScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    if (ShapedType::isDynamicShape(
+            op.getOperand().getType().cast<RankedTensorType>().getShape()) ||
+        ShapedType::isDynamicShape(op.getResult().getType().getShape())) {
+      // TODO: add dynamic support.
+      return rewriter.notifyMatchFailure(op->getLoc(),
+                                         "Dynamic tensor case is unsupported.");
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+    Value channel = buildCachedChannelLoading(op, symbolTableCollection,
+                                              useNamedDefaultChannels, builder);
+
+    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+
+    // When scatter axis != 0, we need to transpose between 0 and
+    // scatter axis before and after the flow reduce-scatter op.
+    Value flowReduceScatterOperand =
+        buildTranspose(op.getOperand(), 0, scatterAxis, builder);
+    RankedTensorType flowReduceScatterResultType = transpose(
+        op.getResult().getType().cast<RankedTensorType>(), 0, scatterAxis);
+
+    Value target = builder.create<tensor::EmptyOp>(
+        flowReduceScatterResultType.getShape(),
+        op.getResult().getType().getElementType());
+    auto flowReduceScatter = builder.create<CollectiveReduceScatterOp>(
+        convertReductionKind(op.getReductionAttr()),
+        getCollectiveElementTypeAttr(flowReduceScatterResultType), target,
+        flowReduceScatterOperand, channel);
+
+    Value res = buildTranspose(flowReduceScatter, 0, scatterAxis, builder);
+
+    rewriter.replaceAllUsesWith(op.getResult(), res);
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+using MeshAndAxesSet =
+    DenseSet<std::tuple<mesh::MeshOp, SmallVector<mesh::MeshAxis>>>;
+
+template <typename Op>
+struct CollectiveOpVisitor {
+  CollectiveOpVisitor(MeshAndAxesSet &meshAndAxesSet,
+                      SymbolTableCollection &symbolTableCollection)
+      : meshAndAxesSet(meshAndAxesSet),
+        symbolTableCollection(symbolTableCollection) {}
+  void operator()(Op op) {
+    meshAndAxesSet.insert(std::make_tuple(
+        symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+            op, op.getMeshAttr()),
+        llvm::to_vector(op.getMeshAxes())));
+  }
+
+private:
+  MeshAndAxesSet &meshAndAxesSet;
+  SymbolTableCollection &symbolTableCollection;
+};
+
+template <typename Op>
+struct CollectiveOpWithoutMeshAxesVisitor {
+  CollectiveOpWithoutMeshAxesVisitor(
+      MeshAndAxesSet &meshAndAxesSet,
+      SymbolTableCollection &symbolTableCollection)
+      : meshAndAxesSet(meshAndAxesSet),
+        symbolTableCollection(symbolTableCollection) {}
+  void operator()(Op op) {
+    mesh::MeshOp mesh =
+        symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+            op, op.getMeshAttr());
+    meshAndAxesSet.insert(std::make_tuple(mesh, getAllMeshAxes(mesh)));
+  }
+
+private:
+  MeshAndAxesSet &meshAndAxesSet;
+  SymbolTableCollection &symbolTableCollection;
+};
+
+static void populateMeshAndAxes(Operation *op, MeshAndAxesSet &meshAndAxesSet,
+                                SymbolTableCollection &symbolTableCollection) {
+  OpVisitorCollection opVisitors;
+  opVisitors.emplaceVisitors<
+      CollectiveOpVisitor<mesh::AllGatherOp>,
+      CollectiveOpVisitor<mesh::AllReduceOp>,
+      CollectiveOpVisitor<mesh::AllToAllOp>,
+      CollectiveOpVisitor<mesh::ReduceScatterOp>,
+      CollectiveOpWithoutMeshAxesVisitor<mesh::ProcessLinearIndexOp>>(
+      meshAndAxesSet, symbolTableCollection);
+
+  op->walk([&opVisitors](Operation *op) {
+    opVisitors(op);
+    return WalkResult::advance();
+  });
+}
+
+static void createChannels(ModuleOp moduleOp,
+                           SymbolTableCollection &symbolTableCollection,
+                           MeshAndAxesSet &meshAndAxesSet,
+                           bool useNamedDefaultChannels) {
+  populateMeshAndAxes(moduleOp, meshAndAxesSet, symbolTableCollection);
+
+  OpBuilder builder(moduleOp->getContext());
+
+  // Sort for deterministic testing with FileCheck.
+  auto meshAndAxesSetSorted = llvm::to_vector(meshAndAxesSet);
+  llvm::sort(meshAndAxesSetSorted, [](auto &a, auto &b) {
+    int nameCompareRes =
+        std::get<0>(a).getSymName().compare(std::get<0>(b).getSymName());
+    if (nameCompareRes == 0)
+      return std::get<1>(a) < std::get<1>(b);
+    return nameCompareRes < 0;
+  });
+  for (auto &[mesh, meshAxes] : llvm::make_range(meshAndAxesSetSorted.rbegin(),
+                                                 meshAndAxesSetSorted.rend())) {
+    buildGlobalChannelCreation(mesh, meshAxes, useNamedDefaultChannels,
+                               moduleOp, builder);
+  }
+}
+
+static LogicalResult
+convertCollectives(ModuleOp moduleOp,
+                   SymbolTableCollection &symbolTableCollection,
+                   bool useNamedDefaultChannels) {
+  RewritePatternSet patterns(moduleOp->getContext());
+  IREE::Flow::populateMeshToFlowCollectivesPatterns(
+      patterns, symbolTableCollection, useNamedDefaultChannels);
+  return applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
+static void removeMeshOps(MeshAndAxesSet &meshAndAxesSet) {
+  auto meshRange =
+      llvm::map_range(meshAndAxesSet, [](auto &v) { return std::get<0>(v); });
+  DenseSet<mesh::MeshOp> meshOpsSet(std::begin(meshRange), std::end(meshRange));
+  for (mesh::MeshOp op : meshOpsSet) {
+    if (op)
+      op.erase();
+  }
+}
+
+struct ConvertMeshToFlowPass
+    : public IREE::Flow::ConvertMeshToFlowBase<ConvertMeshToFlowPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registerMeshToFlowDependencies(registry);
+  }
+
+  void runOnOperation() override {
+    // Run only on the top module.
+    if (getOperation()->getParentOp() != nullptr) {
+      return;
+    }
+
+    MeshAndAxesSet meshAndAxesSet;
+    SymbolTableCollection symbolTableCollection;
+    bool useNamedDefaultChannels = hasMoreThanOneMesh(getOperation());
+
+    createChannels(getOperation(), symbolTableCollection, meshAndAxesSet,
+                   useNamedDefaultChannels);
+    if (failed(convertCollectives(getOperation(), symbolTableCollection,
+                                  useNamedDefaultChannels))) {
+      return signalPassFailure();
+    }
+
+    // Cleanup mesh definition ops that are no longer referenced.
+    removeMeshOps(meshAndAxesSet);
+  }
+};
+
+} // namespace
+
+void populateMeshToFlowCollectivesPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection,
+    bool useNamedDefaultChannels) {
+  patterns.add<MeshAllGatherToFlow, MeshAllReduceToFlow, MeshAllToAllToFlow,
+               MeshReduceScatterToFlow, MeshProcessLinearIndexToFlow>(
+      symbolTableCollection, useNamedDefaultChannels, patterns.getContext());
+  mesh::populateFoldingPatterns(patterns, symbolTableCollection);
+  mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+                                                    symbolTableCollection);
+}
+
+std::unique_ptr<Pass> createConvertMeshToFlowPass() {
+  return std::make_unique<ConvertMeshToFlowPass>();
+}
+
+void registerMeshToFlowDependencies(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, FlowDialect, linalg::LinalgDialect,
+                  mesh::MeshDialect, tensor::TensorDialect>();
+}
+
+void registerMeshToFlowPasses() { registerPasses(); }
+
+} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h
new file mode 100644
index 0000000..acc0572
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h
@@ -0,0 +1,27 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class DialectRegistry;
+class SymbolTableCollection;
+class OpPassManager;
+class Pass;
+namespace iree_compiler::IREE::Flow {
+
+std::unique_ptr<Pass> createConvertMeshToFlowPass();
+
+void populateMeshToFlowCollectivesPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection,
+    bool useNamedDefaultChannels);
+
+void registerMeshToFlowDependencies(DialectRegistry &registry);
+void registerMeshToFlowPasses();
+
+} // namespace iree_compiler::IREE::Flow
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.td
new file mode 100644
index 0000000..c593328
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/Passes.td
@@ -0,0 +1,53 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_FLOW_MESHTOFLOW_PASSES
+#define IREE_DIALECT_FLOW_MESHTOFLOW_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ConvertMeshToFlow :
+    Pass<"iree-convert-mesh-to-flow", "mlir::ModuleOp"> {
+  let summary = "Convert Mesh dialect operations to IREE Flow.";
+  let description = [{
+    Each mesh corresponds to a default Flow channel with the same group name.
+    ```
+    mesh.mesh @mesh_1(shape = 2x3)
+    ```
+    ```
+    %channel = flow.channel.default "mesh_1" : !flow.channel
+    ```
+    If there is only one mesh in the program, the name is omitted:
+    ```
+    %channel = flow.channel.default : !flow.channel
+    ```
+
+    Each (mesh, mesh_axes) pair partitions and orders the devices into disjoint
+    groups, each corresponding to a Flow channel to perform a collective
+    operation.
+    For example:
+    ```
+    mesh.mesh @mesh(shape = 2x3x4x5)
+    ...
+    %1 = mesh.all_reduce %0 on @mesh mesh_axes = [2, 0]
+      : tensor<10x20xf32> -> tensor<10x20xf32>
+    ```
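+    Here the device groups span mesh axes 2 and 0, so each group has
+    4 * 2 = 8 devices, and the remaining axes 1 and 3 yield 3 * 5 = 15
+    disjoint groups.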
+    For more information see
+    [Mesh dialect](https://mlir.llvm.org/docs/Dialects/Mesh/#device-groups).
+
+    The mesh partition and device ordering determine the values for the
+    `color` and `key` in the corresponding `flow.channel.split` operation used
+    to create the channel.
+    For more information on the meaning of `color` and `key` see
+    [MPI_Comm_split](https://www.mpi-forum.org/docs/mpi-4.1/mpi41-report/node188.htm#Node188)
+    in the MPI standard.
+
+    Each Flow channel is wrapped in an IREE `util.global` and its construction
+    is done only once with `util.initializer`.
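+    For example (a sketch; the exact IR is covered by the lit tests):
+    ```
+    util.global private @_mesh_mesh_axes_2_0 : !flow.channel
+    util.initializer {
+      %default = flow.channel.default : !flow.channel
+      ...
+      %channel = flow.channel.split %default, %color, %key
+        : !flow.channel -> !flow.channel
+      util.global.store %channel, @_mesh_mesh_axes_2_0 : !flow.channel
+    }
+    ```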
+  }];
+  let constructor = "mlir::iree_compiler::IREE::Flow::createConvertMeshToFlowPass()";
+}
+
+#endif  // IREE_DIALECT_FLOW_MESHTOFLOW_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel
new file mode 100644
index 0000000..72882ed
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel
@@ -0,0 +1,29 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = enforce_glob(
+        [
+            "channel_creation.mlir",
+            "collectives.mlir",
+        ],
+        include = ["*.mlir"],
+    ),
+    cfg = "//compiler:lit.cfg.py",
+    tools = [
+        "//tools:iree-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/CMakeLists.txt
new file mode 100644
index 0000000..449e06b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/BUILD.bazel#
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "channel_creation.mlir"
+    "collectives.mlir"
+  TOOLS
+    FileCheck
+    iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/channel_creation.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/channel_creation.mlir
new file mode 100644
index 0000000..fe9cbae
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/channel_creation.mlir
@@ -0,0 +1,130 @@
+// RUN: iree-opt --split-input-file --iree-convert-mesh-to-flow --cse %s | FileCheck %s
+
+// CHECK-LABEL: module @static_1d_mesh_grouping_along_axis_0
+module @static_1d_mesh_grouping_along_axis_0 {
+
+  // No channel global or initializer is expected; the default channel is
+  // used directly.
+  // CHECK-NOT: util.global private @_mesh_mesh_1d_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+  mesh.mesh @mesh_1d(shape = 2)
+
+  func.func @f(
+      %arg0 : tensor<1xi8>) -> tensor<1xi8> {
+    %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] reduction = <sum>
+      : tensor<1xi8> -> tensor<1xi8>
+    return %0 : tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @static_2d_mesh_grouping_along_axis_1
+module @static_2d_mesh_grouping_along_axis_1 {
+
+  // CHECK: util.global private @_mesh_mesh_2d_axes_1 {inlining_policy = #util.inline.never} : !flow.channel
+  // CHECK: util.initializer {
+    // CHECK-DAG: %[[AXIS_1_SIZE:.*]] = arith.constant 4 : index
+    // CHECK-DAG: %[[AXIS_0_SIZE:.*]] = arith.constant 3 : index
+    // CHECK-DAG: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default : !flow.channel
+    // CHECK: %[[CHANNEL_RANK:.*]] = flow.channel.rank %[[DEFAULT_CHANNEL]] : index
+    // CHECK: %[[COLOR_AND_KEY:.*]]:2 = affine.delinearize_index %[[CHANNEL_RANK]] into
+    // CHECK-SAME: (%[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]) : index, index
+    // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+    // CHECK-SAME: %[[DEFAULT_CHANNEL]], %[[COLOR_AND_KEY]]#0, %[[COLOR_AND_KEY]]#1 : !flow.channel -> !flow.channel
+    // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh_2d_axes_1 : !flow.channel
+  mesh.mesh @mesh_2d(shape = 3x4)
+
+  func.func @f(%input : tensor<1xi8>) -> tensor<1xi8> {
+    %out = mesh.all_reduce %input on @mesh_2d mesh_axes = [1] : tensor<1xi8> -> tensor<1xi8>
+    return %out : tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+
+// CHECK-LABEL: module @static_4d_mesh_grouping_along_axes_2_1
+module @static_4d_mesh_grouping_along_axes_2_1 {
+
+  // CHECK: util.global private @_mesh_mesh_4d_axes_2_1 {inlining_policy = #util.inline.never} : !flow.channel
+  // CHECK: util.initializer {
+    // CHECK-DAG: %[[AXIS_3_SIZE:.*]] = arith.constant 6 : index
+    // CHECK-DAG: %[[AXIS_2_SIZE:.*]] = arith.constant 5 : index
+    // CHECK-DAG: %[[AXIS_1_SIZE:.*]] = arith.constant 4 : index
+    // CHECK-DAG: %[[AXIS_0_SIZE:.*]] = arith.constant 3 : index
+    // CHECK-DAG: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default : !flow.channel
+    // CHECK: %[[CHANNEL_RANK:.*]] = flow.channel.rank %[[DEFAULT_CHANNEL]] : index
+    // CHECK: %[[DEVICE_MULTI_IDX:.*]]:4 = affine.delinearize_index %[[CHANNEL_RANK]] into
+    // CHECK-SAME: (%[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]], %[[AXIS_2_SIZE]], %[[AXIS_3_SIZE]]) : index, index, index, index
+    // CHECK: %[[IN_GROUP_IDX:.*]] = affine.apply
+    // CHECK-SAME: #map()[%[[DEVICE_MULTI_IDX]]#2, %[[AXIS_1_SIZE]], %[[DEVICE_MULTI_IDX]]#1]
+    // CHECK: %[[GROUP_IDX:.*]] = affine.apply
+    // CHECK-SAME: #map()[%[[DEVICE_MULTI_IDX]]#0, %[[AXIS_3_SIZE]], %[[DEVICE_MULTI_IDX]]#3]
+    // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+    // CHECK-SAME: %[[DEFAULT_CHANNEL]], %[[GROUP_IDX]], %[[IN_GROUP_IDX]] : !flow.channel -> !flow.channel
+    // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh_4d_axes_2_1 : !flow.channel
+  mesh.mesh @mesh_4d(shape = 3x4x5x6)
+
+  func.func @f(%input : tensor<1xi8>) -> tensor<1xi8> {
+    %out = mesh.all_reduce %input on @mesh_4d mesh_axes = [2, 1] : tensor<1xi8> -> tensor<1xi8>
+    return %out : tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @multiple_different_channels
+module @multiple_different_channels {
+
+  // CHECK-DAG: util.global private @_mesh_mesh_2d_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+  // CHECK-DAG: util.global private @_mesh_mesh_2d_axes_1 {inlining_policy = #util.inline.never} : !flow.channel
+  mesh.mesh @mesh_2d(shape = 3x4)
+
+  func.func @f(%input : tensor<1xi8>) -> (tensor<1xi8>, tensor<1xi8>) {
+    %out0 = mesh.all_reduce %input on @mesh_2d mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+    %out1 = mesh.all_reduce %input on @mesh_2d mesh_axes = [1] : tensor<1xi8> -> tensor<1xi8>
+    return %out0, %out1 : tensor<1xi8>, tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @same_channel_used_multiple_times
+module @same_channel_used_multiple_times {
+
+  // CHECK: util.global private @_mesh_mesh_2d_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+  mesh.mesh @mesh_2d(shape = 3x4)
+
+  func.func @f(%input0 : tensor<1xi8>, %input1 : tensor<1xi8>) -> (tensor<1xi8>, tensor<1xi8>) {
+    %out0 = mesh.all_reduce %input0 on @mesh_2d mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+    %out1 = mesh.all_reduce %input1 on @mesh_2d mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+    return %out0, %out1 : tensor<1xi8>, tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @multiple_meshes
+module @multiple_meshes {
+
+  // CHECK: util.global private @_mesh_mesh1_axes_0 {inlining_policy = #util.inline.never} : !flow.channel
+  // CHECK: util.initializer {
+    // CHECK: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default "mesh1" : !flow.channel
+    // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+    // CHECK-SAME: %[[DEFAULT_CHANNEL]], %{{.*}}, %{{.*}} : !flow.channel -> !flow.channel
+    // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh1_axes_0 : !flow.channel
+  mesh.mesh @mesh1(shape = 1x2)
+  // CHECK: util.global private @_mesh_mesh2_axes_1 {inlining_policy = #util.inline.never} : !flow.channel
+  // CHECK: util.initializer {
+    // CHECK: %[[DEFAULT_CHANNEL:.*]] = flow.channel.default "mesh2" : !flow.channel
+    // CHECK: %[[CHANNEL:.*]] = flow.channel.split
+    // CHECK-SAME: %[[DEFAULT_CHANNEL]], %{{.*}}, %{{.*}} : !flow.channel -> !flow.channel
+    // CHECK: util.global.store %[[CHANNEL]], @_mesh_mesh2_axes_1 : !flow.channel
+  mesh.mesh @mesh2(shape = 3x4)
+
+  func.func @f(%input0 : tensor<1xi8>, %input1 : tensor<1xi8>) -> (tensor<1xi8>, tensor<1xi8>) {
+    %out0 = mesh.all_reduce %input0 on @mesh1 mesh_axes = [0] : tensor<1xi8> -> tensor<1xi8>
+    %out1 = mesh.all_reduce %input1 on @mesh2 mesh_axes = [1] : tensor<1xi8> -> tensor<1xi8>
+    return %out0, %out1 : tensor<1xi8>, tensor<1xi8>
+  }
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/collectives.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/collectives.mlir
new file mode 100644
index 0000000..60d16ab
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow/test/collectives.mlir
@@ -0,0 +1,143 @@
+// RUN: iree-opt --split-input-file --iree-convert-mesh-to-flow --cse %s | FileCheck %s
+
+mesh.mesh @mesh_2d(shape = 3x4)
+
+// CHECK-LABEL: func @all_gather_non_default_channel
+func.func @all_gather_non_default_channel(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xi8>
+    %arg0 : tensor<3x4xi8>) -> tensor<3x16xi8> {
+  // CHECK-DAG: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_2d_axes_1 : !flow.channel
+  // CHECK-DAG: %[[TRANSPOSED_OPERAND_INIT_VAL:.*]] = tensor.empty() : tensor<4x3xi8>
+  // CHECK: %[[TRANSPOSED_OPERAND:.*]] = linalg.transpose
+  // CHECK-SAME: ins(%[[ARG]] : tensor<3x4xi8>) outs(%[[TRANSPOSED_OPERAND_INIT_VAL]] : tensor<4x3xi8>) permutation = [1, 0]
+  // CHECK: %[[ALL_GATHER_INITIAL_VAL:.*]] = tensor.empty() : tensor<16x3xi8>
+  // CHECK: %[[ALL_GATHER_RES:.*]] = flow.collective.all_gather ui8,
+  // CHECK-SAME: %[[ALL_GATHER_INITIAL_VAL]], %[[TRANSPOSED_OPERAND]], %[[CHANNEL]]
+  // CHECK-SAME: (tensor<16x3xi8>, tensor<4x3xi8>, !flow.channel) -> %[[ALL_GATHER_INITIAL_VAL]] as tensor<16x3xi8>
+  // CHECK: %[[RES_INIT_VAL:.*]] = tensor.empty() : tensor<3x16xi8>
+  // CHECK: %[[RES:.*]] = linalg.transpose
+  // CHECK-SAME: ins(%[[ALL_GATHER_RES]] : tensor<16x3xi8>) outs(%[[RES_INIT_VAL]] : tensor<3x16xi8>) permutation = [1, 0]
+  %0 = mesh.all_gather %arg0 on @mesh_2d mesh_axes = [1] gather_axis = 1
+    : tensor<3x4xi8> -> tensor<3x16xi8>
+  // CHECK: return %[[RES]] : tensor<3x16xi8>
+  return %0 : tensor<3x16xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @all_reduce_sum_default_channel
+func.func @all_reduce_sum_default_channel(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+    %arg0 : tensor<1xi8>) -> tensor<1xi8> {
+  // CHECK: %[[CHANNEL:.*]] = flow.channel.default : !flow.channel
+  // CHECK: %[[INITIAL_VAL:.*]] = tensor.empty() : tensor<1xi8>
+  // CHECK: %[[RES:.*]] = flow.collective.all_reduce sum, ui8, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
+  // CHECK-SAME: (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xi8>
+  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0]
+    : tensor<1xi8> -> tensor<1xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_2d(shape = 2x2)
+
+// CHECK-LABEL: func @all_reduce_min_non_default_channel
+func.func @all_reduce_min_non_default_channel(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+    %arg0 : tensor<1xi8>) -> tensor<1xi8> {
+  // CHECK-DAG: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_2d_axes_1_0 : !flow.channel
+  // CHECK-DAG: %[[INITIAL_VAL:.*]] = tensor.empty() : tensor<1xi8>
+  // CHECK: %[[RES:.*]] = flow.collective.all_reduce minimum, ui8, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
+  // CHECK-SAME: (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xi8>
+  %0 = mesh.all_reduce %arg0 on @mesh_2d mesh_axes = [1, 0] reduction = <min>
+    : tensor<1xi8> -> tensor<1xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @all_reduce_f32
+func.func @all_reduce_f32(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+    %arg0 : tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG: %[[CHANNEL:.*]] = flow.channel.default : !flow.channel
+  // CHECK-DAG: %[[INITIAL_VAL:.*]] = tensor.empty() : tensor<1xf32>
+  // CHECK: %[[RES:.*]] = flow.collective.all_reduce sum, f32, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
+  // CHECK-SAME: (tensor<1xf32>, tensor<1xf32>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xf32>
+  %0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0]
+    : tensor<1xf32> -> tensor<1xf32>
+  // CHECK: return %[[RES]] : tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index {
+  // CHECK: %[[CHANNEL:.*]] = flow.channel.default : !flow.channel
+  // CHECK: %[[RES:.*]] = flow.channel.rank %[[CHANNEL]] : index
+  %0 = mesh.process_linear_index on @mesh_1d : index
+  // CHECK: return %[[RES]] : index
+  return %0 : index
+}
+
+// -----
+
+mesh.mesh @mesh_3d(shape = 2x3x4)
+
+// CHECK-LABEL: func @all_to_all_non_default_channel
+func.func @all_to_all_non_default_channel(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<1x12x3x4x5xf32>
+    %arg0 : tensor<1x12x3x4x5xf32>) -> tensor<1x2x3x24x5xf32> {
+  // CHECK: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_3d_axes_1_0 : !flow.channel
+  // CHECK: %[[SPLIT_AXIS_AT_FRONT:.*]] = linalg.transpose ins(%[[ARG]] : tensor<1x12x3x4x5xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<12x1x3x4x5xf32>) permutation = [1, 0, 2, 3, 4]
+  // CHECK: %[[FLOW_ALL_TO_ALL:.*]] = flow.collective.all_to_all f32, %[[ALL_TO_ALL_INIT:.*]], %[[SPLIT_AXIS_AT_FRONT]], %[[CHANNEL]] :
+  // CHECK-SAME: (tensor<12x1x3x4x5xf32>, tensor<12x1x3x4x5xf32>, !flow.channel) -> %[[ALL_TO_ALL_INIT]] as tensor<12x1x3x4x5xf32>
+  // CHECK: %[[SPLIT_AXIS_BACK_IN_ITS_PLACE:.*]] = linalg.transpose ins(%[[FLOW_ALL_TO_ALL]] : tensor<12x1x3x4x5xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<1x12x3x4x5xf32>) permutation = [1, 0, 2, 3, 4]
+  // CHECK: %[[SPLIT_AXIS_IS_SPLIT:.*]] = tensor.expand_shape %[[SPLIT_AXIS_BACK_IN_ITS_PLACE]]
+  // CHECK-SAME{LITERAL}: [[0], [1, 2], [3], [4], [5]] : tensor<1x12x3x4x5xf32> into tensor<1x6x2x3x4x5xf32>
+  // CHECK: %[[MOVED_SPLIT_COUNT_AXIS:.*]] = linalg.transpose ins(%[[SPLIT_AXIS_IS_SPLIT]] : tensor<1x6x2x3x4x5xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<1x2x3x6x4x5xf32>) permutation = [0, 2, 3, 1, 4, 5]
+  // CHECK: %[[COLLAPSED_SPLIT_COUNT_INTO_CONCAT_AXIS:.*]] = tensor.collapse_shape %[[MOVED_SPLIT_COUNT_AXIS]]
+  // CHECK-SAME{LITERAL}: [[0], [1], [2], [3, 4], [5]] : tensor<1x2x3x6x4x5xf32> into tensor<1x2x3x24x5xf32>
+  %0 = mesh.all_to_all %arg0 on @mesh_3d mesh_axes = [1, 0] split_axis = 1 concat_axis = 3
+    : tensor<1x12x3x4x5xf32> -> tensor<1x2x3x24x5xf32>
+  // CHECK: return %[[COLLAPSED_SPLIT_COUNT_INTO_CONCAT_AXIS]] : tensor<1x2x3x24x5xf32>
+  return %0 : tensor<1x2x3x24x5xf32>
+}
+
+// -----
+
+mesh.mesh @mesh_2d(shape = 2x2)
+
+// CHECK-LABEL: func @reduce_scatter_non_default_channel
+func.func @reduce_scatter_non_default_channel(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x2xi8>
+    %arg0 : tensor<3x2xi8>) -> tensor<3x1xi8> {
+  // CHECK-DAG: %[[CHANNEL:.*]] = util.global.load @_mesh_mesh_2d_axes_0 : !flow.channel
+  // CHECK-DAG: %[[TRANSPOSED_OPERAND_INIT_VAL:.*]] = tensor.empty() : tensor<2x3xi8>
+  // CHECK: %[[TRANSPOSED_OPERAND:.*]] = linalg.transpose
+  // CHECK-SAME: ins(%[[ARG]] : tensor<3x2xi8>) outs(%[[TRANSPOSED_OPERAND_INIT_VAL]] : tensor<2x3xi8>) permutation = [1, 0]
+  // CHECK: %[[REDUCE_SCATTER_INITIAL_VAL:.*]] = tensor.empty() : tensor<1x3xi8>
+  // CHECK: %[[REDUCE_SCATTER_RES:.*]] = flow.collective.reduce_scatter sum, ui8,
+  // CHECK-SAME: %[[REDUCE_SCATTER_INITIAL_VAL]], %[[TRANSPOSED_OPERAND]], %[[CHANNEL]]
+  // CHECK-SAME: (tensor<1x3xi8>, tensor<2x3xi8>, !flow.channel) -> %[[REDUCE_SCATTER_INITIAL_VAL]] as tensor<1x3xi8>
+  // CHECK: %[[RES_INIT_VAL:.*]] = tensor.empty() : tensor<3x1xi8>
+  // CHECK: %[[RES:.*]] = linalg.transpose
+  // CHECK-SAME: ins(%[[REDUCE_SCATTER_RES]] : tensor<1x3xi8>) outs(%[[RES_INIT_VAL]] : tensor<3x1xi8>) permutation = [1, 0]
+  %0 = mesh.reduce_scatter %arg0 on @mesh_2d mesh_axes = [0] scatter_axis = 1
+    : tensor<3x2xi8> -> tensor<3x1xi8>
+  // CHECK: return %[[RES]] : tensor<3x1xi8>
+  return %0 : tensor<3x1xi8>
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 1ce78c8..6266615 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -1947,6 +1948,16 @@
   setNameFn(getResult(), "default_channel");
 }
 
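+// Convenience builders: create a flow.channel.default op with or without an
+// explicit channel group name.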
+void ChannelDefaultOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                             StringRef group) {
+  ChannelDefaultOp::build(odsBuilder, odsState,
+                          odsBuilder.getStringAttr(group));
+}
+
+void ChannelDefaultOp::build(OpBuilder &odsBuilder, OperationState &odsState) {
+  ChannelDefaultOp::build(odsBuilder, odsState, StringAttr());
+}
+
 //===----------------------------------------------------------------------===//
 // flow.channel.split
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 716a043..3f335d2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1732,6 +1732,11 @@
     `:` type($result)
     attr-dict-with-keyword
   }];
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$group)>,
+    OpBuilder<(ins)>
+  ];
 }
 
 def FLOW_ChannelSplitOp : FLOW_Op<"channel.split", [
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
index 4e898cd..4f93ea7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 
 // clang-format off: must be included after all LLVM/MLIR headers.
@@ -219,4 +220,60 @@
   }
 }
 
+std::optional<IREE::Flow::CollectiveElementType>
+convertToFlowCollectiveElementType(Type type) {
+  if (type.isa<FloatType>()) {
+    if (type.isF16()) {
+      return IREE::Flow::CollectiveElementType::Float16;
+    }
+    if (type.isBF16()) {
+      return IREE::Flow::CollectiveElementType::BFloat16;
+    }
+    if (type.isF32()) {
+      return IREE::Flow::CollectiveElementType::Float32;
+    }
+    if (type.isF64()) {
+      return IREE::Flow::CollectiveElementType::Float64;
+    }
+  } else if (type.isa<IntegerType>()) {
+    if (type.isInteger(8)) {
+      if (type.isSignedInteger()) {
+        return IREE::Flow::CollectiveElementType::Sint8;
+      }
+      return IREE::Flow::CollectiveElementType::Uint8;
+    }
+    if (type.isInteger(16)) {
+      if (type.isSignedInteger()) {
+        return IREE::Flow::CollectiveElementType::Sint16;
+      }
+      return IREE::Flow::CollectiveElementType::Uint16;
+    }
+    if (type.isInteger(32)) {
+      if (type.isSignedInteger()) {
+        return IREE::Flow::CollectiveElementType::Sint32;
+      }
+      return IREE::Flow::CollectiveElementType::Uint32;
+    }
+    if (type.isInteger(64)) {
+      if (type.isSignedInteger()) {
+        return IREE::Flow::CollectiveElementType::Sint64;
+      }
+      return IREE::Flow::CollectiveElementType::Uint64;
+    }
+  }
+
+  return std::nullopt;
+}
+
+IREE::Flow::CollectiveElementTypeAttr
+getCollectiveElementTypeAttr(RankedTensorType type) {
+  std::optional<IREE::Flow::CollectiveElementType> collectiveElemType =
+      convertToFlowCollectiveElementType(type.getElementType());
+  if (!collectiveElemType) {
+    return IREE::Flow::CollectiveElementTypeAttr();
+  }
+  return IREE::Flow::CollectiveElementTypeAttr::get(type.getContext(),
+                                                    *collectiveElemType);
+}
+
 } // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
index 80151c5..3cc34ff 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h
@@ -171,4 +171,18 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h.inc" // IWYU pragma: keep
 // clang-format on
 
+namespace mlir::iree_compiler::IREE::Flow {
+
+// Creates an attribute corresponding to the underlying numeric element type.
+// If there is no such correspondence, a null attribute is returned.
+IREE::Flow::CollectiveElementTypeAttr
+getCollectiveElementTypeAttr(RankedTensorType type);
+
+// Converts the numeric type `type` to the corresponding enum value.
+// If there is no correspondence, std::nullopt is returned.
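+// For example, f32 maps to CollectiveElementType::Float32, while i1 has no
+// corresponding collective element type and yields std::nullopt.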
+std::optional<IREE::Flow::CollectiveElementType>
+convertToFlowCollectiveElementType(Type type);
+
+} // namespace mlir::iree_compiler::IREE::Flow
+
 #endif // IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index a4600af..a7a9e30 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -66,6 +66,7 @@
         "//runtime/src/iree/schemas/instruments:dispatch_def_c_fbs",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineToStandard",
+        "@llvm-project//mlir:AffineTransforms",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:ControlFlowDialect",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 4a94132..01381be 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -42,6 +42,7 @@
     ::PassesIncGen
     LLVMSupport
     MLIRAffineToStandard
+    MLIRAffineTransforms
     MLIRArithDialect
     MLIRBufferizationDialect
     MLIRControlFlowDialect
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 7146c1d..a732114 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
 #include "iree/compiler/Utils/PassUtils.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
@@ -405,6 +406,10 @@
   FunctionLikeNest(passManager)
       .addPass(IREE::HAL::createElideRedundantCommandsPass);
 
+  // TODO: Maybe this should be part of the affine lowering pass.
+  // Remove this once it is handled there.
+  // https://github.com/llvm/llvm-project/issues/78458
+  passManager.addPass(affine::createAffineExpandIndexOpsPass());
   // Fixup workgroup count calculations that may have used the affine dialect.
   // Kind of random here but can happen if the benchmarking code does things.
   passManager.addPass(mlir::createLowerAffinePass());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel
index f7c94d6..0484ee7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD.bazel
@@ -23,6 +23,7 @@
     deps = [
         "//compiler/src/iree/compiler/Dialect/Stream/Conversion",
         "//compiler/src/iree/compiler/Dialect/Stream/IR",
+        "//compiler/src/iree/compiler/Dialect/Util/Conversion",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
index f8e7852..ff93dbd 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
@@ -24,6 +24,7 @@
     MLIRTransforms
     iree::compiler::Dialect::Stream::Conversion
     iree::compiler::Dialect::Stream::IR
+    iree::compiler::Dialect::Util::Conversion
     iree::compiler::Dialect::Util::IR
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
index c29fa9a..c73549f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -8,7 +8,9 @@
 
 #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
 #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -226,6 +228,10 @@
   patterns
       .insert<GlobalOpExpansion, GlobalLoadOpExpansion, GlobalStoreOpExpansion>(
           expansionState, typeConverter, context);
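+  // Run the generic type conversion patterns over util.global,
+  // util.global.load, and util.global.store so that their result types and
+  // type attributes are rewritten by the type converter.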
+  patterns.add<GenericConvertTypesPattern<IREE::Util::GlobalOp>,
+               GenericConvertTypesPattern<IREE::Util::GlobalLoadOp>,
+               GenericConvertTypesPattern<IREE::Util::GlobalStoreOp>>(
+      typeConverter, context);
 }
 
 void populateUtilToStreamConversionPatterns(MLIRContext *context,
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
index 6971f1e..0614d22 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
@@ -7,7 +7,11 @@
 #ifndef IREE_COMPILER_DIALECT_UTIL_CONVERSION_CONVERSIONPATTERNS_H_
 #define IREE_COMPILER_DIALECT_UTIL_CONVERSION_CONVERSIONPATTERNS_H_
 
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir::iree_compiler {
@@ -18,7 +22,7 @@
   LogicalResult
   matchAndRewrite(T op, typename T::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> resultTypes;
+    SmallVector<Type> newResultTypes;
     for (auto oldType : op.getOperation()->getResultTypes()) {
       SmallVector<Type> newTypes;
       if (failed(this->getTypeConverter()->convertType(oldType, newTypes))) {
@@ -26,13 +30,45 @@
       }
       // TODO(benvanik): figure out this silly expansion stuff. Seems broken.
       // resultTypes.append(newTypes);
-      resultTypes.push_back(newTypes.front());
+      newResultTypes.push_back(newTypes.front());
     }
-    auto newOp = rewriter.create<T>(op.getLoc(), resultTypes,
-                                    adaptor.getOperands(), op->getAttrs());
+
+    SmallVector<NamedAttribute> newAttrs;
+    if (failed(convertTypeAttributes(op->getAttrs(), newAttrs))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "failed converting type attributes");
+    }
+
+    if (newResultTypes == op->getResultTypes() &&
+        op->getOperands() == adaptor.getOperands() &&
+        newAttrs == op->getAttrs()) {
+      return rewriter.notifyMatchFailure(op, "op does not need transformation");
+    }
+
+    auto newOp = rewriter.create<T>(op.getLoc(), newResultTypes,
+                                    adaptor.getOperands(), newAttrs);
     rewriter.replaceOp(op, newOp->getResults());
     return success();
   }
+
+protected:
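+  // Rebuilds the attribute list, converting any TypeAttr values through the
+  // type converter and passing all other attributes through unchanged.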
+  LogicalResult convertTypeAttributes(ArrayRef<NamedAttribute> attrs,
+                                      SmallVector<NamedAttribute> &res) const {
+    for (NamedAttribute attr : attrs) {
+      TypeAttr oldType = attr.getValue().dyn_cast<TypeAttr>();
+      if (!oldType) {
+        res.push_back(attr);
+        continue;
+      }
+
+      Type newType = this->getTypeConverter()->convertType(oldType.getValue());
+      if (!newType) {
+        return failure();
+      }
+      res.push_back(NamedAttribute(attr.getName(), TypeAttr::get(newType)));
+    }
+    return success();
+  }
 };
 
 template <typename OpT>
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 6113201..823558d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -772,7 +772,6 @@
     custom<TypeOrAttr>($type, $initial_value)
   }];
 
-  let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins
       "StringRef":$name,
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 806cd54..508ead4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -79,6 +79,10 @@
   FunctionLikeNest(passManager)
       .addPass(mlir::createLoopInvariantCodeMotionPass)
       .addPass(mlir::createConvertSCFToCFPass)
+      // TODO: Maybe this should be part of the affine lowering pass.
+      // Remove this once it is handled there.
+      // https://github.com/llvm/llvm-project/issues/78458
+      .addPass(affine::createAffineExpandIndexOpsPass)
       .addPass(mlir::createLowerAffinePass)
       .addPass(mlir::arith::createArithUnsignedWhenEquivalentPass);
 
diff --git a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
index 043bfcf..c214d54 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
@@ -59,6 +59,7 @@
     deps = [
         ":PassHeaders",
         ":PassesIncGen",
+        "//compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
index 6f1716c..1f1dc4f 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -67,6 +67,7 @@
     MLIRTensorDialect
     MLIRTensorUtils
     MLIRTransforms
+    iree::compiler::Dialect::Flow::Conversion::MeshToFlow
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp b/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp
index 08728c0..f4750b7 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/Passes.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/InputConversion/Common/Passes.h"
 
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassOptions.h"
@@ -22,6 +23,7 @@
   passManager.addPass(createIREEImportPublicPass());
   passManager.addPass(createImportMLProgramPass());
   passManager.addPass(createSanitizeModuleNamesPass());
+  passManager.addPass(IREE::Flow::createConvertMeshToFlowPass());
 }
 
 void registerCommonInputConversionPasses() {
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 637a66b..2324a70 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -37,6 +37,7 @@
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Interfaces",
         "//compiler/src/iree/compiler/ConstEval",
+        "//compiler/src/iree/compiler/Dialect/Flow/Conversion/MeshToFlow",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h
index d9849da..93816cb 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h
@@ -18,6 +18,7 @@
 #include "iree/compiler/Bindings/Native/Transforms/Passes.h"
 #include "iree/compiler/Bindings/TFLite/Transforms/Passes.h"
 #include "iree/compiler/ConstEval/Passes.h"
+#include "iree/compiler/Dialect/Flow/Conversion/MeshToFlow/MeshToFlow.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
@@ -53,6 +54,7 @@
   GlobalOptimization::registerGlobalOptimizationPipeline();
   Preprocessing::registerPreprocessingPasses();
   IREE::Flow::registerFlowPasses();
+  IREE::Flow::registerMeshToFlowPasses();
   IREE::HAL::registerHALPasses();
   IREE::HAL::Inline::registerHALInlinePasses();
   IREE::HAL::Loader::registerHALLoaderPasses();
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
index ddc310d..873a311 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
@@ -67,6 +68,7 @@
                   LLVM::LLVMDialect,
                   linalg::LinalgDialect,
                   math::MathDialect,
+                  mesh::MeshDialect,
                   memref::MemRefDialect,
                   ml_program::MLProgramDialect,
                   pdl::PDLDialect,
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index b3e1de8..4ebbdaf 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -54,8 +54,6 @@
 
   // Affine
   affine::registerAffinePasses();
-  affine::registerAffineLoopFusionPass();
-  affine::registerAffinePipelineDataTransferPass();
   registerConvertAffineToStandardPass();
 
   // Arm SME
diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel
index 6b8ef5e..c094aa1 100644
--- a/compiler/src/iree/compiler/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -33,11 +33,16 @@
         "ElementPackingUtils.h",
         "EquivalenceUtils.h",
         "FlatbufferUtils.h",
+        "Folding.h",
         "IndexSet.h",
+        "Indexing.h",
         "ModuleUtils.h",
+        "OpVisitor.h",
         "OptionUtils.h",
         "PassUtils.h",
         "PatternUtils.h",
+        "Permutation.h",
+        "SmallVectorDenseMapInfo.h",
         "StringUtils.h",
         "ToolUtils.h",
         "TracingUtils.h",
diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt
index 01939cb..042552a 100644
--- a/compiler/src/iree/compiler/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -18,11 +18,16 @@
     "ElementPackingUtils.h"
     "EquivalenceUtils.h"
     "FlatbufferUtils.h"
+    "Folding.h"
     "IndexSet.h"
+    "Indexing.h"
     "ModuleUtils.h"
+    "OpVisitor.h"
     "OptionUtils.h"
     "PassUtils.h"
     "PatternUtils.h"
+    "Permutation.h"
+    "SmallVectorDenseMapInfo.h"
     "StringUtils.h"
     "ToolUtils.h"
     "TracingUtils.h"
diff --git a/compiler/src/iree/compiler/Utils/Folding.h b/compiler/src/iree/compiler/Utils/Folding.h
new file mode 100644
index 0000000..397f983
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Folding.h
@@ -0,0 +1,32 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_FOLDING_H_
+#define IREE_COMPILER_UTILS_FOLDING_H_
+
+#include <iterator>
+#include <utility>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir::iree_compiler {
+
+// Convert a `Value` or an `Attribute` range to a range of `OpFoldResult`.
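+//
+// For example (illustrative), collecting an op's operands as OpFoldResults:
+//   SmallVector<OpFoldResult> ofrs = toOpFoldResults(op->getOperands());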
+template <typename Range, typename OutIt>
+void toOpFoldResults(Range &&range, OutIt outIt) {
+  llvm::transform(std::forward<Range>(range), outIt,
+                  [](auto v) { return OpFoldResult(v); });
+}
+
+template <typename Range>
+SmallVector<OpFoldResult> toOpFoldResults(Range &&range) {
+  SmallVector<OpFoldResult> res;
+  toOpFoldResults(std::forward<Range>(range), std::back_inserter(res));
+  return res;
+}
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_FOLDING_H_
diff --git a/compiler/src/iree/compiler/Utils/Indexing.h b/compiler/src/iree/compiler/Utils/Indexing.h
new file mode 100644
index 0000000..58484e8
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Indexing.h
@@ -0,0 +1,54 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_INDEXING_H_
+#define IREE_COMPILER_UTILS_INDEXING_H_
+
+#include <algorithm>
+#include <iterator>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler {
+
+// Construct IR that extracts the linear index from a multi-index according to
+// a shape.
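+//
+// For example (illustrative), with multiIndex = (i, j) and shape = (3, 4) the
+// resulting affine.apply computes the row-major index i * 4 + j.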
+inline OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                         ArrayRef<OpFoldResult> shape,
+                                         ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  SmallVector<AffineExpr> shapeAffine;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+  }
+
+  SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(stridesAffine.size());
+  llvm::transform(stridesAffine, std::back_inserter(strides),
+                  [&builder, &shape](AffineExpr strideExpr) {
+                    return affine::makeComposedFoldedAffineApply(
+                        builder, builder.getLoc(), strideExpr, shape);
+                  });
+
+  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(
+      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_INDEXING_H_
diff --git a/compiler/src/iree/compiler/Utils/OpVisitor.h b/compiler/src/iree/compiler/Utils/OpVisitor.h
new file mode 100644
index 0000000..247bc7a
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/OpVisitor.h
@@ -0,0 +1,109 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_OPVISITOR_H_
+#define IREE_COMPILER_UTILS_OPVISITOR_H_
+
+#include <functional>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
+
+namespace mlir::iree_compiler {
+
+// Calls a collection of callbacks for an operation.
+// A callback registered for a concrete op type is called only when the
+// operation has that concrete type.
+//
+// Operation* op = ...;
+// OpVisitorCollection visitors = ...;
+// visitors.emplaceVisitors<MyVisitor1, MyVisitor2>(
+//   visitorConstructorArg1,
+//   visitorConstructorArg2);
+// visitors.insertVisitors(
+//   [](ConcreteOp op) { ... }, // Call only for ConcreteOp
+//   [](Operation* op) { ... } // Call for all op types
+// );
+// visitors(op);
+struct OpVisitorCollection {
+  void operator()(Operation *op) {
+    for (auto &fn : everyOpFns) {
+      fn(op);
+    }
+
+    auto it = opFnMap.find(op->getName().getTypeID());
+    if (it == opFnMap.end()) {
+      return;
+    }
+
+    for (auto &fn : it->second) {
+      fn(op);
+    }
+  }
+
+  template <typename Op, typename Fn>
+  void insertVisitor(Fn &&fn) {
+    opFnMap[TypeID::get<Op>()].emplace_back(
+        ConcreteOpFn<Op, Fn>(std::forward<Fn>(fn)));
+  }
+
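+  // Overload for callables whose first argument is a concrete op type; the op
+  // type is deduced from the callable's signature.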
+  template <typename Fn,
+            typename = std::enable_if_t<!std::is_invocable_v<Fn, Operation *>>>
+  void insertVisitor(Fn &&fn) {
+    insertVisitor<
+        std::decay_t<typename llvm::function_traits<Fn>::template arg_t<0>>>(
+        std::forward<Fn>(fn));
+  }
+
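+  // Overload for callables taking Operation*; these are invoked for every op.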
+  template <typename Fn,
+            typename = std::enable_if_t<std::is_invocable_v<Fn, Operation *>>,
+            typename = void>
+  void insertVisitor(Fn &&fn) {
+    everyOpFns.emplace_back(std::forward<Fn>(fn));
+  }
+
+  template <typename Fn, typename... RestFns>
+  void insertVisitors(Fn &&fn, RestFns &&...restFns) {
+    insertVisitor(std::forward<Fn>(fn));
+    insertVisitors(std::forward<RestFns>(restFns)...);
+  }
+
+  template <typename Fn>
+  void insertVisitors(Fn &&fn) {
+    insertVisitor(std::forward<Fn>(fn));
+  }
+
+  template <typename... Fns, typename... ConstructorArgs>
+  void emplaceVisitors(ConstructorArgs &&...args) {
+    (emplaceVisitor<Fns>(std::forward<ConstructorArgs>(args)...), ...);
+  }
+
+private:
+  template <typename Fn, typename... ConstructorArgs>
+  void emplaceVisitor(ConstructorArgs &&...args) {
+    insertVisitor(Fn(std::forward<ConstructorArgs>(args)...));
+  }
+
+  template <typename Op, typename Fn>
+  struct ConcreteOpFn {
+    template <typename FnArg>
+    ConcreteOpFn(FnArg &&fn) : fn(std::forward<FnArg>(fn)) {}
+
+    void operator()(Operation *op) { fn(llvm::cast<Op>(op)); }
+
+  private:
+    Fn fn;
+  };
+
+  DenseMap<TypeID, SmallVector<std::function<void(Operation *)>>> opFnMap;
+  SmallVector<std::function<void(Operation *)>> everyOpFns;
+};
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_OPVISITOR_H_
diff --git a/compiler/src/iree/compiler/Utils/Permutation.h b/compiler/src/iree/compiler/Utils/Permutation.h
new file mode 100644
index 0000000..0cd3df1
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/Permutation.h
@@ -0,0 +1,107 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_PERMUTATION_H_
+#define IREE_COMPILER_UTILS_PERMUTATION_H_
+
+#include <iterator>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::iree_compiler {
+
+// Example: values = (1, 2, 3), permutation = (2, 1, 0)
+// output = (3, 2, 1).
+// TODO: Make applyPermutation in mlir/Dialect/Utils/IndexingUtils.h generic
+// and use it instead.
+template <typename ValuesIt, typename PermutationRange, typename OutIt>
+void permute(ValuesIt valuesBegin, ValuesIt valuesEnd,
+             PermutationRange &&permutation, OutIt outBegin) {
+  assert(std::distance(valuesBegin, valuesEnd) >= llvm::adl_size(permutation));
+  llvm::transform(permutation, outBegin,
+                  [valuesBegin](auto i) { return valuesBegin[i]; });
+}
+
+template <typename ValuesRange, typename PermutationRange, typename OutIt>
+void permute(ValuesRange &&values, PermutationRange &&permutation,
+             OutIt outBegin) {
+  permute(llvm::adl_begin(std::forward<ValuesRange>(values)),
+          llvm::adl_end(std::forward<ValuesRange>(values)), permutation,
+          outBegin);
+}
+
+template <typename T, typename Index>
+SmallVector<T> permute(ArrayRef<T> values, ArrayRef<Index> permutation) {
+  SmallVector<T> res;
+  permute(values, permutation, std::back_inserter(res));
+  return res;
+}
+
+// Check if the range is a sequence of numbers starting from 0.
+// Example: (0, 1, 2, 3).
+// TODO: Make isIdentityPermutation in MLIR generic enough to accept more than
+// just int64_t, then delete this.
+template <typename Range>
+bool isIdentityPermutation(Range &&range) {
+  using ValueType = std::decay_t<decltype(*std::begin(range))>;
+  ValueType i = static_cast<ValueType>(0);
+  return llvm::all_of(std::forward<Range>(range), [&i](ValueType v) {
+    bool res = (v == i);
+    ++i;
+    return res;
+  });
+}
+
+// Make a permutation that moves src to dst.
+// Example with size = 5, src = 1, dst = 3.
+// output = (0, 2, 3, 1, 4).
+// Example with size = 2, src = 0, dst = 1.
+// output = (1, 0).
+template <typename T, typename OutIt>
+void makeMovePermutation(T size, T src, T dst, OutIt outBegin) {
+  assert(src < size && dst < size && size > static_cast<T>(0));
+  T outSize = 0;
+  for (T i = 0; i < size; ++i) {
+    if (outSize == dst) {
+      *outBegin = src;
+      ++outBegin;
+      ++outSize;
+    }
+    if (i == src) {
+      ++i;
+      if (i >= size) {
+        break;
+      }
+    }
+
+    *outBegin = i;
+    ++outBegin;
+    ++outSize;
+  }
+
+  if (size != outSize) {
+    *outBegin = src;
+    ++outBegin;
+  }
+}
+
+template <typename T>
+SmallVector<T> makeMovePermutation(T size, T src, T dst) {
+  SmallVector<T> res;
+  res.reserve(size);
+  makeMovePermutation(size, src, dst, std::back_inserter(res));
+  return res;
+}
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_PERMUTATION_H_
diff --git a/compiler/src/iree/compiler/Utils/SmallVectorDenseMapInfo.h b/compiler/src/iree/compiler/Utils/SmallVectorDenseMapInfo.h
new file mode 100644
index 0000000..6bf3f79
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/SmallVectorDenseMapInfo.h
@@ -0,0 +1,52 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_SMALLVECTORDENSEMAPINFO_H_
+#define IREE_COMPILER_UTILS_SMALLVECTORDENSEMAPINFO_H_
+
+#include <numeric>
+#include <utility>
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
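+// Enables the use of SmallVector as a key in llvm::DenseMap / llvm::DenseSet,
+// e.g. (illustrative):
+//   llvm::DenseMap<llvm::SmallVector<int64_t, 2>, unsigned> groupIndex;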
+namespace llvm {
+template <typename T, unsigned N>
+struct DenseMapInfo<SmallVector<T, N>> {
+  static SmallVector<T, N> getEmptyKey() {
+    return SmallVector<T, N>(1, llvm::DenseMapInfo<T>::getEmptyKey());
+  }
+
+  static SmallVector<T, N> getTombstoneKey() {
+    return SmallVector<T, N>(1, llvm::DenseMapInfo<T>::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const SmallVector<T, N> &v) {
+    hash_code hash = llvm::DenseMapInfo<T>::getHashValue(
+        llvm::DenseMapInfo<T>::getEmptyKey());
+    // Fold each element into the hash; note std::accumulate returns the
+    // combined value rather than mutating `hash` in place.
+    hash = std::accumulate(v.begin(), v.end(), hash,
+                           [](hash_code hash, const T &element) {
+                             return hash_combine(hash, element);
+                           });
+    return hash;
+  }
+
+  static bool isEqual(const SmallVector<T, N> &lhs,
+                      const SmallVector<T, N> &rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+
+    return llvm::all_of_zip(lhs, rhs, [](const T &lhs, const T &rhs) {
+      return DenseMapInfo<T>::isEqual(lhs, rhs);
+    });
+  }
+};
+} // namespace llvm
+
+#endif // IREE_COMPILER_UTILS_SMALLVECTORDENSEMAPINFO_H_
diff --git a/compiler/src/iree/compiler/Utils/test/BUILD.bazel b/compiler/src/iree/compiler/Utils/test/BUILD.bazel
new file mode 100644
index 0000000..f48694f
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/test/BUILD.bazel
@@ -0,0 +1,22 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_test")
+
+package(
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_compiler_cc_test(
+    name = "utils",
+    testonly = True,
+    srcs = ["UtilsTest.cpp"],
+    deps = [
+        "//compiler/src/iree/compiler/Utils",
+        "@com_google_googletest//:gtest",
+    ],
+)
diff --git a/compiler/src/iree/compiler/Utils/test/CMakeLists.txt b/compiler/src/iree/compiler/Utils/test/CMakeLists.txt
new file mode 100644
index 0000000..7c26dca
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# compiler/src/iree/compiler/Utils/test/BUILD.bazel                            #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_test(
+  NAME
+    utils
+  SRCS
+    "UtilsTest.cpp"
+  DEPS
+    gmock
+    gtest
+    iree::compiler::Utils
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Utils/test/UtilsTest.cpp b/compiler/src/iree/compiler/Utils/test/UtilsTest.cpp
new file mode 100644
index 0000000..44810ca
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/test/UtilsTest.cpp
@@ -0,0 +1,26 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "iree/compiler/Utils/Permutation.h"
+
+using namespace mlir::iree_compiler;
+using namespace testing;
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
+
+TEST(Permutation, MakeMovePermutation) {
+  EXPECT_THAT(makeMovePermutation(1, 0, 0), ElementsAre(0));
+  EXPECT_THAT(makeMovePermutation(2, 0, 1), ElementsAre(1, 0));
+  EXPECT_THAT(makeMovePermutation(5, 1, 3), ElementsAre(0, 2, 3, 1, 4));
+  EXPECT_THAT(makeMovePermutation(3, 1, 2), ElementsAre(0, 2, 1));
+  EXPECT_THAT(makeMovePermutation(3, 2, 0), ElementsAre(2, 0, 1));
+}
diff --git a/tests/e2e/collectives/CMakeLists.txt b/tests/e2e/collectives/CMakeLists.txt
index c98f3d9..b0d1ab1 100644
--- a/tests/e2e/collectives/CMakeLists.txt
+++ b/tests/e2e/collectives/CMakeLists.txt
@@ -13,7 +13,7 @@
 if(IREE_TARGET_BACKEND_CUDA AND IREE_HAL_DRIVER_CUDA)
   iree_py_test(
     NAME
-      collectives_test_single_gpu
+      collectives_test_1_gpu
     SRCS
       "collectives_test.py"
     ARGS
@@ -27,7 +27,7 @@
 
   iree_py_test(
     NAME
-      collectives_test_multi_gpu
+      collectives_test_2_gpus
     SRCS
       "collectives_test.py"
     ARGS
@@ -41,4 +41,19 @@
       # To properly test for 2 ranks we need 2 GPUs.
       "requires-multiple-devices"
   )
+
+  iree_py_test(
+    NAME
+      collectives_test_4_gpus
+    SRCS
+      "collectives_test.py"
+    ARGS
+      "-k" "FourRanks"
+      "--target_backend=cuda"
+      "--driver=cuda"
+    LABELS
+      "requires-gpu-nvidia"
+      "driver=cuda"
+      "requires-multiple-devices"
+  )
 endif()
diff --git a/tests/e2e/collectives/collectives_test.py b/tests/e2e/collectives/collectives_test.py
index a28e082..fa995f6 100644
--- a/tests/e2e/collectives/collectives_test.py
+++ b/tests/e2e/collectives/collectives_test.py
@@ -11,13 +11,13 @@
 import iree.runtime
 from iree.runtime.array_interop import DeviceArray
 import os
-from typing import List, Tuple, TypeVar
+from typing import List, Tuple
 import numpy as np
 import tempfile
 import subprocess
 import test_utils
 
-ArrayLike = TypeVar("ArrayLike")
+ArrayLike = object
 
 
 def parse_args():
@@ -92,7 +92,10 @@
 
 
 def run_test(
-    mlir: str, inputs: List[List[ArrayLike]], expected_outputs: List[List[ArrayLike]]
+    mlir: str,
+    inputs: List[List[ArrayLike]],
+    expected_outputs: List[List[ArrayLike]],
+    mlir_input_type: iree.compiler.InputType | str = iree.compiler.InputType.AUTO,
 ):
     with tempfile.TemporaryDirectory() as tmp_dir:
         module_filepath = os.path.join(tmp_dir, "module.vmfb")
@@ -100,14 +103,15 @@
             input_str=mlir,
             output_file=module_filepath,
             target_backends=[args.target_backend],
-            input_type="stablehlo",
+            input_type=mlir_input_type,
+            extra_args=["--iree-hal-cuda-llvm-target-arch", "sm_53"],
         )
 
         num_ranks = len(inputs)
         # Ranks on the 0th axis.
         outputs = run_ranks(
             num_ranks=num_ranks,
-            function="all_reduce_sum",
+            function="main",
             driver=args.driver,
             module_filepath=module_filepath,
             inputs=inputs,
@@ -125,7 +129,7 @@
         all_reduce([1, 2, 3, 4]) == [1, 2, 3, 4].
         """
         stablehlo_mlir = """
-            func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> {
+            func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
             %out = "stablehlo.all_reduce"(%input) ({
                 ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
                 %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
@@ -138,7 +142,60 @@
         """
         inputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
         expected_outputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
-        run_test(mlir=stablehlo_mlir, inputs=inputs, expected_outputs=expected_outputs)
+        run_test(
+            mlir=stablehlo_mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+            mlir_input_type=iree.compiler.InputType.STABLEHLO,
+        )
+
+    def test_mesh_all_reduce(self):
+        """
+        Test the trivial case of all_reduce with one rank.
+        all_reduce([1, 2, 3, 4]) == [1, 2, 3, 4].
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 1)
+
+            func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
+            %out = mesh.all_reduce %input on @mesh mesh_axes = [0] : tensor<4xf32> -> tensor<4xf32>
+            return %out : tensor<4xf32>
+            }
+        """
+        inputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
+        expected_outputs = [[np.array([1, 2, 3, 4], dtype=np.float32)]]
+        run_test(mlir=mlir, inputs=inputs, expected_outputs=expected_outputs)
+
+    def test_mesh_all_to_all(self):
+        """
+        Trivial all_to_all on a single-device 1D mesh, grouping along mesh
+        dimension 0. With one device the contents are unchanged.
+
+        Device contents before operation:
+        [[1, 2], [3, 4]]
+
+        Device contents after operation:
+        [[1, 2], [3, 4]]
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 1)
+
+            func.func @main(%input : tensor<2x2xf32>) -> tensor<2x2xf32> {
+            %out = mesh.all_to_all %input on @mesh mesh_axes = [0]
+              split_axis = 0 concat_axis = 1 : tensor<2x2xf32> -> tensor<2x2xf32>
+            return %out : tensor<2x2xf32>
+            }
+        """
+        inputs = [
+            [np.array([[1, 2], [3, 4]], dtype=np.float32)],
+        ]
+        expected_outputs = [
+            [np.array([[1, 2], [3, 4]], dtype=np.float32)],
+        ]
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
 
 
 class TwoRanks(unittest.TestCase):
@@ -147,7 +204,7 @@
         Test all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) == [6, 8, 10, 12].
         """
         stablehlo_mlir = """
-            func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> {
+            func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
             %out = "stablehlo.all_reduce"(%input) ({
                 ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
                 %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
@@ -162,8 +219,234 @@
             [np.array([1, 2, 3, 4], dtype=np.float32)],
             [np.array([5, 6, 7, 8], dtype=np.float32)],
         ]
-        expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]]
-        run_test(mlir=stablehlo_mlir, inputs=inputs, expected_outputs=expected_outputs)
+        expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]] * 2
+        run_test(
+            mlir=stablehlo_mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+            mlir_input_type=iree.compiler.InputType.STABLEHLO,
+        )
+
+    def test_mesh_all_reduce_1d_mesh(self):
+        """
+        Test all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) == [6, 8, 10, 12].
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 2)
+
+            func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
+            %out = mesh.all_reduce %input on @mesh mesh_axes = [0] : tensor<4xf32> -> tensor<4xf32>
+            return %out : tensor<4xf32>
+            }
+        """
+        inputs = [
+            [np.array([1, 2, 3, 4], dtype=np.float32)],
+            [np.array([5, 6, 7, 8], dtype=np.float32)],
+        ]
+        expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]] * 2
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
+
+    def test_mesh_all_reduce_3d_mesh(self):
+        """
+        Test all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) == [6, 8, 10, 12].
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 1x2x1)
+
+            func.func @main(%input : tensor<4xf32>) -> tensor<4xf32> {
+            %out = mesh.all_reduce %input on @mesh mesh_axes = [1] : tensor<4xf32> -> tensor<4xf32>
+            return %out : tensor<4xf32>
+            }
+        """
+        inputs = [
+            [np.array([1, 2, 3, 4], dtype=np.float32)],
+            [np.array([5, 6, 7, 8], dtype=np.float32)],
+        ]
+        expected_outputs = [[np.array([6, 8, 10, 12], dtype=np.float32)]] * 2
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
+
+
+class FourRanks(unittest.TestCase):
+    def test_mesh_all_reduce_on_2d_mesh_along_axis_1(self):
+        """
+        Test on a 2x2 device mesh reduction along dimension 1.
+        Mesh devices:
+        axis 1
+        ------>
+        0 1
+        2 3
+
+        Device contents before operation:
+        [1, 2] [3, 4]
+        [5, 6] [7, 8]
+
+        Device contents after operation:
+        [ 4,  6] [ 4,  6]
+        [12, 14] [12, 14]
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 2x2)
+
+            func.func @main(%input : tensor<2xf32>) -> tensor<2xf32> {
+            %out = mesh.all_reduce %input on @mesh mesh_axes = [1] : tensor<2xf32> -> tensor<2xf32>
+            return %out : tensor<2xf32>
+            }
+        """
+        inputs = [
+            [np.array([1, 2], dtype=np.float32)],
+            [np.array([3, 4], dtype=np.float32)],
+            [np.array([5, 6], dtype=np.float32)],
+            [np.array([7, 8], dtype=np.float32)],
+        ]
+        expected_outputs = [
+            [np.array([4, 6], dtype=np.float32)],
+            [np.array([4, 6], dtype=np.float32)],
+            [np.array([12, 14], dtype=np.float32)],
+            [np.array([12, 14], dtype=np.float32)],
+        ]
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
+
+    def test_mesh_all_reduce_on_2d_mesh_along_axis_0(self):
+        """
+        Test on a 2x2 device mesh reduction along dimension 0.
+        Mesh devices:
+        axis 1
+        ------>
+        0 1
+        2 3
+
+        Device contents before operation:
+        [1, 2] [3, 4]
+        [5, 6] [7, 8]
+
+        Device contents after operation:
+        [6, 8] [10, 12]
+        [6, 8] [10, 12]
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 2x2)
+
+            func.func @main(%input : tensor<2xf32>) -> tensor<2xf32> {
+            %out = mesh.all_reduce %input on @mesh mesh_axes = [0] : tensor<2xf32> -> tensor<2xf32>
+            return %out : tensor<2xf32>
+            }
+        """
+        inputs = [
+            [np.array([1, 2], dtype=np.float32)],
+            [np.array([3, 4], dtype=np.float32)],
+            [np.array([5, 6], dtype=np.float32)],
+            [np.array([7, 8], dtype=np.float32)],
+        ]
+        expected_outputs = [
+            [np.array([6, 8], dtype=np.float32)],
+            [np.array([10, 12], dtype=np.float32)],
+            [np.array([6, 8], dtype=np.float32)],
+            [np.array([10, 12], dtype=np.float32)],
+        ]
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
+
+    def test_mesh_all_reduce_on_4d_mesh_along_1_axis(self):
+        """
+        Test on a 1x2x1x2 device mesh reduction along mesh dimension 1.
+        Mesh devices:
+        axis 3
+        ------>
+        0 1     | axis 1
+        2 3     ↓
+
+        Device contents before operation:
+        [1, 2] [3, 4]
+        [5, 6] [7, 8]
+
+        Device contents after operation:
+        [6, 8] [10, 12]
+        [6, 8] [10, 12]
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 1x2x1x2)
+
+            func.func @main(%input : tensor<2xf32>) -> tensor<2xf32> {
+            %out = mesh.all_reduce %input on @mesh mesh_axes = [1] : tensor<2xf32> -> tensor<2xf32>
+            return %out : tensor<2xf32>
+            }
+        """
+        inputs = [
+            [np.array([1, 2], dtype=np.float32)],
+            [np.array([3, 4], dtype=np.float32)],
+            [np.array([5, 6], dtype=np.float32)],
+            [np.array([7, 8], dtype=np.float32)],
+        ]
+        expected_outputs = [
+            [np.array([6, 8], dtype=np.float32)],
+            [np.array([10, 12], dtype=np.float32)],
+            [np.array([6, 8], dtype=np.float32)],
+            [np.array([10, 12], dtype=np.float32)],
+        ]
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
+
+    def test_mesh_all_to_all_on_4d_mesh_along_1_axis(self):
+        """
+        Test on a 1x2x1x2 device mesh, grouping along mesh dimension 1.
+        Mesh devices:
+        axis 3
+        ------>
+        0 1     | axis 1
+        2 3     ↓
+
+        Device contents before operation:
+        [[1], [2]]  [[3], [4]]
+        [[5], [6]]  [[7], [8]]
+
+        Device contents after operation:
+        [[1, 5]]  [[3, 7]]
+        [[2, 6]]  [[4, 8]]
+        """
+        mlir = """
+            mesh.mesh @mesh(shape = 1x2x1x2)
+
+            func.func @main(%input : tensor<2x1xf32>) -> tensor<1x2xf32> {
+            %out = mesh.all_to_all %input on @mesh mesh_axes = [1]
+              split_axis = 0 concat_axis = 1 : tensor<2x1xf32> -> tensor<1x2xf32>
+            return %out : tensor<1x2xf32>
+            }
+        """
+        inputs = [
+            [np.array([[1], [2]], dtype=np.float32)],
+            [np.array([[3], [4]], dtype=np.float32)],
+            [np.array([[5], [6]], dtype=np.float32)],
+            [np.array([[7], [8]], dtype=np.float32)],
+        ]
+        expected_outputs = [
+            [np.array([[1, 5]], dtype=np.float32)],
+            [np.array([[3, 7]], dtype=np.float32)],
+            [np.array([[2, 6]], dtype=np.float32)],
+            [np.array([[4, 8]], dtype=np.float32)],
+        ]
+        run_test(
+            mlir=mlir,
+            inputs=inputs,
+            expected_outputs=expected_outputs,
+        )
 
 
 if __name__ == "__main__":