Making flow.dispatch/stream.async.dispatch take multiple symbols. (#15295)

stream.cmd.dispatch already supported this for making external
hal.executable.variant ops work and by making this consistent up the
stack it allows for the use of hal.executable.variant all the way up in
flow. This will allow hal.dispatch.extern to expand to HAL ops instead
of flow.executable and avoid the need for plumbing all of the
HAL-specific behavior through those layers.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 0f13fc2..877b318 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -703,6 +703,34 @@
 }
 
 //===----------------------------------------------------------------------===//
+// flow.dispatch
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct DeduplicateDispatchEntryRefs final
+    : public OpRewritePattern<DispatchOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DispatchOp dispatchOp,
+                                PatternRewriter &rewriter) const override {
+    auto originalAttr = dispatchOp.getEntryPointsAttr();
+    auto newAttr = deduplicateArrayElements(originalAttr);
+    if (newAttr == originalAttr)
+      return failure();
+    rewriter.updateRootInPlace(
+        dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); });
+    return success();
+  }
+};
+
+} // namespace
+
+void DispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                             MLIRContext *context) {
+  results.insert<DeduplicateDispatchEntryRefs>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // Tensor ops
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 26dd4e9..59a02c4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -333,6 +333,44 @@
 }
 
 //===----------------------------------------------------------------------===//
+// custom<DispatchEntryPoints>($entry_points)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDispatchEntryPoints(OpAsmParser &parser,
+                                            ArrayAttr &entryPointAttrsArray) {
+  SmallVector<Attribute> entryPointAttrs;
+  if (succeeded(parser.parseOptionalLBrace())) {
+    do {
+      SymbolRefAttr entryPointAttr;
+      if (failed(parser.parseAttribute(entryPointAttr)))
+        return failure();
+      entryPointAttrs.push_back(entryPointAttr);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (failed(parser.parseRBrace()))
+      return failure();
+  } else {
+    SymbolRefAttr entryPointAttr;
+    if (failed(parser.parseAttribute(entryPointAttr)))
+      return failure();
+    entryPointAttrs.push_back(entryPointAttr);
+  }
+  entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs);
+  return success();
+}
+
+static void printDispatchEntryPoints(OpAsmPrinter &p, Operation *op,
+                                     ArrayAttr entryPointAttrs) {
+  if (entryPointAttrs.size() == 1) {
+    p.printAttribute(entryPointAttrs.getValue().front());
+  } else {
+    p << '{';
+    llvm::interleaveComma(entryPointAttrs, p.getStream(),
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << '}';
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // flow.dispatch.region
 //===----------------------------------------------------------------------===//
 
@@ -1329,7 +1367,7 @@
                        ValueRange operands, ValueRange operandDims,
                        ArrayAttr tiedOperands,
                        ArrayRef<NamedAttribute> attributes) {
-  state.addAttribute("entry_point", entryPoint);
+  state.addAttribute("entry_points", builder.getArrayAttr(entryPoint));
   state.addOperands(workload);
   state.addTypes(resultTypes);
   state.addOperands(operands);
@@ -1349,51 +1387,72 @@
                      }));
 }
 
-StringAttr DispatchOp::executable() {
-  return getEntryPoint().getRootReference();
-}
-
 FunctionType DispatchOp::getEntryPointType() {
   SmallVector<Type, 8> argTypes(operand_type_range{getArguments()});
   return FunctionType::get(getContext(), argTypes, getResultTypes());
 }
 
+std::string DispatchOp::getEntryPointName() {
+  // Pick the first entry point we have. The common case is we only have one
+  // but frontends may provide multiple variants - they're all likely the
+  // same name but with slight differences and enough for a user to know what's
+  // happening.
+  auto anyEntryPoint = *getEntryPointRefs().begin();
+  std::string entryPointName =
+      anyEntryPoint.getRootReference().getValue().str();
+  for (FlatSymbolRefAttr nestedRef : anyEntryPoint.getNestedReferences()) {
+    entryPointName = (entryPointName + "::" + nestedRef.getValue()).str();
+  }
+  return entryPointName;
+}
+
 std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() {
   return getODSOperandIndexAndLength(1); // $operands
 }
 
 LogicalResult DispatchOp::verify() {
   Operation *op = getOperation();
+
+  if (getEntryPoints().empty()) {
+    return op->emitOpError("at least one entry point reference is required");
+  }
+
   if (failed(verifyOpDynamicDims(op, getArguments(), getArgumentDims())) ||
       failed(verifyOpDynamicDims(op, getResults(), getResultDims()))) {
     return failure();
   }
+
   return success();
 }
 
 LogicalResult DispatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   Operation *op = getOperation();
-  auto exportOp =
-      symbolTable.lookupNearestSymbolFrom<IREE::Flow::ExecutableExportOp>(
-          op, getEntryPoint());
-  if (!exportOp) {
-    // TODO(benvanik): there are a lot of tests that are assuming this is not
-    // verified. We'll need to go add dummy executables for all of them. Today
-    // we just bail on the verifier if the symbol isn't found.
-    //
-    // Should be:
-    //   return op->emitOpError() << "undefined entry point: " <<
-    //   getEntryPoint();
-    return success();
+  auto entryPointRefs = getEntryPointRefs();
+  if (entryPointRefs.empty()) {
+    return emitOpError() << "at least one entry point must be defined";
   }
+  for (auto entryPointAttr : entryPointRefs) {
+    auto exportOp =
+        symbolTable.lookupNearestSymbolFrom<IREE::Flow::ExecutableExportOp>(
+            op, entryPointAttr);
+    if (!exportOp) {
+      // TODO(benvanik): there are a lot of tests that are assuming this is not
+      // verified. We'll need to go add dummy executables for all of them. Today
+      // we just bail on the verifier if the symbol isn't found.
+      //
+      // Should be:
+      //   return op->emitOpError() << "undefined entry point: " <<
+      //   getEntryPoint();
+      return success();
+    }
 
-  // Verify that the workload parameters captured match the target export.
-  if (failed(verifyDispatchWorkload(op, exportOp, getWorkload()))) {
-    return failure();
+    // Verify that the workload parameters captured match the target export.
+    if (failed(verifyDispatchWorkload(op, exportOp, getWorkload()))) {
+      return failure();
+    }
+
+    // TODO(benvanik): verify that the target function has matching operands.
   }
-
-  // TODO(benvanik): verify that the target function has matching operands.
-
   return success();
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index e3d976c..2234870 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -766,7 +766,7 @@
 
   let arguments = (ins
     Variadic<FLOW_Dim>:$workload,
-    SymbolRefAttr:$entry_point,
+    SymbolRefArrayAttr:$entry_points,
     Variadic<AnyType>:$arguments,
     FLOW_ShapeDynamicDims:$argument_dims,
     FLOW_ShapeDynamicDims:$result_dims,
@@ -777,7 +777,7 @@
   );
 
   let assemblyFormat = [{
-    $entry_point
+    custom<DispatchEntryPoints>($entry_points)
     (`[` $workload^ `]`)? ``
     `(` $arguments `)` attr-dict `:`
     custom<ShapedFunctionType>(ref($arguments),
@@ -827,9 +827,19 @@
   ];
 
   let extraClassDeclaration = [{
-    StringAttr executable();
     FunctionType getEntryPointType();
 
+    auto getEntryPointRefs() {
+      return getEntryPoints().getAsRange<SymbolRefAttr>();
+    }
+    void forEachEntryPointAttr(std::function<void(SymbolRefAttr)> fn) {
+      for (auto entryPointAttr : getEntryPointRefs()) fn(entryPointAttr);
+    }
+
+    // Returns a human-friendly string name for what is being dispatched.
+    // May not be unique or a valid reference to an executable.
+    std::string getEntryPointName();
+
     // StreamableOpInterface:
     bool isTransfer() { return false; }
 
@@ -841,6 +851,7 @@
     }
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
index eebae24..7d0138e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -1,12 +1,12 @@
 // RUN: iree-opt --split-input-file %s --verify-diagnostics | FileCheck %s
 
 flow.executable @ex0 {
+  flow.executable.export @dispatch_fn
   builtin.module {
     func.func @dispatch_fn(%cst : index, %arg0 : tensor<4xf32>) -> tensor<4xf32> {
       return %arg0 : tensor<4xf32>
     }
   }
-  flow.executable.export @dispatch_fn
 }
 
 // CHECK-LABEL: @dispatch
@@ -21,18 +21,28 @@
 // -----
 
 flow.executable private @ex0 {
+  flow.executable.export public @dispatch_a
+  flow.executable.export public @dispatch_b
+}
+
+// CHECK-LABEL: @dispatchWithMultipleRefs
+func.func @dispatchWithMultipleRefs(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: = flow.dispatch {@ex0::@dispatch_a, @ex0::@dispatch_b}(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch {@ex0::@dispatch_a, @ex0::@dispatch_b}(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+
+// -----
+
+flow.executable private @ex0 {
   flow.executable.export public @dispatch workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
     flow.return %arg0, %arg1, %arg0 : index, index, index
   }
-  builtin.module {
-    func.func @dispatch() {
-      return
-    }
-  }
 }
 
-// CHECK-LABEL: @asyncDispatchWithWorkgroupCount
-func.func @asyncDispatchWithWorkgroupCount(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
+// CHECK-LABEL: @dispatchWithWorkgroupCount
+func.func @dispatchWithWorkgroupCount(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   // CHECK: = flow.dispatch @ex0::@dispatch[%c1, %c2](%arg0, %arg1) : (tensor<4xf32>, index) -> tensor<4xf32>
@@ -46,14 +56,9 @@
   flow.executable.export public @dispatch workgroups(%arg0: index) -> (index, index, index) {
     flow.return %arg0, %arg0, %arg0 : index, index, index
   }
-  builtin.module {
-    func.func @dispatch() {
-      return
-    }
-  }
 }
 
-func.func @asyncDispatchWithInvalidWorkload(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
+func.func @dispatchWithInvalidWorkload(%arg0: tensor<4xf32>, %arg1: index) -> tensor<4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   // expected-error @+1 {{op workload mismatch; entry point expects 1 arguments but dispatch provides 2}}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp
index 30e32aa..851acf5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp
@@ -360,10 +360,17 @@
     // new symbol name.
     for (auto funcLikeOp : getOperation().getOps<FunctionOpInterface>()) {
       funcLikeOp->walk([&](IREE::Flow::DispatchOp dispatchOp) {
-        auto it = entryPointRefReplacements.find(dispatchOp.getEntryPoint());
-        if (it != entryPointRefReplacements.end()) {
-          dispatchOp.setEntryPointAttr(llvm::cast<SymbolRefAttr>(it->second));
+        SmallVector<Attribute> replacementRefs;
+        for (auto originalRef : dispatchOp.getEntryPointRefs()) {
+          auto it = entryPointRefReplacements.find(originalRef);
+          if (it != entryPointRefReplacements.end()) {
+            replacementRefs.push_back(it->second);
+          } else {
+            replacementRefs.push_back(originalRef);
+          }
         }
+        dispatchOp.setEntryPointsAttr(
+            ArrayAttr::get(dispatchOp.getContext(), replacementRefs));
       });
     }
   }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
index 24b8f4a..80a26a3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
@@ -202,9 +202,20 @@
     const DenseMap<Attribute, SymbolRefAttr> &replacements) {
   for (auto funcLikeOp : moduleOp.getOps<FunctionOpInterface>()) {
     funcLikeOp->walk([&](DispatchOp dispatchOp) {
-      auto it = replacements.find(dispatchOp.getEntryPoint());
-      if (it != replacements.end()) {
-        dispatchOp.setEntryPointAttr(llvm::cast<SymbolRefAttr>(it->second));
+      bool didChange = false;
+      SmallVector<Attribute> newAttrs;
+      for (auto oldAttr : dispatchOp.getEntryPoints()) {
+        auto it = replacements.find(oldAttr);
+        if (it != replacements.end()) {
+          didChange = true;
+          newAttrs.push_back(it->second);
+        } else {
+          newAttrs.push_back(oldAttr);
+        }
+      }
+      if (didChange) {
+        dispatchOp.setEntryPointsAttr(
+            ArrayAttr::get(moduleOp.getContext(), newAttrs));
       }
     });
   }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp
index 1603d15..fc99f10 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp
@@ -377,7 +377,7 @@
   void printDispatchBody(raw_ostream &os, DispatchOp &dispatchOp) {
     // Find the entry point function from the dispatch entry point symbol
     // attribute.
-    auto entryPoint = dispatchOp.getEntryPoint();
+    auto entryPoint = *dispatchOp.getEntryPointRefs().begin();
     auto executableOp = cast<ExecutableOp>(SymbolTable::lookupNearestSymbolFrom(
         dispatchOp, entryPoint.getRootReference()));
     if (!executableOp)
@@ -452,7 +452,7 @@
           // Print entry function name, if there is only one entry function,
           // then the name space and the entry function names are the same,
           // and we can just print the function name to save space.
-          auto entryPoint = dispatch.getEntryPoint();
+          auto entryPoint = *dispatch.getEntryPointRefs().begin();
           auto rootName = entryPoint.getRootReference();
           auto leafName = entryPoint.getLeafReference();
           if (rootName == leafName) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
index c70620b..77ae538 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
@@ -37,12 +37,7 @@
   void runOnOperation() override {
     auto funcOp = getOperation();
     for (auto dispatchOp : funcOp.getFunctionBody().getOps<DispatchOp>()) {
-      std::string entryPointName =
-          dispatchOp.getEntryPoint().getRootReference().getValue().str();
-      for (FlatSymbolRefAttr nestedRef :
-           dispatchOp.getEntryPoint().getNestedReferences()) {
-        entryPointName = (entryPointName + "::" + nestedRef.getValue()).str();
-      }
+      std::string entryPointName = dispatchOp.getEntryPointName();
 
       // Input tensors:
       OpBuilder builder(dispatchOp);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
index 6186fd6..f4d845c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp
@@ -169,13 +169,8 @@
       // Trace on a valid ordinal.
       if (localTraceOrdinal >= 0 && localTraceOrdinal < dispatchOps.size()) {
         auto traceTarget = dispatchOps[localTraceOrdinal];
-        std::string entryPointName =
-            traceTarget.getEntryPoint().getRootReference().getValue().str();
-        for (FlatSymbolRefAttr nestedRef :
-             traceTarget.getEntryPoint().getNestedReferences()) {
-          entryPointName = (entryPointName + "::" + nestedRef.getValue()).str();
-        }
         // Append the ordinal to the trace name.
+        std::string entryPointName = traceTarget.getEntryPointName();
         traceOpWithName(traceTarget, entryPointName + std::string("::") +
                                          std::to_string(localTraceOrdinal));
       }
@@ -226,15 +221,9 @@
       // dispatches.
       IREE::Flow::DispatchOp breakTarget;
       funcOp.walk([&](IREE::Flow::DispatchOp dispatchOp) {
-        std::string entryPointName =
-            dispatchOp.getEntryPoint().getRootReference().getValue().str();
-        for (FlatSymbolRefAttr nestedRef :
-             dispatchOp.getEntryPoint().getNestedReferences()) {
-          entryPointName = (entryPointName + "::" + nestedRef.getValue()).str();
-        }
+        std::string entryPointName = dispatchOp.getEntryPointName();
         if (traceMatcher.match(entryPointName))
           traceOpWithName(dispatchOp, entryPointName);
-
         if (!breakTarget && breakMatcher.match(entryPointName))
           breakTarget = dispatchOp;
       });
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
index ed27b7f..254fca8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
@@ -14,7 +14,7 @@
 func.func @single_executable(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = arith.constant 4 : index
   // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -51,15 +51,28 @@
   }
 }
 // CHECK-LABEL: func.func @duplicate_executables
-func.func @duplicate_executables(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+func.func @duplicate_executables(%arg0: tensor<4xf32>) {
   %c4 = arith.constant 4 : index
-  // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  return %0 : tensor<4xf32>
+  // CHECK: = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: = flow.dispatch {@duplicate_executables_ex_0::@duplicate_executables_entry_0, @duplicate_executables_ex_0::@duplicate_executables_entry_0}
+  %3 = flow.dispatch {@duplicate_executables_ex_0::@duplicate_executables_entry_0, @duplicate_executables_ex_1::@duplicate_executables_entry_1}[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  return
+}
+
+// Ensure that symbol renaming is done within initializers.
+// CHECK: util.initializer
+util.initializer {
+  // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00>
+  %cst = arith.constant dense<1.000000e+00> : tensor<4xf32>
+  // CHECK: {{.*}} = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0(%[[CST]]) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1(%cst) : (tensor<4xf32>) -> tensor<4xf32>
+  util.optimization_barrier %0 : tensor<4xf32>
+  util.initializer.return
 }
 
 // -----
@@ -88,9 +101,9 @@
 func.func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> {
   %c4 = arith.constant 4 : index
   // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }
 
@@ -131,30 +144,16 @@
   // CHECK: %[[C4:.*]] = arith.constant 4
   %c4 = arith.constant 4 : index
   // CHECK:      {{.*}} = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%[[C4]]](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: {{.*}} = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%[[C4]]](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: {{.*}} = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%[[C4]]](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: {{.*}} = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%[[C4]]](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
-// Ensure that symbol renaming is done within initializers.
-util.global private @result : tensor<4xf32>
-// CHECK: util.initializer
-util.initializer {
-  // CHECK: %[[C4:.*]] = arith.constant 4
-  %c4 = arith.constant 4 : index
-  // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00>
-  %cst = arith.constant dense<1.000000e+00> : tensor<4xf32>
-  // CHECK: {{.*}} = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%[[C4]]](%[[CST]]) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4] (%cst) : (tensor<4xf32>) -> tensor<4xf32>
-  util.global.store %0, @result : tensor<4xf32>
-  util.initializer.return
-}
-
 // -----
 
 // CHECK-LABEL: flow.executable public @different_types_float_ex
@@ -181,9 +180,9 @@
 func.func @different_types(%arg0: tensor<4xf32>) -> tensor<4xi1> {
   %c4 = arith.constant 4 : index
   // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
-  %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
   // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
-  %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
   return %0 : tensor<4xi1>
 }
 
@@ -239,11 +238,11 @@
 func.func @nested_ops(%arg0: tensor<5x6xf32>, %arg1: tensor<5x6xf32>) -> tensor<5x6xf32> {
   %c4 = arith.constant 4 : index
   // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
-  %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
   // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
-  %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
+  %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
   // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4](%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
-  %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
+  %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4](%arg0, %arg1) : (tensor<5x6xf32>, tensor<5x6xf32>) -> tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 4ae1571..a3c9f44 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -571,7 +571,7 @@
     }
 
     auto newOp = rewriter.replaceOpWithNewOp<IREE::Stream::AsyncDispatchOp>(
-        op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPoint(),
+        op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(),
         dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets,
         dispatchOperandEnds, dispatchOperandLengths, resultSizes,
         adaptor.getTiedOperandsAttr(), getAffinityFor(op));
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 3951a03..31783e6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1888,10 +1888,30 @@
 // stream.async.dispatch
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+struct DeduplicateAsyncDispatchEntryRefs final
+    : public OpRewritePattern<AsyncDispatchOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AsyncDispatchOp dispatchOp,
+                                PatternRewriter &rewriter) const override {
+    auto originalAttr = dispatchOp.getEntryPointsAttr();
+    auto newAttr = deduplicateArrayElements(originalAttr);
+    if (newAttr == originalAttr)
+      return failure();
+    rewriter.updateRootInPlace(
+        dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); });
+    return success();
+  }
+};
+
+} // namespace
+
 void AsyncDispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  // TODO(benvanik): nothing? maybe tied type/lifetime updates?
+  // TODO(benvanik):maybe tied type/lifetime updates?
   results.insert<ElideUnusedOp<AsyncDispatchOp>>(context);
+  results.insert<DeduplicateAsyncDispatchEntryRefs>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2316,8 +2336,28 @@
 // stream.cmd.dispatch
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+struct DeduplicateCmdDispatchEntryRefs final
+    : public OpRewritePattern<CmdDispatchOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CmdDispatchOp dispatchOp,
+                                PatternRewriter &rewriter) const override {
+    auto originalAttr = dispatchOp.getEntryPointsAttr();
+    auto newAttr = deduplicateArrayElements(originalAttr);
+    if (newAttr == originalAttr)
+      return failure();
+    rewriter.updateRootInPlace(
+        dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); });
+    return success();
+  }
+};
+
+} // namespace
+
 void CmdDispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
+  results.insert<DeduplicateCmdDispatchEntryRefs>(context);
   results.insert<FoldSubviewsIntoDispatchOp<CmdDispatchOp>>(context);
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index a294923..b3dfc62 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -275,6 +275,44 @@
 }
 
 //===----------------------------------------------------------------------===//
+// custom<DispatchEntryPoints>($entry_points)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDispatchEntryPoints(OpAsmParser &parser,
+                                            ArrayAttr &entryPointAttrsArray) {
+  SmallVector<Attribute> entryPointAttrs;
+  if (succeeded(parser.parseOptionalLBrace())) {
+    do {
+      SymbolRefAttr entryPointAttr;
+      if (failed(parser.parseAttribute(entryPointAttr)))
+        return failure();
+      entryPointAttrs.push_back(entryPointAttr);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (failed(parser.parseRBrace()))
+      return failure();
+  } else {
+    SymbolRefAttr entryPointAttr;
+    if (failed(parser.parseAttribute(entryPointAttr)))
+      return failure();
+    entryPointAttrs.push_back(entryPointAttr);
+  }
+  entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs);
+  return success();
+}
+
+static void printDispatchEntryPoints(OpAsmPrinter &p, Operation *op,
+                                     ArrayAttr entryPointAttrs) {
+  if (entryPointAttrs.size() == 1) {
+    p.printAttribute(entryPointAttrs.getValue().front());
+  } else {
+    p << '{';
+    llvm::interleaveComma(entryPointAttrs, p.getStream(),
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << '}';
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // custom<EncodedResourceOperands>(
 //     $resources, type($resources), $resource_sizes,
 //     $resource_encodings, $resource_encoding_dims)
@@ -1790,26 +1828,32 @@
 LogicalResult
 AsyncDispatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   Operation *op = getOperation();
-  auto exportOp =
-      symbolTable.lookupNearestSymbolFrom<IREE::Stream::ExecutableExportOp>(
-          op, getEntryPoint());
-  if (!exportOp) {
-    // TODO(benvanik): there are a lot of tests that are assuming this is not
-    // verified. We'll need to go add dummy executables for all of them. Today
-    // we just bail on the verifier if the symbol isn't found.
-    //
-    // Should be:
-    //   return op->emitOpError() << "undefined entry point: " << entry_point();
-    return success();
+  auto entryPointRefs = getEntryPointRefs();
+  if (entryPointRefs.empty()) {
+    return emitOpError() << "at least one entry point must be defined";
   }
+  for (auto entryPointAttr : entryPointRefs) {
+    auto exportOp =
+        symbolTable.lookupNearestSymbolFrom<IREE::Stream::ExecutableExportOp>(
+            op, entryPointAttr);
+    if (!exportOp) {
+      // TODO(benvanik): there are a lot of tests that are assuming this is not
+      // verified. We'll need to go add dummy executables for all of them. Today
+      // we just bail on the verifier if the symbol isn't found.
+      //
+      // Should be:
+      //   return op->emitOpError() << "undefined entry point: " <<
+      //   entry_point();
+      return success();
+    }
 
-  // Verify that the workload parameters captured match the target export.
-  if (failed(verifyDispatchWorkload(op, exportOp, getWorkload()))) {
-    return failure();
+    // Verify that the workload parameters captured match the target export.
+    if (failed(verifyDispatchWorkload(op, exportOp, getWorkload()))) {
+      return failure();
+    }
+
+    // TODO(benvanik): verify that the target function has matching operands.
   }
-
-  // TODO(benvanik): verify that the target function has matching operands.
-
   return success();
 }
 
@@ -2489,40 +2533,6 @@
   return success();
 }
 
-static ParseResult parseDispatchEntryPoints(OpAsmParser &parser,
-                                            ArrayAttr &entryPointAttrsArray) {
-  SmallVector<Attribute> entryPointAttrs;
-  if (succeeded(parser.parseOptionalLBrace())) {
-    do {
-      SymbolRefAttr entryPointAttr;
-      if (failed(parser.parseAttribute(entryPointAttr)))
-        return failure();
-      entryPointAttrs.push_back(entryPointAttr);
-    } while (succeeded(parser.parseOptionalComma()));
-    if (failed(parser.parseRBrace()))
-      return failure();
-  } else {
-    SymbolRefAttr entryPointAttr;
-    if (failed(parser.parseAttribute(entryPointAttr)))
-      return failure();
-    entryPointAttrs.push_back(entryPointAttr);
-  }
-  entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs);
-  return success();
-}
-
-static void printDispatchEntryPoints(OpAsmPrinter &p, Operation *op,
-                                     ArrayAttr entryPointAttrs) {
-  if (entryPointAttrs.size() == 1) {
-    p.printAttribute(entryPointAttrs.getValue().front());
-  } else {
-    p << '{';
-    llvm::interleaveComma(entryPointAttrs, p.getStream(),
-                          [&](Attribute attr) { p.printAttribute(attr); });
-    p << '}';
-  }
-}
-
 static ParseResult parseDispatchResources(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resources,
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 0d52718..988cf86 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2057,7 +2057,7 @@
 
   let arguments = (ins
     Variadic<Index>:$workload,
-    SymbolRefAttr:$entry_point,
+    SymbolRefArrayAttr:$entry_points,
     Variadic<AnyTypeOf<[
       Stream_AnyStreamResource,
       Stream_PrimitiveType,
@@ -2076,7 +2076,7 @@
 
   let assemblyFormat = [{
     (`on` `(` $affinity^ `)`)?
-    $entry_point
+    custom<DispatchEntryPoints>($entry_points)
     (`[` $workload^ `]`)? ``
     custom<DispatchOperands>($resource_operands,
                              $resource_operand_offsets,
@@ -2089,6 +2089,13 @@
   }];
 
   let extraClassDeclaration = [{
+    auto getEntryPointRefs() {
+      return getEntryPoints().getAsRange<SymbolRefAttr>();
+    }
+    void forEachEntryPointAttr(std::function<void(SymbolRefAttr)> fn) {
+      for (auto entryPointAttr : getEntryPointRefs()) fn(entryPointAttr);
+    }
+
     Value getOperandSize(unsigned idx) {
       return findValueSizeInList(idx, getOperands(), getResourceOperandSizes());
     }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp
index 8d5d694..caf8658 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp
@@ -215,9 +215,9 @@
   OpBuilder builder(splatOp);
   auto dispatchOp = builder.create<IREE::Stream::AsyncDispatchOp>(
       loc, resultTypes, workload,
-      SymbolRefAttr::get(
+      builder.getArrayAttr({SymbolRefAttr::get(
           builder.getStringAttr(builtinName),
-          FlatSymbolRefAttr::get(builder.getContext(), builtinName)),
+          FlatSymbolRefAttr::get(builder.getContext(), builtinName))}),
       operands, operandSizes, operandOffsets, operandEnds, operandLengths,
       resultSizes, builder.getIndexArrayAttr(tiedOperands),
       splatOp.getAffinityAttr());
@@ -311,9 +311,9 @@
   OpBuilder builder(fillOp);
   auto dispatchOp = builder.create<IREE::Stream::AsyncDispatchOp>(
       loc, resultTypes, workload,
-      SymbolRefAttr::get(
+      builder.getArrayAttr({SymbolRefAttr::get(
           builder.getStringAttr(builtinName),
-          FlatSymbolRefAttr::get(builder.getContext(), builtinName)),
+          FlatSymbolRefAttr::get(builder.getContext(), builtinName))}),
       operands, operandSizes, operandOffsets, operandEnds, operandLengths,
       resultSizes, builder.getIndexArrayAttr(tiedOperands),
       fillOp.getAffinityAttr());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 14af630..367f4aa 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -765,10 +765,9 @@
   }
 
   auto newOp = builder.create<IREE::Stream::CmdDispatchOp>(
-      asyncOp.getLoc(), asyncOp.getWorkload(),
-      builder.getArrayAttr({asyncOp.getEntryPoint()}), newOperands,
-      newResources, newResourceSizes, newResourceOffsets, newResourceLengths,
-      builder.getArrayAttr(newResourceAccesses));
+      asyncOp.getLoc(), asyncOp.getWorkload(), asyncOp.getEntryPointsAttr(),
+      newOperands, newResources, newResourceSizes, newResourceOffsets,
+      newResourceLengths, builder.getArrayAttr(newResourceAccesses));
   newOp->setDialectAttrs(asyncOp->getDialectAttrs());
   asyncOp.erase();
   return success();
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 5e4f0ea..f5cefd2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -60,6 +60,13 @@
 // Utils
 //===----------------------------------------------------------------------===//
 
+ArrayAttr deduplicateArrayElements(ArrayAttr arrayAttr) {
+  SetVector<Attribute> attrsSet(arrayAttr.begin(), arrayAttr.end());
+  if (attrsSet.size() == arrayAttr.size())
+    return arrayAttr;
+  return ArrayAttr::get(arrayAttr.getContext(), attrsSet.takeVector());
+}
+
 Value findValueSizeInList(unsigned index, ValueRange values, ValueRange sizes) {
   assert(values[index].getType().isa<IREE::Util::SizeAwareTypeInterface>() &&
          "must be a size-aware type to get dims");
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
index b89675f..170f7eb 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
@@ -48,6 +48,9 @@
 // Utils
 //===----------------------------------------------------------------------===//
 
+// Removes duplicate attributes in the array (if any).
+ArrayAttr deduplicateArrayElements(ArrayAttr arrayAttr);
+
 // Returns the dynamic size of the value at |index|.
 Value findValueSizeInList(unsigned index, ValueRange values, ValueRange sizes);