Integrate llvm/llvm-project@eb141867 (#12264)

* Reset third_party/llvm-project:
eb14186771e7bca992c043637aac3ed7104eaa1f (2023-02-17 11:27:24 +0200):
Revert "[LLD] [COFF] Don't try to detect MSVC installations in mingw
mode"
* Updated to https://github.com/tensorflow/mlir-hlo/commit/50584fafb42af5dc34355282e380b6a74c355b29
* Updated to https://github.com/tensorflow/tensorflow/commit/20ff2f32d85e79195b9b38f69039eb185f8cb1a5
* Massive renaming to update `scf.foreach_thread` to `scf.forall`, along with the associated ops and methods
* Added a `threadIdGenerator` callback for `mapNestedForeachToThreadsImpl`
* Updated vector op accessor call sites (e.g., `getVectorType()` → `getSourceVectorType()` / `getResultVectorType()`)
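For reference, the new `threadIdGenerator` callback just materializes `gpu.thread_id` ops for the x/y/z dimensions and hands them to `mapNestedForeachToThreadsImpl`; the sketch below mirrors the lambda added at the two IREE call sites in this diff (all identifiers are taken from those hunks, nothing new is introduced):

```cpp
// Mirrors the lambda added in LLVMGPUDistribute.cpp and LLVMGPUExtensions.cpp:
// it builds one gpu.thread_id op per dimension and returns the values to
// mapNestedForeachToThreadsImpl through the output vector.
auto threadIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
                            SmallVectorImpl<Value> &threadIds) {
  Location loc = forallOp.getLoc();
  IndexType indexType = rewriter.getIndexType();
  threadIds.assign(
      {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
       rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
       rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::z)});
};
```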
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
index 94ac28d..fb27df4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
@@ -156,8 +156,8 @@
         transferReadOp.hasOutOfBoundsDim()) {
       return failure();
     }
-    int64_t rankDiff =
-        op.getVectorType().getRank() - transferReadOp.getVectorType().getRank();
+    int64_t rankDiff = op.getResultVectorType().getRank() -
+                       transferReadOp.getVectorType().getRank();
     SmallVector<AffineExpr> exprs(rankDiff, rewriter.getAffineConstantExpr(0));
     ArrayRef<AffineExpr> originalExpr =
         transferReadOp.getPermutationMap().getResults();
@@ -167,7 +167,7 @@
                        transferReadOp.getPermutationMap().getNumSymbols(),
                        exprs, op.getContext());
     ArrayAttr inBounds = rewriter.getBoolArrayAttr(
-        SmallVector<bool>(op.getVectorType().getRank(), true));
+        SmallVector<bool>(op.getResultVectorType().getRank(), true));
     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
         op, op.getType(), transferReadOp.getSource(),
         transferReadOp.getIndices(), newMap, transferReadOp.getPadding(),
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 23a3162..4f1c2a1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -67,11 +67,12 @@
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter& rewriter) const override {
-    unsigned numNonUnitSrcDim = llvm::count_if(
-        op.getVectorType().getShape(), [](int64_t dim) { return dim != 1; });
+    unsigned numNonUnitSrcDim =
+        llvm::count_if(op.getSourceVectorType().getShape(),
+                       [](int64_t dim) { return dim != 1; });
     if (numNonUnitSrcDim > 1) return failure();
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     op.getVector());
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), op.getVector());
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index fef674d..b9e0b8e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -59,7 +59,7 @@
 // workgroups to a static value. Ideally this should not be done and the static
 // and dyamic cases are handled the same way. When the tile+distribute moves
 // away from using `scf.for` to using a construct that better captures
-// distribution (like `scf.foreach_thread`) this information can be dropped.
+// distribution (like `scf.forall`) this information can be dropped.
 static LogicalResult getTileAndDistributeConfig(
     ArrayRef<Operation *> computeOps, Operation *&dispatchRootOp,
     SmallVectorImpl<int64_t> &tileSizes,
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 46a6474..753da93 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -311,30 +311,29 @@
 }
 
 //===----------------------------------------------------------------------===//
-// ShareForeachThreadOperandsOp
+// ShareForallOperandsOp
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform_dialect::ShareForeachThreadOperandsOp::applyToOne(
-    scf::ForeachThreadOp foreachThreadOp,
-    transform::ApplyToEachResultList &results,
+transform_dialect::ShareForallOperandsOp::applyToOne(
+    scf::ForallOp forallOp, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   IRRewriter rewriter(getContext());
   SmallVector<int64_t> shareOperands(getShareOperands());
   // Empty case: consider all operands need to be shared.
   if (shareOperands.empty()) {
-    shareOperands = llvm::to_vector(
-        llvm::seq<int64_t>(0, foreachThreadOp.getOutputs().size()));
+    shareOperands =
+        llvm::to_vector(llvm::seq<int64_t>(0, forallOp.getOutputs().size()));
   }
   for (int64_t outputIdx : getShareOperands()) {
-    if (outputIdx < 0 || outputIdx >= foreachThreadOp.getOutputs().size())
-      return mlir::emitDefiniteFailure(foreachThreadOp, "operand idx overflow");
-    Value toShare = foreachThreadOp.getOutputs()[outputIdx];
+    if (outputIdx < 0 || outputIdx >= forallOp.getOutputs().size())
+      return mlir::emitDefiniteFailure(forallOp, "operand idx overflow");
+    Value toShare = forallOp.getOutputs()[outputIdx];
     if (std::distance(toShare.getUses().begin(), toShare.getUses().end()) !=
         2) {
       /*return mlir::emitSilenceableFailure(
-          foreachThreadOp,
-          "operand to share must have exactly 2 uses, the foreach_thread op "
+          forallOp,
+          "operand to share must have exactly 2 uses, the forall op "
           "and an extract_slice op.");*/
       continue;
     }
@@ -345,12 +344,12 @@
     }
     if (!extractSliceOp) {
       /*return mlir::emitSilenceableFailure(
-        foreachThreadOp,
+        forallOp,
         "shared operands use must be extractSliceOp.");*/
       continue;
     }
     // Get the corresponding bbArg.
-    BlockArgument bbArg = foreachThreadOp.getOutputBlockArguments()[outputIdx];
+    BlockArgument bbArg = forallOp.getOutputBlockArguments()[outputIdx];
 
     // Check if the extract_slice has a matching parallel_insert_slice
     // (i.e., same source/target, offsets, sizes and strides).
@@ -365,7 +364,7 @@
              llvm::equal(insertSlice.getMixedStrides(),
                          extractSliceOp.getMixedStrides());
     };
-    if (llvm::none_of(foreachThreadOp.getTerminator().getYieldingOps(),
+    if (llvm::none_of(forallOp.getTerminator().getYieldingOps(),
                       isMatchingParallelInsertSlice)) {
       continue;
     }
@@ -376,16 +375,17 @@
     });
   }
 
-  results.push_back(foreachThreadOp);
+  results.push_back(forallOp);
   return DiagnosedSilenceableFailure::success();
 }
 
 //===---------------------------------------------------------------------===//
-// ForeachThreadToWorkgroupOp
+// ForallToWorkgroupOp
 //===---------------------------------------------------------------------===//
 
-void transform_dialect::ForeachThreadToWorkgroupOp::build(
-    OpBuilder &builder, OperationState &result, Value target) {
+void transform_dialect::ForallToWorkgroupOp::build(OpBuilder &builder,
+                                                   OperationState &result,
+                                                   Value target) {
   result.addOperands(target);
   MLIRContext *ctx = builder.getContext();
   result.addTypes({pdl::OperationType::get(ctx)});
@@ -396,9 +396,9 @@
 /// operands. Assumes the HAL::ExecutableExportOp is built with an empty
 /// region.
 static LogicalResult populateWorkgroupCountComputingRegion(
-    PatternRewriter &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    PatternRewriter &rewriter, scf::ForallOp forallOp,
     HAL::ExecutableExportOp exportOp) {
-  Location loc = foreachThreadOp.getLoc();
+  Location loc = forallOp.getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   Region &r = exportOp.getWorkgroupCount();
   assert(r.empty() && "expected block-less workgroup_count region");
@@ -410,7 +410,7 @@
   SmallVector<Value> results;
   // For now, this assumes that we only pull in constants.
   // TODO: Iteratively pull required operations.
-  for (Value v : foreachThreadOp.getNumThreads()) {
+  for (Value v : forallOp.getUpperBound(rewriter)) {
     auto op = dyn_cast_or_null<arith::ConstantIndexOp>(v.getDefiningOp());
     if (!op) return failure();
     results.push_back(
@@ -426,41 +426,41 @@
 }
 
 //===---------------------------------------------------------------------===//
-// Patterns for ForeachThreadToWorkgroup rewrite.
+// Patterns for ForallToWorkgroup rewrite.
 //===---------------------------------------------------------------------===//
 
-LogicalResult rewriteForeachThreadToWorkgroup(
-    scf::ForeachThreadOp foreachThreadOp,
-    IREE::HAL::ExecutableExportOp exportOp, PatternRewriter &rewriter) {
+LogicalResult rewriteForallToWorkgroup(scf::ForallOp forallOp,
+                                       IREE::HAL::ExecutableExportOp exportOp,
+                                       PatternRewriter &rewriter) {
   // Step 0. Target-specific verifications. There is no good place to anchor
-  // those right now: the ForeachThreadOp is target-independent and the
-  // transform op does not apply to individual ForeachThreadOp.
-  MLIRContext *ctx = foreachThreadOp->getContext();
-  Location loc = foreachThreadOp->getLoc();
+  // those right now: the ForallOp is target-independent and the
+  // transform op does not apply to individual ForallOp.
+  MLIRContext *ctx = forallOp->getContext();
+  Location loc = forallOp->getLoc();
   // TODO iree should have own device mapping like #hal.workgroup<x/y/z>
   Attribute bX = gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimX);
   Attribute bY = gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimY);
   Attribute bZ = gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimZ);
-  if (foreachThreadOp.getNumResults() > 0)
-    return foreachThreadOp->emitError(
-        "only bufferized scf.foreach_thread lowers to workgroup");
-  if (foreachThreadOp.getNumThreads().size() > 3)
-    return foreachThreadOp->emitError(
-        "scf.foreach_thread with rank > 3 does not lower to workgroup");
+  if (forallOp.getNumResults() > 0)
+    return forallOp->emitError(
+        "only bufferized scf.forall lowers to workgroup");
+  if (forallOp.getRank() > 3)
+    return forallOp->emitError(
+        "scf.forall with rank > 3 does not lower to workgroup");
 
-  if (!foreachThreadOp.getMapping().has_value())
-    return foreachThreadOp->emitError("mapping must be present");
+  if (!forallOp.getMapping().has_value())
+    return forallOp->emitError("mapping must be present");
   SmallVector<Attribute> blockMapping =
-      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+      llvm::to_vector(forallOp.getMapping()->getValue());
   if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
         return !map.isa<gpu::GPUBlockMappingAttr>();
       })) {
-    return foreachThreadOp->emitError("mapping must be #gpu.block<x/y/z/>");
+    return forallOp->emitError("mapping must be #gpu.block<x/y/z/>");
   }
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
   SmallVector<Value> numBlocks =
-      llvm::to_vector(foreachThreadOp.getNumThreads());
+      llvm::to_vector(forallOp.getUpperBound(rewriter));
   // Ensure we have 3 block sizes, one for each id.
   Value one;
   for (auto attr : {bX, bY, bZ}) {
@@ -476,12 +476,12 @@
     return static_cast<int64_t>(a.cast<gpu::GPUBlockMappingAttr>().getBlock()) <
            static_cast<int64_t>(b.cast<gpu::GPUBlockMappingAttr>().getBlock());
   };
-  SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
-      blockMapping, numBlocks, comparator);
+  SmallVector<Value> gridDimValues =
+      scf::ForallOp::getValuesSortedByKey(blockMapping, numBlocks, comparator);
 
   // Step 3. Outline the compute workload region and set up the workload
   // operands, if this has not been done already.
-  // Using `transform.iree.tile_to_foreach_thread_and_workgroup_count_region` is
+  // Using `transform.iree.tile_to_forall_and_workgroup_count_region` is
   // the preferred way to set up tiling and workgroup_count region **at the same
   // time**.
   //
@@ -492,21 +492,21 @@
   // the flow level and explicitly match the ops we want to fuse.
   // Once fusion is customizable enough in perpetuity, we can retire this.
   if (exportOp.getWorkgroupCount().empty()) {
-    if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+    if (llvm::any_of(forallOp.getUpperBound(rewriter), [](Value v) {
           return !v.getDefiningOp<arith::ConstantIndexOp>();
         })) {
-      return foreachThreadOp->emitError(
+      return forallOp->emitError(
           "unsupported dynamic workgroup_count atm --- need to slice out "
           "workgroup_count computation into ExecutableExport::workgroup_count."
           "\nThis region may require arbitrary computations and cannot "
           "magically match what the `stream.cmd.dispatch` has already imposed "
           "on us at a distance."
           "\nFor now we must specify the number of values properly when "
-          "applying the topLevel tile_to_foreach_thread_op");
+          "applying the topLevel tile_to_forall_op");
     }
-    if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
+    if (failed(populateWorkgroupCountComputingRegion(rewriter, forallOp,
                                                      exportOp))) {
-      return foreachThreadOp->emitOpError(
+      return forallOp->emitOpError(
                  "failed to populate workload region for dispatchOp: ")
              << exportOp;
     }
@@ -523,24 +523,24 @@
     workgroupCountOps.push_back(
         rewriter.create<HAL::InterfaceWorkgroupCountOp>(loc, idx));
   }
-  bvm.map(foreachThreadOp.getThreadIndices(), workgroupIdOps);
-  bvm.map(foreachThreadOp.getNumThreads(), workgroupCountOps);
+  bvm.map(forallOp.getInductionVars(), workgroupIdOps);
+  bvm.map(forallOp.getUpperBound(rewriter), workgroupCountOps);
 
-  // Step 5. Predicate omitted given unique topLevel scf::ForeachThreadOp.
+  // Step 5. Predicate omitted given unique topLevel scf::ForallOp.
 
-  // Step 6. Move the body of foreachThreadOp.
+  // Step 6. Move the body of forallOp.
   // Erase the terminator first, it will not be used since we are on buffers.
-  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  rewriter.eraseOp(forallOp.getTerminator());
   Block *targetBlock;
   Block::iterator insertionPoint;
-  targetBlock = foreachThreadOp->getBlock();
-  insertionPoint = Block::iterator(foreachThreadOp);
-  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock = forallOp->getBlock();
+  insertionPoint = Block::iterator(forallOp);
+  Block &sourceBlock = forallOp.getRegion().front();
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
   // Step 7. RAUW thread indices to thread ops.
-  for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
+  for (Value blockIdx : forallOp.getInductionVars()) {
     for (Operation *user : llvm::make_early_inc_range(blockIdx.getUsers())) {
       rewriter.updateRootInPlace(user, [&]() {
         user->replaceUsesOfWith(blockIdx, bvm.lookup(blockIdx));
@@ -548,10 +548,10 @@
     }
   }
 
-  // Step 5. Barriers omitted given unique topLevel scf::ForeachThreadOp.
+  // Step 5. Barriers omitted given unique topLevel scf::ForallOp.
 
   // Step 6. Erase old op.
-  rewriter.eraseOp(foreachThreadOp);
+  rewriter.eraseOp(forallOp);
 
   return success();
 }
@@ -560,8 +560,7 @@
 // IREE-specific transformations defined outside of iree_linalg_transform.
 //===---------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure
-transform_dialect::ForeachThreadToWorkgroupOp::applyToOne(
+DiagnosedSilenceableFailure transform_dialect::ForallToWorkgroupOp::applyToOne(
     func::FuncOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
@@ -582,26 +581,24 @@
         target, "no IREE::HAL::ExecutableExportOp found");
   }
 
-  scf::ForeachThreadOp topLevelForeachThreadOp;
-  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
-    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
+  scf::ForallOp topLevelForallOp;
+  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
+    if (forallOp->getParentOfType<scf::ForallOp>())
       return WalkResult::advance();
-    if (topLevelForeachThreadOp) return WalkResult::interrupt();
-    topLevelForeachThreadOp = foreachThreadOp;
+    if (topLevelForallOp) return WalkResult::interrupt();
+    topLevelForallOp = forallOp;
     return WalkResult::advance();
   });
 
   if (walkResult.wasInterrupted()) {
     results.assign(1, nullptr);
     return mlir::emitSilenceableFailure(
-        target, "could not find a unique topLevel scf.foreach_thread");
+        target, "could not find a unique topLevel scf.forall");
   }
 
-  SimplePatternRewriter rewriter(topLevelForeachThreadOp);
-  if (failed(rewriteForeachThreadToWorkgroup(topLevelForeachThreadOp, exportOp,
-                                             rewriter))) {
-    return mlir::emitDefiniteFailure(target,
-                                     "rewriteForeachThreadToWorkgroup failed");
+  SimplePatternRewriter rewriter(topLevelForallOp);
+  if (failed(rewriteForallToWorkgroup(topLevelForallOp, exportOp, rewriter))) {
+    return mlir::emitDefiniteFailure(target, "rewriteForallToWorkgroup failed");
   }
 
   results.push_back(target);
@@ -609,10 +606,10 @@
 }
 
 //===---------------------------------------------------------------------===//
-// TileToForeachThreadAndWorkgroupCountRegionOp
+// TileToForallAndWorkgroupCountRegionOp
 //===---------------------------------------------------------------------===//
 
-void transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::build(
+void transform_dialect::TileToForallAndWorkgroupCountRegionOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec,
     ArrayAttr mappingAttr) {
@@ -622,7 +619,7 @@
                /*_=*/transform::TileSizesSpec(), /*mapping=*/mappingAttr);
 }
 
-void transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::build(
+void transform_dialect::TileToForallAndWorkgroupCountRegionOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
     ArrayAttr mappingAttr) {
@@ -646,7 +643,7 @@
         /*mapping=*/mappingAttr);
 }
 
-void transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::build(
+void transform_dialect::TileToForallAndWorkgroupCountRegionOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec,
     ArrayAttr mappingAttr) {
@@ -658,7 +655,7 @@
                /*mapping=*/mappingAttr);
 }
 
-void transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::build(
+void transform_dialect::TileToForallAndWorkgroupCountRegionOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
     ArrayAttr mappingAttr) {
@@ -777,27 +774,27 @@
   return success();
 }
 
-SmallVector<OpFoldResult> transform_dialect::
-    TileToForeachThreadAndWorkgroupCountRegionOp::getMixedNumThreads() {
+SmallVector<OpFoldResult>
+transform_dialect::TileToForallAndWorkgroupCountRegionOp::getMixedNumThreads() {
   Builder b(getContext());
   return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
 }
 
-SmallVector<OpFoldResult> transform_dialect::
-    TileToForeachThreadAndWorkgroupCountRegionOp::getMixedTileSizes() {
+SmallVector<OpFoldResult>
+transform_dialect::TileToForallAndWorkgroupCountRegionOp::getMixedTileSizes() {
   Builder b(getContext());
   return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
 }
 
 LogicalResult
-transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::verify() {
+transform_dialect::TileToForallAndWorkgroupCountRegionOp::verify() {
   if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
     return emitOpError("either num_threads or tile_sizes must be specified");
   return success();
 }
 
-void transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::
-    getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+void transform_dialect::TileToForallAndWorkgroupCountRegionOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::consumesHandle(getTarget(), effects);
   transform::onlyReadsHandle(getTileSizes(), effects);
   transform::onlyReadsHandle(getNumThreads(), effects);
@@ -806,12 +803,12 @@
 }
 
 DiagnosedSilenceableFailure
-transform_dialect::TileToForeachThreadAndWorkgroupCountRegionOp::apply(
+transform_dialect::TileToForallAndWorkgroupCountRegionOp::apply(
     transform::TransformResults &transformResults,
     transform::TransformState &state) {
   ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
   if (targetOps.empty()) {
-    transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+    transformResults.set(getForallOp().cast<OpResult>(), {});
     transformResults.set(getTiledOp().cast<OpResult>(), {});
     return DiagnosedSilenceableFailure::success();
   }
@@ -844,20 +841,20 @@
   SmallVector<Operation *> tileOps;
   SmallVector<Operation *> tiledOps;
 
-  DiagnosedSilenceableFailure diag = transform::tileToForeachThreadOpImpl(
+  DiagnosedSilenceableFailure diag = transform::tileToForallOpImpl(
       rewriter, state, cast<transform::TransformOpInterface>(getOperation()),
       targets, getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
       tiledOps);
 
   if (!diag.succeeded()) {
-    transformResults.set(getForeachThreadOp().cast<OpResult>(),
+    transformResults.set(getForallOp().cast<OpResult>(),
                          SmallVector<mlir::Operation *>{});
     transformResults.set(getTiledOp().cast<OpResult>(),
                          SmallVector<mlir::Operation *>{});
     return diag;
   }
 
-  transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
+  transformResults.set(getForallOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index d9241fc..6ef6aa7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -20,7 +20,7 @@
 }  // namespace func
 
 namespace scf {
-class ForeachThreadOp;
+class ForallOp;
 }  // namespace scf
 
 namespace transform {
@@ -44,7 +44,7 @@
   bool foldReassociativeReshapes = false;
   bool foldTensorEmptyExtract = false;
   bool lowerTransferOpPermutations = false;
-  bool promoteForeachThreadCaptureToShared = false;
+  bool promoteForallCaptureToShared = false;
   bool rankReducingLinalg = false;
   bool rankReducingVector = false;
   bool rewritePackOps = false;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index cbd948b..ea4491e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -228,40 +228,40 @@
   let builders = [OpBuilder<(ins "Value":$target)>];
 }
 
-def ForeachThreadToWorkgroupOp : Op<Transform_Dialect,
-    "iree.foreach_thread_to_workgroup",
+def ForallToWorkgroupOp : Op<Transform_Dialect,
+    "iree.forall_to_workgroup",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
      TransformOpInterface,
      TransformEachOpTrait]> {
   let description = [{
     Target the whole hal.executable_variant op and rewrite the unique topLevel
-    scf.foreach_thread to distributed workgroup_id and workgroup_count.
+    scf.forall to distributed workgroup_id and workgroup_count.
 
     The mapping of threads to workgroup_id is currently one-to-one and in order.
-    Only **bufferized** scf.foreach_thread are currently supported.
-    Only scf.foreach_thread distributed to **at most 3 dimensions** are currently
+    Only **bufferized** scf.forall are currently supported.
+    Only scf.forall distributed to **at most 3 dimensions** are currently
     supported.
 
     Return modes:
     =============
     This operation ignores non-Func ops and drops them in the return.
 
-    If no unique scf.foreach_thread topLevel operation is found, then the
+    If no unique scf.forall topLevel operation is found, then the
     transform definitely fails.
-    If the unique topLevel scf.foreach_thread has results (i.e. tensors), then
+    If the unique topLevel scf.forall has results (i.e. tensors), then
     the transform definitely fails.
 
-    If the unique topLevel scf.foreach_thread maps to a dynamic number of
+    If the unique topLevel scf.forall maps to a dynamic number of
     threads, then the transform definitely fails. This is a temporary
-    limitation until the backward slice computing scf.foreach_thread.num_threads
+    limitation until the backward slice computing scf.forall.num_threads
     can be extracted into the hal::executable_export workgroup_count region.
     This region may require arbitrary computations and cannot magically match
     what the `stream.cmd.dispatch` has already imposed on us at a distance.
     For now we must specify the number of values properly when applying the
-    topLevel tile_to_foreach_thread_op.
+    topLevel tile_to_forall_op.
 
-    If the unique topLevel scf.foreach_thread operation contained within the
+    If the unique topLevel scf.forall operation contained within the
     FuncOp referred to by the `target` PDLOperation lowers to workgroup properly,
     the transform succeeds. Otherwise the transform definitely fails.
 
@@ -288,14 +288,14 @@
   }];
 }
 
-def ShareForeachThreadOperandsOp : Op<
-    Transform_Dialect, "iree.share_foreach_thread_operands", [
+def ShareForallOperandsOp : Op<
+    Transform_Dialect, "iree.share_forall_operands", [
       FunctionalStyleTransformOpTrait,
       MemoryEffectsOpInterface,
       TransformEachOpTrait,
       TransformOpInterface]> {
   let description = [{
-    Target a single scf.foreach_thread op and shares all uses of the specified 
+    Target a single scf.forall op and shares all uses of the specified 
     `share_operands` operand indices.
 
     Sharing can be thought of as the inverse of traditional privatization.
@@ -303,7 +303,7 @@
     by a single thread to and subsequently slicing out that part into a 
     thread_private storage that has smaller footprint, better locality and better
     alignment properties.
-    In the case of scf.foreach_thread on tensors, tensor values are immutable 
+    In the case of scf.forall on tensors, tensor values are immutable 
     and the same tensor value may be passed as `shared_outs` and also captured
     for internal uses.
     Due to the immutability property, the whole tensor values are private by 
@@ -320,7 +320,7 @@
     However this can still be unsafe wrt parallelism so use carefully!
 
     Sharing consists in rewriting all uses of the operands passed as 
-    `shared_outs` that are also captured wihtin the `scf.foreach_thread` region
+    `shared_outs` that are also captured wihtin the `scf.forall` region
     into the matching `shared_outs` bbarg.
 
     Only those operands whose indices are specified in `share_operands` are
@@ -335,42 +335,42 @@
     In the future, we should emit a notification.
 
     This transform consumes the target handle and produces a result handle to
-    the modified `scf.foreach_thread` op.
+    the modified `scf.forall` op.
   }];
 
   let arguments = (
-      ins Transform_ConcreteOpType<"scf.foreach_thread">:$foreach_thread_op,
+      ins Transform_ConcreteOpType<"scf.forall">:$forall_op,
           DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$share_operands
   );
-  let results = (outs Transform_ConcreteOpType<"scf.foreach_thread">:$result);
+  let results = (outs Transform_ConcreteOpType<"scf.forall">:$result);
 
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 
   let assemblyFormat = [{
-    $foreach_thread_op (`share_operands` `=` $share_operands^ )? attr-dict
+    $forall_op (`share_operands` `=` $share_operands^ )? attr-dict
       `:` functional-type(operands, results)
   }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::scf::ForeachThreadOp foreachThreadOp,
+        ::mlir::scf::ForallOp forallOp,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
 }
 
-def TileToForeachThreadAndWorkgroupCountRegionOp :
-    Op<Transform_Dialect, "iree.tile_to_foreach_thread_and_workgroup_count_region",
+def TileToForallAndWorkgroupCountRegionOp :
+    Op<Transform_Dialect, "iree.tile_to_forall_and_workgroup_count_region",
       [AttrSizedOperandSegments,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        TransformOpInterface]> {
   let description = [{
-    Wrapper around `structured.tile_to_foreach_thread_op` for use within IREE.
+    Wrapper around `structured.tile_to_forall_op` for use within IREE.
 
-    In addition to tile and distribute using `scf.foreach_thread`, lowers the
+    In addition to tile and distribute using `scf.forall`, lowers the
     the `workgroup_count` region of the export op corresponding to the parent
     `func.func` of the target to return the number of workgroups.
-    Please see the doc of `structured.tile_to_foreach_thread_op` for full
+    Please see the doc of `structured.tile_to_forall_op` for full
     description of op semantics.
   }];
 
@@ -380,7 +380,7 @@
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
-  let results = (outs PDL_Operation:$foreach_thread_op,
+  let results = (outs PDL_Operation:$forall_op,
                       PDL_Operation:$tiled_op);
   let assemblyFormat = [{
     $target oilist(
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reductions.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reductions.mlir
index 0503be7..289bb02 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reductions.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reductions.mlir
@@ -30,11 +30,11 @@
 }
 
 // CHECK-LABEL: @reduce
-// CHECK: scf.foreach_thread
-// CHECK:   scf.foreach_thread
+// CHECK: scf.forall
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
-// CHECK:   scf.foreach_thread
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
 
@@ -77,12 +77,12 @@
 }
 
 // CHECK-LABEL: @eltwise_reduce
-// CHECK: scf.foreach_thread
-// CHECK:   scf.foreach_thread
+// CHECK: scf.forall
+// CHECK:   scf.forall
 // CHECK:     linalg.generic
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
-// CHECK:   scf.foreach_thread
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
 
@@ -124,11 +124,11 @@
 
 
 // CHECK-LABEL: @reduce_eltwise
-// CHECK: scf.foreach_thread
-// CHECK:   scf.foreach_thread
+// CHECK: scf.forall
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
-// CHECK:   scf.foreach_thread
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
 // CHECK:     linalg.generic
@@ -185,12 +185,12 @@
 }
 
 // CHECK-LABEL: @eltwise_reduce_eltwise
-// CHECK: scf.foreach_thread
-// CHECK:   scf.foreach_thread
+// CHECK: scf.forall
+// CHECK:   scf.forall
 // CHECK:     linalg.generic
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
-// CHECK:   scf.foreach_thread
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
 // CHECK:     linalg.generic
@@ -247,12 +247,12 @@
 }
 
 // CHECK-LABEL: @eltwise_reduce_eltwise_swapped
-// CHECK: scf.foreach_thread
-// CHECK:   scf.foreach_thread
+// CHECK: scf.forall
+// CHECK:   scf.forall
 // CHECK:     linalg.generic
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
-// CHECK:   scf.foreach_thread
+// CHECK:   scf.forall
 // CHECK:     linalg.fill
 // CHECK:     linalg.generic
 // CHECK:     linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
index b63d491..360ba82 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
@@ -14,7 +14,7 @@
   // Step 1. Map to a single block by tiling with size 1 and fusing.
   %fusion_root_1, %fusion_group_1 = transform.iree.take_first %maybe_trailing_0, %combiner_op
     : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
-  %grid_loop, %outer_tiled = transform.structured.tile_to_foreach_thread_op %fusion_root_1 tile_sizes [1]
+  %grid_loop, %outer_tiled = transform.structured.tile_to_forall_op %fusion_root_1 tile_sizes [1]
     ( mapping = [#gpu.block<x>] )
   
   %func = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
@@ -42,14 +42,14 @@
   %fusion_group_22_full = transform.merge_handles %fused_2, %original_fill_2
     : !pdl.operation
   %block_loop_22, %fusion_root_22_tiled =
-    transform.structured.tile_to_foreach_thread_op %outer_tiled
+    transform.structured.tile_to_forall_op %outer_tiled
     tile_sizes [1] ( mapping = [#gpu.thread<z>] )
   transform.structured.fuse_into_containing_op %fusion_group_22_full into %block_loop_22
 
   %fusion_group_21 = transform.merge_handles %maybe_leading_2, %more_parallel_fill_2
     : !pdl.operation
   %block_loop_21, %fusion_root_21_tiled =
-    transform.structured.tile_to_foreach_thread_op %parallel_reduction_2
+    transform.structured.tile_to_forall_op %parallel_reduction_2
     tile_sizes [1, 1] ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   transform.structured.fuse_into_containing_op %fusion_group_21 into %block_loop_21
   
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
index eda27bb..b74075a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
@@ -32,11 +32,11 @@
   %empty = tensor.empty() : tensor<16x128xf32>
   %filled = linalg.fill ins(%f0 : f32) outs(%empty : tensor<16x128xf32>) -> tensor<16x128xf32>
 
-  // CHECK: foreach_thread{{.*}}shared_outs(%[[ARG:.*]] =
+  // CHECK: forall{{.*}}shared_outs(%[[ARG:.*]] =
   // CHECK:   %[[A:.*]] = tensor.extract_slice %[[ARG]]
   // CHECK:   %[[B:.*]] = tensor.extract_slice %[[ARG]]
   // CHECK:   %[[C:.*]] = linalg.generic{{.*}}ins(%[[A]]{{.*}}outs(%[[B]]
-  %10 = scf.foreach_thread (%arg0, %arg1) in (%c16, %c32) shared_outs(%arg2 = %filled) -> (tensor<16x128xf32>) {
+  %10 = scf.forall (%arg0, %arg1) in (%c16, %c32) shared_outs(%arg2 = %filled) -> (tensor<16x128xf32>) {
     %11 = affine.apply #map2(%arg1)
     %extracted_slice = tensor.extract_slice %filled[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
     %extracted_slice_2 = tensor.extract_slice %arg2[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
@@ -45,7 +45,7 @@
       %res = arith.addf %in, %in: f32
       linalg.yield %res : f32
     } -> tensor<1x4xf32>
-    scf.foreach_thread.perform_concurrently {
+    scf.forall.in_parallel {
       tensor.parallel_insert_slice %13 into %arg2[%arg0, %11] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<16x128xf32>
     }
   }
@@ -54,9 +54,9 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.foreach_thread">
-  transform.iree.share_foreach_thread_operands %1 share_operands = [0] : (!transform.op<"scf.foreach_thread">) -> !transform.op<"scf.foreach_thread">
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.forall">
+  transform.iree.share_forall_operands %1 share_operands = [0] : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
 }
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
index c466cfd..ce05c5a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
@@ -49,10 +49,21 @@
         gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimY),
         gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimZ)};
 
+    auto threadIdGenerator = [](RewriterBase& rewriter, scf::ForallOp forallOp,
+                                SmallVectorImpl<Value>& threadIds) {
+      Location loc = forallOp.getLoc();
+      IndexType indexType = rewriter.getIndexType();
+      threadIds.assign(
+          {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
+           rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
+           rewriter.create<gpu::ThreadIdOp>(loc, indexType,
+                                            gpu::Dimension::z)});
+    };
+
     DiagnosedSilenceableFailure const result =
         mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-            rewriter, funcOp, workgroupSize, false, std::nullopt,
-            threadMappingAttributes);
+            rewriter, funcOp, workgroupSize, threadIdGenerator, false,
+            std::nullopt, threadMappingAttributes);
 
     if (!result.succeeded()) return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
index 42c58ce..a65da9a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
@@ -138,7 +138,7 @@
     std::reverse(idDims.begin(), idDims.end());
     ArrayAttr mapping = rewriter.getArrayAttr(idDims);
     auto tilingResult =
-        linalg::tileToForeachThreadOp(rewriter, tilingOp, numThreads, mapping);
+        linalg::tileToForallOp(rewriter, tilingOp, numThreads, mapping);
     rewriter.replaceOp(tilingOp, tilingResult->tileOp->getResults());
   }
   return success();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 4883ca3..7141f1e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -52,13 +52,12 @@
 // IREE-specific LLVMGPU transformations.
 //===---------------------------------------------------------------------===//
 
-void transform_dialect::MapNestedForeachThreadToGpuThreadsOp::build(
+void transform_dialect::MapNestedForallToGpuThreadsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<int64_t> workgroupSize) {
   result.addOperands(target);
   result.addAttribute(
-      MapNestedForeachThreadToGpuThreadsOp::getWorkgroupSizeAttrName(
-          result.name),
+      MapNestedForallToGpuThreadsOp::getWorkgroupSizeAttrName(result.name),
       builder.getI64ArrayAttr(workgroupSize));
   MLIRContext *ctx = builder.getContext();
   result.addTypes({pdl::OperationType::get(ctx)});
@@ -68,7 +67,7 @@
 // reuse most of the code and not require a static number of threads.
 // TODO: synchronizations for imperfectly nested stuff.
 DiagnosedSilenceableFailure
-transform_dialect::MapNestedForeachThreadToGpuThreadsOp::applyToOne(
+transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne(
     func::FuncOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
@@ -101,9 +100,19 @@
       gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimY),
       gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimZ)};
 
+  auto threadIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
+                              SmallVectorImpl<Value> &threadIds) {
+    Location loc = forallOp.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    threadIds.assign(
+        {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
+         rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
+         rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::z)});
+  };
+
   DiagnosedSilenceableFailure diag =
       mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-          rewriter, target, workgroupSize, true, transformOp,
+          rewriter, target, workgroupSize, threadIdGenerator, true, transformOp,
           threadMappingAttributes);
 
   if (diag.succeeded()) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h
index ff86d25..4513124 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h
@@ -19,7 +19,7 @@
 }
 
 namespace scf {
-class ForeachThreadOp;
+class ForallOp;
 class IfOp;
 class ForOp;
 }  // namespace scf
@@ -46,11 +46,10 @@
 }  // namespace transform_dialect
 }  // namespace IREE
 
-/// Transformation to convert scf.foreach_thread to gpu distribution.
-FailureOr<SmallVector<OpFoldResult>> rewriteForeachThreadToGpu(
-    scf::ForeachThreadOp foreachThreadOp,
-    const SmallVector<int64_t> &globalWorkgroupSizes, RewriterBase &rewriter,
-    bool syncAfterDistribute = true);
+/// Transformation to convert scf.forall to gpu distribution.
+FailureOr<SmallVector<OpFoldResult>> rewriteForallToGpu(
+    scf::ForallOp forallOp, const SmallVector<int64_t> &globalWorkgroupSizes,
+    RewriterBase &rewriter, bool syncAfterDistribute = true);
 
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 656fee7..bb1b4fb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -14,36 +14,36 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
-def MapNestedForeachThreadToGpuThreadsOp :
-  Op<Transform_Dialect, "iree.map_nested_foreach_thread_to_gpu_threads",
+def MapNestedForallToGpuThreadsOp :
+  Op<Transform_Dialect, "iree.map_nested_forall_to_gpu_threads",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
      TransformEachOpTrait,
      TransformOpInterface]> {
   let description = [{
-    Target the whole hal.executable_variant op and rewrite all scf.foreach_thread
+    Target the whole hal.executable_variant op and rewrite all scf.forall
     to distributed gpu.thread_id and translation_info attribute.
 
     The mapping of threads to gpu.thread_id is currently one-to-one and in order.
-    Only **bufferized** scf.foreach_thread are currently supported.
-    Only scf.foreach_thread distributed to **at most 3 dimensions** are currently
+    Only **bufferized** scf.forall are currently supported.
+    Only scf.forall distributed to **at most 3 dimensions** are currently
     supported.
 
-    Multiple scf.foreach_thread are supported per function in which case, the
+    Multiple scf.forall are supported per function in which case, the
     max of all the threads is computed and taken for the global gpu.thread_id.
-    If necessary, scf.foreach_thread that do not use the whole thread range
+    If necessary, scf.forall that do not use the whole thread range
     result in predicated computations.
 
-    Barriers are inserted after each scf.foreach_thread op for now.
+    Barriers are inserted after each scf.forall op for now.
 
     Return modes:
     =============
     This operation ignores non-Func ops and drops them in the return.
 
-    If any scf.foreach_thread with tensors is found, the transform definitely
+    If any scf.forall with tensors is found, the transform definitely
     fails.
 
-    If all the scf.foreach_thread operations contained within the FuncOp
+    If all the scf.forall operations contained within the FuncOp
     referred to by the `target` PDLOperation lower to GPU properly, the
     transform succeeds. Otherwise the transform definitely fails.
 
@@ -59,10 +59,10 @@
       hal.executable.variant {
         hal.executable.export {
           func @foo() {
-            scf.foreach_thread (%i, %j) in (7, 9) {
+            scf.forall (%i, %j) in (7, 9) {
               ... // body 1
             }
-            scf.foreach_thread (%i) in (12) {
+            scf.forall (%i) in (12) {
               ... // body 2
             }
           }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_foreach.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_foreach.mlir
index 5049c6e..fddbf57 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_foreach.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_foreach.mlir
@@ -28,7 +28,7 @@
       %4 = memref.subview %2[%workgroup_id_y, %3] [1, 256] [1, 1] : memref<233x1024xf32> to memref<1x256xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
       %5 = memref.subview %0[%workgroup_id_y, %3] [1, 256] [1, 1] : memref<233x1024xf32> to memref<1x256xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
       %6 = memref.subview %1[%workgroup_id_y, %3] [1, 256] [1, 1] : memref<233x1024xf32> to memref<1x256xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
-      scf.foreach_thread (%arg0) in (%c64) shared_outs() -> () {
+      scf.forall (%arg0) in (%c64) shared_outs() -> () {
         %7 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
         %8 = memref.subview %4[0, %7] [1, 4] [1, 1] : memref<1x256xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<1x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
         %9 = vector.transfer_read %5[%c0, %7], %cst {in_bounds = [true]} : memref<1x256xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, vector<4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index d263246..e827c0b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -32,22 +32,22 @@
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
 //         CHECK:   transform.iree.match_callback failures(propagate) "reduction"(%{{.+}})
 //         CHECK:   transform.iree.take_first
-//         CHECK:   transform.iree.tile_to_foreach_thread_and_workgroup_count_region {{.*}} tile_sizes [1](mapping = [#gpu.block<x>])
+//         CHECK:   transform.iree.tile_to_forall_and_workgroup_count_region {{.*}} tile_sizes [1](mapping = [#gpu.block<x>])
 // CHECK-COUNT-2:   transform.structured.fuse_into_containing_op
 //         CHECK:   transform.iree.take_first
-//         CHECK:   tile_reduction_using_foreach_thread {{.*}} by num_threads = [0, 64], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
+//         CHECK:   tile_reduction_using_forall {{.*}} by num_threads = [0, 64], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
 //         CHECK:   transform.structured.fuse_into_containing_op
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [1](mapping = [#gpu.thread<y>])
-//         CHECK:   cast %{{.*}} : !pdl.operation to !transform.op<"scf.foreach_thread">
-//         CHECK:   transform.iree.share_foreach_thread_operands %{{.*}} share_operands = [0] : (!transform.op<"scf.foreach_thread">) -> !transform.op<"scf.foreach_thread">
+//         CHECK:   transform.structured.tile_to_forall_op %{{.*}} tile_sizes [1](mapping = [#gpu.thread<y>])
+//         CHECK:   cast %{{.*}} : !pdl.operation to !transform.op<"scf.forall">
+//         CHECK:   transform.iree.share_forall_operands %{{.*}} share_operands = [0] : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
 //         CHECK:   transform.structured.match ops{["func.func"]} in %arg0
 //         CHECK:   transform.structured.vectorize
 //         CHECK:   transform.iree.bufferize {target_gpu}
 //         CHECK:   transform.structured.match ops{["func.func"]} in %{{.*}}
 //         CHECK:   transform.iree.erase_hal_descriptor_type_from_memref
 //         CHECK:   transform.structured.match ops{["func.func"]} in %{{.*}}
-//         CHECK:   transform.iree.foreach_thread_to_workgroup
-//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [64, 1, 1]}
+//         CHECK:   transform.iree.forall_to_workgroup
+//         CHECK:   transform.iree.map_nested_forall_to_gpu_threads %{{.*}} {workgroup_size = [64, 1, 1]}
 //         CHECK:   transform.iree.apply_patterns %{{.*}} {fold_memref_aliases, rank_reducing_vector}
 //         CHECK:   transform.structured.match ops{["scf.if"]} in %{{.*}}
 //         CHECK:   sequence {{.*}} failures(suppress) {
@@ -92,8 +92,8 @@
 
 //   CHECK-LABEL: func.func @group_reduction_128
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
-//         CHECK:   transform.structured.tile_reduction_using_foreach_thread %{{.*}} by num_threads = [0, 32], tile_sizes = [0, 4], mapping = [#gpu.thread<x>]
-//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [32, 1, 1]}
+//         CHECK:   transform.structured.tile_reduction_using_forall %{{.*}} by num_threads = [0, 32], tile_sizes = [0, 4], mapping = [#gpu.thread<x>]
+//         CHECK:   transform.iree.map_nested_forall_to_gpu_threads %{{.*}} {workgroup_size = [32, 1, 1]}
 //         CHECK:   transform.iree.vector.to_warp_execute_on_lane_0 %{{.*}} {warp_size = 32 : i64}
 
 // -----
@@ -134,8 +134,8 @@
 
 //   CHECK-LABEL: func.func @group_reduction_D
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
-//         CHECK:   transform.structured.tile_reduction_using_foreach_thread %{{.*}} by num_threads = [0, 256], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
-//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [256, 1, 1]}
+//         CHECK:   transform.structured.tile_reduction_using_forall %{{.*}} by num_threads = [0, 256], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
+//         CHECK:   transform.iree.map_nested_forall_to_gpu_threads %{{.*}} {workgroup_size = [256, 1, 1]}
 //         CHECK:   transform.iree.vector.to_warp_execute_on_lane_0 %{{.*}} {warp_size = 256 : i64}
 
 // -----
@@ -174,12 +174,12 @@
 
 //   CHECK-LABEL: func.func @group_reduction_34
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
-//         CHECK:   transform.iree.tile_to_foreach_thread_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [64](mapping = [#gpu.block<x>])
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}}   num_threads [64] tile_sizes [](mapping = [#gpu.thread<x>])
+//         CHECK:   transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [64](mapping = [#gpu.block<x>])
+//         CHECK:   transform.structured.tile_to_forall_op %{{.*}}   num_threads [64] tile_sizes [](mapping = [#gpu.thread<x>])
 // CHECK-COUNT-4:   transform.structured.scalarize %{{.*}}
 //         CHECK:   transform.structured.split %{{.*}} after 32  {dimension = 1 : i64}
 //         CHECK:   transform.structured.tile %{{.*}}[0, 4]
-//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [64, 1, 1]}
+//         CHECK:   transform.iree.map_nested_forall_to_gpu_threads %{{.*}} {workgroup_size = [64, 1, 1]}
 //     CHECK-NOT:   transform.iree.vector.to_warp_execute_on_lane_0
 
 
@@ -227,11 +227,11 @@
 
 //   CHECK-LABEL: func.func @group_reduction_12345
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
-//         CHECK:   transform.iree.tile_to_foreach_thread_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [1](mapping = [#gpu.block<x>])
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}}   num_threads [] tile_sizes [1](mapping = [#gpu.thread<y>])
+//         CHECK:   transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [1](mapping = [#gpu.block<x>])
+//         CHECK:   transform.structured.tile_to_forall_op %{{.*}}   num_threads [] tile_sizes [1](mapping = [#gpu.thread<y>])
 //         CHECK:   transform.structured.split %{{.*}} after 8192  {dimension = 1 : i64}
 //         CHECK:   transform.structured.tile %{{.*}}[0, 8192]
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}}   num_threads [0, 1024] tile_sizes [](mapping = [#gpu.thread<x>])
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}}   num_threads [0, 1024] tile_sizes [](mapping = [#gpu.thread<x>])
-//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [1024, 1, 1]}
+//         CHECK:   transform.structured.tile_to_forall_op %{{.*}}   num_threads [0, 1024] tile_sizes [](mapping = [#gpu.thread<x>])
+//         CHECK:   transform.structured.tile_to_forall_op %{{.*}}   num_threads [0, 1024] tile_sizes [](mapping = [#gpu.thread<x>])
+//         CHECK:   transform.iree.map_nested_forall_to_gpu_threads %{{.*}} {workgroup_size = [1024, 1, 1]}
 //         CHECK:   transform.iree.vector.to_warp_execute_on_lane_0{{.*}}{warp_size = 1024 : i64}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir
index 5f4ebee..4127abf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/tile_on_tensor.mlir
@@ -37,14 +37,13 @@
 
 //         CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 4)>
 //   CHECK-LABEL: func.func @add_tensor
-//     CHECK-DAG:   %[[C64:.*]] = arith.constant 64 : index
 //     CHECK-DAG:   %[[A:.*]] = hal.interface.binding.subspan set(0) binding(0)
 //     CHECK-DAG:   %[[B:.*]] = hal.interface.binding.subspan set(0) binding(1)
 //     CHECK-DAG:   %[[C:.*]] = hal.interface.binding.subspan set(0) binding(2)
 //     CHECK-DAG:   %[[LA:.*]] = flow.dispatch.tensor.load %[[A]]
 //     CHECK-DAG:   %[[LB:.*]] = flow.dispatch.tensor.load %[[B]]
 //     CHECK-DAG:   %[[LC:.*]] = flow.dispatch.tensor.load %[[C]]
-//         CHECK:   %[[T:.*]] = scf.foreach_thread (%[[ARG:.*]]) in (%[[C64]]) shared_outs(%[[O:.*]] = %[[LC]]) -> (tensor<1x256xf32>) {
+//         CHECK:   %[[T:.*]] = scf.forall (%[[ARG:.*]]) in (64) shared_outs(%[[O:.*]] = %[[LC]]) -> (tensor<1x256xf32>) {
 //         CHECK:     %[[OFF:.*]] = affine.apply #[[$MAP]](%[[ARG]])
 //     CHECK-DAG:     %[[TA:.*]] = tensor.extract_slice %[[LA]][0, %[[OFF]]] [1, 4] [1, 1] : tensor<1x256xf32> to tensor<1x4xf32>
 //     CHECK-DAG:     %[[TB:.*]] = tensor.extract_slice %[[LB]][0, %[[OFF]]] [1, 4] [1, 1] : tensor<1x256xf32> to tensor<1x4xf32>
@@ -53,7 +52,7 @@
 //         CHECK:       %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
 //         CHECK:       linalg.yield %{{.*}} : f32
 //         CHECK:     } -> tensor<1x4xf32>
-//         CHECK:     scf.foreach_thread.perform_concurrently {
+//         CHECK:     scf.forall.in_parallel {
 //         CHECK:       tensor.parallel_insert_slice %[[L]] into %[[O]][0, %[[OFF]]] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<1x256xf32>
 //         CHECK:     }
 //         CHECK:   } {mapping = [#gpu.thread<x>]}
@@ -98,12 +97,11 @@
 //   CHECK-LABEL: func.func @reduction
 //     CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
-//     CHECK-DAG:   %[[C64:.*]] = arith.constant 64 : index
 //     CHECK-DAG:   %[[C384:.*]] = arith.constant 384 : index
 //         First the scf.foreach for the linalg.fill.
-//         CHECK:   scf.foreach_thread
+//         CHECK:   scf.forall
 //         then the reduction case.
-//         CHECK:   %[[T:.*]] = scf.foreach_thread (%[[ARG:.*]]) in (%[[C64]]) shared_outs(%[[O:.+]] = %{{.+}}) -> (tensor<64xf32>) {
+//         CHECK:   %[[T:.*]] = scf.forall (%[[ARG:.*]]) in (64) shared_outs(%[[O:.+]] = %{{.+}}) -> (tensor<64xf32>) {
 //         CHECK:     %[[OUTSLICE:.*]] = tensor.extract_slice %{{.*}}[%[[ARG]], 0] [1, 384] [1, 1] : tensor<64x384xf32> to tensor<1x384xf32>
 //         CHECK:     %[[A:.*]] = tensor.extract_slice %[[O]][%[[ARG]]] [1] [1] : tensor<64xf32> to tensor<1xf32>
 //         CHECK:     %[[R:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C384]] step %[[C4]] iter_args(%[[ACC:.*]] = %[[A]]) -> (tensor<1xf32>) {
@@ -114,7 +112,7 @@
 //         CHECK:       } -> tensor<1xf32>
 //         CHECK:       scf.yield %[[L]] : tensor<1xf32>
 //         CHECK:     }
-//         CHECK:     scf.foreach_thread.perform_concurrently {
+//         CHECK:     scf.forall.in_parallel {
 //         CHECK:       tensor.parallel_insert_slice %[[R]] into %[[O]][%[[ARG]]] [1] [1] : tensor<1xf32> into tensor<64xf32>
 //         CHECK:     }
 //         CHECK:   } {mapping = [#gpu.thread<x>]}
@@ -179,11 +177,11 @@
 //     CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
 //     CHECK-DAG:   %[[C10:.*]] = arith.constant 10 : index
 //     CHECK-DAG:   %[[C4096:.*]] = arith.constant 4096 : index
-//         CHECK:   scf.foreach_thread
+//         CHECK:   scf.forall
 //         CHECK:     linalg.fill
-//         CHECK:   scf.foreach_thread
+//         CHECK:   scf.forall
 //         CHECK:     scf.for %{{.*}} = %[[C0]] to %[[C10]] step %[[C4]]
 //         CHECK:       scf.for %{{.*}} = %[[C0]] to %[[C4096]] step %[[C4]]
 //         CHECK:         linalg.generic
-//         CHECK:   scf.foreach_thread
+//         CHECK:   scf.forall
 //         CHECK:     linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index 3aa246f..8ee8556 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -1,11 +1,11 @@
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%variant_op: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread, %tiled_fill = transform.structured.tile_to_foreach_thread_op %0 num_threads [5, 1] 
+  %forall, %tiled_fill = transform.structured.tile_to_forall_op %0 num_threads [5, 1] 
   ( mapping = [#gpu.thread<y>, #gpu.thread<x>] )
 
   %1 = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_2, %tiled_matmul = transform.structured.tile_to_foreach_thread_op %1 num_threads [7, 9]
+  %forall_2, %tiled_matmul = transform.structured.tile_to_forall_op %1 num_threads [7, 9]
   ( mapping = [#gpu.thread<x>, #gpu.thread<y>] )
 
   %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
@@ -16,5 +16,5 @@
   // Get the function to which to apply to.
   %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
   %func = transform.get_closest_isolated_parent %2 : (!pdl.operation) -> !pdl.operation
-  transform.iree.map_nested_foreach_thread_to_gpu_threads %func { workgroup_size = [10, 11]}
+  transform.iree.map_nested_forall_to_gpu_threads %func { workgroup_size = [10, 11]}
 }
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index 5bf08ca..36d1d24 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -472,7 +472,7 @@
 
 def LLVMGPUDistribute :
     Pass<"iree-llvmgpu-distribute", "func::FuncOp"> {
-  let summary = "Pass to distribute foreachthread ops.";
+  let summary = "Pass to distribute scf.forall ops.";
   let constructor = "mlir::iree_compiler::createLLVMGPUDistribute()";
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp
index 3281b17..4a8cdf8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp
@@ -25,7 +25,7 @@
     // mismatched vector size for transfer and compute.
     vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
         patterns, [](vector::ExtractStridedSliceOp op) {
-          return op.getVectorType().getNumElements() > 4;
+          return op.getSourceVectorType().getNumElements() > 4;
         });
     vector::InsertOp::getCanonicalizationPatterns(patterns, context);
     vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 7b779bd..bf2e7c8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -98,7 +98,7 @@
     return nativeSize;
   } else if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
     // Unroll all reduction dimensions by size 1 for vector.multi_reduction.
-    auto srcVectorType = reductionOp.getSourceVectorType();
+    VectorType srcVectorType = reductionOp.getSourceVectorType();
     auto nativeSize = llvm::to_vector<>(srcVectorType.getShape());
     auto dims = reductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
     for (const auto &dimAttr : dims) {
@@ -106,12 +106,12 @@
     }
     return nativeSize;
   } else if (auto reductionOp = dyn_cast<vector::ReductionOp>(op)) {
-    auto srcVectorType = reductionOp.getVectorType();
+    VectorType srcVectorType = reductionOp.getSourceVectorType();
     assert(srcVectorType.getRank() == 1);  // Guaranteed by semantics
     int64_t vectorSize = getComputeVectorSize(srcVectorType.getDimSize(0));
     return SmallVector<int64_t>{vectorSize};
   } else if (auto transposeOp = dyn_cast<vector::TransposeOp>(op)) {
-    auto vectorType = transposeOp.getResultType();
+    VectorType vectorType = transposeOp.getResultVectorType();
     SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
     nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
     return nativeSize;
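
The hunks above are representative of the "Fixed vector op accessor method callsites" bullet: the generic `getVectorType()` / `getResultType()` accessors on several vector ops were replaced by source/result-specific ones. Below is a minimal, hedged sketch of the post-rename spelling; the helper name `pickUnrollShape` and its unrolling policy are illustrative and not part of IREE, only the accessors shown in this diff are assumed.

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Returns a per-dimension native unroll shape for a couple of vector ops,
// spelled with the renamed accessors used throughout this patch.
static SmallVector<int64_t> pickUnrollShape(Operation *op) {
  if (auto transposeOp = dyn_cast<vector::TransposeOp>(op)) {
    // Was transposeOp.getResultType() before the integrate.
    VectorType resultType = transposeOp.getResultVectorType();
    SmallVector<int64_t> nativeSize(resultType.getRank(), 1);
    nativeSize.back() = resultType.getShape().back();
    return nativeSize;
  }
  if (auto sliceOp = dyn_cast<vector::ExtractStridedSliceOp>(op)) {
    // Was sliceOp.getVectorType(); now explicitly the source vector type.
    return llvm::to_vector(sliceOp.getSourceVectorType().getShape());
  }
  return {};
}
```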
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
index 9d7ced1..07bc9ba 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
@@ -31,7 +31,7 @@
 using iree_compiler::cpu::CPUModel;
 using iree_compiler::cpu::ReductionConfig;
 using iree_compiler::cpu::ReductionStrategy;
-using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using transform::LowerVectorsOp;
 using transform::MatchOp;
 using transform::SplitHandlesOp;
@@ -62,7 +62,7 @@
   // Need to match again since bufferize invalidated all handles.
   // TODO: assumes a single func::FuncOp to transform, may need hardening.
   funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
+  funcH = b.create<ForallToWorkgroupOp>(funcH);
   auto pdlOperation = pdl::OperationType::get(b.getContext());
 
   // Step N. Lower vectors.
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/ReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/ReductionStrategy.cpp
index 0f17af3..3bd6764 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/ReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/ReductionStrategy.cpp
@@ -62,7 +62,7 @@
     const ReductionStrategy &strategy) {
   // Step 1. Tiling to the block/workgroup level. Keep everything fused.
   auto [maybeLeadingHBlock, gridFillH, gridReductionH, maybeTiledTrailingHBlock,
-        foreachThread] =
+        forall] =
       buildReductionStrategyBlockDistribution(b, variantH, strategy);
 
   // Step 2. Naive first strategy to tile the most minor dimension by
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
index d81605a..1bc37b5 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
@@ -26,13 +26,13 @@
 using iree_compiler::IREE::transform_dialect::ApplyBufferOptimizationsOp;
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
-using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
 using iree_compiler::IREE::transform_dialect::IREEEliminateEmptyTensorsOp;
 using iree_compiler::IREE::transform_dialect::
     IREEEraseHALDescriptorTypeFromMemRefOp;
 using iree_compiler::IREE::transform_dialect::
-    TileToForeachThreadAndWorkgroupCountRegionOp;
+    TileToForallAndWorkgroupCountRegionOp;
 using transform::FuseIntoContainingOp;
 using transform::MatchOp;
 using transform::MergeHandlesOp;
@@ -40,7 +40,7 @@
 using transform::SequenceOp;
 using transform::SplitHandlesOp;
 using transform::SplitReductionOp;
-using transform::TileToForeachThreadOp;
+using transform::TileToForallOp;
 using transform::VectorizeOp;
 using transform_ext::RegisterMatchCallbacksOp;
 using transform_ext::TakeFirstOp;
@@ -148,12 +148,12 @@
 }
 
 /// Performs the following transformations:
-///   1. Tiles `rootH` to scf.foreach_thread to with `tileSizesOrNumThreads`
+///   1. Tiles `rootH` to scf.forall with `tileSizesOrNumThreads`
 ///      according to whether spec is a TileSizesSpec or a NumThreadsSpec.
-///   2. Maps the resulting scf.foreach_thread to threads according to
+///   2. Maps the resulting scf.forall to threads according to
 ///      `threadDimMapping`.
 ///   3. Iterates over `opsHToFuse` in order and fuses into the containing op.
-/// Returns a handle to the resulting scf.foreach_thread.
+/// Returns a handle to the resulting scf.forall.
 ///
 /// Fusion operates in batch mode: a single fusion command is issued and a
 /// topological sort is automatically computed by the fusion.
@@ -168,25 +168,25 @@
 /// appended in order.
 // TODO: apply forwarding pattern.
 template <typename TilingTransformOp, typename TileOrNumThreadSpec>
-static iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
+static iree_compiler::TileToForallAndFuseAndDistributeResult
 buildTileAndFuseAndDistributeImpl(ImplicitLocOpBuilder &b, Value rootH,
                                   ValueRange opsHToFuse,
                                   ArrayRef<OpFoldResult> tileSizesOrNumThreads,
                                   ArrayAttr threadDimMapping) {
-  iree_compiler::TileToForeachThreadAndFuseAndDistributeResult result;
+  iree_compiler::TileToForallAndFuseAndDistributeResult result;
   auto tileToForeachOp = b.create<TilingTransformOp>(
       rootH, tileSizesOrNumThreads, TileOrNumThreadSpec(), threadDimMapping);
-  result.foreachThreadH = tileToForeachOp.getForeachThreadOp();
+  result.forallH = tileToForeachOp.getForallOp();
   result.tiledOpH = tileToForeachOp.getTiledOp();
 
   // Batch fusion if requested.
   if (opsHToFuse.size() > 1) {
     Value mergedOpsH =
         b.create<MergeHandlesOp>(opsHToFuse, /*deduplicate=*/true);
-    b.create<FuseIntoContainingOp>(mergedOpsH, result.foreachThreadH);
+    b.create<FuseIntoContainingOp>(mergedOpsH, result.forallH);
   } else if (opsHToFuse.size() == 1) {
-    Value fusedH = b.create<FuseIntoContainingOp>(opsHToFuse.front(),
-                                                  result.foreachThreadH);
+    Value fusedH =
+        b.create<FuseIntoContainingOp>(opsHToFuse.front(), result.forallH);
     result.resultingFusedOpsHandles.push_back(fusedH);
   }
   return result;
@@ -195,7 +195,7 @@
 // TODO: if someone knows how to properly export templates go for it ..
 // sigh.
 template <typename TilingTransformOp>
-static iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
+static iree_compiler::TileToForallAndFuseAndDistributeResult
 buildTileFuseDistWithTileSizes(ImplicitLocOpBuilder &b, Value rootH,
                                ValueRange opsHToFuse,
                                ArrayRef<OpFoldResult> tileSizes,
@@ -204,20 +204,18 @@
                                            transform::TileSizesSpec>(
       b, rootH, opsHToFuse, tileSizes, threadDimMapping);
 }
-iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
-mlir::iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
+iree_compiler::TileToForallAndFuseAndDistributeResult
+mlir::iree_compiler::buildTileFuseDistToForallWithTileSizes(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
     ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
-  return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+  return buildTileFuseDistWithTileSizes<TileToForallOp>(
       b, rootH, opsHToFuse, tileSizes, threadDimMapping);
 }
-iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
-mlir::iree_compiler::
-    buildTileFuseDistToForeachThreadAndWorkgroupCountWithTileSizes(
-        ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
-        ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
-  return buildTileFuseDistWithTileSizes<
-      TileToForeachThreadAndWorkgroupCountRegionOp>(
+iree_compiler::TileToForallAndFuseAndDistributeResult
+mlir::iree_compiler::buildTileFuseDistToForallAndWorkgroupCountWithTileSizes(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
+  return buildTileFuseDistWithTileSizes<TileToForallAndWorkgroupCountRegionOp>(
       b, rootH, opsHToFuse, tileSizes, threadDimMapping);
 }
 
@@ -225,7 +223,7 @@
 // TODO: if someone knows how to properly export templates go for it ..
 // sigh.
 template <typename TilingTransformOp>
-static iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
+static iree_compiler::TileToForallAndFuseAndDistributeResult
 buildTileFuseDistWithNumThreads(ImplicitLocOpBuilder &b, Value rootH,
                                 ValueRange opsHToFuse,
                                 ArrayRef<OpFoldResult> numThreads,
@@ -234,20 +232,18 @@
                                            transform::NumThreadsSpec>(
       b, rootH, opsHToFuse, numThreads, threadDimMapping);
 }
-iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
-mlir::iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+iree_compiler::TileToForallAndFuseAndDistributeResult
+mlir::iree_compiler::buildTileFuseDistToForallWithNumThreads(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
     ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
-  return buildTileFuseDistWithNumThreads<TileToForeachThreadOp>(
+  return buildTileFuseDistWithNumThreads<TileToForallOp>(
       b, rootH, opsHToFuse, tileSizes, threadDimMapping);
 }
-iree_compiler::TileToForeachThreadAndFuseAndDistributeResult
-mlir::iree_compiler::
-    buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
-        ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
-        ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
-  return buildTileFuseDistWithNumThreads<
-      TileToForeachThreadAndWorkgroupCountRegionOp>(
+iree_compiler::TileToForallAndFuseAndDistributeResult
+mlir::iree_compiler::buildTileFuseDistToForallAndWorgroupCountWithNumThreads(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
+  return buildTileFuseDistWithNumThreads<TileToForallAndWorkgroupCountRegionOp>(
       b, rootH, opsHToFuse, tileSizes, threadDimMapping);
 }
 
@@ -334,8 +330,8 @@
 /// Then tile the parallel part and map it to `tileSize` threads, each reducing
 /// on `vectorSize` elements.
 /// Lastly, fuse the newly created fill and elementwise operations into the
-/// resulting containing foreach_thread op.
-/// Return a triple of handles to (foreach_thread, fill, combiner)
+/// resulting containing forall op.
+/// Return a triple of handles to (forall, fill, combiner)
 std::tuple<Value, Value, Value>
 mlir::iree_compiler::buildTileReductionUsingScfForeach(
     ImplicitLocOpBuilder &b, Value reductionH, int64_t reductionRank,
@@ -345,18 +341,18 @@
   numThreads.push_back(tileSize);
   SmallVector<int64_t> tileSizes = leadingParallelDims;
   tileSizes.push_back(reductionVectorSize);
-  auto tileReduction = b.create<transform::TileReductionUsingForeachThreadOp>(
+  auto tileReduction = b.create<transform::TileReductionUsingForallOp>(
       /*target=*/reductionH,
       /*numThreads=*/numThreads,
       /*tileSizes=*/tileSizes,
       /*threadDimMapping=*/b.getArrayAttr(mappingAttr));
-  Value blockParallelForeachThreadOp = tileReduction.getForeachThreadOp();
+  Value blockParallelForallOp = tileReduction.getForallOp();
   Value blockParallelFillH = tileReduction.getFillOp();
   Value blockCombinerOpH = tileReduction.getCombiningLinalgOp();
   // Fuse the fill and elementwise to privatize them.
-  blockParallelFillH = b.create<FuseIntoContainingOp>(
-      blockParallelFillH, blockParallelForeachThreadOp);
-  return std::make_tuple(blockParallelForeachThreadOp, blockParallelFillH,
+  blockParallelFillH =
+      b.create<FuseIntoContainingOp>(blockParallelFillH, blockParallelForallOp);
+  return std::make_tuple(blockParallelForallOp, blockParallelFillH,
                          blockCombinerOpH);
 }
 
@@ -375,8 +371,8 @@
   auto [fusionTargetH, fusionGroupH] =
       buildSelectFirstNonEmpty(b, maybeTrailingH, reductionH);
   ArrayRef<Attribute> allBlocksRef(strategy.allBlockAttrs);
-  TileToForeachThreadAndFuseAndDistributeResult tileResult =
-      buildTileFuseDistToForeachThreadAndWorkgroupCountWithTileSizes(
+  TileToForallAndFuseAndDistributeResult tileResult =
+      buildTileFuseDistToForallAndWorkgroupCountWithTileSizes(
           /*builder=*/b,
           /*rootH=*/fusionTargetH,
           /*opsToFuseH=*/fusionGroupH,
@@ -385,14 +381,14 @@
           /*threadDimMapping=*/
           b.getArrayAttr(
               allBlocksRef.take_front(strategy.captures.reductionRank - 1)));
-  fillH = b.create<FuseIntoContainingOp>(fillH, tileResult.foreachThreadH);
+  fillH = b.create<FuseIntoContainingOp>(fillH, tileResult.forallH);
   maybeLeadingH =
-      b.create<FuseIntoContainingOp>(maybeLeadingH, tileResult.foreachThreadH);
+      b.create<FuseIntoContainingOp>(maybeLeadingH, tileResult.forallH);
   // Step 3. Normalize to reorder results irrespective of emptiness.
   auto [blockReductionH, maybeBlockTrailingH] = buildSelectFirstNonEmpty(
       b, tileResult.resultingFusedOpsHandles.front(), tileResult.tiledOpH);
   return std::make_tuple(maybeLeadingH, fillH, blockReductionH,
-                         maybeBlockTrailingH, tileResult.foreachThreadH);
+                         maybeBlockTrailingH, tileResult.forallH);
 }
 
 Value mlir::iree_compiler::buildMemoryOptimizations(ImplicitLocOpBuilder &b,
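
For readers updating downstream strategies, here is a hedged usage sketch of the renamed entry point. The signature matches the declarations in this patch; the wrapper function, handle names, tile sizes, and the include path are illustrative assumptions.

```cpp
#include "iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;

// Tiles `matmulH` to an scf.forall with 4x4 tiles, fuses the fill into it,
// and distributes according to `threadMapping`.
static void tileMatmulToForall(ImplicitLocOpBuilder &b, Value matmulH,
                               Value fillH, ArrayAttr threadMapping) {
  SmallVector<OpFoldResult> tileSizes = {b.getIndexAttr(4), b.getIndexAttr(4)};
  iree_compiler::TileToForallAndFuseAndDistributeResult res =
      iree_compiler::buildTileFuseDistToForallWithTileSizes(
          b, /*rootH=*/matmulH, /*opsHToFuse=*/ValueRange{fillH},
          /*tileSizes=*/tileSizes, /*threadDimMapping=*/threadMapping);
  // res.forallH handles the new scf.forall, res.tiledOpH the tiled root op;
  // fused producers land in res.resultingFusedOpsHandles.
  (void)res;
}
```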
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h
index 83c2686..7d8620f 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h
@@ -91,10 +91,10 @@
 
 /// Result of the combined transform performing tiling, fusion and
 /// distribution to parallel constructs.
-struct TileToForeachThreadAndFuseAndDistributeResult {
-  /// Outer `scf.foreach_thread` loop containing the tiled and fused
+struct TileToForallAndFuseAndDistributeResult {
+  /// Outer `scf.forall` loop containing the tiled and fused
   /// operations.
-  Value foreachThreadH;
+  Value forallH;
   /// Handles to fused operations other than the final consumer operation. May
   /// be empty if fusion was not performed iteratively.
   // TODO: support returning handles from `fuse_into_containing_op` and remove
@@ -105,9 +105,9 @@
 };
 
 /// Build transform IR to perform the following transformations:
-///   1. Tiles `rootH` to scf.foreach_thread to with `tileSizesOrNumThreads`
+///   1. Tiles `rootH` to scf.forall with `tileSizesOrNumThreads`
 ///      according to whether spec is a TileSizesSpec or a NumThreadsSpec.
-///   2. Maps the resulting scf.foreach_thread to threads according to
+///   2. Maps the resulting scf.forall to threads according to
 ///      `threadDimMapping`.
 ///   3. Iterates over `opsHToFuse` in order and fuses into the containing op.
 ///
@@ -122,37 +122,33 @@
 /// enabling transform will be introduced and may result in better fusions.
 ///
 /// Note: this version cannot be used for the block-level tiling in a dispatch
-/// region. `buildTileFuseDistToForeachThreadAndWorkgroupCountWithTileSizes` is
+/// region. `buildTileFuseDistToForallAndWorkgroupCountWithTileSizes` is
 /// the modified version that is aware of the `workgroup_count` region.
 ///
 // TODO: if someone knows how to properly export templates go for it .. sigh.
-TileToForeachThreadAndFuseAndDistributeResult
-buildTileFuseDistToForeachThreadWithTileSizes(ImplicitLocOpBuilder &b,
-                                              Value rootH,
-                                              ValueRange opsHToFuse,
-                                              ArrayRef<OpFoldResult> tileSizes,
-                                              ArrayAttr threadDimMapping);
+TileToForallAndFuseAndDistributeResult buildTileFuseDistToForallWithTileSizes(
+    ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+    ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping);
 
-/// Version of `buildTileFuseDistToForeachThreadWithTileSizes` that is aware of
+/// Version of `buildTileFuseDistToForallWithTileSizes` that is aware of
 /// IREE's `workgroup_count` region and should be used for the block-level
 /// tiling in a dispatch region.
-TileToForeachThreadAndFuseAndDistributeResult
-buildTileFuseDistToForeachThreadAndWorkgroupCountWithTileSizes(
+TileToForallAndFuseAndDistributeResult
+buildTileFuseDistToForallAndWorkgroupCountWithTileSizes(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
     ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping);
 
 /// Similar to `buildTileFuseDistWithTileSizes` but using `numThreads` instead
 /// of `tileSizes`.
-TileToForeachThreadAndFuseAndDistributeResult
-buildTileFuseDistToForeachThreadWithNumThreads(
+TileToForallAndFuseAndDistributeResult buildTileFuseDistToForallWithNumThreads(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
     ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping);
 
-/// Version of `buildTileFuseDistToForeachThreadWithNumThreads` that is aware of
+/// Version of `buildTileFuseDistToForallWithNumThreads` that is aware of
 /// IREE's `workgroup_count` region and should be used for the block-level
 /// tiling in a dispatch region.
-TileToForeachThreadAndFuseAndDistributeResult
-buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
+TileToForallAndFuseAndDistributeResult
+buildTileFuseDistToForallAndWorgroupCountWithNumThreads(
     ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
     ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping);
 
@@ -170,8 +166,8 @@
 /// Then tile the parallel part and map it to `tileSize` threads, each reducing
 /// on `vectorSize` elements.
 /// Lastly, fuse the newly created fill and elementwise operations into the
-/// resulting containing foreach_thread op.
-/// Return a triple of handles to (foreach_thread, fill, combiner)
+/// resulting containing forall op.
+/// Return a triple of handles to (forall, fill, combiner)
 std::tuple<Value, Value, Value> buildTileReductionUsingScfForeach(
     ImplicitLocOpBuilder &b, Value reductionH, int64_t reductionRank,
     int64_t tileSize, int64_t reductionVectorSize, Attribute mappingAttr);
@@ -183,16 +179,16 @@
 
 /// Build transform IR to match exactly an N-D reduction operation (with
 /// optional leading and trailing elementwise) and create a top-level
-/// `scf.foreach_thread` tiled by `strategy.workgroupTileSizes`.
+/// `scf.forall` tiled by `strategy.workgroupTileSizes`.
 /// The matched `maybeLeadingH`, `fillH`, `reductionH` and `maybeTrailingH` are
-/// fused into the top-level `scf.foreach_thread` and handles are returned to
+/// fused into the top-level `scf.forall` and handles are returned to
 /// the fused versions of these ops, in order, that are all tiled and
-/// distributed accordingly. The scf.foreach_thread is returned as the last
+/// distributed accordingly. The scf.forall is returned as the last
 /// value.
-/// The mapping of the `scf.foreach_thread` dimensions is tied the first
+/// The mapping of the `scf.forall` dimensions is tied to the first
 /// dimensions of `strategy.allBlockAttrs`.
 ///
-/// Note: `buildTileFuseDistToForeachThreadAndWorkgroupCountWithTileSizes` is
+/// Note: `buildTileFuseDistToForallAndWorkgroupCountWithTileSizes` is
 /// called internally, this version is only for the block-level tiling inside a
 /// dispatch region with an attached workgroup_count region.
 ///
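
Along the same lines, a hedged sketch of consuming the (forall, fill, combiner) triple returned by `buildTileReductionUsingScfForeach`; the signature is the one declared in this header, while the caller, the constants, and the include path are illustrative.

```cpp
#include "iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;

// Tiles a rank-2 reduction to 64 threads, each reducing 4 elements at a time.
static void tileReductionExample(ImplicitLocOpBuilder &b, Value reductionH,
                                 Attribute threadXMapping) {
  auto [forallH, fillH, combinerH] =
      iree_compiler::buildTileReductionUsingScfForeach(
          b, /*reductionH=*/reductionH, /*reductionRank=*/2,
          /*tileSize=*/64, /*reductionVectorSize=*/4,
          /*mappingAttr=*/threadXMapping);
  // The fill is already fused into the forall; combinerH is the remaining
  // linalg op performing the final combine across threads.
  (void)forallH; (void)fillH; (void)combinerH;
}
```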
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
index 9ed8323..b029779 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
@@ -33,9 +33,8 @@
 // TODO: significantly better namespacing.
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
-using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
-using iree_compiler::IREE::transform_dialect::
-    MapNestedForeachThreadToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::MapNestedForallToGpuThreadsOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
 using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
 using transform::FuseIntoContainingOp;
@@ -139,8 +138,8 @@
 /// func.func.
 Value mlir::iree_compiler::gpu::buildMapToBlockAndThreads(
     ImplicitLocOpBuilder &b, Value funcH, ArrayRef<int64_t> blockSize) {
-  funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
-  return b.create<MapNestedForeachThreadToGpuThreadsOp>(funcH, blockSize);
+  funcH = b.create<ForallToWorkgroupOp>(funcH);
+  return b.create<MapNestedForallToGpuThreadsOp>(funcH, blockSize);
 }
 
 /// Post-bufferization vector distribution with rank-reduction.
@@ -222,7 +221,7 @@
     }
     if (numThreads > 1) {
       assert(mappingAttr && "must specify a mapping attribute");
-      iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+      iree_compiler::buildTileFuseDistToForallWithNumThreads(
           /*b=*/b,
           /*rootH=*/opH,
           /*opsHToFuse=*/{},
@@ -243,7 +242,7 @@
   }
   if (numThreads > 1) {
     assert(mappingAttr && "must specify a mapping attribute");
-    iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+    iree_compiler::buildTileFuseDistToForallWithNumThreads(
         /*b=*/b,
         /*rootH=*/opH,
         /*opsHToFuse=*/{},
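
A hedged usage sketch of the GPU mapping entry point after the rename: per the diff above it now emits `ForallToWorkgroupOp` followed by `MapNestedForallToGpuThreadsOp`. The include path and the [32, 1, 1] workgroup size are illustrative assumptions.

```cpp
#include "iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;

// Maps the top-level scf.forall of `funcH` to workgroups and the nested
// scf.forall ops to GPU threads with a 32x1x1 workgroup.
static Value mapFuncToGpu(ImplicitLocOpBuilder &b, Value funcH) {
  return iree_compiler::gpu::buildMapToBlockAndThreads(
      b, /*funcH=*/funcH, /*blockSize=*/{32, 1, 1});
}
```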
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.h b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.h
index d8d0dbf..7f36138 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.h
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.h
@@ -65,8 +65,8 @@
 /// Take a handle `opH` to a Linalg op of rank `rank`, sizes `opSizes` and for
 /// which we know the most minor dimension `mostMinorDim` (assuming all accesses
 /// are contiguous along that dimension for now).
-/// Build a schedule that maps `mostMinorDim` to a `scf.foreach_thread` op.
-/// When `numThreads` > 1, the `scf.foreach_thread` is also mapped to
+/// Build a schedule that maps `mostMinorDim` to a `scf.forall` op.
+/// When `numThreads` > 1, the `scf.forall` is also mapped to
 /// `mappingAttr` (which must then be non-null).
 /// The constructed schedule first performs a split of the largest possible
 /// multiple of `numThreads * maxVectorSize` to form a maximally divisible
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/SmallReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/SmallReductionStrategy.cpp
index 340bc31..6cf9c72 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/SmallReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/SmallReductionStrategy.cpp
@@ -24,9 +24,8 @@
 // TODO: significantly better namespacing.
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
-using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
-using iree_compiler::IREE::transform_dialect::
-    MapNestedForeachThreadToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::MapNestedForallToGpuThreadsOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
 using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
 using transform::FuseIntoContainingOp;
@@ -102,8 +101,8 @@
   auto [fusionTargetH, fusionGroupH] =
       iree_compiler::buildSelectFirstNonEmpty(b, maybeTrailingH, reductionH);
   ArrayRef<Attribute> allThreadsRef(strategy.allThreadAttrs);
-  iree_compiler::TileToForeachThreadAndFuseAndDistributeResult tileResult =
-      iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+  iree_compiler::TileToForallAndFuseAndDistributeResult tileResult =
+      iree_compiler::buildTileFuseDistToForallWithNumThreads(
           /*builder=*/b,
           /*rootH=*/fusionTargetH,
           /*opsToFuseH=*/fusionGroupH,
@@ -112,9 +111,9 @@
           /*threadDimMapping=*/
           b.getArrayAttr(
               allThreadsRef.take_front(strategy.captures.reductionRank - 1)));
-  fillH = b.create<FuseIntoContainingOp>(fillH, tileResult.foreachThreadH);
+  fillH = b.create<FuseIntoContainingOp>(fillH, tileResult.forallH);
   maybeLeadingH =
-      b.create<FuseIntoContainingOp>(maybeLeadingH, tileResult.foreachThreadH);
+      b.create<FuseIntoContainingOp>(maybeLeadingH, tileResult.forallH);
 
   // 1. Scalarize all ops to ensure vectorization.
   auto pdlOperation = pdl::OperationType::get(b.getContext());
@@ -155,7 +154,7 @@
     const SmallReductionStrategy &strategy) {
   // Step 1. Apply block-level part of the strategy, keeps everything fused.
   auto [maybeLeadingHBlock, gridFillH, gridReductionH, maybeTiledTrailingHBlock,
-        foreachThread] =
+        forall] =
       buildReductionStrategyBlockDistribution(b, variantH, strategy);
 
   // Step 2. Apply thread-level part of the strategy, keeps everything fused.
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/StagedReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/StagedReductionStrategy.cpp
index 4ca0994..5f1b480 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/StagedReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/StagedReductionStrategy.cpp
@@ -25,10 +25,9 @@
 // TODO: significantly better namespacing.
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
-using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
-using iree_compiler::IREE::transform_dialect::
-    MapNestedForeachThreadToGpuThreadsOp;
-using iree_compiler::IREE::transform_dialect::ShareForeachThreadOperandsOp;
+using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::MapNestedForallToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::ShareForallOperandsOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
 using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
 using transform::FuseIntoContainingOp;
@@ -114,14 +113,14 @@
   }
 }
 
-static Value shareForeachArgument(ImplicitLocOpBuilder &b, Value foreachThread,
+static Value shareForeachArgument(ImplicitLocOpBuilder &b, Value Forall,
                                   ArrayRef<int64_t> indices) {
   auto foreachType = transform::OperationType::get(
-      b.getContext(), scf::ForeachThreadOp::getOperationName());
-  foreachThread = b.create<transform::CastOp>(foreachType, foreachThread);
-  return b.create<
-      iree_compiler::IREE::transform_dialect::ShareForeachThreadOperandsOp>(
-      foreachType, foreachThread, indices);
+      b.getContext(), scf::ForallOp::getOperationName());
+  Forall = b.create<transform::CastOp>(foreachType, Forall);
+  return b
+      .create<iree_compiler::IREE::transform_dialect::ShareForallOperandsOp>(
+          foreachType, Forall, indices);
 }
 
 static void buildStagedReductionStrategyThreadLevel(
@@ -148,7 +147,7 @@
   }
 
   // Staged reduction step 1: break gridReductionH apart.
-  auto [blockParallelForeachThreadOp, blockParallelFillH, blockCombinerOpH] =
+  auto [blockParallelForallOp, blockParallelFillH, blockCombinerOpH] =
       buildTileReductionUsingScfForeach(
           /*b=*/b,
           /*reductionH=*/gridReductionH,
@@ -160,7 +159,7 @@
   // Staged reduction step 2: multi-warp shuffle reduce.
   // Map the combiner reduction to one thread along y. Mapping this part along
   // y only will trigger the insertion of an `scf.if (threadIdx.x == 0)`
-  // predicate after `scf.foreach_thread` is lowered.
+  // predicate after `scf.forall` is lowered.
   // This predicate allows further vector distribution to kick in.
   Value root = blockCombinerOpH;
   SmallVector<Value> opsToFuse = {gridFillH};
@@ -182,7 +181,7 @@
     root = maybeTiledTrailingH;
     opsToFuse.push_back(blockCombinerOpH);
   }
-  iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
+  iree_compiler::buildTileFuseDistToForallWithTileSizes(
       /*b=*/b,
       /*rootH=*/root,
       /*opsToFuse=*/opsToFuse,
@@ -214,10 +213,10 @@
 void mlir::iree_compiler::gpu::buildStagedReductionStrategy(
     ImplicitLocOpBuilder &b, Value variantH,
     const StagedReductionStrategy &strategy) {
-  // Step 1. Match and tile to introduce the top-level scf.foreach_thread for
+  // Step 1. Match and tile to introduce the top-level scf.forall for
   // the block/workgroup level. Keep everything fused.
   auto [maybeLeadingHBlock, gridFillH, gridReductionH, maybeTiledTrailingHBlock,
-        commonEnclosingForeachThreadH] =
+        commonEnclosingForallH] =
       buildReductionStrategyBlockDistribution(b, variantH, strategy);
 
   // Step 2. Split the reduction and tile the pieces to ensure vector
@@ -226,15 +225,14 @@
                                           maybeLeadingHBlock,
                                           maybeTiledTrailingHBlock, strategy);
 
-  // Step 3. Make sure we don't create allocation by sharing foreach_thread
+  // Step 3. Make sure we don't create allocation by sharing forall
   // output. This amounts to injecting user-defined static information that each
   // thread accesses only a private slice. This needs to be added late, once we
   // don't need handles anymore, because contained handles are currently always
   // invalidated, even when modified inplace.
   // TODO: Relax nested invalidation for transforms that only move or modify
   // contained ops inplace.
-  shareForeachArgument(b, commonEnclosingForeachThreadH,
-                       ArrayRef<int64_t>({0}));
+  shareForeachArgument(b, commonEnclosingForallH, ArrayRef<int64_t>({0}));
 
   // Step 4-5. Common trailing steps.
   auto [variantH2, funcH] = buildCommonTrailingStrategy(b, variantH, strategy);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
index 3ce4909..456bac3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
@@ -45,16 +45,16 @@
 }
 
 //===---------------------------------------------------------------------===//
-// Patterns for ForeachThreadOpToFlow rewrite.
+// Patterns for ForallOpToFlow rewrite.
 //===---------------------------------------------------------------------===//
 
 /// Populate the workgroup_count region of `dispatchOp`.
 /// For now, this only supports constant index ops and empty workload operands.
 /// Assumes the Flow::DispatchWorkgroupsOp is built with an empty region.
 static LogicalResult populateWorkgroupCountComputingRegion(
-    PatternRewriter &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    PatternRewriter &rewriter, scf::ForallOp forallOp,
     Flow::DispatchWorkgroupsOp dispatchOp) {
-  Location loc = foreachThreadOp.getLoc();
+  Location loc = forallOp.getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   Region &r = dispatchOp.getWorkgroupCount();
   assert(r.empty() && "expected block-less workgroup_count region");
@@ -64,7 +64,7 @@
   SmallVector<Value> results;
   // For now, this assumes that we only pull in constants.
   // TODO: Iteratively pull operations that are only consuming IndexType.
-  for (Value v : foreachThreadOp.getNumThreads()) {
+  for (Value v : forallOp.getUpperBound(rewriter)) {
     auto op = dyn_cast_or_null<arith::ConstantIndexOp>(v.getDefiningOp());
     if (!op) return failure();
     results.push_back(
@@ -79,27 +79,29 @@
   return success();
 }
 
-/// Rewrite ParallelInsertSlice ops in `performConcurrentlyOp` as Flow
+/// Rewrite ParallelInsertSlice ops in `InParallelOp` as Flow
 /// DispatchTensorStoreOps.
 /// Ops are inserted just before the `block` terminator.
-static void rewriteParallelInsertSlices(
-    PatternRewriter &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    scf::PerformConcurrentlyOp performConcurrentlyOp, Block &block,
-    ValueRange resultTensorOperands, ValueRange resultTensorsDynamicDims,
-    IRMapping tensorToFlowBvm) {
-  Location loc = performConcurrentlyOp.getLoc();
+static void rewriteParallelInsertSlices(PatternRewriter &rewriter,
+                                        scf::ForallOp forallOp,
+                                        scf::InParallelOp InParallelOp,
+                                        Block &block,
+                                        ValueRange resultTensorOperands,
+                                        ValueRange resultTensorsDynamicDims,
+                                        IRMapping tensorToFlowBvm) {
+  Location loc = InParallelOp.getLoc();
   int64_t resultIndex = 0;
   for (const Operation &yieldingOp :
-       llvm::make_early_inc_range(performConcurrentlyOp.getYieldingOps())) {
+       llvm::make_early_inc_range(InParallelOp.getYieldingOps())) {
     auto parallelInsertOp = cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(block.getTerminator());
     auto dynamicDims = Util::findVariadicDynamicDims(
         resultIndex, resultTensorOperands, resultTensorsDynamicDims);
     BlockArgument destBbArg = parallelInsertOp.getDest().cast<BlockArgument>();
-    assert(destBbArg.getOwner()->getParentOp() == foreachThreadOp &&
+    assert(destBbArg.getOwner()->getParentOp() == forallOp &&
            "expected that dest is an output bbArg");
-    Value dest = foreachThreadOp.getTiedOpOperand(destBbArg)->get();
+    Value dest = forallOp.getTiedOpOperand(destBbArg)->get();
     // clang-format off
     rewriter.create<Flow::DispatchTensorStoreOp>(
         loc,
@@ -120,7 +122,7 @@
 /// dispatchOp as well as a IRMapping from tensor operands to the
 /// corresponding Flow dispatch tensor bbArgs.
 static void rewriteExtractSlices(PatternRewriter &rewriter,
-                                 scf::ForeachThreadOp foreachThreadOp,
+                                 scf::ForallOp forallOp,
                                  Flow::DispatchWorkgroupsOp dispatchOp,
                                  ValueRange tensorOperands,
                                  ValueRange tensorDynamicDims,
@@ -128,9 +130,8 @@
   dispatchOp->walk([&](tensor::ExtractSliceOp extractSliceOp) {
     Value source = extractSliceOp.getSource();
     if (auto sourceBbArg = source.dyn_cast<BlockArgument>())
-      if (sourceBbArg.getOwner()->getParentOp() ==
-          foreachThreadOp.getOperation())
-        source = foreachThreadOp.getTiedOpOperand(sourceBbArg)->get();
+      if (sourceBbArg.getOwner()->getParentOp() == forallOp.getOperation())
+        source = forallOp.getTiedOpOperand(sourceBbArg)->get();
 
     auto it = llvm::find(tensorOperands, source);
     if (it == tensorOperands.end()) return;
@@ -156,13 +157,12 @@
   });
 }
 
-static void cloneOpsIntoForeachThreadOp(RewriterBase &rewriter,
-                                        scf::ForeachThreadOp foreachThreadOp) {
-  // 1. Find all ops that should be cloned into the ForeachThreadOp.
+static void cloneOpsIntoForallOp(RewriterBase &rewriter,
+                                 scf::ForallOp forallOp) {
+  // 1. Find all ops that should be cloned into the ForallOp.
   llvm::SetVector<Value> valuesDefinedAbove;
-  mlir::getUsedValuesDefinedAbove(foreachThreadOp.getRegion(),
-                                  valuesDefinedAbove);
-  // Add all ops who's results are used inside the ForeachThreadOp to the
+  mlir::getUsedValuesDefinedAbove(forallOp.getRegion(), valuesDefinedAbove);
+  // Add all ops whose results are used inside the ForallOp to the
   // worklist.
   llvm::SetVector<Operation *> worklist;
   for (Value v : valuesDefinedAbove)
@@ -182,7 +182,7 @@
 
     // Do not clone ParallelInsertSliceOp destinations.
     bool isDestination =
-        any_of(foreachThreadOp.getTerminator().getYieldingOps(),
+        any_of(forallOp.getTerminator().getYieldingOps(),
                [&](Operation &insertOp) {
                  return cast<tensor::ParallelInsertSliceOp>(&insertOp)
                             .getDest()
@@ -200,16 +200,14 @@
     }
   }
 
-  // 2. Clone ops and replace their uses inside the ForeachThreadOp.
+  // 2. Clone ops and replace their uses inside the ForallOp.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPointToStart(
-      &foreachThreadOp.getRegion().getBlocks().front());
+  rewriter.setInsertionPointToStart(&forallOp.getRegion().getBlocks().front());
   for (Operation *op : llvm::reverse(opsToClone)) {
     Operation *cloned = rewriter.clone(*op);
     SmallVector<OpOperand *> uses;
     for (OpOperand &use : op->getUses())
-      if (foreachThreadOp->isProperAncestor(use.getOwner()))
-        uses.push_back(&use);
+      if (forallOp->isProperAncestor(use.getOwner())) uses.push_back(&use);
     for (OpOperand *use : uses) {
       unsigned resultNum = use->get().cast<OpResult>().getResultNumber();
       rewriter.updateRootInPlace(
@@ -218,13 +216,13 @@
   }
 }
 
-/// Rewrite a ForeachThreadOp into a Flow::DispatchWorkGroupsOp.
+/// Rewrite a ForallOp into a Flow::DispatchWorkGroupsOp.
 /// This rewrite proceeds in a few steps:
-///   - Step 0: Clone certain ops into the ForeachThreadOp (as per IREE
+///   - Step 0: Clone certain ops into the ForallOp (as per IREE
 ///     heuristic), so that they are part of the dispatch region.
 ///   - Step 1: Compute the result types and their result dynamic dim operands.
 ///     This first step takes advantage of the ops contained in the
-///     ForeachThreadOp terminator and that are tied to the results.
+///     ForallOp terminator and that are tied to the results.
 ///   - Step 2: Get values defined above and separate them between non-tensors,
 ///     tensors and introduce appropriate tensor dims.
 ///   - Step 3: Create ordered vectors of operands to pass to the builder and
@@ -232,7 +230,7 @@
 ///   - Step 4: Populate the workgroupCount region of the dispatchOp and set
 ///     the workload operands to the values defined above.
 ///   - Step 5: Fixup dispatchOp bbArgs and terminator.
-///   - Step 6: Move the body of foreachThreadOp to the dispatchOp.
+///   - Step 6: Move the body of forallOp to the dispatchOp.
 ///   - Step 7: Set up bvm for RAUWIf. In particular, tensor operands become
 ///     flow dispatch tensor bbArgs and need to be
 ///     flow.dispatch.tensor.load'ed.
@@ -240,31 +238,30 @@
 ///   - Step 9. Rewrite tensor::ExtractSlice and ParallelInsert ops to the
 ///     relevant Flow DispatchTensorLoad/Store version.
 ///   - Step 10: Perform RAUWIf.
-///   - Step 11: Drop the terminator and replace foreachThreadOp.
-// TODO: n-D ForeachThreadOp
+///   - Step 11: Drop the terminator and replace forallOp.
+// TODO: n-D ForallOp
 FailureOr<Flow::DispatchWorkgroupsOp>
-rewriteForeachThreadToFlowDispatchWorkgroups(
-    scf::ForeachThreadOp foreachThreadOp, PatternRewriter &rewriter) {
-  // Step 0: Clone ops into the ForeachThreadOp.
-  cloneOpsIntoForeachThreadOp(rewriter, foreachThreadOp);
+rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp,
+                                             PatternRewriter &rewriter) {
+  // Step 0: Clone ops into the ForallOp.
+  cloneOpsIntoForallOp(rewriter, forallOp);
 
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(foreachThreadOp);
+  rewriter.setInsertionPoint(forallOp);
 
-  // Entry point start just before the foreachThreadOp.
-  Location loc = foreachThreadOp.getLoc();
-  scf::PerformConcurrentlyOp performConcurrentlyOp =
-      foreachThreadOp.getTerminator();
+  // Entry point start just before the forallOp.
+  Location loc = forallOp.getLoc();
+  scf::InParallelOp InParallelOp = forallOp.getTerminator();
 
   // Step 1: Compute all dynamic result dims.
   // The `dest` of the ParallelInsertSliceOp are tied to the results and carry
   // over to the Flow::DispatchWorkgroupsOp.
   // Use a SetVector to ensure tensor operand uniqueness.
   llvm::SetVector<Value> resultTensorOperands, resultTensorsDynamicDims;
-  for (const Operation &yieldingOp : performConcurrentlyOp.getYieldingOps()) {
+  for (const Operation &yieldingOp : InParallelOp.getYieldingOps()) {
     auto parallelInsertOp = cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
     BlockArgument destBbArg = parallelInsertOp.getDest().cast<BlockArgument>();
-    Value dest = foreachThreadOp.getTiedOpOperand(destBbArg)->get();
+    Value dest = forallOp.getTiedOpOperand(destBbArg)->get();
     bool inserted = resultTensorOperands.insert(dest);
     if (!inserted) continue;
     auto dynamicDims =
@@ -273,16 +270,15 @@
       resultTensorsDynamicDims.insert(
           rewriter.create<tensor::DimOp>(loc, dest, dim));
   }
-  assert(resultTensorOperands.size() == foreachThreadOp.getNumResults() &&
-         "Expected as many resultTensorOperands as results of foreachThreadOp");
+  assert(resultTensorOperands.size() == forallOp.getNumResults() &&
+         "Expected as many resultTensorOperands as results of forallOp");
 
   // Step 2. Get values defined above and separate them between non-tensors,
   // tensors and introduce appropriate tensor dims.
   // Tensors that have already been recorded as resultTensorOperands are
   // omitted to avoid duplications.
   llvm::SetVector<Value> valuesDefinedAbove;
-  mlir::getUsedValuesDefinedAbove(foreachThreadOp.getRegion(),
-                                  valuesDefinedAbove);
+  mlir::getUsedValuesDefinedAbove(forallOp.getRegion(), valuesDefinedAbove);
 
   SmallVector<Value> nonTensorOperands, tensorOperands, tensorDynamicDims;
   for (Value v : valuesDefinedAbove) {
@@ -298,7 +294,7 @@
   }
   // Also add shared outputs. (These are usually already added as result
   // tensor operands.)
-  for (Value v : foreachThreadOp.getOutputs()) {
+  for (Value v : forallOp.getOutputs()) {
     auto tensorType = v.getType().cast<RankedTensorType>();
     if (resultTensorOperands.contains(v)) continue;
     tensorOperands.push_back(v);
@@ -313,7 +309,7 @@
   llvm::append_range(nonDimOperands, nonTensorOperands);
   llvm::append_range(nonDimOperands, tensorOperands);
   llvm::append_range(nonDimOperands, resultTensorOperands);
-  // scf::ForeachThreadOp tensors inserted into are tied to results and
+  // scf::ForallOp tensors inserted into are tied to results and
   // translate to the tied operands of the dispatch.
   int64_t sizeNonTensors = nonTensorOperands.size();
   int64_t sizeNonResultTensors = tensorOperands.size();
@@ -330,7 +326,7 @@
   auto dispatchOp = rewriter.create<Flow::DispatchWorkgroupsOp>(
       loc,
       /*workload=*/ValueRange{},
-      /*resultTypes=*/foreachThreadOp.getResultTypes(),
+      /*resultTypes=*/forallOp.getResultTypes(),
       /*resultDims=*/resultTensorsDynamicDims.getArrayRef(),
       /*operands=*/nonDimOperands,
       /*operandDims=*/allTensorDynamicDims,
@@ -339,9 +335,9 @@
 
   // Step 4. Outline the compute workload region and set up the workload
   // operands.
-  if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
+  if (failed(populateWorkgroupCountComputingRegion(rewriter, forallOp,
                                                    dispatchOp)))
-    return foreachThreadOp->emitOpError(
+    return forallOp->emitOpError(
                "failed to populate workload region for dispatchOp: ")
            << dispatchOp;
 
@@ -382,9 +378,9 @@
   assert(block->getNumArguments() == allOperands.size() &&
          "Expected as many bbArgs as operands");
 
-  // Step 6. Move the body of foreachThreadOp to the dispatchOp.
-  block->getOperations().splice(
-      block->begin(), foreachThreadOp.getRegion().front().getOperations());
+  // Step 6. Move the body of forallOp to the dispatchOp.
+  block->getOperations().splice(block->begin(),
+                                forallOp.getRegion().front().getOperations());
 
   // Step 7. Set up bvm for RAUWIf.
   // Generally, allOperands map to their corresponding bbArg but there is a
@@ -420,22 +416,22 @@
   rewriter.setInsertionPointToStart(block);
   SmallVector<Value, 8> workgroupIds, workgroupCounts;
   for (int64_t rank :
-       llvm::seq<int64_t>(0, foreachThreadOp.getThreadIndices().size())) {
+       llvm::seq<int64_t>(0, forallOp.getInductionVars().size())) {
     workgroupIds.push_back(
         rewriter.create<Flow::DispatchWorkgroupIDOp>(loc, rank));
     workgroupCounts.push_back(
         rewriter.create<Flow::DispatchWorkgroupCountOp>(loc, rank));
   }
-  bvm.map(foreachThreadOp.getThreadIndices(), workgroupIds);
-  bvm.map(foreachThreadOp.getNumThreads(), workgroupCounts);
+  bvm.map(forallOp.getInductionVars(), workgroupIds);
+  bvm.map(forallOp.getUpperBound(rewriter), workgroupCounts);
 
   // Step 9. Rewrite tensor::ExtractSlice and ParallelInsert ops to the
   // relevant Flow DispatchTensorLoad/Store version.
-  rewriteParallelInsertSlices(rewriter, foreachThreadOp, performConcurrentlyOp,
-                              *block, resultTensorOperands.getArrayRef(),
+  rewriteParallelInsertSlices(rewriter, forallOp, InParallelOp, *block,
+                              resultTensorOperands.getArrayRef(),
                               resultTensorsDynamicDims.getArrayRef(),
                               tensorToFlowBvm);
-  rewriteExtractSlices(rewriter, foreachThreadOp, dispatchOp, allTensorOperands,
+  rewriteExtractSlices(rewriter, forallOp, dispatchOp, allTensorOperands,
                        allTensorDynamicDims, tensorToFlowBvm);
 
   // Step 10. Perform RAUWIf.
@@ -447,9 +443,9 @@
     });
   }
 
-  // Step 11. Drop the terminator and replace foreachThreadOp.
-  rewriter.eraseOp(performConcurrentlyOp);
-  rewriter.replaceOp(foreachThreadOp, dispatchOp.getResults());
+  // Step 11. Drop the terminator and replace forallOp.
+  rewriter.eraseOp(InParallelOp);
+  rewriter.replaceOp(forallOp, dispatchOp.getResults());
 
   return dispatchOp;
 }
@@ -460,7 +456,7 @@
 
 DiagnosedSilenceableFailure
 transform_dialect::ForeachThreadToFlowDispatchWorkgroupsOp::applyToOne(
-    scf::ForeachThreadOp target, transform::ApplyToEachResultList &results,
+    scf::ForallOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &) {
   SimplePatternRewriter rewriter(target->getContext());
   FailureOr<Flow::DispatchWorkgroupsOp> result =
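
For downstream code doing similar rewrites, the scf::ForallOp accessors this file now relies on are summarized in the hedged sketch below; the accessor names come straight from the hunks above, while the helper itself is illustrative.

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Gathers the pieces of an scf.forall that the Flow rewrite consumes, using
// the renamed accessors (formerly getThreadIndices / getNumThreads, with a
// terminator of type scf.foreach_thread.perform_concurrently).
static void inspectForall(scf::ForallOp forallOp, PatternRewriter &rewriter) {
  // Thread indices are now the loop's induction variables.
  auto inductionVars = forallOp.getInductionVars();
  // Per-dimension trip counts, materialized as Values via the rewriter.
  auto upperBounds = forallOp.getUpperBound(rewriter);
  // The terminator is now scf.forall.in_parallel.
  scf::InParallelOp terminator = forallOp.getTerminator();
  (void)inductionVars; (void)upperBounds; (void)terminator;
}
```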
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h
index c843297..69dd4dd 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h
@@ -15,7 +15,7 @@
 class DialectRegistry;
 
 namespace scf {
-class ForeachThreadOp;
+class ForallOp;
 }  // namespace scf
 
 namespace iree_compiler {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
index 037c39f..1150d93 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
@@ -13,13 +13,13 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
-def ForeachThreadToFlowDispatchWorkgroupsOp : Op<Transform_Dialect, "iree.foreach_thread_to_flow",
+def ForeachThreadToFlowDispatchWorkgroupsOp : Op<Transform_Dialect, "iree.forall_to_flow",
     [FunctionalStyleTransformOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformEachOpTrait,
      TransformOpInterface]> {
   let description = [{
-    Rewrite an scf.foreach_thread to Flow::DispatchWorkgroups.
+    Rewrite an scf.forall to Flow::DispatchWorkgroups.
 
     #### Return modes
 
@@ -45,7 +45,7 @@
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::scf::ForeachThreadOp target,
+        ::mlir::scf::ForallOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
index 96f7f09..6539fd4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
@@ -1,6 +1,6 @@
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [42, 67]
-  %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
+  %foreach_op, %tiled_op = transform.structured.tile_to_forall_op %0 num_threads [42, 67]
+  %dispatch_op = transform.iree.forall_to_flow %foreach_op
 }
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE
index 7232e2a..57908f1 100644
--- a/integrations/tensorflow/WORKSPACE
+++ b/integrations/tensorflow/WORKSPACE
@@ -7,7 +7,7 @@
 
 load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
 
-TENSORFLOW_COMMIT = "75eaca49ed62c37278113b270a3e19edab0ba26d"
+TENSORFLOW_COMMIT = "20ff2f32d85e79195b9b38f69039eb185f8cb1a5"
 
 git_repository(
     name = "org_tensorflow",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h
index 1fe0baf..1f49c79 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h
@@ -18,7 +18,7 @@
 namespace mlir {
 namespace scf {
 class ForOp;
-class ForeachThreadOp;
+class ForallOp;
 } // namespace scf
 } // namespace mlir
 
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index f47c2d8..acadacc 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -29,20 +29,20 @@
   let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
 }
 
-def RewriteForeachThreadToAsyncOp :
-  Op<Transform_Dialect, "foreach_thread_to_async",
+def RewriteForallToAsyncOp :
+  Op<Transform_Dialect, "forall_to_async",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
      TransformEachOpTrait,
      TransformOpInterface]> {
 
   let description = [{
-    Rewrite a bufferized scf.foreach_thread op to the async dialect.
+    Rewrite a bufferized scf.forall op to the async dialect.
 
     Return modes:
     =============
     This operation ignores non-Linalg ops and drops them in the return.
-    This transform is currently only implemented for 1-D scf.foreach_thread that
+    This transform is currently only implemented for 1-D scf.forall ops that
     have been bufferized and definitely fails for the rest.
 
     If all the operations referred to by the `target` PDLOperation lower
@@ -59,26 +59,26 @@
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::scf::ForeachThreadOp target,
+        ::mlir::scf::ForallOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
 }
 
-def RewriteForeachThreadToScfForOp :
-  Op<Transform_Dialect, "foreach_thread_to_scf_for",
+def RewriteForallToScfForOp :
+  Op<Transform_Dialect, "forall_to_scf_for",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
      TransformEachOpTrait,
      TransformOpInterface]> {
 
   let description = [{
-    Rewrite a bufferized scf.foreach_thread to a sequential scf.for.
+    Rewrite a bufferized scf.forall to a sequential scf.for.
 
     Return modes:
     =============
     This operation ignores non-Linalg ops and drops them in the return.
-    This transform is currently only implemented for 1-D scf.foreach_thread that
+    This transform is currently only implemented for 1-D scf.forall ops that
     have been bufferized and definitely fails for the rest.
 
     If all the operations referred to by the `target` PDLOperation lower
@@ -95,7 +95,7 @@
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::scf::ForeachThreadOp target,
+        ::mlir::scf::ForallOp target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
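
Both transform ops above share the precondition called out in their
descriptions: the target must be a bufferized (memref-based, zero-result),
one-dimensional scf.forall. A minimal qualifying example, as a sketch with
illustrative names and shapes rather than an excerpt from the tests:

    func.func @bufferized_forall(%buf: memref<?xf32>, %n: index, %v: f32) {
      scf.forall (%i) in (%n) shared_outs() -> () {
        memref.store %v, %buf[%i] : memref<?xf32>
        scf.forall.in_parallel { }
      }
      return
    }

Multi-dimensional scf.forall ops, or ones still carrying tensor results, are
rejected with the errors visible in the .cpp diffs further down.
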
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 63795fa..5194775 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -16,7 +16,7 @@
 namespace mlir {
 namespace scf {
 class ForOp;
-class ForeachThreadOp;
+class ForallOp;
 } // namespace scf
 namespace linalg {
 class LinalgOp;
@@ -40,33 +40,31 @@
   }
 };
 
-/// Pattern to rewrite a scf::ForEachThreadOp to the async dialect.
-struct ForeachThreadOpToAsyncRewriter
-    : public OpRewritePattern<scf::ForeachThreadOp> {
+/// Pattern to rewrite a scf::ForallOp to the async dialect.
+struct ForallOpToAsyncRewriter : public OpRewritePattern<scf::ForallOp> {
   using OpRewritePattern::OpRewritePattern;
 
   FailureOr<Operation *>
-  returningMatchAndRewrite(scf::ForeachThreadOp foreachThreadOp,
+  returningMatchAndRewrite(scf::ForallOp forallOp,
                            PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(scf::ForeachThreadOp foreachThreadOp,
+  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                 PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(foreachThreadOp, rewriter);
+    return returningMatchAndRewrite(forallOp, rewriter);
   }
 };
 
-/// Pattern to rewrite a ForeachThreadOp to an scf::ForOp.
-struct ForeachThreadOpToScfForRewriter
-    : public OpRewritePattern<scf::ForeachThreadOp> {
+/// Pattern to rewrite a ForallOp to an scf::ForOp.
+struct ForallOpToScfForRewriter : public OpRewritePattern<scf::ForallOp> {
   using OpRewritePattern::OpRewritePattern;
 
   FailureOr<scf::ForOp>
-  returningMatchAndRewrite(scf::ForeachThreadOp foreachThreadOp,
+  returningMatchAndRewrite(scf::ForallOp forallOp,
                            PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(scf::ForeachThreadOp foreachThreadOp,
+  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                 PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(foreachThreadOp, rewriter);
+    return returningMatchAndRewrite(forallOp, rewriter);
   }
 };
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index e4ea7f1..04da78c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -122,11 +122,10 @@
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
-DiagnosedSilenceableFailure
-LinalgExt::RewriteForeachThreadToAsyncOp::applyToOne(
-    scf::ForeachThreadOp target, transform::ApplyToEachResultList &results,
+DiagnosedSilenceableFailure LinalgExt::RewriteForallToAsyncOp::applyToOne(
+    scf::ForallOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  LinalgExt::ForeachThreadOpToAsyncRewriter pattern(this->getContext());
+  LinalgExt::ForallOpToAsyncRewriter pattern(this->getContext());
   SimplePatternRewriter rewriter(target);
   FailureOr<Operation *> result =
       pattern.returningMatchAndRewrite(target, rewriter);
@@ -136,11 +135,10 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilenceableFailure
-LinalgExt::RewriteForeachThreadToScfForOp::applyToOne(
-    scf::ForeachThreadOp target, transform::ApplyToEachResultList &results,
+DiagnosedSilenceableFailure LinalgExt::RewriteForallToScfForOp::applyToOne(
+    scf::ForallOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  LinalgExt::ForeachThreadOpToScfForRewriter pattern(this->getContext());
+  LinalgExt::ForallOpToScfForRewriter pattern(this->getContext());
   SimplePatternRewriter rewriter(target);
   FailureOr<Operation *> result =
       pattern.returningMatchAndRewrite(target, rewriter);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp
index 3cbabbf..d7b3a66 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToAsync.cpp
@@ -26,32 +26,31 @@
 using namespace mlir::iree_compiler::IREE::LinalgExt;
 
 FailureOr<Operation *>
-mlir::iree_compiler::IREE::LinalgExt::ForeachThreadOpToAsyncRewriter::
-    returningMatchAndRewrite(scf::ForeachThreadOp foreachThreadOp,
+mlir::iree_compiler::IREE::LinalgExt::ForallOpToAsyncRewriter::
+    returningMatchAndRewrite(scf::ForallOp forallOp,
                              PatternRewriter &rewriter) const {
-  if (foreachThreadOp.getNumResults() > 0)
-    return foreachThreadOp->emitError(
-        "only bufferized scf.foreach_thread lowers to async");
+  if (forallOp.getNumResults() > 0)
+    return forallOp->emitError("only bufferized scf.forall lowers to async");
 
-  if (foreachThreadOp.getNumThreads().size() > 1)
-    return foreachThreadOp->emitError(
-        "only single-dimension scf.foreach_thread lowers to async");
+  if (forallOp.getRank() > 1)
+    return forallOp->emitError(
+        "only single-dimension scf.forall lowers to async");
 
-  // Only consider the top level ForeachThreadOp op and skip if it already
+  // Only consider the top-level ForallOp and skip if it already
   // contains an ExecuteOp.
-  if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>() ||
-      llvm::any_of(foreachThreadOp.getBody()->getOperations(),
+  if (forallOp->getParentOfType<scf::ForallOp>() ||
+      llvm::any_of(forallOp.getBody()->getOperations(),
                    [](Operation &op) { return isa<async::ExecuteOp>(&op); }))
     return failure();
 
-  auto *ctx = foreachThreadOp.getContext();
-  Location loc = foreachThreadOp.getLoc();
+  auto *ctx = forallOp.getContext();
+  Location loc = forallOp.getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   // TODO: allow multi-dim.
-  Value numThreads = foreachThreadOp.getNumThreads().front();
+  Value numThreads = forallOp.getUpperBound(rewriter).front();
 
-  // Wrap the scf.foreach_thread into an async::ExecuteOp.
+  // Wrap the scf.forall into an async::ExecuteOp.
   // 1. Create the async::GroupType object on which we synchronize.
   Value asyncGroup = rewriter.create<async::CreateGroupOp>(
       loc, async::GroupType::get(ctx), numThreads);
@@ -68,15 +67,15 @@
                                         /*dependencies=*/ValueRange(),
                                         /*operands=*/ValueRange(), noopExec);
 
-  // 3. Steal the ops nested under scf::ForeachThread, except the terminator,
+  // 3. Steal the ops nested under the scf.forall, except the terminator,
   // into the body of the async::ExecuteOp, just before the terminator.
   SmallVector<Value> bbArgsTranslated{forOp.getInductionVar()};
-  rewriter.mergeBlocks(&foreachThreadOp.getRegion().front(),
-                       executeOp.getBody(), bbArgsTranslated);
-  // 3.b. Erase the terminator stolen from foreachThreadOp.
+  rewriter.mergeBlocks(&forallOp.getRegion().front(), executeOp.getBody(),
+                       bbArgsTranslated);
+  // 3.b. Erase the terminator stolen from forallOp.
   rewriter.eraseOp(&executeOp.getBody()->back());
-  // 3.c. Erase foreachThreadOp.
-  rewriter.eraseOp(foreachThreadOp);
+  // 3.c. Erase forallOp.
+  rewriter.eraseOp(forallOp);
   // 3.d. Add ExecuteOp terminator.
   rewriter.setInsertionPointToEnd(executeOp.getBody());
   rewriter.create<async::YieldOp>(loc, ValueRange{});
@@ -85,7 +84,7 @@
   rewriter.create<async::AddToGroupOp>(loc, rewriter.getIndexType(),
                                        executeOp.getToken(), asyncGroup);
 
-  // 4. After the iree_compiler::IREE::LinalgExt::ForeachThread, await all async
+  // 4. After the enclosing scf.for created above, await all async
   // tasks in `asyncGroup`.
   rewriter.setInsertionPointAfter(forOp);
   return rewriter.create<async::AwaitAllOp>(loc, asyncGroup).getOperation();
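
Taken together, the steps above turn a bufferized 1-D scf.forall into an
async group, an scf.for whose body launches one async.execute per iteration,
and a trailing async.await_all, consistent with the CHECK lines in the async
test below. A rough sketch with an illustrative body:

    func.func @forall_to_async(%buf: memref<?xf32>, %n: index, %v: f32) {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %group = async.create_group %n : !async.group
      scf.for %i = %c0 to %n step %c1 {
        %token = async.execute {
          // Body stolen from the original scf.forall.
          memref.store %v, %buf[%i] : memref<?xf32>
          async.yield
        }
        %rank = async.add_to_group %token, %group : !async.token
      }
      async.await_all %group
      return
    }
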
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
index dfe4434..a46512f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
@@ -24,7 +24,7 @@
 
 namespace {
 
-SmallVector<Value> getValuesToYield(scf::PerformConcurrentlyOp op) {
+SmallVector<Value> getValuesToYield(scf::InParallelOp op) {
   return llvm::to_vector(
       llvm::map_range(op.getYieldingOps(), [](Operation &op) -> Value {
         return cast<tensor::ParallelInsertSliceOp>(&op).getDest();
@@ -33,30 +33,28 @@
 
 } // namespace
 
-FailureOr<scf::ForOp> ForeachThreadOpToScfForRewriter::returningMatchAndRewrite(
-    scf::ForeachThreadOp foreachThreadOp, PatternRewriter &rewriter) const {
-  if (foreachThreadOp.getNumResults() > 0)
-    return foreachThreadOp->emitError(
-        "only bufferized scf.foreach_thread lowers to scf.for");
+FailureOr<scf::ForOp> ForallOpToScfForRewriter::returningMatchAndRewrite(
+    scf::ForallOp forallOp, PatternRewriter &rewriter) const {
+  if (forallOp.getNumResults() > 0)
+    return forallOp->emitError("only bufferized scf.forall lowers to scf.for");
 
-  if (foreachThreadOp.getNumThreads().size() > 1)
-    return foreachThreadOp->emitError(
-        "only single-dimension scf.foreach_thread lowers to scf.for");
+  if (forallOp.getRank() > 1)
+    return forallOp->emitError(
+        "only single-dimension scf.forall lowers to scf.for");
 
   // Construct the loop bounds based on the canonical arithmetic progression.
-  Location loc = foreachThreadOp.getLoc();
+  Location loc = forallOp.getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   // TODO: allow multi-dim.
-  Value numThreads = foreachThreadOp.getNumThreads().front();
+  Value numThreads = forallOp.getUpperBound(rewriter).front();
 
   // Construct the op without a body builder: we need to clone the ops in the
   // body explicitly after having access to the new bbArgs.
   // As a consequence, `ensureTerminator` is not called and the `forOp` body
   // has no terminator.
-  scf::PerformConcurrentlyOp performConcurrentlyOp =
-      foreachThreadOp.getTerminator();
-  SmallVector<Value> valuesToYield = getValuesToYield(performConcurrentlyOp);
+  scf::InParallelOp inParallelOp = forallOp.getTerminator();
+  SmallVector<Value> valuesToYield = getValuesToYield(inParallelOp);
   scf::ForOp forOp =
       rewriter.create<scf::ForOp>(loc, zero, numThreads, one, valuesToYield);
 
@@ -66,11 +64,10 @@
   bool hasTerminator =
       !body->empty() && body->back().hasTrait<OpTrait::IsTerminator>();
   if (hasTerminator) {
-    rewriter.mergeBlockBefore(&foreachThreadOp.getRegion().front(),
+    rewriter.mergeBlockBefore(&forallOp.getRegion().front(),
                               body->getTerminator(), bbArgsTranslated);
   } else {
-    rewriter.mergeBlocks(&foreachThreadOp.getRegion().front(), body,
-                         bbArgsTranslated);
+    rewriter.mergeBlocks(&forallOp.getRegion().front(), body, bbArgsTranslated);
   }
 
   rewriter.setInsertionPointToStart(body);
@@ -79,8 +76,8 @@
 
   // Create sequential insertSlice ops.
   SmallVector<Value> toYield;
-  rewriter.setInsertionPoint(performConcurrentlyOp);
-  for (Operation &operation : performConcurrentlyOp.getYieldingOps()) {
+  rewriter.setInsertionPoint(inParallelOp);
+  for (Operation &operation : inParallelOp.getYieldingOps()) {
     tensor::ParallelInsertSliceOp op =
         cast<tensor::ParallelInsertSliceOp>(&operation);
     toYield.push_back(rewriter.createOrFold<tensor::InsertSliceOp>(
@@ -88,7 +85,7 @@
         op.getMixedSizes(), op.getMixedStrides()));
   }
 
-  // performConcurrentlyOp.yieldedValues come from above, not from bbArgs.
+  // The values yielded by inParallelOp come from above, not from bbArgs.
   // There is no rewriter method to make mergeBlocks update non-bbArgs.
   // Need to manually clone + bvm all uses that are now nested under forOp.
   // Warning: this replacement is currently optimistic and may change the
@@ -116,8 +113,8 @@
   }
 
   // Cleanup and replace.
-  rewriter.eraseOp(performConcurrentlyOp);
-  rewriter.replaceOp(foreachThreadOp, forOp.getResults());
+  rewriter.eraseOp(inParallelOp);
+  rewriter.replaceOp(forallOp, forOp.getResults());
 
   return forOp;
 }
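
For the same kind of bufferized 1-D scf.forall, the sequential lowering above
reduces to a canonical scf.for over the original upper bound, with any
tensor.parallel_insert_slice in the old terminator replaced by a plain
tensor.insert_slice (the loop over getYieldingOps() above). A sketch of the
output for the memref example given earlier:

    func.func @bufferized_forall(%buf: memref<?xf32>, %n: index, %v: f32) {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      scf.for %i = %c0 to %n step %c1 {
        memref.store %v, %buf[%i] : memref<?xf32>
      }
      return
    }
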
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
index 9e2c353..eb70e14 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
@@ -32,7 +32,7 @@
   // CHECK:   async.add_to_group %[[token]], %[[group]] : !async.token
   // CHECK: }
   // CHECK: async.await_all %[[group]]
-  scf.foreach_thread (%arg3) in (%1) shared_outs() -> () {
+  scf.forall (%arg3) in (%1) shared_outs() -> () {
       %3 = affine.apply #map1(%arg3)[%arg0]
       %4 = affine.apply #map2(%0, %3)
       %5 = affine.min #map3(%4, %arg0)
@@ -52,6 +52,6 @@
 
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
-  %0 = transform.structured.match ops{["scf.foreach_thread"]} in %module_op : (!pdl.operation) -> !pdl.operation
-  %1 = foreach_thread_to_async %0
+  %0 = transform.structured.match ops{["scf.forall"]} in %module_op : (!pdl.operation) -> !pdl.operation
+  %1 = forall_to_async %0
 }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
index a01da7e..0aef6ab 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
@@ -23,7 +23,7 @@
   // CHECK: %[[C1:.*]] = arith.constant 1 : index
   // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
   // CHECK: scf.for %[[IV:.*]] = {{.*}} step %[[C1]] {
-  scf.foreach_thread (%arg3) in (%1) shared_outs() -> () {
+  scf.forall (%arg3) in (%1) shared_outs() -> () {
       %3 = affine.apply #map1(%arg3)[%arg0]
       %4 = affine.apply #map2(%0, %3)
       %5 = affine.min #map3(%4, %arg0)
@@ -46,6 +46,6 @@
 
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
-  %0 = transform.structured.match ops{["scf.foreach_thread"]} in %module_op : (!pdl.operation) -> !pdl.operation
-  %1 = foreach_thread_to_scf_for %0
+  %0 = transform.structured.match ops{["scf.forall"]} in %module_op : (!pdl.operation) -> !pdl.operation
+  %1 = forall_to_scf_for %0
 }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
index 367da64..b790dbc 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
@@ -28,8 +28,8 @@
     } -> tensor<64xf32>
 
     %2 = affine.apply #map0()[%arg0]
-    // CHECK: scf.foreach_thread
-    %3 = scf.foreach_thread (%arg3) in (%2) shared_outs(%O = %arg2) -> (tensor<64xf32>) {
+    // CHECK: scf.forall
+    %3 = scf.forall (%arg3) in (%2) shared_outs(%O = %arg2) -> (tensor<64xf32>) {
       // CHECK:    %[[OFFSET:.*]] = affine.apply
       // CHECK:    %[[SIZE:.*]] = affine.min
       %4 = affine.apply #map1(%arg3)[%arg0]
@@ -44,7 +44,7 @@
 
       // CHECK:    %[[T4:.*]] = linalg.elemwise_unary ins(%[[T1]] {{.*}} outs(%[[T3]]
       %8 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%7 : tensor<?xf32>) -> tensor<?xf32>
-      scf.foreach_thread.perform_concurrently {
+      scf.forall.in_parallel {
         tensor.parallel_insert_slice %8 into %O[%4] [%5] [1] : tensor<?xf32> into tensor<64xf32>
       }
     }
@@ -62,7 +62,7 @@
     pdl.pattern @match_in_parallel : benefit(1) {
       %0 = operands
       %1 = types
-      %2 = operation "scf.foreach_thread"(%0 : !pdl.range<value>)  -> (%1 : !pdl.range<type>)
+      %2 = operation "scf.forall"(%0 : !pdl.range<value>)  -> (%1 : !pdl.range<type>)
       rewrite %2 with "transform.dialect"
     }
     transform.structured.canonicalized_sequence %arg0 failures(propagate) {
@@ -96,8 +96,8 @@
     // TODO: Choosing %arg2 here complicates the size computation.
     %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
     %1 = affine.apply #map0()[%d0, %arg0]
-    // CHECK: scf.foreach_thread
-    %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%O = %arg2) -> (tensor<?xf32>) {
+    // CHECK: scf.forall
+    %2 = scf.forall (%arg3) in (%1) shared_outs(%O = %arg2) -> (tensor<?xf32>) {
       // CHECK:    %[[OFFSET:.*]] = affine.apply
       // CHECK:    %[[SIZE:.*]] = affine.min
       %3 = affine.apply #map1(%arg3)[%arg0]
@@ -110,7 +110,7 @@
 
       // CHECK:    %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
       %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
-      scf.foreach_thread.perform_concurrently {
+      scf.forall.in_parallel {
         tensor.parallel_insert_slice %7 into %O[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
       }
     }
@@ -128,7 +128,7 @@
     pdl.pattern @match_in_parallel : benefit(1) {
       %0 = operands
       %1 = types
-      %2 = operation "scf.foreach_thread"(%0 : !pdl.range<value>)  -> (%1 : !pdl.range<type>)
+      %2 = operation "scf.forall"(%0 : !pdl.range<value>)  -> (%1 : !pdl.range<type>)
       rewrite %2 with "transform.dialect"
     }
     transform.structured.canonicalized_sequence %arg0 failures(propagate) {
diff --git a/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir b/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
index 249a926..065c2a2 100644
--- a/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
+++ b/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
@@ -1,6 +1,6 @@
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [13, 33]
-  %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
+  %foreach_op, %tiled_op = transform.structured.tile_to_forall_op %0 num_threads [13, 33]
+  %dispatch_op = transform.iree.forall_to_flow %foreach_op
 }
diff --git a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
index 3f25087..2ffb880 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
@@ -4,8 +4,8 @@
 ^bb1(%variant_op: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation
 
-  %foreach_thread, %tiled_generic =
-    transform.structured.tile_to_foreach_thread_op %0 num_threads [2] 
+  %forall, %tiled_generic =
+    transform.structured.tile_to_forall_op %0 num_threads [2] 
     // TODO: IREE needs own workgroup mapping attribute.
     ( mapping = [#gpu.block<x>] )
 
@@ -14,5 +14,5 @@
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   %func = transform.structured.match ops{["func.func"]} in %1 : (!pdl.operation) -> !pdl.operation
-  transform.iree.foreach_thread_to_workgroup %func
+  transform.iree.forall_to_workgroup %func
 }
diff --git a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
index 8d33c24..2d13cb7 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
@@ -4,10 +4,10 @@
 ^bb1(%variant_op: !pdl.operation):
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation
 
-  // Step 1. Tile to foreach_thread with tile_sizes [2].
+  // Step 1. Tile to forall with tile_sizes [2].
   // ===================================================
-  %foreach_thread, %tiled_generic =
-    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %matmul tile_sizes [2]
+  %forall, %tiled_generic =
+    transform.iree.tile_to_forall_and_workgroup_count_region %matmul tile_sizes [2]
       // TODO: IREE needs own workgroup mapping attribute.
       ( mapping = [#gpu.block<x>] )
 
@@ -21,5 +21,5 @@
   // Step 3. Post-bufferization mapping workgroup.
   // =========================================================
   %func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  transform.iree.foreach_thread_to_workgroup %func
+  transform.iree.forall_to_workgroup %func
 }
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
index beacf03..25a08b2 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
@@ -14,8 +14,8 @@
 
   // Step 2. First level of tiling + fusion parallelizes to blocks.
   // ===========================================================================
-  %foreach_thread_grid, %grid_combiner_op =
-    transform.structured.tile_to_foreach_thread_op %combiner_op tile_sizes [1]
+  %forall_grid, %grid_combiner_op =
+    transform.structured.tile_to_forall_op %combiner_op tile_sizes [1]
       ( mapping = [#gpu.block<x>] )
 
   // Step 2.1: Cannot fuse across the "expand_shape" produced by reduction
@@ -32,29 +32,29 @@
   %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %expanded_eltwise, %more_parallel_2, %combiner_2 =
     transform.split_handles %generics in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
-  %foreach_thread_grid_2 = transform.structured.match ops{["scf.foreach_thread"]} in %variant_op : (!pdl.operation) -> !pdl.operation
+  %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %not_combiner = transform.merge_handles %fill_2, %more_parallel_fill_2, %more_parallel_2, %expanded_eltwise : !pdl.operation
-  transform.structured.fuse_into_containing_op %not_combiner into %foreach_thread_grid_2
+  transform.structured.fuse_into_containing_op %not_combiner into %forall_grid_2
 
   // Step 3. Second level of tiling + fusion parallelizes to threads. Also
   // fuse in the leading elementwise.
   // ===========================================================================
   %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_block_combiner_op, %block_combiner_op =
-    transform.structured.tile_to_foreach_thread_op %combiner_2 tile_sizes [1] 
+  %forall_block_combiner_op, %block_combiner_op =
+    transform.structured.tile_to_forall_op %combiner_2 tile_sizes [1] 
     ( mapping = [#gpu.thread<z>] )
-  transform.structured.fuse_into_containing_op %fill_1d into %foreach_thread_block_combiner_op
+  transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op
 
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
  %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
  %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_foreach_thread_op %grid_more_parallel_op tile_sizes [1, 1] 
+  %forall_block_more_parallel_op, %block_more_parallel_op =
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
-  transform.structured.fuse_into_containing_op %fill_2d into %foreach_thread_block_more_parallel_op
-  transform.structured.fuse_into_containing_op %grid_eltwise_op into %foreach_thread_block_more_parallel_op
+  transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
+  transform.structured.fuse_into_containing_op %grid_eltwise_op into %forall_block_more_parallel_op
 
   // Step 4. Rank-reduce and vectorize.
   // ===========================================================================
@@ -72,8 +72,8 @@
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
   %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
+  %func_5 = transform.iree.forall_to_workgroup %func_4
+  %func_6 = transform.iree.map_nested_forall_to_gpu_threads %func_5
       { workgroup_size = [32, 2, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
index 43dedf3..29f3860 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -17,7 +17,7 @@
   // trailing elementwise the same way we want to tile the reduction.
   // ===========================================================================
   %grid_loop, %trailing_eltwise_grid_op =
-    transform.structured.tile_to_foreach_thread_op %trailing_eltwise tile_sizes [1]
+    transform.structured.tile_to_forall_op %trailing_eltwise tile_sizes [1]
       ( mapping = [#gpu.block<x>] )
 
   // Step 2.1: Cannot fuse across the "expand_shape" produced by reduction
@@ -35,33 +35,33 @@
   %expanded_eltwise, %more_parallel_2, %combiner_2, %trailing_eltwise_2 =
     transform.split_handles %generics in [4]
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
-  %foreach_thread_grid_2 = transform.structured.match ops{["scf.foreach_thread"]} in %variant_op
+  %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %not_trailing = transform.merge_handles %fill_2, %more_parallel_fill_2,
     %more_parallel_2, %expanded_eltwise, %combiner_2 : !pdl.operation
-  transform.structured.fuse_into_containing_op %not_trailing into %foreach_thread_grid_2
+  transform.structured.fuse_into_containing_op %not_trailing into %forall_grid_2
 
   // Step 3. Second level of tiling + fusion parallelizes to threads. Also
   // fuse in the leading and trailing elementwise.
   // ===========================================================================
   %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_trailing_eltwise_op, %block_trailing_eltwise_op =
-    transform.structured.tile_to_foreach_thread_op %trailing_eltwise_2 tile_sizes [1] 
+  %forall_trailing_eltwise_op, %block_trailing_eltwise_op =
+    transform.structured.tile_to_forall_op %trailing_eltwise_2 tile_sizes [1] 
     ( mapping = [#gpu.thread<z>] )
   %block_combiner_op = transform.structured.match ops{["linalg.generic"]}
     attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %fill_and_reduction = transform.merge_handles %fill_1d, %block_combiner_op : !pdl.operation
-  transform.structured.fuse_into_containing_op %fill_and_reduction into %foreach_thread_trailing_eltwise_op
+  transform.structured.fuse_into_containing_op %fill_and_reduction into %forall_trailing_eltwise_op
 
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
   %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_foreach_thread_op %grid_more_parallel_op tile_sizes [1, 1] 
+  %forall_block_more_parallel_op, %block_more_parallel_op =
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
-  transform.structured.fuse_into_containing_op %fill_2d into %foreach_thread_block_more_parallel_op
-  transform.structured.fuse_into_containing_op %grid_eltwise_op into %foreach_thread_block_more_parallel_op
+  transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
+  transform.structured.fuse_into_containing_op %grid_eltwise_op into %forall_block_more_parallel_op
 
   // Step 4. Rank-reduce and vectorize.
   // ===========================================================================
@@ -79,8 +79,8 @@
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
   %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
+  %func_5 = transform.iree.forall_to_workgroup %func_4
+  %func_6 = transform.iree.map_nested_forall_to_gpu_threads %func_5
       { workgroup_size = [32, 2, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index a29843d..ddf3fc7 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -13,27 +13,27 @@
 
   // Step 2. First level of tiling + fusion parallelizes to blocks.
   // ===========================================================================
-  %foreach_thread_grid, %grid_combiner_op =
-    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %combiner_op tile_sizes [1]
+  %forall_grid, %grid_combiner_op =
+    transform.iree.tile_to_forall_and_workgroup_count_region %combiner_op tile_sizes [1]
       ( mapping = [#gpu.block<x>] )
   %not_combiner = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op : !pdl.operation
-  transform.structured.fuse_into_containing_op %not_combiner into %foreach_thread_grid
+  transform.structured.fuse_into_containing_op %not_combiner into %forall_grid
 
   // Step 3. Second level of tiling + fusion parallelizes to threads.
   // ===========================================================================
   %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_block_combiner_op, %block_combiner_op =
-    transform.structured.tile_to_foreach_thread_op %grid_combiner_op tile_sizes [1] 
+  %forall_block_combiner_op, %block_combiner_op =
+    transform.structured.tile_to_forall_op %grid_combiner_op tile_sizes [1] 
     ( mapping = [#gpu.thread<z>] )
-  transform.structured.fuse_into_containing_op %fill_1d into %foreach_thread_block_combiner_op
+  transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op
 
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
   %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_foreach_thread_op %grid_more_parallel_op tile_sizes [1, 1] 
+  %forall_block_more_parallel_op, %block_more_parallel_op =
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
-  transform.structured.fuse_into_containing_op %fill_2d into %foreach_thread_block_more_parallel_op
+  transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
 
   // Step 4. Rank-reduce and vectorize.
   // ===========================================================================
@@ -52,8 +52,8 @@
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
   %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_6 = transform.iree.foreach_thread_to_workgroup %func_5
-  %func_7 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_6
+  %func_6 = transform.iree.forall_to_workgroup %func_5
+  %func_7 = transform.iree.map_nested_forall_to_gpu_threads %func_6
       { workgroup_size = [32, 2, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 69354fb..c7213ee 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -15,7 +15,7 @@
   // Step 2. First level of tiling + fusion parallelizes to blocks. Tile the
   // trailing elementwise the same way we want to tile the reduction.
   // ===========================================================================
-  %grid_loop, %eltwise_grid_op = transform.iree.tile_to_foreach_thread_and_workgroup_count_region %eltwise 
+  %grid_loop, %eltwise_grid_op = transform.iree.tile_to_forall_and_workgroup_count_region %eltwise 
     tile_sizes [1] (mapping = [#gpu.block<x>])
   %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op : !pdl.operation
   transform.structured.fuse_into_containing_op %not_eltwise into %grid_loop
@@ -24,7 +24,7 @@
   // ===========================================================================
   %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
   %eltwise_block_loop, %eltwise_block_op =
-    transform.structured.tile_to_foreach_thread_op %eltwise_grid_op tile_sizes [1]
+    transform.structured.tile_to_forall_op %eltwise_grid_op tile_sizes [1]
     ( mapping = [#gpu.thread<z>] )
   %block_combiner_op = transform.structured.match ops{["linalg.generic"]}
     attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
@@ -34,10 +34,10 @@
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
   %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_foreach_thread_op %grid_more_parallel_op tile_sizes [1, 1] 
+  %forall_block_more_parallel_op, %block_more_parallel_op =
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
-  transform.structured.fuse_into_containing_op %fill_2d into %foreach_thread_block_more_parallel_op
+  transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
 
   // Step 4. Rank-reduce and vectorize.
   // ===========================================================================
@@ -56,8 +56,8 @@
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
   %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_6 = transform.iree.foreach_thread_to_workgroup %func_5
-  %func_7 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_6
+  %func_6 = transform.iree.forall_to_workgroup %func_5
+  %func_7 = transform.iree.map_nested_forall_to_gpu_threads %func_6
       { workgroup_size = [32, 2, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
index 8224a2e..8951d6a 100644
--- a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
@@ -7,29 +7,29 @@
 
   // Step 1. First level of tiling + fusion parallelizes to blocks.
   // ===========================================================================
-  %foreach_thread_grid, %grid_reduction =
-    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %reduction tile_sizes [1]
+  %forall_grid, %grid_reduction =
+    transform.iree.tile_to_forall_and_workgroup_count_region %reduction tile_sizes [1]
       ( mapping = [#gpu.block<x>] )
-  transform.structured.fuse_into_containing_op %fill into %foreach_thread_grid
+  transform.structured.fuse_into_containing_op %fill into %forall_grid
 
   // Step 2. Split the reduction to get meatier parallelism.
   // ===========================================================================
-  %foreach_thread, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 = 
+  %forall, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 = 
     transform.structured.tile_reduction_using_scf %grid_reduction by tile_sizes = [0, 128]
   %_1:2 =
-    transform.structured.tile_to_foreach_thread_op %block_more_parallel_op_2 num_threads [0, 32]
+    transform.structured.tile_to_forall_op %block_more_parallel_op_2 num_threads [0, 32]
     ( mapping = [#gpu.thread<x>] )
 
   // Step 3. Second level of tiling parallelizes to threads.
   // ===========================================================================
   // 1st op is [parallel, parallel], map it to threadIdx.x by 4.
   %_2:2 =
-    transform.structured.tile_to_foreach_thread_op %block_more_parallel_fill_op_2 tile_sizes [0, 4]
+    transform.structured.tile_to_forall_op %block_more_parallel_fill_op_2 tile_sizes [0, 4]
     ( mapping = [#gpu.thread<x>] )
   // 2nd op is [parallel, reduction] of 1x128, map the 1-dim to threadIdx.y to
   // trigger mapping of the reduction to threadIdx.x via predication via `if (x==0)`.
   %_3:2 =
-    transform.structured.tile_to_foreach_thread_op %block_combiner_op_2 tile_sizes [1] 
+    transform.structured.tile_to_forall_op %block_combiner_op_2 tile_sizes [1] 
     ( mapping = [#gpu.thread<y>] )
 
   // Step 4. Rank-reduce and vectorize.
@@ -51,8 +51,8 @@
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
   %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_8 = transform.iree.foreach_thread_to_workgroup %func_7
-  %func_9 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_8
+  %func_8 = transform.iree.forall_to_workgroup %func_7
+  %func_9 = transform.iree.map_nested_forall_to_gpu_threads %func_8
       { workgroup_size = [32, 1, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index 867db8e..967d37a 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -7,26 +7,26 @@
 
   // Step 1. First level of tiling + fusion parallelizes to blocks.
   // ===========================================================================
-  %foreach_thread_grid, %grid_reduction =
-    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %reduction tile_sizes [1]
+  %forall_grid, %grid_reduction =
+    transform.iree.tile_to_forall_and_workgroup_count_region %reduction tile_sizes [1]
       ( mapping = [#gpu.block<x>] )
-  transform.structured.fuse_into_containing_op %fill into %foreach_thread_grid
+  transform.structured.fuse_into_containing_op %fill into %forall_grid
 
   // Step 2. Split the reduction to get meatier parallelism.
   // This also parallelizes to threads.
   // ===========================================================================
-  %foreach_thread, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 = 
-     transform.structured.tile_reduction_using_foreach_thread %grid_reduction 
+  %forall, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 = 
+     transform.structured.tile_reduction_using_forall %grid_reduction 
         by num_threads = [0, 1024], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
 
   // Fuse the fill and pointwise to privatize them.
   transform.structured.fuse_into_containing_op %block_more_parallel_fill_op_2
-    into %foreach_thread
+    into %forall
 
   // block_combiner_op_2 op is [parallel, reduction] of 1x384 that cannot fuse.
   // map the 1-dim to threadIdx.y to trigger mapping of the reduction to 
   // threadIdx.x via predication via `if (x==0)`.
-  transform.structured.tile_to_foreach_thread_op %block_combiner_op_2 num_threads [1] 
+  transform.structured.tile_to_forall_op %block_combiner_op_2 num_threads [1] 
     ( mapping = [#gpu.thread<y>] )
 
   // Step 3. Rank-reduce and vectorize.
@@ -50,8 +50,8 @@
   // Step 5. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
   %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_8 = transform.iree.foreach_thread_to_workgroup %func_7
-  %func_9 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_8
+  %func_8 = transform.iree.forall_to_workgroup %func_7
+  %func_9 = transform.iree.map_nested_forall_to_gpu_threads %func_8
       { workgroup_size = [1024, 1, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 9ace54f..df7464d 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -18,25 +18,25 @@
   // ==============================================================
   // This must be used with the custom dispatch region formation because IREE's
   // does not fuse even with --iree-flow-enable-aggressive-fusion.
-  // %foreach_thread, %_ =
-  // transform.iree.tile_to_foreach_thread_and_workgroup_count_region %div tile_sizes [1, 4]
+  // %forall, %_ =
+  // transform.iree.tile_to_forall_and_workgroup_count_region %div tile_sizes [1, 4]
   //   ( mapping = [#gpu.thread<x>, #gpu.thread<y>] )
-  %foreach_thread, %_ =
-    transform.structured.tile_to_foreach_thread_op %div tile_sizes [1, 4]
+  %forall, %_ =
+    transform.structured.tile_to_forall_op %div tile_sizes [1, 4]
       ( mapping = [#gpu.block<x>, #gpu.block<y>] )
   // TODO: Merging and fusing merged handles does not work properly atm.
-  transform.structured.fuse_into_containing_op %exps_sum into %foreach_thread
-  transform.structured.fuse_into_containing_op %exps into %foreach_thread
-  transform.structured.fuse_into_containing_op %exps_sum_fill into %foreach_thread
-  transform.structured.fuse_into_containing_op %input_max into %foreach_thread
-  transform.structured.fuse_into_containing_op %input_max_fill into %foreach_thread
-  // By default, fusion into scf.foreach_thread does not promote captured values
+  transform.structured.fuse_into_containing_op %exps_sum into %forall
+  transform.structured.fuse_into_containing_op %exps into %forall
+  transform.structured.fuse_into_containing_op %exps_sum_fill into %forall
+  transform.structured.fuse_into_containing_op %input_max into %forall
+  transform.structured.fuse_into_containing_op %input_max_fill into %forall
+  // By default, fusion into scf.forall does not promote captured values
   // to shared as this involves a cross-thread dependence analysis.
   // Instead, we activate it explicitly post-hoc to promote all the extract_slice
   // ops that we find and match the prerequisites
-  %foreach_thread_with_type = transform.cast %foreach_thread : !pdl.operation to !transform.op<"scf.foreach_thread">
-  transform.iree.share_foreach_thread_operands %foreach_thread_with_type
-    : (!transform.op<"scf.foreach_thread">) -> !transform.op<"scf.foreach_thread">
+  %forall_with_type = transform.cast %forall : !pdl.operation to !transform.op<"scf.forall">
+  transform.iree.share_forall_operands %forall_with_type
+    : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
 
   // Step 2. Second level of tiling + fusion parallelizes to threads.
   // ================================================================
@@ -58,14 +58,14 @@
                                                   %tiled_exp_and_exps_sum,
                                                   %tiled_exp_and_exps_sum_2
     : !pdl.operation
-  transform.structured.tile_to_foreach_thread_op %reduction_linalg_ops tile_sizes [1, 1]
+  transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1]
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   // Fully parallel ops are tiled and mapped.
   %parallel_linalg_ops = transform.merge_handles %tiled_input_max_fill,
                                                  %tiled_exps_sum_fill,
                                                  %tiled_div
     : !pdl.operation
-  transform.structured.tile_to_foreach_thread_op %parallel_linalg_ops num_threads [1, 4, 32]
+  transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32]
       ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] )
 
   // Step 3. Rank-reduce and vectorize.
@@ -84,8 +84,8 @@
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
-  transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+  %func_3 = transform.iree.forall_to_workgroup %func_2
+  transform.iree.map_nested_forall_to_gpu_threads %func_3
     { workgroup_size = [32, 4, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
diff --git a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
index 46a4fe9..ba90328 100644
--- a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
@@ -12,10 +12,10 @@
   %red = transform.structured.match interface{LinalgOp}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %not_root = merge_handles %fill, %red : !pdl.operation
-  %foreach_thread, %tiled_generic =
-    transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 4]
+  %forall, %tiled_generic =
+    transform.iree.tile_to_forall_and_workgroup_count_region %root tile_sizes [1, 4]
     ( mapping = [#gpu.block<x>, #gpu.block<y>] )
-  transform.structured.fuse_into_containing_op %not_root into %foreach_thread
+  transform.structured.fuse_into_containing_op %not_root into %forall
 
   // Step 2. Second level of tiling + fusion parallelizes to threads.
   // ================================================================
@@ -24,21 +24,21 @@
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %foreach_thread_reduction, %tiled_reduction_generic =
-    transform.structured.tile_to_foreach_thread_op %reduction_linalg tile_sizes [1, 1]
+  %forall_reduction, %tiled_reduction_generic =
+    transform.structured.tile_to_forall_op %reduction_linalg tile_sizes [1, 1]
       ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   // TODO: this fusion currently does not happen properly, this is related to the clone
-  // behavior when fusing into scf.foreach_thread.
+  // behavior when fusing into scf.forall.
   // Once fixed we'll be able to fuse.
   // Fusion will save us one roundtrip to memory.
-  // transform.structured.fuse_into_containing_op %fill_linalg into %foreach_thread_reduction
-  transform.structured.tile_to_foreach_thread_op %parallel_linalg num_threads [1, 4, 32]
+  // transform.structured.fuse_into_containing_op %fill_linalg into %forall_reduction
+  transform.structured.tile_to_forall_op %parallel_linalg num_threads [1, 4, 32]
       ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] )
 
 
-  // Inability to tile reductions to scf.foreach_thread has 2 implications:
-  //   1. since no scf.foreach_thread is present, no gpu.barrier is added.
-  //      This should be fixed independently: ops that are not nested in an scf.foreach_thread
+  // Inability to tile reductions to scf.forall has 2 implications:
+  //   1. since no scf.forall is present, no gpu.barrier is added.
+  //      This should be fixed independently: ops that are not nested in an scf.forall
   //      should have a gpu.barrier. Later needs to be complemented by a barrier
   //      removal pass.
  //   2. Similarly, needs to be predicated under an if threadIdx == 0 to avoid
@@ -67,8 +67,8 @@
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
-  transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+  %func_3 = transform.iree.forall_to_workgroup %func_2
+  transform.iree.map_nested_forall_to_gpu_threads %func_3
     { workgroup_size = [32, 4, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
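
The predication noted in the comments above can be made concrete with a rough sketch: code that does not end up nested in an scf.forall would have to be guarded so that a single thread executes it, followed by a workgroup barrier. This is a hypothetical illustration, not IR produced by this spec:

```mlir
// Hypothetical shape of the guard: run the non-distributed epilogue on
// thread 0 only, then synchronize the workgroup.
func.func @guarded_epilogue() {
  %tidx = gpu.thread_id x
  %c0 = arith.constant 0 : index
  %is_thread0 = arith.cmpi eq, %tidx, %c0 : index
  scf.if %is_thread0 {
    // ... code that must run exactly once per workgroup ...
  }
  gpu.barrier
  return
}
```
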
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 8e44a16..89f18df 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -15,21 +15,21 @@
 
   // Step 1. First level of tiling + fusion parallelizes to blocks.
   // ==============================================================
-  %foreach_thread, %_ =
-  transform.iree.tile_to_foreach_thread_and_workgroup_count_region %div tile_sizes [1, 4]  
+  %forall, %_ =
+  transform.iree.tile_to_forall_and_workgroup_count_region %div tile_sizes [1, 4]  
     ( mapping = [#gpu.block<x>, #gpu.block<y>] )
   // TODO: Merging and fusing merged handles does not work properly atm.
-  transform.structured.fuse_into_containing_op %exp_and_exps_sum into %foreach_thread
-  transform.structured.fuse_into_containing_op %exps_sum_fill into %foreach_thread
-  transform.structured.fuse_into_containing_op %input_max into %foreach_thread
-  transform.structured.fuse_into_containing_op %input_max_fill into %foreach_thread
-  // By default, fusion into scf.foreach_thread does not promote captured values
+  transform.structured.fuse_into_containing_op %exp_and_exps_sum into %forall
+  transform.structured.fuse_into_containing_op %exps_sum_fill into %forall
+  transform.structured.fuse_into_containing_op %input_max into %forall
+  transform.structured.fuse_into_containing_op %input_max_fill into %forall
+  // By default, fusion into scf.forall does not promote captured values
   // to shared as this involves a cross-thread dependence analysis.
   // Instead, we activate it explicitly post-hoc to promote all the extract_slice
   // ops that we find and match the prerequisites
-  %foreach_thread_with_type = transform.cast %foreach_thread : !pdl.operation to !transform.op<"scf.foreach_thread">
-  transform.iree.share_foreach_thread_operands %foreach_thread_with_type
-    : (!transform.op<"scf.foreach_thread">) -> !transform.op<"scf.foreach_thread">
+  %forall_with_type = transform.cast %forall : !pdl.operation to !transform.op<"scf.forall">
+  transform.iree.share_forall_operands %forall_with_type
+    : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
 
   // Step 2. Second level of tiling + fusion parallelizes to threads.
   // ================================================================
@@ -49,14 +49,14 @@
   %reduction_linalg_ops = transform.merge_handles %tiled_input_max,
                                                   %tiled_exp_and_exps_sum
     : !pdl.operation
-  transform.structured.tile_to_foreach_thread_op %reduction_linalg_ops tile_sizes [1, 1]
+  transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1]
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   // Fully parallel ops are tiled and mapped.
   %parallel_linalg_ops = transform.merge_handles %tiled_input_max_fill,
                                                  %tiled_exps_sum_fill,
                                                  %tiled_div
     : !pdl.operation
-  transform.structured.tile_to_foreach_thread_op %parallel_linalg_ops num_threads [1, 4, 32]
+  transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32]
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] )
 
   // Step 3. Rank-reduce and vectorize.
@@ -75,8 +75,8 @@
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
-  transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
+  %func_3 = transform.iree.forall_to_workgroup %func_2
+  transform.iree.map_nested_forall_to_gpu_threads %func_3
     { workgroup_size = [32, 4, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
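
The cast that appears in this spec exists because `transform.iree.share_forall_operands` (formerly `share_foreach_thread_operands`) takes a handle typed as `!transform.op<"scf.forall">` rather than a plain `!pdl.operation`. An illustrative excerpt with hypothetical handle names:

```mlir
// %block_forall is an untyped handle produced by the block-level tiling step.
// Cast it to the typed handle the IREE op expects, then promote the captured
// values to shared, as described in the spec comments above.
%typed_forall = transform.cast %block_forall
  : !pdl.operation to !transform.op<"scf.forall">
transform.iree.share_forall_operands %typed_forall
  : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
```
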
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
index f55dac8..5748777 100644
--- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
@@ -3,7 +3,7 @@
   // Step 1. Find three linalg.generics and tile to GPU thread blocks.
   // ===========================================================================
   %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  transform.iree.tile_to_foreach_thread_and_workgroup_count_region %generics 
+  transform.iree.tile_to_forall_and_workgroup_count_region %generics 
                   tile_sizes [5, 3] ( mapping = [#gpu.block<z>, #gpu.block<x>])
 
  // Step 2. Rank reduce and bufferize and drop HAL descriptor from memref ops.
@@ -18,5 +18,5 @@
   // Step 3. Map to GPU thread blocks.
   // ===========================================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  transform.iree.foreach_thread_to_workgroup %func_2
+  transform.iree.forall_to_workgroup %func_2
 }
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
index 84836dd..f651ed1 100644
--- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -2,6 +2,6 @@
 ^bb1(%variant_op: !pdl.operation):
   %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   // Tile only one dimension, skip the other one.
-  transform.iree.tile_to_foreach_thread_and_workgroup_count_region %generics 
+  transform.iree.tile_to_forall_and_workgroup_count_region %generics 
                   tile_sizes [0, 3] ( mapping = [#gpu.block<z>])
 }
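
A quick note on the partial-tile form above: a zero entry in `tile_sizes` leaves that dimension untiled, so only the tiled dimension becomes an scf.forall induction variable and only one `#gpu.block` mapping entry is needed. Illustrative excerpt (the `%generics` handle is assumed to come from the match above):

```mlir
// tile_sizes [0, 3]: dimension 0 is left untiled (a 0 tile size means "skip"),
// dimension 1 is tiled by 3 and its forall loop is mapped to gpu block dim z.
transform.iree.tile_to_forall_and_workgroup_count_region %generics
    tile_sizes [0, 3] ( mapping = [#gpu.block<z>] )
```
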
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 677ea5e..eb14186 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 677ea5eb2e5b25b137495221a21877504ee5f2ce
+Subproject commit eb14186771e7bca992c043637aac3ed7104eaa1f
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index a913e03..50584fa 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit a913e03964df57009a51b46ccba09b322f2ba31b
+Subproject commit 50584fafb42af5dc34355282e380b6a74c355b29