Replacing flow.dispatch with flow.dispatch2. (#4345)

* Adding a default conversion for flow.dispatch.workgroup.* -> HAL.
This is just a placeholder; no mapping from non-rank-3 dispatches to the
required rank-3 HAL ops is performed and it's assumed that the ops have
already been normalized to rank-3.

* Replacing flow.dispatch with flow.dispatch2.
The new dispatch op takes N-dimensional workgroup counts whereas the old
one just took 1D workloads. This allows us to specify the workgroup
counts when forming the dispatch regions in their natural form and defer
remapping them to the HAL's 3D XYZ until we know the target and are
recording the hal.command_buffer.dispatch ops, though the mechanism still
needs some work.

The legacy dispatch paths for non-linalg-on-tensors still thread through
with the assumption that they have 1D workloads; as we make more progress
on #4140 we can hopefully remove these.

See iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir's
dynamic_tiled_dispatch for IR before/after on the new path.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 3699ee4..f3810ae 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -818,87 +818,10 @@
 // flow.dispatch
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseDispatchOp(OpAsmParser &parser,
-                                   OperationState *result) {
-  SymbolRefAttr entryPointAttr;
-  if (failed(parser.parseAttribute(entryPointAttr, "entry_point",
-                                   result->attributes))) {
-    return failure();
-  }
-
-  OpAsmParser::OperandType workloadArg;
-  Type workloadArgType;
-  if (failed(parser.parseLSquare()) ||
-      failed(parser.parseOperand(workloadArg)) ||
-      failed(parser.parseColonType(workloadArgType)) ||
-      failed(parser.parseRSquare()) ||
-      failed(parser.resolveOperand(workloadArg, workloadArgType,
-                                   result->operands))) {
-    return failure();
-  }
-
-  SmallVector<OpAsmParser::OperandType, 4> operands;
-  FunctionType entryPointType;
-  if (failed(
-          parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
-      failed(parser.parseOptionalAttrDict(result->attributes)) ||
-      failed(parser.parseColonType(entryPointType)) ||
-      failed(
-          parser.addTypesToList(entryPointType.getResults(), result->types)) ||
-      failed(parser.resolveOperands(operands, entryPointType.getInputs(),
-                                    parser.getNameLoc(), result->operands))) {
-    return failure();
-  }
-  return success();
-}
-
-static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
-  p << op.getOperationName() << ' ';
-  p.printAttributeWithoutType(op.entry_point());
-  p << "[";
-  p.printOperand(op.workload());
-  p << " : ";
-  p.printType(op.workload().getType());
-  p << "](";
-  p.printOperands(op.operands());
-  p << ')';
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"entry_point"});
-  p << " : ";
-  p.printType(op.getEntryPointType());
-}
-
 void DispatchOp::build(OpBuilder &builder, OperationState &state,
-                       DispatchEntryOp entryPoint, Value workload,
-                       ArrayRef<Type> results, ValueRange operands) {
-  state.addOperands({workload});
-  state.addOperands(operands);
-  // Construct Executable::Entry nested reference.
-  StringRef executableOpSymName =
-      entryPoint->getParentOp()
-          ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
-          .getValue();
-  state.addAttribute(
-      "entry_point",
-      builder.getSymbolRefAttr(executableOpSymName,
-                               {builder.getSymbolRefAttr(entryPoint)}));
-  state.addTypes(results);
-}
-
-StringRef DispatchOp::executable() { return entry_point().getRootReference(); }
-
-FunctionType DispatchOp::getEntryPointType() {
-  SmallVector<Type, 8> argTypes(operand_type_range{operands()});
-  return FunctionType::get(getContext(), argTypes, getResultTypes());
-}
-
-//===----------------------------------------------------------------------===//
-// flow.dispatch2
-//===----------------------------------------------------------------------===//
-
-void Dispatch2Op::build(OpBuilder &builder, OperationState &state,
-                        DispatchEntryOp entryPoint, ValueRange workgroupCount,
-                        TypeRange results, ValueRange operands,
-                        ArrayRef<NamedAttribute> attributes) {
+                       DispatchEntryOp entryPoint, ValueRange workgroupCount,
+                       TypeRange results, ValueRange operands,
+                       ArrayRef<NamedAttribute> attributes) {
   StringRef executableOpSymName =
       entryPoint->getParentOp()
           ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -918,14 +841,14 @@
                                 static_cast<int32_t>(operands.size())}));
 }
 
-StringRef Dispatch2Op::executable() { return entry_point().getRootReference(); }
+StringRef DispatchOp::executable() { return entry_point().getRootReference(); }
 
-FunctionType Dispatch2Op::getEntryPointType() {
+FunctionType DispatchOp::getEntryPointType() {
   SmallVector<Type, 8> argTypes(operand_type_range{operands()});
   return FunctionType::get(getContext(), argTypes, getResultTypes());
 }
 
-static LogicalResult verifyDispatch2Op(Dispatch2Op op) {
+static LogicalResult verifyDispatchOp(DispatchOp op) {
   if (op.workgroup_count().empty()) {
     return op.emitOpError() << "at least one workgroup dimension is required";
   }
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index b8b29af..1711703 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -621,40 +621,6 @@
 //===----------------------------------------------------------------------===//
 
 def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [
-    FLOW_StreamableOp,
-  ]> {
-  let summary = [{a dispatch to an outlined dispatch region}];
-  let description = [{
-    Dispatches a workload to the specified executable function.
-  }];
-
-  let arguments = (ins
-    SymbolRefAttr:$entry_point,
-    FLOW_Workload:$workload,
-    Variadic<AnyType>:$operands
-  );
-  let results = (outs
-    Variadic<AnyType>:$results
-  );
-
-  let skipDefaultBuilders = 1;
-  let builders = [
-    OpBuilderDAG<(ins "DispatchEntryOp":$entryPoint, "Value":$workload,
-      "ArrayRef<Type>":$results, CArg<"ValueRange", "{}">:$operands)>,
-  ];
-
-  let extraClassDeclaration = [{
-    StringRef executable();
-    FunctionType getEntryPointType();
-
-    // StreamableOpInterface:
-    bool isTransfer() { return false; }
-    bool isUsableInStream() { return true; }
-    bool isStreamOnly() { return true; }
-  }];
-}
-
-def FLOW_Dispatch2Op : FLOW_PureOp<"dispatch2", [
     AttrSizedOperandSegments,
     FLOW_StreamableOp,
   ]> {
@@ -698,7 +664,7 @@
     functional-type($operands, $results)
   }];
 
-  let verifier = [{ return verifyDispatch2Op(*this); }];
+  let verifier = [{ return verifyDispatchOp(*this); }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
index d201e53..2325c63 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -15,7 +15,7 @@
 func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CST:.+]] = constant
   %cst = constant 4 : index
-  // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]] : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @ex0::@dispatch_fn[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @ex0::@dispatch_fn[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir
index 706ede0..8ab0c0f 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir
@@ -9,7 +9,7 @@
   // CHECK: %[[DR0:.+]] = addf %[[CA1]], %[[CA1]]
   // CHECK: flow.return %[[DR0]] : tensor<?xf32>
   %ret0, %ret1 = flow.dispatch.region[%workload : index](
-      %i0 = %arg0 : tensor<?xf32>, %i1 = %arg0 : tensor<?xf32>, %i2 = %arg0 : tensor<?xf32>) 
+      %i0 = %arg0 : tensor<?xf32>, %i1 = %arg0 : tensor<?xf32>, %i2 = %arg0 : tensor<?xf32>)
       -> (tensor<?xf32>, tensor<?xf32>) {
     %1 = addf %i0, %i1 : tensor<?xf32>
     flow.return %1, %i2 : tensor<?xf32>, tensor<?xf32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index b769f7b..028a583 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
@@ -21,10 +21,10 @@
   // CHECK: flow.ex.stream.fragment(%arg1 = %[[CST]] : index, %arg2 = %[[A0]] : tensor<4xf32>)
   %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %cst : index, %arg3 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
     // Both referreants of the constant should use the deduped arg.
-    // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1 : index]
-    // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1 : index]
-    %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1 : index](%arg3) : (tensor<4xf32>) -> tensor<4xf32>
-    %2 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg2 : index](%1) : (tensor<4xf32>) -> tensor<4xf32>
+    // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1]
+    // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1]
+    %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] (%arg3) : (tensor<4xf32>) -> tensor<4xf32>
+    %2 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg2] (%1) : (tensor<4xf32>) -> tensor<4xf32>
     flow.return %2, %2 : tensor<4xf32>, tensor<4xf32>
   }
   return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
index aeb3f55..457ea93 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
@@ -19,7 +19,7 @@
   // CHECK: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
   %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
     // CHECK-NEXT: flow.dispatch
-    %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+    %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
     // CHECK-NEXT: flow.return
     flow.return %1, %1 : tensor<4xf32>, tensor<4xf32>
     // CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp
index 94882ec..15208fb 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp
@@ -68,7 +68,8 @@
         // calculate the workload from the results.
         auto dummyWorkload = blockBuilder.create<ConstantIndexOp>(loc, 0);
         auto dispatchOp = blockBuilder.create<DispatchOp>(
-            loc, dispatchEntryOp, dummyWorkload, funcType.getResults(), args);
+            loc, dispatchEntryOp, ValueRange{dummyWorkload},
+            funcType.getResults(), args);
         blockBuilder.create<mlir::ReturnOp>(loc, dispatchOp.getResults());
       }
     }
diff --git a/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp b/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
index c54bfe0..cb87627 100644
--- a/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
@@ -41,7 +41,7 @@
   InjectDispatchTracingPass() = default;
 
   void runOnOperation() override {
-    for (auto dispatchOp : getOperation().getOps<Dispatch2Op>()) {
+    for (auto dispatchOp : getOperation().getOps<DispatchOp>()) {
       std::string entryPointName =
           dispatchOp.entry_point().getRootReference().str();
       for (FlatSymbolRefAttr nestedRef :
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 8f2e06f..fbc0ead 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -78,7 +78,7 @@
 
   // Create the dispatch op to the executable function.
   auto dispatchOp = builder.create<DispatchOp>(
-      regionOp.getLoc(), entryPointOp, regionOp.workload(),
+      regionOp.getLoc(), entryPointOp, ValueRange{regionOp.workload()},
       outlinedFuncOp.getType().getResults(), newArgs);
 
   if (traceDispatchTensors) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp
index 34d9b6e..1a6a8a3 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp
@@ -81,7 +81,7 @@
   }
 
   // Create the dispatch op to the executable function.
-  auto dispatchOp = builder.create<Dispatch2Op>(
+  auto dispatchOp = builder.create<DispatchOp>(
       regionOp.getLoc(), entryPointOp, regionOp.workgroup_count(),
       regionOp.getResultTypes(), newOperands);
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir
index f53085c..6d785eb 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir
@@ -13,7 +13,7 @@
 //     CHECK: %{{.+}} = flow.variable.load @[[IN0_0]] : tensor<5x3xf32>
 //     CHECK: %{{.+}} = flow.variable.load @[[IN0_1]] : tensor<3x5xf32>
 //     CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> {
-//     CHECK:   %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}} : index](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+//     CHECK:   %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}] (%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
 //     CHECK:   flow.return %[[DISPATCH_RES]] : tensor<5x5xf32>
 //     CHECK: return %[[RES]] : tensor<5x5xf32>
 //
@@ -23,7 +23,7 @@
 //     CHECK: %{{.+}} = flow.variable.load @[[IN1_0]] : tensor<3x5xf32>
 //     CHECK: %{{.+}} = flow.variable.load @[[IN1_1]] : tensor<5x5xf32>
 //     CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<3x5xf32>
-//     CHECK:   %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}} : index](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
+//     CHECK:   %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}] (%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
 //     CHECK:   flow.return %[[DISPATCH_RES]] : tensor<3x5xf32>
 //     CHECK: return %[[RES]] : tensor<3x5xf32>
 //
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
index aa0d239..18f22ba 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
@@ -13,8 +13,8 @@
 // CHECK-LABEL: func @single_executable
 func @single_executable(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = constant 4 : index
-  // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -53,12 +53,12 @@
 // CHECK-LABEL: func @duplicate_executables
 func @duplicate_executables(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = constant 4 : index
-  // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -87,10 +87,10 @@
 // CHECK-LABEL: func @same_ops_diff_operands
 func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> {
   %c4 = constant 4 : index
-  // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }
 
@@ -129,14 +129,14 @@
 // CHECK-LABEL: func @multiple_entry_points
 func @multiple_entry_points(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = constant 4 : index
-  // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -165,10 +165,10 @@
 // CHECK-LABEL: func @different_types
 func @different_types(%arg0: tensor<4xf32>) -> tensor<4xi1> {
   %c4 = constant 4 : index
-  // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
-  %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
-  // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
-  %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4 : index](%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1>
+  %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1>
   return %0 : tensor<4xi1>
 }
 
@@ -222,12 +222,12 @@
 // CHECK-LABEL: func @nested_ops
 func @nested_ops(%arg0: tensor<1x4xi32>) -> tensor<1xi32> {
   %c4 = constant 4 : index
-  // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
-  %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
-  // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
-  %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
-  // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
-  %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
+  // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
+  %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
+  // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
+  %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
+  // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
+  %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
   return %0 : tensor<1xi32>
 }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
index 734ddb2..8796e24 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
@@ -12,10 +12,10 @@
   // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg2 = %arg0 : tensor<?xi32>, %arg3 = %1 : index, %arg4 = %[[WORKLOAD0]] : index, %arg5 = %0 : tensor<?xi32>) -> tensor<?xi32> {
   // CHECK-NEXT:   %3 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?]>
   // CHECK-NEXT:   %4 = shapex.tie_shape %arg2, %3 : tensor<?xi32>, !shapex.ranked_shape<[?]>
-  // CHECK-NEXT:   %5 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%arg4 : index](%arg4, %4) : (index, tensor<?xi32>) -> tensor<?xi32>
+  // CHECK-NEXT:   %5 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%arg4] (%arg4, %4) : (index, tensor<?xi32>) -> tensor<?xi32>
   // CHECK-NEXT:   flow.return %5 : tensor<?xi32>
   // CHECK-NEXT: }
-  %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0 : index](%c0, %2) : (index, tensor<?xi32>) -> tensor<?xi32>
+  %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0] (%c0, %2) : (index, tensor<?xi32>) -> tensor<?xi32>
   // CHECK-NEXT: return %2 : tensor<?xi32>
   return %15 : tensor<?xi32>
 }
@@ -40,10 +40,10 @@
   // CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32>
   %0 = addf %arg0, %arg0 : tensor<4xf32>
   // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %0 : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %3 : tensor<4xf32>
   // CHECK-NEXT: }
-  %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst : index](%0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst] (%0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK: %2 = addf %1, %1 : tensor<4xf32>
   %2 = addf %1, %1 : tensor<4xf32>
   // CHECK-NEXT: return %2 : tensor<4xf32>
@@ -59,10 +59,10 @@
   // CHECK-NEXT: %[[ADD1:.+]] = addf %arg0, %arg0 : tensor<4xf32>
   %add1 = addf %arg0, %arg0 : tensor<4xf32>
   // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %[[ADD1]] : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%arg1 : index](%arg2, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst : index](%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  // CHECK-NEXT:   %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%arg1 : index](%[[D1]], %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst : index](%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%arg1] (%arg2, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst] (%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%arg1] (%[[D1]], %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst] (%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %[[D2]] : tensor<4xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: %[[ADD2:.+]] = addf %[[S]], %arg0 : tensor<4xf32>
@@ -89,17 +89,17 @@
   // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index
   %cst = constant 4 : index
   // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %3 : tensor<4xf32>
   // CHECK-NEXT: }
-  %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: %1 = addf %0, %0 : tensor<4xf32>
   %1 = addf %0, %0 : tensor<4xf32>
   // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %1 : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %3 : tensor<4xf32>
   // CHECK-NEXT: }
-  %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst : index](%1) : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst] (%1) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %2 : tensor<4xf32>
   return %2 : tensor<4xf32>
 }
@@ -119,10 +119,10 @@
   // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index
   %cst = constant 4 : index
   // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
-  // CHECK-DAG:    = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%arg1 : index](%arg2)
-  %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK-DAG:    = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%arg1 : index](%arg2)
-  %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-DAG:    = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%arg1] (%arg2)
+  %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-DAG:    = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%arg1] (%arg2)
+  %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %{{.+}}, %{{.+}}
   // CHECK-NEXT: }
   // CHECK-NEXT: return %{{.+}}, %{{.+}}
@@ -169,14 +169,14 @@
   // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index
   %cst = constant 16 : index
   // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
-  // CHECK-NEXT:   %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32>
-  // CHECK-NEXT:   %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%arg1 : index](%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
-  // CHECK-NEXT:   %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%arg1 : index](%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+  // CHECK-NEXT:   %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  // CHECK-NEXT:   %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%arg1] (%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+  // CHECK-NEXT:   %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%arg1] (%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
   // CHECK-NEXT:   flow.return %3 : tensor<4x4xf32>
   // CHECK-NEXT: }
-  %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
-  %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst : index](%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
-  %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst : index](%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst] (%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst] (%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+  %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst] (%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
   // CHECK-NEXT: return %0 : tensor<4x4xf32>
   return %2 : tensor<4x4xf32>
 }
@@ -210,17 +210,17 @@
   // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index
   %cst = constant 4 : index
   // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %3 : tensor<4xf32>
   // CHECK-NEXT: }
-  %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32>
   %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %1 : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%arg1 : index](%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%arg1] (%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %3 : tensor<4xf32>
   // CHECK-NEXT: }
-  %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst : index](%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst] (%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %2 : tensor<4xf32>
   return %2 : tensor<4xf32>
 }
@@ -238,10 +238,10 @@
   // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index
   %cst = constant 4 : index
   // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK-NEXT:   %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT:   %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   flow.return %1 : tensor<4xf32>
   // CHECK-NEXT: }
-  %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst : index](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: return %0 : tensor<4xf32>
   return %0 : tensor<4xf32>
 }
@@ -269,7 +269,7 @@
     // CHECK: %[[STREAM_R2:.+]] = shapex.tie_shape %[[STREAM_R1]], %[[STREAM_RS0]]
     // CHECK: return %[[STREAM_R2]]
   // CHECK: }
-  %6 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2 : index](%3, %4, %5) : (tensor<?x?xf32>, index, index) -> tensor<?x?xf32>
+  %6 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2] (%3, %4, %5) : (tensor<?x?xf32>, index, index) -> tensor<?x?xf32>
   %7 = shapex.tie_shape %6, %arg1 : tensor<?x?xf32>, !shapex.ranked_shape<[?,?]>
   return %7, %arg1 : tensor<?x?xf32>, !shapex.ranked_shape<[?,?]>
 }
@@ -286,11 +286,11 @@
   //  CHECK-DAG:   %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1
   //      CHECK:   flow.return
   // CHECK-NEXT: }
-  %0 = flow.dispatch @dispatch_1::@dispatch_1[%workload : index]() : () -> tensor<i32>
+  %0 = flow.dispatch @dispatch_1::@dispatch_1[%workload] () : () -> tensor<i32>
   //      CHECK: %[[C2:.+]] = constant 2 : i32
   %c2 = constant 2 : i32
   //  CHECK-DAG:   %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2
-  %1 = flow.dispatch @dispatch_2::@dispatch_2[%workload : index](%c2) : (i32) -> tensor<f32>
+  %1 = flow.dispatch @dispatch_2::@dispatch_2[%workload] (%c2) : (i32) -> tensor<f32>
   return %0, %1 : tensor<i32>, tensor<f32>
 }
 
@@ -307,8 +307,8 @@
                     // Could be returned in either order
   // CHECK-NEXT:    flow.return
   // CHECK-NEXT: }
-  %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w : index]() : () -> tensor<i32>
-  %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w : index]() : () -> tensor<f32>
+  %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w] () : () -> tensor<i32>
+  %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w] () : () -> tensor<f32>
   //      CHECK: %[[READBACK:.+]] = flow.tensor.load %[[S1]]
   %readback = flow.tensor.load %d1 : tensor<i32>
   //      CHECK: %[[S2:.+]] = flow.ex.stream.fragment
@@ -316,7 +316,7 @@
   //  CHECK-DAG:    %[[D3:.+]] = flow.dispatch @dispatch_3::@dispatch_3
   //      CHECK:    flow.return %[[D3]]
   // CHECK-NEXT: }
-  %d3 = flow.dispatch @dispatch_3::@dispatch_3[%w : index](%readback) : (i32) -> tensor<2xf32>
+  %d3 = flow.dispatch @dispatch_3::@dispatch_3[%w] (%readback) : (i32) -> tensor<2xf32>
   //      CHECK: return %[[S1]]#
   // CHECK-SAME:   %[[S1]]#
   // CHECK-SAME:   %[[S2]]
@@ -334,7 +334,7 @@
   //  CHECK-DAG:    %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1
   // CHECK-NEXT:    flow.return %[[D1]]
   // CHECK-NEXT: }
-  %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w : index](%c1) : (i32) -> (tensor<i32>)
+  %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w] (%c1) : (i32) -> (tensor<i32>)
   // CHECK: %[[SE_USER:.+]] = iree.do_not_optimize(%[[S1]])
   %side_effecting_user = iree.do_not_optimize(%d1) : tensor<i32>
   // CHECK: %[[C2:.+]] = constant 2
@@ -344,7 +344,7 @@
   //  CHECK-DAG:    %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2
   // CHECK-NEXT:    flow.return %[[D2]]
   // CHECK-NEXT: }
-  %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w : index](%c2) : (i32) -> (tensor<f32>)
+  %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w] (%c2) : (i32) -> (tensor<f32>)
   //      CHECK: return %[[S1]], %[[S2]], %[[SE_USER]]
   return %d1, %d2, %side_effecting_user : tensor<i32>, tensor<f32>, tensor<i32>
 }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir
index 6b8dfcd..1872945 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir
@@ -8,24 +8,24 @@
   // CHECK-DAG: constant 4 : index
   // CHECK-DAG: constant 5 : index
   // CHECK-DAG: constant 6 : index
-  // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]] : index]() : () -> tensor<f32>
-  // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]] : index]() : () -> tensor<f32>
-  // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]] : index]() : () -> tensor<f32>
-  // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]] : index]() : () -> tensor<f32>
-  // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]] : index]() : () -> tensor<f32>
-  // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]] : index]() : () -> tensor<f32>
+  // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]] () : () -> tensor<f32>
+  // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]] () : () -> tensor<f32>
+  // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]] () : () -> tensor<f32>
+  // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]] () : () -> tensor<f32>
+  // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]] () : () -> tensor<f32>
+  // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]]] () : () -> tensor<f32>
   %w = constant 1 : index
-  %d0 = flow.dispatch @dispatch0::@dispatch0[%w : index]() : () -> (tensor<f32>)
+  %d0 = flow.dispatch @dispatch0::@dispatch0[%w] () : () -> (tensor<f32>)
   %c2 = constant 2 : index
-  %d1 = flow.dispatch @dispatch1::@dispatch1[%w : index]() : () -> (tensor<f32>)
+  %d1 = flow.dispatch @dispatch1::@dispatch1[%w] () : () -> (tensor<f32>)
   %c3 = constant 3 : index
-  %d2 = flow.dispatch @dispatch2::@dispatch2[%w : index]() : () -> (tensor<f32>)
+  %d2 = flow.dispatch @dispatch2::@dispatch2[%w] () : () -> (tensor<f32>)
   %c4 = constant 4 : index
-  %d3 = flow.dispatch @dispatch3::@dispatch3[%w : index]() : () -> (tensor<f32>)
+  %d3 = flow.dispatch @dispatch3::@dispatch3[%w] () : () -> (tensor<f32>)
   %c5 = constant 5 : index
-  %d4 = flow.dispatch @dispatch4::@dispatch4[%w : index]() : () -> (tensor<f32>)
+  %d4 = flow.dispatch @dispatch4::@dispatch4[%w] () : () -> (tensor<f32>)
   %c6 = constant 6 : index
-  %d5 = flow.dispatch @dispatch5::@dispatch5[%w : index]() : () -> (tensor<f32>)
+  %d5 = flow.dispatch @dispatch5::@dispatch5[%w] () : () -> (tensor<f32>)
   return
 }
 
@@ -45,7 +45,7 @@
   %tie0 = shapex.tie_shape %input, %shape : tensor<?x?xf32>, !shapex.ranked_shape<[?,?]>
   %dim0 = shapex.ranked_dim %shape[0] : !shapex.ranked_shape<[?,?]> -> index
   %dim1 = shapex.ranked_dim %shape[1] : !shapex.ranked_shape<[?,?]> -> index
-  %d = flow.dispatch @dispatch::@dispatch[%w : index](%tie0, %dim0, %dim1) : (tensor<?x?xf32>, index, index) -> tensor<?x?xf32>
+  %d = flow.dispatch @dispatch::@dispatch[%w] (%tie0, %dim0, %dim1) : (tensor<?x?xf32>, index, index) -> tensor<?x?xf32>
   %tie1 = shapex.tie_shape %d, %shape : tensor<?x?xf32>, !shapex.ranked_shape<[?,?]>
   return %tie1, %shape : tensor<?x?xf32>, !shapex.ranked_shape<[?,?]>
 }
@@ -87,7 +87,7 @@
   %c2 = constant 2 : index
   %ct3 = constant dense<3> : tensor<i32>
   // CHECK: flow.dispatch
-  %d0 = flow.dispatch @dispatch0::@dispatch0[%w : index]() : () -> (tensor<i32>)
+  %d0 = flow.dispatch @dispatch0::@dispatch0[%w] () : () -> (tensor<i32>)
   // CHECK: addi
   %add0 = addi %d0, %ct3 : tensor<i32>
   return
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
index 59d9fb5..6800464 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
@@ -5,8 +5,8 @@
 func @singleDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = constant 4 : index
   //      CHECK: flow.tensor.trace {trace_info = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32>
-  // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch2 @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch2 @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @ex::@entry0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: flow.tensor.trace {trace_info = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32>
   // CHECK-NEXT: return %[[RET0]]
   return %0 : tensor<4xf32>
@@ -20,13 +20,13 @@
   %c4 = constant 4 : index
 
   //     CHECK: flow.tensor.trace {trace_info = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32>
-  // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch2 @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32>
-  %0 = flow.dispatch2 @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @ex::@entry0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: flow.tensor.trace {trace_info = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32>
 
   //     CHECK: flow.tensor.trace {trace_info = "ex::entry1 inputs"} %[[RET0]] : tensor<4xf32>
-  // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch2 @ex::@entry1[%c4] (%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = flow.dispatch2 @ex::@entry1[%c4](%0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4] (%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = flow.dispatch @ex::@entry1[%c4] (%0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: flow.tensor.trace {trace_info = "ex::entry1 outputs"} %[[RET1]] : tensor<4xf32>
 
   // CHECK: return %[[RET1]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir
index 5a99c50..a58a4f7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir
@@ -22,7 +22,7 @@
   %x = constant 100 : index
   // CHECK-DAG: %[[Y:.+]] = constant 50
   %y = constant 50 : index
-  // CHECK: %[[RET:.+]] = flow.dispatch2 @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[
+  // CHECK: %[[RET:.+]] = flow.dispatch @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[
   // CHECK-SAME: %[[X]], %[[Y]]
   // CHECK-SAME: ] (%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32>
   %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
@@ -60,7 +60,7 @@
   %x = constant 100 : index
   // CHECK-DAG: %[[Y:.+]] = constant 50
   %y = constant 50 : index
-  // CHECK: %[[RET0:.+]] = flow.dispatch2 @dispatchFnMuli_dispatch_0::@dispatchFnMuli_dispatch_0[
+  // CHECK: %[[RET0:.+]] = flow.dispatch @dispatchFnMuli_dispatch_0::@dispatchFnMuli_dispatch_0[
   // CHECK-SAME: %[[X]], %[[Y]]
   // CHECK-SAME: ] (%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32>
   %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
@@ -73,7 +73,7 @@
     flow.dispatch.output.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32>
     flow.return
   }
-  // CHECK: %[[RET1:.+]] = flow.dispatch2 @dispatchFnMuli_dispatch_1::@dispatchFnMuli_dispatch_1[
+  // CHECK: %[[RET1:.+]] = flow.dispatch @dispatchFnMuli_dispatch_1::@dispatchFnMuli_dispatch_1[
   // CHECK-SAME: %[[Y]], %[[X]]
   // CHECK-SAME: ] (%[[RET0]]) : (tensor<4x8xf32>) -> tensor<8x4xf32>
   %1 = flow.dispatch.workgroups[%y, %x](%0) : (tensor<4x8xf32>) -> (tensor<8x4xf32>) = (
@@ -98,7 +98,7 @@
 func @dispatchFn1(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> {
   %x = constant 100 : index
   %y = constant 50 : index
-  // CHECK: flow.dispatch2 @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0
+  // CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0
   %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
     %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32>
   ) {
@@ -113,7 +113,7 @@
 func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> {
   %x = constant 100 : index
   %y = constant 50 : index
-  // CHECK: flow.dispatch2 @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0
+  // CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0
   %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
     %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32>
   ) {
@@ -178,7 +178,7 @@
   //  CHECK-DAG: %[[IN_ARG0_DIM3:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][3]
   //  CHECK-DAG: %[[IN_RET0_DIM0:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][0]
   //  CHECK-DAG: %[[IN_RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1]
-  // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch2 @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[
+  // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[
   // CHECK-SAME:   %[[X]], %[[Y]]
   // CHECK-SAME: ] (%[[ARG0_SHAPED]], %[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]], %[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]])
   // CHECK-SAME: : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor<?x?x1024xf32>
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir
index 47d5a5b..60c843a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir
@@ -19,7 +19,7 @@
 // CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
 // CHECK-DAG: %[[D3:.+]] = dim %[[ARG0]], %[[C3]]
 // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 1024 : index
-// CHECK-DAG: %[[DISPATCH:.+]] = flow.dispatch @dynamicRankedShape_ex_dispatch_0::@dynamicRankedShape_ex_dispatch_0[%[[WORKLOAD0]] : index](%[[ARG0]], %[[D1]], %[[D3]]) : (tensor<7x?x24x?xf32>, index, index)
+// CHECK-DAG: %[[DISPATCH:.+]] = flow.dispatch @dynamicRankedShape_ex_dispatch_0::@dynamicRankedShape_ex_dispatch_0[%[[WORKLOAD0]]] (%[[ARG0]], %[[D1]], %[[D3]]) : (tensor<7x?x24x?xf32>, index, index)
 // CHECK-DAG: return %[[DISPATCH]]
 module @dynamicRankedShapeModule {
 func @dynamicRankedShape(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> {
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir
index 7d0ab56..b13d93f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir
@@ -25,7 +25,7 @@
 // CHECK-NEXT: func @simpleMath(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK-NEXT:   %[[WORKLOAD0:.+]] = constant 4 : index
 // CHECK-NEXT:   %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NEXT:     %1 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT:     %1 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
 // CHECK-NEXT:     flow.return %1 : tensor<4xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %0 : tensor<4xf32>
@@ -54,7 +54,7 @@
 // CHECK-NEXT: func @stdElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK-NEXT:   %[[WORKLOAD0:.+]] = constant 4 : index
 // CHECK-NEXT:   %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NEXT:     %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT:     %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
 // CHECK-NEXT:     flow.return %1 : tensor<4xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %0 : tensor<4xf32>
@@ -83,7 +83,7 @@
 // CHECK-NEXT: func @hloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK-NEXT:   %[[WORKLOAD0:.+]] = constant 4 : index
 // CHECK-NEXT:   %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NEXT:     %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT:     %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32>
 // CHECK-NEXT:     flow.return %1 : tensor<4xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %0 : tensor<4xf32>
@@ -128,9 +128,9 @@
 // CHECK-NEXT: func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
 // CHECK-NEXT:   %[[WORKLOAD0:.+]] = constant 16 : index
 // CHECK-NEXT:   %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
-// CHECK-NEXT:     %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:     %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%arg1 : index](%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:     %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%arg1 : index](%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:     %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%arg1] (%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:     %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%arg1] (%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:     %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%arg1] (%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT:     flow.return %3 : tensor<4x4xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %0 : tensor<4x4xf32>
@@ -165,7 +165,7 @@
 //  CHECK-NEXT: func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
 //  CHECK-NEXT:   %[[WORKLOAD0:.+]] = constant 4 : index
 //  CHECK-NEXT:   %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x8xf32>) -> tensor<4xf32> {
-//  CHECK-NEXT:     %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%arg1 : index](%arg2) : (tensor<4x8xf32>) -> tensor<4xf32>
+//  CHECK-NEXT:     %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%arg1] (%arg2) : (tensor<4x8xf32>) -> tensor<4xf32>
 //  CHECK-NEXT:     flow.return %1 : tensor<4xf32>
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   return %0 : tensor<4xf32>
@@ -196,7 +196,7 @@
 // CHECK-DAG:   %[[ARG3_INDEX:.+]] = index_cast %[[ARG3_LOAD]] : i32 to index
 // CHECK-NEXT:   %4 = flow.ex.stream.fragment(%arg4 = %arg1 : tensor<1x1xi32>, %arg5 = %arg0 : tensor<2x4xi32>, %arg6 = %[[ARG2_INDEX]] : index, %arg7 = %[[ARG3_INDEX]] : index, %arg8 = %[[WORKLOAD0]] : index) -> tensor<2x4xi32> {
 // CHECK-NEXT:     %5 = flow.tensor.update %arg4, %arg5[%arg6, %arg7] : tensor<1x1xi32> -> tensor<2x4xi32>
-// CHECK-NEXT:     %6 = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%arg8 : index](%arg5, %5) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+// CHECK-NEXT:     %6 = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%arg8] (%arg5, %5) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
 // CHECK-NEXT:     flow.return %6 : tensor<2x4xi32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %4 : tensor<2x4xi32>
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index 1272f92..31915be 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -408,7 +408,9 @@
   dispatchState.device = device;
   dispatchState.commandBuffer = commandBuffer;
   dispatchState.executableLayout = executableLayout;
-  dispatchState.workload = rewriter.getRemappedValue(dispatchOp.workload());
+  for (auto dim : dispatchOp.workgroup_count()) {
+    dispatchState.workgroupCount.push_back(rewriter.getRemappedValue(dim));
+  }
   // TODO(benvanik): support extended push constants.
   dispatchState.basePushConstantOffset = 0;
   dispatchState.operands = operandAdaptors;
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
index 72922dc..18f21ea 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s
 
 hal.executable @ex0 {
   hal.interface @interface {
@@ -29,11 +29,11 @@
     //      CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set=0, bindings=[0 = (%arg0, %c0, %c512), 1 = (%[[TMP_BUF]], %c0, %c512)]
     //      CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz
     //      CHECK: hal.command_buffer.execution_barrier
-    %1 = flow.dispatch @ex0::@entry0[%arg1 : index](%arg2) : (tensor<128xf32>) -> tensor<128xf32>
+    %1 = flow.dispatch @ex0::@entry0[%arg1] (%arg2) : (tensor<128xf32>) -> tensor<128xf32>
     //      CHECK: hal.command_buffer.push_descriptor_set
     //      CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz
     //      CHECK: hal.command_buffer.execution_barrier
-    %2 = flow.dispatch @ex0::@entry0[%arg1 : index](%1) : (tensor<128xf32>) -> tensor<128xf32>
+    %2 = flow.dispatch @ex0::@entry0[%arg1] (%1) : (tensor<128xf32>) -> tensor<128xf32>
     flow.return %2 : tensor<128xf32>
   }
   // CHECK: hal.command_buffer.end %[[CMD]]
@@ -98,13 +98,111 @@
   %0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<?x128xf32>, %arg3 = %bs : index) -> tensor<?x128xf32> {
     %1 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?,128]>
     %2 = shapex.tie_shape %arg2, %1 : tensor<?x128xf32>, !shapex.ranked_shape<[?,128]>
-    %3 = flow.dispatch @ex0::@entry0[%arg1 : index](%2, %arg3) : (tensor<?x128xf32>, index) -> tensor<?x128xf32>
+    %3 = flow.dispatch @ex0::@entry0[%arg1] (%2, %arg3) : (tensor<?x128xf32>, index) -> tensor<?x128xf32>
     %4 = shapex.tie_shape %3, %1 : tensor<?x128xf32>, !shapex.ranked_shape<[?,128]>
-    %5 = flow.dispatch @ex0::@entry0[%arg1 : index](%4, %arg3) : (tensor<?x128xf32>, index) -> tensor<?x128xf32>
+    %5 = flow.dispatch @ex0::@entry0[%arg1] (%4, %arg3) : (tensor<?x128xf32>, index) -> tensor<?x128xf32>
     %6 = shapex.tie_shape %5, %1 : tensor<?x128xf32>, !shapex.ranked_shape<[?,128]>
-    %7 = flow.dispatch @ex0::@entry0[%arg1 : index](%6, %arg3) : (tensor<?x128xf32>, index) -> tensor<?x128xf32>
+    %7 = flow.dispatch @ex0::@entry0[%arg1] (%6, %arg3) : (tensor<?x128xf32>, index) -> tensor<?x128xf32>
     %8 = shapex.tie_shape %7, %1 : tensor<?x128xf32>, !shapex.ranked_shape<[?,128]>
     flow.return %8 : tensor<?x128xf32>
   }
   return %0 : tensor<?x128xf32>
 }
+
+// -----
+
+hal.executable @ex attributes {sym_visibility = "private"} {
+  hal.interface @legacy_io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @tgt, filter="dylib-llvm-aot" {
+    hal.executable.entry_point @entry attributes {
+      interface = @legacy_io,
+      ordinal = 0 : i32,
+      signature = (!flow.dispatch.input<7x4x24xf32>, !flow.dispatch.output<4x7x1024xf32>) -> ()
+    }
+    module {}
+  }
+}
+
+// CHECK-LABEL: func @static_tiled_dispatch
+func @static_tiled_dispatch(%arg0: tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> {
+  %c1024 = constant 1024 : index
+  %c512 = constant 512 : index
+  // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, "OneShot", "Transfer|Dispatch"
+  // CHECK-NEXT: hal.command_buffer.begin %[[CMD]]
+  %1 = flow.ex.stream.fragment(
+      %arg3 = %arg0 : tensor<7x4x24xf32>,
+      %arg6 = %c1024 : index,
+      %arg7 = %c512 : index
+    ) -> tensor<4x7x1024xf32> {
+    // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set=0, bindings=[0 = (%arg0, %c0, %c2688), 1 = (%buffer, %c0, %c114688)]
+    // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex::@tgt::@entry, workgroup_xyz
+    %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32>
+    flow.return %0 : tensor<4x7x1024xf32>
+  }
+  // CHECK: hal.command_buffer.end %[[CMD]]
+  return %1 : tensor<4x7x1024xf32>
+}
+
+// -----
+
+hal.executable @ex attributes {sym_visibility = "private"} {
+  hal.interface @legacy_io attributes {push_constants = 4 : i32} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @tgt, filter="dylib-llvm-aot" {
+    hal.executable.entry_point @entry attributes {
+      interface = @legacy_io,
+      ordinal = 0 : i32,
+      signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output<?x?x1024xf32>, index, index, index, index) -> ()
+    }
+    module {}
+  }
+}
+
+// CHECK-LABEL: func @dynamic_tiled_dispatch
+func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: index) -> tensor<?x?x1024xf32> {
+  %c1024 = constant 1024 : index
+  %c512 = constant 512 : index
+  // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, "OneShot", "Transfer|Dispatch"
+  // CHECK-NEXT: hal.command_buffer.begin %[[CMD]]
+  %2 = flow.ex.stream.fragment(
+      %arg3 = %arg0 : tensor<7x?x24x?xf32>,
+      %arg4 = %arg1 : index,
+      %arg5 = %arg2 : index,
+      %arg6 = %c1024 : index,
+      %arg7 = %c512 : index
+    ) -> tensor<?x?x1024xf32> {
+    %3 = shapex.make_ranked_shape %arg4, %arg5 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]>
+    %4 = shapex.make_ranked_shape %arg5, %arg4 : (index, index) -> !shapex.ranked_shape<[?,?,1024]>
+    %5 = shapex.tie_shape %arg3, %3 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>
+    // CHECK: hal.command_buffer.push_constants %[[CMD]], %executable_layout, offset = 0, values = [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : i32
+    // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set=0, bindings=[0 = (%arg0, %c0, %9), 1 = (%buffer, %c0, %12)]
+
+    // CHECK: #hal.device.match.id<"dylib*">(
+    // CHECK-SAME: %[[CMD_INNER:.+]] = %cmd : !hal.command_buffer,
+    // CHECK-SAME: %[[COUNT_X_INNER:.+]] = %c1024 : index,
+    // CHECK-SAME: %[[COUNT_Y_INNER:.+]] = %c512 : index,
+    // CHECK-SAME: %[[COUNT_Z_INNER:.+]] = %c512 : index
+
+    // TODO(#4140): the redundant addi/subi pairs below should fold away once workgroup count mapping is cleaned up.
+    //      CHECK: %[[C1:.+]] = constant 1
+    // CHECK-NEXT: %[[COUNT_X_TMP:.+]] = addi %[[COUNT_X_INNER]], %[[C1]]
+    // CHECK-NEXT: %[[COUNT_X:.+]] = subi %[[COUNT_X_TMP]], %[[C1]]
+    // CHECK-NEXT: %[[COUNT_Y_TMP:.+]] = addi %[[COUNT_Y_INNER]], %[[C1]]
+    // CHECK-NEXT: %[[COUNT_Y:.+]] = subi %[[COUNT_Y_TMP]], %[[C1]]
+    // CHECK-NEXT: %[[COUNT_Z_TMP:.+]] = addi %[[COUNT_Z_INNER]], %[[C1]]
+    // CHECK-NEXT: %[[COUNT_Z:.+]] = subi %[[COUNT_Z_TMP]], %[[C1]]
+
+    // CHECK: hal.command_buffer.dispatch.symbol %[[CMD_INNER]], @ex::@tgt::@entry, workgroup_xyz =
+    // CHECK-SAME: [%[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]]
+    %6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%5, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor<?x?x1024xf32>
+    %7 = shapex.tie_shape %6, %4 : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>
+    flow.return %7 : tensor<?x?x1024xf32>
+  }
+  // CHECK: hal.command_buffer.end %[[CMD]]
+  return %2 : tensor<?x?x1024xf32>
+}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index 670d4b9..2b01894 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -216,6 +216,15 @@
 
   LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
                                DeviceSwitchRewriter &switchRewriter) override {
+    // TODO(#4140): remove this legacy path when linalg-on-tensors is used.
+    // In the linalg-on-tensors world where we are performing the tiling logic
+    // in the flow dialect we don't even really need the ability to override
+    // dispatch recording at all - just a way to allow targets to map workgroup
+    // counts from the N-dimensional flow workgroup counts to the 3D hal counts.
+    if (dispatchState.workgroupCount.size() == 3) {
+      return TargetBackend::recordDispatch(loc, dispatchState, switchRewriter);
+    }
+
     IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
     ModuleOp llvmIRModuleOp;
     for (auto executableTargetOp :
@@ -240,7 +249,7 @@
     auto *region = switchRewriter.addConditionRegion(
         IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
         {
-            dispatchState.workload,
+            dispatchState.workgroupCount[0],
             dispatchState.commandBuffer,
         });
     auto &entryBlock = region->front();
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
index 377e665..6e3dbdf 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -68,6 +68,15 @@
 LogicalResult SPIRVTargetBackend::recordDispatch(
     Location loc, DispatchState dispatchState,
     DeviceSwitchRewriter &switchRewriter) {
+  // TODO(#4140): remove this legacy path when linalg-on-tensors is used.
+  // In the linalg-on-tensors world where we are performing the tiling logic
+  // in the flow dialect we don't even really need the ability to override
+  // dispatch recording at all - just a way to allow targets to map workgroup
+  // counts from the N-dimensional flow workgroup counts to the 3D hal counts.
+  if (dispatchState.workgroupCount.size() == 3) {
+    return TargetBackend::recordDispatch(loc, dispatchState, switchRewriter);
+  }
+
   // Multiple entry points might be generated for a single dispatch function.
   // Under such circumstances, we will have a special attribute indicating the
   // schedule of the split entry points. Try to see if we can find such
@@ -118,7 +127,7 @@
   auto *region = switchRewriter.addConditionRegion(
       IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
       {
-          dispatchState.workload,
+          dispatchState.workgroupCount[0],
           dispatchState.commandBuffer,
       });
 
@@ -188,7 +197,7 @@
 // query independently so that we don't need to lookup the value here.
 std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
     Location loc, IREE::HAL::ExecutableOp executableOp,
-    IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+    IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
     OpBuilder &builder) {
   // TODO(ravishankarm): possibly emit different recordDispatch logic if the
   // workgroup sizes differ among targets.
@@ -212,7 +221,7 @@
 
 std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
     Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
-    Value workload, OpBuilder &builder) {
+    ValueRange workload, OpBuilder &builder) {
   std::array<Value, 3> workgroupSize;
   for (auto executionModeOp :
        spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
index de6f057..ea3389c 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
@@ -44,13 +44,13 @@
   // Finds the spv.ExecutionMode operation to get the workgroup size from.
   std::array<Value, 3> calculateDispatchWorkgroupSize(
       Location loc, IREE::HAL::ExecutableOp executableOp,
-      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+      IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
       OpBuilder &builder) override;
 
  private:
   std::array<Value, 3> calculateDispatchWorkgroupSize(
       Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
-      Value workload, OpBuilder &builder);
+      ValueRange workload, OpBuilder &builder);
 
   SPIRVCodegenOptions spvCodeGenOptions_;
 };
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index b2037e6..5878d76 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -102,7 +102,7 @@
 
 std::array<Value, 3> TargetBackend::calculateDispatchWorkgroupSize(
     Location loc, IREE::HAL::ExecutableOp executableOp,
-    IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+    IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
     OpBuilder &builder) {
   // When no workgroup size is specified we just assume [1,1,1].
   // This yields a workgroup count that models the extents of the workload.
@@ -115,7 +115,7 @@
 
 std::array<Value, 3> TargetBackend::calculateDispatchWorkgroupCount(
     Location loc, IREE::HAL::ExecutableOp executableOp,
-    IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+    IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
     OpBuilder &builder) {
   auto workgroupSize = calculateDispatchWorkgroupSize(
       loc, executableOp, entryPointOp, workload, builder);
@@ -123,54 +123,75 @@
 }
 
 std::array<Value, 3> TargetBackend::calculateDispatchWorkgroupCount(
-    Location loc, Value workload, const std::array<Value, 3> &workgroupSize,
-    OpBuilder &builder) {
+    Location loc, ValueRange workload,
+    const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
   std::array<Value, 3> result;
+
   auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
-  for (int i = 0; i < 3; ++i) {
-    // Round up: (workload + workgroup_size - 1) / workgroup_size;
-    auto rounded = builder.createOrFold<mlir::SubIOp>(
-        loc,
-        builder.createOrFold<mlir::AddIOp>(loc, workload, workgroupSize[i]),
-        constantOne);
-    auto workgroupCountI = builder.createOrFold<mlir::UnsignedDivIOp>(
-        loc, rounded, workgroupSize[i]);
-    result[i] = workgroupCountI;
+  if (workload.size() <= 3) {
+    // 1-D to 3-D workloads are easy: pad any missing dimensions with 1, then divide by workgroup size.
+    for (int i = 0; i < 3; ++i) {
+      // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
+      Value workloadI = i < workload.size() ? workload[i] : constantOne;
+      workloadI = builder.createOrFold<mlir::SubIOp>(
+          loc,
+          builder.createOrFold<mlir::AddIOp>(loc, workloadI, workgroupSize[i]),
+          constantOne);
+      result[i] = builder.createOrFold<UnsignedDivIOp>(loc, workloadI,
+                                                       workgroupSize[i]);
+    }
+  } else {
+    // TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
+    Value flatWorkload = constantOne;
+    for (auto workloadI : workload) {
+      flatWorkload = builder.createOrFold<MulIOp>(loc, flatWorkload, workloadI);
+    }
+    for (int i = 0; i < 3; ++i) {
+      // Round up: (flat_workload + workgroup_size - 1) / workgroup_size;
+      auto rounded = builder.createOrFold<mlir::SubIOp>(
+          loc,
+          builder.createOrFold<mlir::AddIOp>(loc, flatWorkload,
+                                             workgroupSize[i]),
+          constantOne);
+      auto workgroupCountI = builder.createOrFold<mlir::UnsignedDivIOp>(
+          loc, rounded, workgroupSize[i]);
+      result[i] = workgroupCountI;
 
-    // Multiply back out and subtract from invocations.
-    workload = builder.createOrFold<SubIOp>(
-        loc, workload,
-        builder.createOrFold<MulIOp>(loc, workgroupCountI, rounded));
-
-    // Ensure > 0.
-    auto workloadGreaterZero =
-        builder.create<CmpIOp>(loc, CmpIPredicate::sge, workload, constantOne);
-    workload = builder.create<SelectOp>(loc, workloadGreaterZero, workload,
-                                        constantOne);
+      // Multiply back out and subtract from invocations.
+      flatWorkload = builder.createOrFold<SubIOp>(
+          loc, flatWorkload,
+          builder.createOrFold<MulIOp>(loc, workgroupCountI, rounded));
+    }
   }
+
   return result;
 }
 
 LogicalResult TargetBackend::recordDispatch(
     Location loc, DispatchState dispatchState,
     DeviceSwitchRewriter &switchRewriter) {
+  SmallVector<Value, 4> regionArgs;
+  regionArgs.push_back(dispatchState.commandBuffer);
+  for (auto dim : dispatchState.workgroupCount) {
+    regionArgs.push_back(dim);
+  }
   auto *region = switchRewriter.addConditionRegion(
       IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
-      {
-          dispatchState.workload,
-          dispatchState.commandBuffer,
-      });
+      regionArgs);
   auto &entryBlock = region->front();
-  auto workload = entryBlock.getArgument(0);
-  auto commandBuffer = entryBlock.getArgument(1);
+  auto commandBuffer = entryBlock.getArgument(0);
+  SmallVector<Value, 3> originalWorkgroupCount;
+  for (int i = 0; i < dispatchState.workgroupCount.size(); ++i) {
+    originalWorkgroupCount.push_back(entryBlock.getArgument(1 + i));
+  }
 
   auto builder = OpBuilder::atBlockBegin(&entryBlock);
-  auto workgroupCount = calculateDispatchWorkgroupCount(
-      loc, dispatchState.executableOp, dispatchState.entryPointOp, workload,
-      builder);
+  auto remappedWorkgroupCount = calculateDispatchWorkgroupCount(
+      loc, dispatchState.executableOp, dispatchState.entryPointOp,
+      originalWorkgroupCount, builder);
   builder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
-      loc, commandBuffer, dispatchState.entryPointOp, workgroupCount[0],
-      workgroupCount[1], workgroupCount[2]);
+      loc, commandBuffer, dispatchState.entryPointOp, remappedWorkgroupCount[0],
+      remappedWorkgroupCount[1], remappedWorkgroupCount[2]);
 
   builder.create<IREE::HAL::ReturnOp>(loc);
   return success();
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 7019644..3c1db8a 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -192,9 +192,9 @@
     // SSA value of the loaded hal.executable_layout reference.
     Value executableLayout;
 
-    // SSA value of the total workload of the dispatch. See `flow.dispatch` for
-    // more information on how this is calculated.
-    Value workload;
+    // SSA values of the workgroup count of the dispatch. See `flow.dispatch`
+    // for more information on how this is calculated.
+    SmallVector<Value, 3> workgroupCount;
 
     // A base offset within the push constants array that all new push constants
     // must follow. Note that backend-specific push constants must have been
@@ -232,9 +232,10 @@
   // beginning at offset |dispatchState.basePushConstantOffset|. Note that the
   // push constants must have been declared by `extractInterface`.
   //
-  // The provided |dispatchState.workload| can be used to derive the workgroup
-  // counts for dispatch using `calculateDispatchWorkgroupCounts` (or other
-  // logic).
+  // The provided |dispatchState.workgroupCount| can be used to access the
+  // workgroup count values for dispatch as provided on the original
+  // flow.dispatch op. These arbitrarily-ranked dimensions need to be adapted
+  // into the target-dependent 3-D XYZ grid space.
   //
   // |dispatchState.operands| and |dispatchState.results| can be used to access
   // the buffers allocated in case additional command buffer operations are
@@ -340,22 +341,22 @@
   // for a single workgroup.
   virtual std::array<Value, 3> calculateDispatchWorkgroupSize(
       Location loc, IREE::HAL::ExecutableOp executableOp,
-      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+      IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
       OpBuilder &builder);
 
   // Calculates the workgroup count (x, y, z) for dispatching to the given
-  // |entryPointOp|. The provided |workload| is the total number of invocations
-  // required as calculated by the generic workload logic (basically, number of
-  // output elements in tensors).
+  // |entryPointOp|. The provided N-dimensional |workload| is the total number
+  // of invocations required as calculated by the generic workload logic
+  // (basically, number of output elements in tensors).
   virtual std::array<Value, 3> calculateDispatchWorkgroupCount(
       Location loc, IREE::HAL::ExecutableOp executableOp,
-      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+      IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
       OpBuilder &builder);
-  // Calculates the workgroup count (x, y, z) given the total |workload| and
-  // specific |workgroupSize|.
+  // Calculates the workgroup count (x, y, z) given the total N-dimensional
+  // |workload| and specific |workgroupSize|.
   std::array<Value, 3> calculateDispatchWorkgroupCount(
-      Location loc, Value workload, const std::array<Value, 3> &workgroupSize,
-      OpBuilder &builder);
+      Location loc, ValueRange workload,
+      const std::array<Value, 3> &workgroupSize, OpBuilder &builder);
 };
 
 }  // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index 12c5568..ad048ad 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -243,7 +243,7 @@
 
   std::array<Value, 3> calculateDispatchWorkgroupCount(
       Location loc, IREE::HAL::ExecutableOp executableOp,
-      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+      IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
       OpBuilder &builder) override {
     // For now we are not tiling and just dispatch everything as 1,1,1.
     auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
index 429db63..df5ed38 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
@@ -6,7 +6,7 @@
   // CHECK: %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 0, 1, 0]> : tensor<4xi8>
   %cst = constant dense<[true, false, true, false]> : tensor<4xi1>
   %0 = flow.ex.stream.fragment(%arg1 = %c4 : index, %arg2 = %arg0 : tensor<4xi1>, %arg3 = %cst : tensor<4xi1>) -> tensor<4xi1> {
-    %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1 : index](%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
+    %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1] (%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
     flow.return %1 : tensor<4xi1>
   }
   return %0 : tensor<4xi1>
diff --git a/iree/test/e2e/hackability/flow_partitioned.mlir b/iree/test/e2e/hackability/flow_partitioned.mlir
index 3b85a0e..d9b99f4 100644
--- a/iree/test/e2e/hackability/flow_partitioned.mlir
+++ b/iree/test/e2e/hackability/flow_partitioned.mlir
@@ -15,7 +15,7 @@
 func @staticShapedFn() -> tensor<4xf32> {
   %input = iree.unfoldable_constant dense<[-1.0, 2.0, -3.0, 4.0]> : tensor<4xf32>
   %workload = constant 4 : index
-  %0 = flow.dispatch @ex0::@dispatch0[%workload : index](%input) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = flow.dispatch @ex0::@dispatch0[%workload] (%input) : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 // CHECK: 4xf32=-2 4 -6 8