Integrate llvm-project and bump dependencies 20230322 (#12730)

* llvm-project: 411b1d8f0795
* mlir-hlo: fa4f6f47d7a7c0123cbbc85c8929a29a51143b3b
* tensorflow: 6bedca8e818b152ea594a911faf6a6add9f7d795

mlir-hlo patch:
* Move LLVMSupport to LINK_COMPONENTS

tensorflow patch:
* Revert https://github.com/tensorflow/tensorflow/commit/c5920f0727124a374e72146a70a2b32153cdcfab
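
Notable upstream API changes adapted in this integrate:
* Tiling and pipelining entry points take the RewriterBase as the leading
  parameter (tileAndFuseDispatchUsingSCFForOp, pipelineSharedMemoryCopy).
* TilingInterface::getTiledImplementation returns FailureOr<TilingResult>
  (tiledOps + tiledValues) instead of SmallVector<Operation *>; the LinalgExt
  ops are updated to match.
* IREE's local TilingResult/TileAndFuseResult structs are renamed to
  IREETilingResult/IREETileAndFuseResult to avoid colliding with the upstream
  mlir::TilingResult.
* mlir::transform::gpu::mapNestedForallToThreadsImpl takes warp dimensions
  directly instead of a thread-id generator callback;
  transform.iree.map_nested_forall_to_gpu_threads gains an optional
  warp_dims attribute.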

---------

Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
Co-authored-by: Nicolas Vasilache <nicolas.vasilache@gmail.com>
Co-authored-by: Lei Zhang <antiagainst@google.com>
Co-authored-by: Hanhan Wang <hanchung@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
index 5d66cc9..c7675de 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUPipelining.cpp
@@ -647,9 +647,9 @@
 }  // namespace
 
 FailureOr<scf::ForOp> pipelineSharedMemoryCopy(
-    scf::ForOp forOp, PipeliningSchedulingStrategy startegy, bool peelEpilogue,
-    int64_t depth, RewriterBase& rewriter) {
-  return applyPipelining(forOp, depth, peelEpilogue, startegy);
+    RewriterBase& rewriter, scf::ForOp forOp,
+    PipeliningSchedulingStrategy strategy, bool peelEpilogue, int64_t depth) {
+  return applyPipelining(forOp, depth, peelEpilogue, strategy);
 }
 
 /// Pass options
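
For downstream callers the rewriter now leads the parameter list. A minimal
before/after sketch (assumes an in-scope RewriterBase and scf::ForOp; the
depth value is illustrative):

    // Old order: pipelineSharedMemoryCopy(forOp, strategy, peelEpilogue,
    //                                     depth, rewriter);
    FailureOr<scf::ForOp> pipelined = pipelineSharedMemoryCopy(
        rewriter, forOp, PipeliningSchedulingStrategy::loadGlobalStage0,
        /*peelEpilogue=*/false, /*depth=*/2);
    if (failed(pipelined)) return failure();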
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index df6bb45..c79e5be 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -425,8 +425,8 @@
 
     IRRewriter rewriter(context);
     if (failed(tileAndFuseDispatchUsingSCFForOp(
-            cast<TilingInterface>(computeOps.back()), linalgTilingOptions,
-            rewriter))) {
+            rewriter, cast<TilingInterface>(computeOps.back()),
+            linalgTilingOptions))) {
       funcOp.emitOpError("Tile+Distribute failed");
       return signalPassFailure();
     }
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index 978cf70..d03f8d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -180,9 +180,12 @@
 /// `flow.dispatch.tensor.store` that writes only a tile of the result at
 /// offsets given by `tiledOffsets` and sizes given by `tiledSizes`, using
 /// `tiledValue` as the source.
-static LogicalResult replaceStoreWithTiledVersion(
-    RewriterBase &rewriter, OpResult untiledValue, OpResult tiledValue,
-    ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes) {
+static LogicalResult replaceStoresWithTiledVersion(
+    RewriterBase &rewriter, OpResult untiledValue, Value tiledValue,
+    ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
+    Block *innerLoopBody) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(innerLoopBody->getTerminator());
   SmallVector<IREE::Flow::DispatchTensorStoreOp> storeOps;
   for (OpOperand &use : untiledValue.getUses()) {
     auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(use.getOwner());
@@ -226,7 +229,9 @@
 static LogicalResult replaceAllStoresWithTiledVersion(
     RewriterBase &rewriter, TilingInterface untiledOp,
     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-    Operation *tiledOp) {
+    ValueRange tiledValues, Block *innerLoopBody) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(innerLoopBody->getTerminator());
   for (auto [index, result] : llvm::enumerate(untiledOp->getResults())) {
     SmallVector<OpFoldResult> resultOffsets, resultSizes;
     if (failed(untiledOp.getResultTilePosition(rewriter, index, offsets, sizes,
@@ -234,12 +239,9 @@
       return rewriter.notifyMatchFailure(
           untiledOp, "failed to rewrite destructive update");
     }
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(tiledOp->getBlock()->getTerminator());
-    if (failed(replaceStoreWithTiledVersion(
-            rewriter, result.cast<OpResult>(),
-            tiledOp->getResult(index).cast<OpResult>(), resultOffsets,
-            resultSizes))) {
+    if (failed(replaceStoresWithTiledVersion(rewriter, result.cast<OpResult>(),
+                                             tiledValues[index], resultOffsets,
+                                             resultSizes, innerLoopBody))) {
       return failure();
     }
   }
@@ -248,8 +250,9 @@
 
 namespace {
 /// Result of the tiled operation.
-struct TilingResult {
-  Operation *tiledOp = nullptr;
+struct IREETilingResult {
+  SmallVector<Operation *> tiledOps;
+  SmallVector<Value> tiledValues;
   SmallVector<scf::ForOp> loops;
   llvm::SmallBitVector tiledLoops;
   SmallVector<OpFoldResult> tileOffsets;
@@ -257,9 +260,9 @@
 };
 }  // namespace
 
-static FailureOr<TilingResult> tileDispatchUsingSCFFopOp(
-    TilingInterface op, linalg::LinalgTilingOptions options,
-    RewriterBase &rewriter) {
+static FailureOr<IREETilingResult> tileDispatchUsingSCFFopOp(
+    RewriterBase &rewriter, TilingInterface op,
+    linalg::LinalgTilingOptions options) {
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
@@ -273,7 +276,7 @@
   Location loc = op.getLoc();
   size_t numLoops = iterationDomainOfr.size();
   if (numLoops == 0) {
-    return TilingResult();
+    return IREETilingResult();
   }
   auto iterationDomain =
       llvm::to_vector(llvm::map_range(iterationDomainOfr, [&](Range r) {
@@ -295,7 +298,7 @@
   }
   tileSizeVector.resize(numLoops);
 
-  TilingResult tilingResult;
+  IREETilingResult tilingResult;
   tilingResult.tiledLoops.resize(numLoops, false);
   for (auto [index, tileSize] : llvm::enumerate(tileSizeVector)) {
     if (!isZero(tileSize)) {
@@ -304,10 +307,9 @@
   }
 
   if (!tilingResult.tiledLoops.any()) {
-    return TilingResult();
+    return IREETilingResult();
   }
 
-  SmallVector<Operation *> tiledImplementation;
   {
     SmallVector<OpFoldResult> offsets, sizes;
     // If there is an interchange specified, permute the iteration domain and
@@ -383,8 +385,14 @@
     if (!tilingResult.loops.empty())
       rewriter.setInsertionPoint(
           tilingResult.loops.back().getBody()->getTerminator());
-    tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes);
-    tilingResult.tiledOp = tiledImplementation.back();
+    FailureOr<TilingResult> tiledImplementation =
+        op.getTiledImplementation(rewriter, offsets, sizes);
+    if (failed(tiledImplementation)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to generate tiled implementation");
+    }
+    std::swap(tilingResult.tiledOps, tiledImplementation->tiledOps);
+    std::swap(tilingResult.tiledValues, tiledImplementation->tiledValues);
 
     LLVM_DEBUG({
       if (!tilingResult.loops.empty()) {
@@ -406,7 +414,7 @@
  // of the store. It's valid to do this for all stores of the root untiled op.
   if (failed(replaceAllStoresWithTiledVersion(
           rewriter, op, tilingResult.tileOffsets, tilingResult.tileSizes,
-          tilingResult.tiledOp))) {
+          tilingResult.tiledValues, tilingResult.loops.back().getBody()))) {
     return failure();
   }
   return tilingResult;
@@ -458,19 +466,19 @@
   return sliceOps;
 }
 
-FailureOr<TileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
-    TilingInterface op, linalg::LinalgTilingOptions tilingOptions,
-    RewriterBase &rewriter) {
-  TileAndFuseResult tileAndFuseResult;
+FailureOr<IREETileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
+    RewriterBase &rewriter, TilingInterface op,
+    linalg::LinalgTilingOptions tilingOptions) {
+  IREETileAndFuseResult tileAndFuseResult;
   auto fusableProducers = getAllFusableProducers(op);
   // Apply the tiling pattern.
-  FailureOr<TilingResult> tilingResult =
-      tileDispatchUsingSCFFopOp(op, tilingOptions, rewriter);
+  FailureOr<IREETilingResult> tilingResult =
+      tileDispatchUsingSCFFopOp(rewriter, op, tilingOptions);
   if (failed(tilingResult)) {
     return failure();
   }
-  tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
-  tileAndFuseResult.loops.append(tilingResult->loops);
+  tileAndFuseResult.tiledAndFusedOps = tilingResult->tiledOps;
+  tileAndFuseResult.loops = tilingResult->loops;
 
   // If there is no tiling then there is nothing to do for fusion.
   if (!tilingResult->tiledLoops.any()) {
@@ -488,67 +496,34 @@
     // could use the first slice as the insertion point.
     auto sliceOps = getAllFusableProducerUses(
         fusableProducer, tileAndFuseResult.tiledAndFusedOps);
-    if (sliceOps.empty()) continue;
-    tensor::ExtractSliceOp sliceOp = sliceOps.front();
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(sliceOp);
 
-    // Generate the tiled implementation of the producer.
-    FailureOr<Value> tiledProducerVal =
-        tensor::replaceExtractSliceWithTiledProducer(
-            rewriter, sliceOp, sliceOp.getSource().cast<OpResult>());
-    if (failed(tiledProducerVal)) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "fusion along slice op failed");
-    }
-    auto tiledProducer = tiledProducerVal->getDefiningOp();
-    if (tiledProducer->getNumResults() != fusableProducer->getNumResults()) {
-      return rewriter.notifyMatchFailure(fusableProducer,
-                                         "fused operation expected to produce "
-                                         "an op with same number of results");
-    }
-
-    // 2b. Assume that the tile sizes used are such that all tiled loops are
-    //     "common parallel loops" for the consumer and all pulled in
-    //     producers. So using the tile size of the tiled consumer op, and the
-    //     information about which loops are tiled and which arent, compute
-    //     the tile sizes to use for the producer as well.
-    SmallVector<OpFoldResult> producerOffset, producerSizes;
-    SmallVector<Range> producerIterationDomain =
-        fusableProducer.getIterationDomain(rewriter);
-    for (auto [index, range] : llvm::enumerate(producerIterationDomain)) {
-      if (index < tilingResult->tiledLoops.size() &&
-          index < sliceOp.getMixedOffsets().size() &&
-          tilingResult->tiledLoops.test(index)) {
-        // It is not true that the tiling sizes for produces are always as same
-        // as the tiling sizes for consumers. The tensor.extract_slice op
-        // carries the information, so we can get the tiling sizes and offsets
-        // from it.
-        producerOffset.push_back(sliceOp.getMixedOffsets()[index]);
-        producerSizes.push_back(sliceOp.getMixedSizes()[index]);
-      } else {
-        producerOffset.push_back(range.offset);
-        producerSizes.push_back(range.size);
-      }
-    }
-
-    // 2c. Finally replace any `flow.dispatch.tensor.store` operation with
-    //     tiled version of the operation. It is only valid to do this under the
-    //     above assumption that the producer and consumer share the loops
-    //     that can be tiled.
-    if (failed(replaceAllStoresWithTiledVersion(rewriter, fusableProducer,
-                                                producerOffset, producerSizes,
-                                                tiledProducer))) {
-      return failure();
-    }
-    // Replace all uses of the slices processed in this step with values from
-    // the producer.
     for (auto sliceOp : sliceOps) {
-      unsigned resultNumber =
-          sliceOp.getSource().cast<OpResult>().getResultNumber();
-      rewriter.replaceOp(sliceOp, tiledProducer->getResult(resultNumber));
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(sliceOp);
+
+      // Generate the tiled implementation of the producer.
+      OpResult untiledValue = sliceOp.getSource().cast<OpResult>();
+      FailureOr<TilingResult> swapSliceResult =
+          tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
+                                                       untiledValue);
+      if (failed(swapSliceResult) || swapSliceResult->tiledValues.size() != 1) {
+        return rewriter.notifyMatchFailure(sliceOp,
+                                           "fusion along slice op failed");
+      }
+
+      // 2c. Finally replace any `flow.dispatch.tensor.store` operation with
+      //     tiled version of the operation. It is only valid to do this under
+      //     the above assumption that the producer and consumer share the loops
+      //     that can be tiled.
+      if (failed(replaceStoresWithTiledVersion(
+              rewriter, untiledValue, swapSliceResult->tiledValues[0],
+              sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+              tileAndFuseResult.loops.back().getBody()))) {
+        return failure();
+      }
+      rewriter.replaceOp(sliceOp, swapSliceResult->tiledValues[0]);
+      tileAndFuseResult.tiledAndFusedOps.append(swapSliceResult->tiledOps);
     }
-    tileAndFuseResult.tiledAndFusedOps.push_back(tiledProducer);
   }
 
   return tileAndFuseResult;
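
The core upstream change this file adapts to: TilingInterface's
getTiledImplementation now returns FailureOr<TilingResult>, carrying all tiled
ops plus their result values, instead of SmallVector<Operation *>. A sketch of
the new consumption pattern (rewriter, offsets, and sizes assumed in scope;
resultIndex is illustrative):

    FailureOr<TilingResult> tiled =
        op.getTiledImplementation(rewriter, offsets, sizes);
    if (failed(tiled)) {
      return rewriter.notifyMatchFailure(op, "failed to tile");
    }
    // Result values (not just ops) come back, so stores can be rewritten per
    // result without assuming a single tiled operation.
    Value replacement = tiled->tiledValues[resultIndex];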
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
index e966c9d..31cabb3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel
@@ -80,6 +80,7 @@
         "@llvm-project//mlir:ArithUtils",
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:BufferizationTransforms",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 0a186a0..de9c6c7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -38,6 +38,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/PassManager.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index 083fb5b..ca99cb5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -30,17 +30,18 @@
 /// For a given operation within a dispatch, tile and distribute the operation
 /// to workgroups as well as tile + fuse its producers. Returns the
 /// generated tiled and fused ops, as well as the loops used for distribution.
-struct TileAndFuseResult {
+struct IREETileAndFuseResult {
   SmallVector<Operation *> tiledAndFusedOps;
   SmallVector<scf::ForOp> loops;
 };
-FailureOr<TileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
-    TilingInterface op, linalg::LinalgTilingOptions tilingOptions,
-    RewriterBase &rewriter);
+
+FailureOr<IREETileAndFuseResult> tileAndFuseDispatchUsingSCFForOp(
+    RewriterBase &rewriter, TilingInterface op,
+    linalg::LinalgTilingOptions tilingOptions);
 
 FailureOr<scf::ForOp> pipelineSharedMemoryCopy(
-    scf::ForOp forOp, PipeliningSchedulingStrategy startegy, bool peelEpilogue,
-    int64_t depth, RewriterBase &rewriter);
+    RewriterBase &rewriter, scf::ForOp forOp,
+    PipeliningSchedulingStrategy strategy, bool peelEpilogue, int64_t depth);
 
 /// Populate patterns related to clean up the IR after tile and distribute to
 /// workgroups.
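
Sketch of driving the renamed entry point (context, rootOp, and the tiling
options are placeholders for what a caller already has):

    IRRewriter rewriter(context);
    FailureOr<IREETileAndFuseResult> result = tileAndFuseDispatchUsingSCFForOp(
        rewriter, rootOp, linalg::LinalgTilingOptions());
    if (failed(result)) return failure();
    // Distribution loops and every tiled + fused op come back for later use.
    ArrayRef<scf::ForOp> loops = result->loops;
    ArrayRef<Operation *> tiledOps = result->tiledAndFusedOps;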
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 68982b2..4ff32ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -378,13 +378,13 @@
 
     // Move all the regions from the old op to the new op and legalize its
     // signature.
-    for (auto &[index, region] : llvm::enumerate(op->getRegions())) {
+    for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
       Region &newOpRegion = newOp->getRegion(index);
       rewriter.inlineRegionBefore(region, newOpRegion, newOpRegion.begin());
       TypeConverter::SignatureConversion signatureConverter(
           newOpRegion.getNumArguments());
       bool doSignatureConversion = false;
-      for (auto &[argIndex, arg] :
+      for (const auto &[argIndex, arg] :
            llvm::enumerate(newOpRegion.getArguments())) {
         Type argType = arg.getType();
         Type legalizedType = this->typeConverter->convertType(argType);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 299fcdb..2a45439 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -2279,7 +2279,6 @@
         %18 = tensor.pack %16#0 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %17 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 64]]>} : tensor<384x512xf32> -> tensor<48x512x8x1xf32>
         flow.dispatch.tensor.store %18, %6, offsets = [0, 0, 0, 0], sizes = [48, 512, 8, 1], strides = [1, 1, 1, 1] : tensor<48x512x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<48x512x8x1xf32>>
         flow.dispatch.tensor.store %16#0, %7, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
-        flow.dispatch.tensor.store %16#1, %8, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
         return
       }
     }
@@ -2292,7 +2291,6 @@
 // CHECK:             %[[PACK:.+]] = tensor.pack
 // CHECK-DAG:         flow.dispatch.tensor.store %[[PACK]], {{.*}} sizes = [8, 64, 8, 1]
 // CHECK-DAG:         flow.dispatch.tensor.store %[[ELEM]]#0, {{.*}} sizes = [64, 64]
-// CHECK-DAG:         flow.dispatch.tensor.store %[[ELEM]]#1, {{.*}} sizes = [64, 64]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
index 0788de8..d7afcd6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
@@ -179,10 +179,9 @@
     }
 
     // Add more fusion candidates to the worklist.
-    if (auto fusedProducerOp =
-            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
-      addCandidateSlices(fusedProducerOp, candidates);
-      tiledOps.push_back(fusedProducerOp);
+    for (auto tiledOp : fusedProducer->tiledOps) {
+      addCandidateSlices(tiledOp, candidates);
+      tiledOps.push_back(tiledOp);
     }
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
index 735ee94..f915ee7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistribute.cpp
@@ -42,29 +42,11 @@
         [&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
 
     IRRewriter rewriter(funcOp->getContext());
-    rewriter.setInsertionPoint(funcOp);
-    MLIRContext* ctx = funcOp->getContext();
-    SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
-        gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimX),
-        gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimY),
-        gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimZ)};
-
-    auto threadIdGenerator = [](RewriterBase& rewriter, scf::ForallOp forallOp,
-                                SmallVectorImpl<Value>& threadIds) {
-      Location loc = forallOp.getLoc();
-      IndexType indexType = rewriter.getIndexType();
-      threadIds.assign(
-          {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
-           rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
-           rewriter.create<gpu::ThreadIdOp>(loc, indexType,
-                                            gpu::Dimension::z)});
-    };
-
-    DiagnosedSilenceableFailure const result =
+    rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+    DiagnosedSilenceableFailure result =
         mlir::transform::gpu::mapNestedForallToThreadsImpl(
-            rewriter, funcOp, workgroupSize, threadIdGenerator, false,
-            std::nullopt, threadMappingAttributes);
-
+            rewriter, std::nullopt, funcOp, workgroupSize, /*warpDims=*/{},
+            false);
     if (!result.succeeded()) return signalPassFailure();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
index bf73d9e..1199289 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel
@@ -63,8 +63,11 @@
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:BufferizationDialect",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUTransformOps",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
index 521b5ae..00deb7a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt
@@ -33,6 +33,8 @@
     ::LLVMGPUExtensionsOpGen
     IREEDialectsTransforms
     IREELinalgTransformDialect
+    LLVMSupport
+    MLIRAffineDialect
     MLIRArithDialect
     MLIRBufferizationDialect
     MLIRFuncDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 9090e3c..f52c5a6a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -11,7 +11,10 @@
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -22,13 +25,25 @@
 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Region.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+using llvm::dbgs;
+
+#define DEBUG_TYPE "transform-llvmgpu-extensions"
+
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(dbgs() << '[' << DEBUG_TYPE << "] " << X)
+
 using namespace mlir;
 using namespace mlir::iree_compiler::IREE;
 
@@ -61,8 +76,10 @@
     transform::TransformState &state) {
   if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
     state.getTopLevel()->emitOpError(
-        "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel to "
-        "attach the workgroup size information to a nested ExecutableExportOp");
+        "requires HAL::ExecutableOp or HAL::ExecutableVariantOp "
+        "toplevel to "
+        "attach the workgroup size information to a nested "
+        "ExecutableExportOp");
     return emitDefaultDefiniteFailure(target);
   }
 
@@ -75,95 +92,22 @@
     return emitDefaultDefiniteFailure(target);
   }
 
-  SmallVector<int64_t> workgroupSize{getWorkgroupDims()};
-  // TODO: no magic constant but IREE uses this extensively.
-  workgroupSize.resize(/*size=*/3, /*value=*/1);
-
   auto transformOp = cast<transform::TransformOpInterface>(getOperation());
-  SimplePatternRewriter rewriter(target);
 
-  MLIRContext *ctx = target->getContext();
-  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
-      gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimX),
-      gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimY),
-      gpu::GPUThreadMappingAttr::get(ctx, gpu::Threads::DimZ)};
-
-  auto threadIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
-                              SmallVectorImpl<Value> &threadIds) {
-    Location loc = forallOp.getLoc();
-    IndexType indexType = rewriter.getIndexType();
-    threadIds.assign(
-        {rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x),
-         rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
-         rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::z)});
-  };
-
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPointToStart(&target.getBody().front());
   DiagnosedSilenceableFailure diag =
       mlir::transform::gpu::mapNestedForallToThreadsImpl(
-          rewriter, target, workgroupSize, threadIdGenerator, true, transformOp,
-          threadMappingAttributes);
+          rewriter, transformOp, target, getWorkgroupDims(), getWarpDims(),
+          true);
 
   if (diag.succeeded()) {
-    auto newAttr = rewriter.getIndexArrayAttr(workgroupSize);
-    // TODO: should really be: exportOp.setWorkgroupSizeAttr(newAttr);
+    auto newAttr = rewriter.getIndexArrayAttr(getWorkgroupDims());
+    rewriter.startRootUpdate(exportOp);
     exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
+    rewriter.finalizeRootUpdate(exportOp);
   }
 
-  // Map warpIds, only if threadIdx.x is a multiple of the warp size.
-  // TODO: more advanced mechanism to linearize/delinearize the threadIds to
-  // warps.
-  SmallVector<DeviceMappingAttrInterface> warpMappingAttributes = {
-      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimX),
-      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimY),
-      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimZ)};
-  if (diag.succeeded() && (workgroupSize[0] % kWarpSize == 0)) {
-    auto warpIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
-                              SmallVectorImpl<Value> &warpIds) {
-      Location loc = forallOp.getLoc();
-      IndexType indexType = rewriter.getIndexType();
-      Value threadIdX =
-          rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::x);
-      Value cstWarpSize =
-          rewriter.create<arith::ConstantIndexOp>(loc, kWarpSize);
-      Value warpIdX =
-          rewriter.create<arith::DivUIOp>(loc, threadIdX, cstWarpSize);
-      warpIds.assign(
-          {warpIdX,
-           rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpu::Dimension::y),
-           rewriter.create<gpu::ThreadIdOp>(loc, indexType,
-                                            gpu::Dimension::z)});
-    };
-    SmallVector<int64_t> numWarps = {workgroupSize[0] / kWarpSize,
-                                     workgroupSize[1], workgroupSize[2]};
-    diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
-        rewriter, target, numWarps, warpIdGenerator, true, transformOp,
-        warpMappingAttributes);
-  }
-
-  auto walkResult = target->walk([&warpMappingAttributes](
-                                     scf::ForallOp forallOp) -> WalkResult {
-    auto maybeMapping = forallOp.getMapping();
-    if (!maybeMapping) return WalkResult::advance();
-    for (Attribute attr : *maybeMapping) {
-      for (const auto &warpAttr : warpMappingAttributes) {
-        if (attr == warpAttr) {
-          forallOp->emitOpError(
-              "Mapping failed: is threadIdx.x a multiple of the warp size?");
-          return WalkResult::interrupt();
-        }
-      }
-    }
-    return WalkResult::advance();
-  });
-  if (walkResult.wasInterrupted()) {
-    return emitDefaultDefiniteFailure(target);
-  }
-
-  auto newAttr = rewriter.getIndexArrayAttr(workgroupSize);
-  // TODO: should really be: exportOp.setWorkgroupSizeAttr(newAttr);
-  rewriter.startRootUpdate(exportOp);
-  exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
-  rewriter.finalizeRootUpdate(exportOp);
   return diag;
 }
 
@@ -273,19 +217,21 @@
       isThreadIdxxZeroPredicate(ifOp);
   if (failed(maybeThreadIdxxOp)) return failure();
 
-  // All the code below will be executed on a single warp given a fixed
-  // (threadIdxy, threadIdxz).
-  // Note, we reuse `maybeThreadIdxxOp` here because we later want to replace
-  // this op instance by 0 without relying on CSE or canonicalizations.
+  // All the code below will be executed on a single warp given a
+  // fixed (threadIdxy, threadIdxz). Note, we reuse
+  // `maybeThreadIdxxOp` here because we later want to replace this
+  // op instance by 0 without relying on CSE or canonicalizations.
   Value threadIdxx = *maybeThreadIdxxOp;
 
   assert(workgroupSizeX % warpSize == 0);
   if (workgroupSizeX != warpSize) {
-    // Add a guard for `threadIdxx < warp size` around the WarpExecuteOnLane0Op.
+    // Add a guard for `threadIdxx < warp size` around the
+    // WarpExecuteOnLane0Op.
     Value predicate = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ult, threadIdxx,
         rewriter.create<arith::ConstantIndexOp>(loc, warpSize));
-    // Note: return-less IfOp is built with a terminator, no need to add one.
+    // Note: return-less IfOp is built with a terminator, no need to
+    // add one.
     auto newIfOp =
         rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
     rewriter.setInsertionPointToStart(&newIfOp.getThenRegion().front());
@@ -293,7 +239,8 @@
   auto warpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
       loc, TypeRange(), threadIdxx, warpSize);
 
-  // Move the code from the previous ifOp to the WarpExecuteOnLane0Op.
+  // Move the code from the previous ifOp to the
+  // WarpExecuteOnLane0Op.
   Block &sourceBlock = ifOp.getThenRegion().front();
   Block &targetBlock = warpOp.getWarpRegion().front();
   Block::iterator insertionPoint = targetBlock.begin();
@@ -310,8 +257,8 @@
   // This simple rewrite propagates zero in lieu of laneId within the
   // warp_execute_on_lane_0 op.
   // Atm, this **must** occur before any hoisting of code.
-  // TODO: Replace this by a more robust scoped SCCP that will make it more
-  // robust re. hoisting.
+  // TODO: Replace this by a more robust scoped SCCP that will make
+  // it more robust re. hoisting.
   (void)replaceAllUsesOfLaneWithin(rewriter, warpOp);
 
   // Hoist the scalar code outside of the warp region.
@@ -341,9 +288,12 @@
   if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
     results.assign(1, nullptr);
     return emitDefaultSilenceableFailure(state.getTopLevel())
-           << "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel "
-              "so that IR is properly isolated. This is required so we can "
-              "safely inspect the HAL::ExecutableExportOp under multi-threaded "
+           << "requires HAL::ExecutableOp or "
+              "HAL::ExecutableVariantOp toplevel "
+              "so that IR is properly isolated. This is required so "
+              "we can "
+              "safely inspect the HAL::ExecutableExportOp under "
+              "multi-threaded "
               "pass assumptions.";
   }
 
@@ -353,29 +303,35 @@
   HAL::ExecutableExportOp exportOp =
       getExecutableExportOpForFunc(halExecutableVariantOp, funcOp);
   if (!halExecutableVariantOp || !funcOp || !exportOp) {
-    // Return a silenceable failure and set the expected 1 result to nullptr.
+    // Return a silenceable failure and set the expected 1 result to
+    // nullptr.
     results.assign(1, nullptr);
     return emitDefaultSilenceableFailure(target)
-           << "export op is missing --- the transform is not applied";
+           << "export op is missing --- the transform is not "
+              "applied";
   }
 
   std::optional<ArrayAttr> maybeAttr = exportOp.getWorkgroupSize();
   // TODO: Pervasive 3 constant in IREE.
   if (!maybeAttr || maybeAttr->size() != 3) {
-    // Return a silenceable failure and set the expected 1 result to nullptr.
+    // Return a silenceable failure and set the expected 1 result to
+    // nullptr.
     results.assign(1, nullptr);
     return emitDefaultSilenceableFailure(target)
-           << "export op must have workgroup_size attribute set with 3 entries "
+           << "export op must have workgroup_size attribute set "
+              "with 3 entries "
               "--- the transform is not applied";
   }
 
   int64_t workgroupSizeX = (*maybeAttr)[0].cast<IntegerAttr>().getInt();
   int64_t warpSize = getWarpSize();
   if (workgroupSizeX % warpSize != 0) {
-    // Return a silenceable failure and set the expected 1 result to nullptr.
+    // Return a silenceable failure and set the expected 1 result to
+    // nullptr.
     results.assign(1, nullptr);
     return emitDefaultSilenceableFailure(target)
-           << "vector distribution requires workgroup size for x to be a "
+           << "vector distribution requires workgroup size for x to "
+              "be a "
            << "multiple of the warp size: " << workgroupSizeX << " vs "
            << warpSize << " --- the transform is not applied";
   }
@@ -385,10 +341,12 @@
       rewriteScfIfAsWarpExecuteOnLane0(rewriter, target->getLoc(), target,
                                        workgroupSizeX, warpSize);
   if (failed(vectorDistributionResult)) {
-    // Return a silenceable failure and set the expected 1 result to nullptr.
+    // Return a silenceable failure and set the expected 1 result to
+    // nullptr.
     results.assign(1, nullptr);
     return emitDefaultSilenceableFailure(target)
-           << "scf::ifOp needs to be predicated on threadIdx.x == 0 --- the "
+           << "scf::ifOp needs to be predicated on threadIdx.x == 0 "
+              "--- the "
               "transform is not applied";
   }
   results.push_back(vectorDistributionResult->warpOp);
@@ -499,15 +457,15 @@
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
     // TODO: generalize this.
-    // options.warpSyncronizationFn currently must take a WarpExecuteOnLane0Op
-    // which we don't have here.
+    // options.warpSyncronizationFn currently must take a
+    // WarpExecuteOnLane0Op which we don't have here.
     rewriter.create<gpu::BarrierOp>(load.getLoc());
     Value newRead = rewriter.create<memref::LoadOp>(
         load.getLoc(), distributedVal.getType(), load.getMemref(), indices);
 
-    // The result type of WarpExecuteOnLane0Op may or may not match the yielded
-    // type depending on whether the op has "broadcast" behavior (see the doc
-    // of WarpExecuteOnLane0Op).
+    // The result type of WarpExecuteOnLane0Op may or may not match
+    // the yielded type depending on whether the op has "broadcast"
+    // behavior (see the doc of WarpExecuteOnLane0Op).
     for (OpOperand &use : distributedVal.getUses()) {
       rewriter.startRootUpdate(use.getOwner());
       Value replacement = newRead;
@@ -539,8 +497,8 @@
     auto warpParent = alloc->getParentOfType<vector::WarpExecuteOnLane0Op>();
     if (!warpParent) return failure();
     alloc->moveBefore(warpParent);
-    // Conservatively move the dealloc after the warpOp. This may extend the
-    // liverange of the allocation but is always correct.
+    // Conservatively move the dealloc after the warpOp. This may
+    // extend the liverange of the allocation but is always correct.
     for (Operation *user : alloc->getUsers()) {
       if (isa<memref::DeallocOp>(user)) user->moveAfter(warpParent);
     }
@@ -633,7 +591,8 @@
     transform::TransformState &state) {
   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
     target->emitOpError(
-        "applies only to isolated-from-above targets because it needs to apply "
+        "applies only to isolated-from-above targets because it "
+        "needs to apply "
         "patterns greedily");
     return emitDefaultDefiniteFailure(target);
   }
@@ -642,8 +601,8 @@
   // automatically get listening capabilities.
 
   MLIRContext *ctx = target->getContext();
-  // MultiReduction lowering is necessary until we have explicit support for
-  // distributing that op.
+  // MultiReduction lowering is necessary until we have explicit
+  // support for distributing that op.
   RewritePatternSet preProcessingPatterns(ctx);
   populateMultiReductionLoweringPatterns(target, preProcessingPatterns,
                                          /*benefit=*/1);
@@ -657,8 +616,10 @@
   }
 
   RewritePatternSet patterns(ctx);
-  populateVectorTransferWriteDistribution(target, patterns, /*benefit=*/2);
-  populatePropagateVectorDistribution(target, patterns, /*benefit=*/1);
+  populateVectorTransferWriteDistribution(target, patterns,
+                                          /*benefit=*/2);
+  populatePropagateVectorDistribution(target, patterns,
+                                      /*benefit=*/1);
   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
     return mlir::emitDefiniteFailure(
         target, "warp distribution patterns failed to apply");
@@ -668,7 +629,8 @@
   vector::WarpExecuteOnLane0LoweringOptions options;
   options.warpAllocationFn = allocateGlobalSharedMemory;
   options.warpSyncronizationFn = warpSyncronizationFn;
-  populateWarpExecuteOnLane0ToScf(target, endPatterns, options, /*benefit=*/0);
+  populateWarpExecuteOnLane0ToScf(target, endPatterns, options,
+                                  /*benefit=*/0);
   if (failed(applyPatternsAndFoldGreedily(target, std::move(endPatterns)))) {
     return mlir::emitDefiniteFailure(
         target, "warp execute on lane 0 to scf patterns failed to apply");
@@ -689,7 +651,8 @@
     transform::TransformState &state) {
   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
     target->emitOpError(
-        "applies only to isolated-from-above targets because it needs to apply "
+        "applies only to isolated-from-above targets because it "
+        "needs to apply "
         "patterns greedily");
     return emitDefaultDefiniteFailure(target);
   }
@@ -709,8 +672,8 @@
   MLIRContext *ctx = target->getContext();
 
   // Unrolling to native vector size must have previously occurred.
-  // TODO: Add pattern to propagate the extract through the scf.for ops.
-  // Convert slice of contract operations to mma_sync/wmma ops.
+  // TODO: Add pattern to propagate the extract through the scf.for
+  // ops. Convert slice of contract operations to mma_sync/wmma ops.
   RewritePatternSet patterns(ctx);
   mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
   populatePrepareVectorToMMAPatterns(patterns, getUseMmaSync());
@@ -788,8 +751,8 @@
   IRRewriter rewriter(getContext());
   int64_t depth(getDepth());
   FailureOr<scf::ForOp> pipelinedFor = iree_compiler::pipelineSharedMemoryCopy(
-      forOp, PipeliningSchedulingStrategy::loadGlobalStage0, false, depth,
-      rewriter);
+      rewriter, forOp, PipeliningSchedulingStrategy::loadGlobalStage0, false,
+      depth);
   if (failed(pipelinedFor)) return emitDefaultSilenceableFailure(forOp);
   results.push_back(pipelinedFor.value());
   return DiagnosedSilenceableFailure::success();
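
Both mapping rewrites above now funnel into the new upstream entry point,
which takes warp dimensions directly and drops the thread-id generator
callback. Minimal sketch mirroring the pass version (the flag name in the
last comment is an assumption about upstream at this revision):

    DiagnosedSilenceableFailure result =
        mlir::transform::gpu::mapNestedForallToThreadsImpl(
            rewriter, /*transformOp=*/std::nullopt, funcOp, workgroupSize,
            /*warpDims=*/{}, /*syncAfterDistribute=*/false);
    if (!result.succeeded()) return signalPassFailure();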
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 34e4d38..768041d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -88,12 +88,14 @@
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$workgroup_dims);
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$workgroup_dims,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$warp_dims);
   let results = (outs);
 
   let assemblyFormat = [{
     $target
     `workgroup_dims` `=` $workgroup_dims
+    (`warp_dims` `=` $warp_dims^)?
     attr-dict
     `:` functional-type($target, results)
   }];
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index 167798c..f53464f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -19,7 +19,7 @@
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
   transform.iree.map_nested_forall_to_gpu_threads %memref_func 
-    workgroup_dims = [10, 11] : (!pdl.operation) -> ()
+    workgroup_dims = [10, 11, 1] : (!pdl.operation) -> ()
 
   // Late canonicalizations to cleanup and pass the checks
   transform.iree.apply_patterns %memref_func
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
index 4173942..3267723 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -13,6 +13,8 @@
       hal.return %arg1, %c1, %c1 : index, index, index
     }
     builtin.module {
+
+//       CHECK: #[[$DIV32MOD8:.*]] = affine_map<()[s0] -> ((s0 floordiv 32) mod 8)>
 // CHECK-LABEL: func.func @distribute
       func.func @distribute() {
         %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
@@ -23,7 +25,7 @@
         memref.assume_alignment %1, 64 : memref<2xf16>
         %workgroup_id_x = hal.interface.workgroup.id[0] : index
         %subview = memref.subview %1[%workgroup_id_x] [1] [1] : memref<2xf16> to memref<1xf16, strided<[1], offset: ?>>
-// CHECK: %[[C32:.+]] = arith.constant 32 : index
+
 // CHECK: %[[TX:.+]] = gpu.thread_id  x
 // CHECK: %[[COND:.*]] = arith.cmpi ult
 // CHECK: scf.if %[[COND]] {
@@ -32,7 +34,8 @@
           vector.transfer_write %cst_0, %subview[%arg0]
           {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
         } {mapping = [#gpu.thread<x>]}
-// CHECK: %[[WX:.+]] = arith.divui %[[TX]], %[[C32]] : index
+
+// CHECK: %[[WX:.+]] = affine.apply #[[$DIV32MOD8]]()[%[[TX]]]
 // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[WX]]] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>>
         scf.forall (%arg0) in (%c8) {
           vector.transfer_write %cst_0, %subview[%arg0]
@@ -46,7 +49,7 @@
         %17 = transform.structured.match ops{["func.func"]} in %variant_op 
           : (!pdl.operation) -> !pdl.operation
         transform.iree.map_nested_forall_to_gpu_threads %17 
-          workgroup_dims = [256, 1, 1] : (!pdl.operation) -> ()
+          workgroup_dims = [256, 1, 1] warp_dims = [8, 1, 1] : (!pdl.operation) -> ()
 
         // Late canonicalizations to cleanup and pass the checks.
         // Needs to occur on the whole variant to perform cse on the workgroup_count region
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
index 7e24df8..24ffbf1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
@@ -200,8 +200,8 @@
 }
 
 // CHECK-LABEL:   hal.executable public @transpose_3d_yes_dispatch_0_generic_10x768x2048 {
-//       CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[D0:.*]] = gpu.thread_id  x
 //       CHECK:   %[[D1:.*]] = gpu.thread_id  y
 //       CHECK:   %[[D2:.*]] = gpu.thread_id  z
@@ -267,8 +267,8 @@
 }
 
 // CHECK-LABEL:   hal.executable public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 {
-//       CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[D0:.*]] = gpu.thread_id  x
 //       CHECK:   %[[D1:.*]] = gpu.thread_id  y
 //       CHECK:   %[[D2:.*]] = gpu.thread_id  z
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
index 9919567..babde9d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_promotion.mlir
@@ -182,15 +182,15 @@
 //  CHECK-COUNT-32:     vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x64x20xf32, #gpu.address_space<workgroup>>, vector<4xf32>
 //  CHECK-COUNT-16:     vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x16x68xf32, #gpu.address_space<workgroup>>, vector<4xf32>
 // CHECK-COUNT-128:     vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
-//           CHECK:     affine.apply #[[MAP]]
+//       CHECK-DAG:     %[[APPLY:.+]] = affine.apply #[[MAP]]
+//       CHECK-DAG:     vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
+//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
 //           CHECK:     vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
-//           CHECK:     vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<128x512xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
+//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x64x20xf32, #gpu.address_space<workgroup>>
 //           CHECK:     vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<512x256xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
+//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
 //           CHECK:     vector.transfer_read %{{.+}}, %[[CST0]] {__pipelining_first_stage__, in_bounds = [true]} : memref<512x256xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}} {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
+//           CHECK:     vector.transfer_write %{{.+}}, %{{.+}}[%[[APPLY]], {{.+}}] {__pipelining_first_stage__, in_bounds = [true]} : vector<4xf32>, memref<3x16x68xf32, #gpu.address_space<workgroup>>
 //           CHECK:     gpu.barrier {__pipelining_first_stage__}
 //           CHECK:     scf.yield
 //  CHECK-COUNT-32:   vector.transfer_read %{{.+}}, %[[CST0]] {in_bounds = [true]} : memref<3x64x20xf32, #gpu.address_space<workgroup>>, vector<4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 5a35273..01c2510 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -96,7 +96,7 @@
   // For zero dim tensor, consider it's too small to access using all threads.
   if (shape.size() == 0) return false;
   int64_t threadsAvailable = threadCount;
-  for (auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) {
+  for (const auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) {
     int64_t numElementPerThread = index == 0 ? vectorSize : 1;
     int64_t numThreads = dim / numElementPerThread;
     if (numThreads == 0) return false;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
index 0fcda13..3a37335 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
@@ -139,7 +139,7 @@
   llvm::PassInstrumentationCallbacks pic;
 
   llvm::StandardInstrumentations si(module.getContext(), false);
-  si.registerCallbacks(pic, &fam);
+  si.registerCallbacks(pic, &mam);
 
   llvm::PassBuilder pb(&targetMachine, pto, std::nullopt, &pic);
   llvm::ModulePassManager mpm;
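
Context for the one-line fix above: at this LLVM revision,
StandardInstrumentations::registerCallbacks expects a pointer to the
ModuleAnalysisManager rather than the FunctionAnalysisManager, hence the
switch from &fam to &mam.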
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index 3394389..0f3ef14 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -93,7 +93,7 @@
       auto newOperands =
           llvm::to_vector<4>(adaptor.getODSOperands(inputSetIndex));
       ++inputSetIndex;
-      if (auto inputTupleType = inputType.dyn_cast<TupleType>()) {
+      if (auto inputTupleType = inputType.template dyn_cast<TupleType>()) {
         // Unpack a tuple<...> from the variadic.
         // This only supports a single level of unpacking.
         if (inputTupleType.size() != newOperands.size()) {
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index af14977..56d0957 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -3032,7 +3032,7 @@
         funcOp.getNumArguments() - op.getOperands().size();
 
     IRMapping refMapping;
-    for (auto &pair : llvm::enumerate(op.getOperands())) {
+    for (const auto &pair : llvm::enumerate(op.getOperands())) {
       Value operand = pair.value();
       size_t index = pair.index();
 
@@ -3492,7 +3492,7 @@
         this->template getTypeConverter<IREE::VM::EmitCTypeConverter>();
 
     SmallVector<Value> unwrappedOperands;
-    for (auto &operand : llvm::enumerate(adaptor.getOperands())) {
+    for (const auto &operand : llvm::enumerate(adaptor.getOperands())) {
       if (refArgumentIndices.contains(operand.index())) {
         Type originalType =
             op.getOperation()->getOperand(operand.index()).getType();
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE
index 4cf13ed..7b66de5 100644
--- a/integrations/tensorflow/WORKSPACE
+++ b/integrations/tensorflow/WORKSPACE
@@ -7,7 +7,7 @@
 
 load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
 
-TENSORFLOW_COMMIT = "49ffedb82a7e6597b5e5dae9ec6f948fab858842"
+TENSORFLOW_COMMIT = "b4c19400d3a5be519a487cb8461fb7b57e2150ca"
 
 git_repository(
     name = "org_tensorflow",
diff --git a/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run b/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run
index c376280..594b7a6 100644
--- a/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run
+++ b/integrations/tensorflow/test/iree_tf_tests/layers/vulkan__dynamic_dims_Concatenate.run
@@ -1,3 +1,2 @@
 # REQUIRES: vulkan
 # RUN: %PYTHON -m iree_tf_tests.layers.layers_test --target_backends=iree_vulkan --dynamic_dims=true --training=false --test_default_kwargs_only=true --layer=Concatenate --artifacts_dir=%t
-# XFAIL: *
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 398a356..c57bb75 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -30,7 +30,7 @@
 struct SwapTilingInterfaceOp : public OpRewritePattern<tensor::ExtractSliceOp> {
   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
 
-  FailureOr<Operation *>
+  FailureOr<TilingResult>
   returningMatchAndRewrite(tensor::ExtractSliceOp sliceOp,
                            PatternRewriter &rewriter) const;
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index d00a15e..bf2cd6b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -267,7 +267,7 @@
   return ranges;
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 ScatterOp::getTiledImplementation(OpBuilder &builder,
                                   ArrayRef<OpFoldResult> offsets,
                                   ArrayRef<OpFoldResult> sizes) {
@@ -316,7 +316,8 @@
   Operation *tiledScatterOp =
       mlir::clone(builder, getOperation(), resultTypes,
                   ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
-  return {tiledScatterOp};
+  return TilingResult{{tiledScatterOp},
+                      SmallVector<Value>(tiledScatterOp->getResults())};
 }
 
 LogicalResult ScatterOp::getResultTilePosition(
@@ -484,7 +485,7 @@
   return loopBounds;
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 SortOp::getTiledImplementation(OpBuilder &builder,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes) {
@@ -506,7 +507,8 @@
   }
   Operation *tiledSortOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-  return {tiledSortOp};
+  return TilingResult{{tiledSortOp},
+                      SmallVector<Value>(tiledSortOp->getResults())};
 }
 
 LogicalResult SortOp::getResultTilePosition(
@@ -807,7 +809,7 @@
   return success();
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 FftOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
@@ -828,7 +830,8 @@
   }
   Operation *tiledFftOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-  return {tiledFftOp};
+  return TilingResult{{tiledFftOp},
+                      SmallVector<Value>(tiledFftOp->getResults())};
 }
 
 LogicalResult FftOp::getResultTilePosition(
@@ -1006,7 +1009,7 @@
   return success();
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 ScanOp::getTiledImplementation(OpBuilder &builder,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes) {
@@ -1042,7 +1045,8 @@
 
   Operation *tiledScanOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-  return {tiledScanOp};
+  return TilingResult{{tiledScanOp},
+                      SmallVector<Value>(tiledScanOp->getResults())};
 }
 
 LogicalResult ScanOp::getResultTilePosition(
@@ -1161,7 +1165,7 @@
   return success();
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 ReverseOp::getTiledImplementation(OpBuilder &builder,
                                   ArrayRef<OpFoldResult> offsets,
                                   ArrayRef<OpFoldResult> sizes) {
@@ -1191,7 +1195,8 @@
   Operation *tiledRevOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return {tiledRevOp};
+  return TilingResult{{tiledRevOp},
+                      SmallVector<Value>(tiledRevOp->getResults())};
 }
 
 LogicalResult ReverseOp::getResultTilePosition(
@@ -1424,7 +1429,7 @@
   return success();
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 TopkOp::getTiledImplementation(OpBuilder &builder,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes) {
@@ -1465,7 +1470,8 @@
 
   Operation *tiledTopkOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-  return {tiledTopkOp};
+  return TilingResult{{tiledTopkOp},
+                      SmallVector<Value>(tiledTopkOp->getResults())};
 }
 
 LogicalResult TopkOp::getResultTilePosition(
@@ -2170,7 +2176,7 @@
   return iteratorTypes;
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
@@ -2213,7 +2219,7 @@
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return {tiledOp};
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
 }
 
 LogicalResult WinogradInputTransformOp::getResultTilePosition(
@@ -2332,7 +2338,7 @@
   return iteratorTypes;
 }
 
-SmallVector<Operation *> WinogradOutputTransformOp::getTiledImplementation(
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
     OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes) {
 
@@ -2374,7 +2380,7 @@
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return {tiledOp};
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
 }
 
 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
@@ -2450,7 +2456,7 @@
   return iteratorTypes;
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                   ArrayRef<OpFoldResult> offsets,
                                   ArrayRef<OpFoldResult> sizes) {
@@ -2470,7 +2476,7 @@
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return {tiledOp};
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
 }
 
 LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2546,7 +2552,7 @@
   return iteratorTypes;
 }
 
-SmallVector<Operation *>
+FailureOr<TilingResult>
 AttentionOp::getTiledImplementation(OpBuilder &builder,
                                     ArrayRef<OpFoldResult> offsets,
                                     ArrayRef<OpFoldResult> sizes) {
@@ -2594,7 +2600,7 @@
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return {tiledOp};
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
 }
 
 LogicalResult AttentionOp::getResultTilePosition(
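Every `getTiledImplementation` override in this file follows the same mechanical migration, repeated above for Scatter, Sort, Fft, Scan, Reverse, Topk, the Winograd transforms, Softmax, and Attention. Schematically, with the body elided to the return (`SomeOp` is a placeholder):

    FailureOr<TilingResult>
    SomeOp::getTiledImplementation(OpBuilder &builder,
                                   ArrayRef<OpFoldResult> offsets,
                                   ArrayRef<OpFoldResult> sizes) {
      // ... build tiledOperands / resultTypes exactly as before ...
      Operation *tiledOp =
          mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
      // Previously `return {tiledOp};` -- now the replacement values ride along.
      return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
    }
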
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
index ae77fb3..bb020cb 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
@@ -89,30 +89,27 @@
   // the op by invoking the TiledOpInterface methods.
   if (loopDepth == tileSizes.size()) {
     TiledOp ret;
-    SmallVector<Operation *> tiledOps =
+    FailureOr<TilingResult> tiledOps =
         tilableOp.getTiledImplementation(builder, offsets, tileSizes);
-    if (tiledOps.empty()) {
+    if (failed(tiledOps)) {
       return static_cast<LogicalResult>(
           tilableOp.emitOpError("failed to get tiled implementation"));
     }
-    assert(
-        (tiledOps.size() == 1) &&
-        "expected only a single operation returned from tiling implementation");
-    ret.op.assign(tiledOps);
-    for (auto result : llvm::enumerate(ret.op.back()->getResults())) {
-      if (!result.value().getType().isa<RankedTensorType>()) {
-        ret.results.push_back(result.value());
+    ret.op.append(tiledOps->tiledOps);
+    for (auto [index, result] : llvm::enumerate(tilableOp->getResults())) {
+      if (!result.getType().isa<RankedTensorType>()) {
+        ret.results.push_back(result);
         continue;
       }
       SmallVector<OpFoldResult> resultOffsets, resultSizes;
-      if (succeeded(tilableOp.getResultTilePosition(
-              builder, result.index(), offsets, tileSizes, resultOffsets,
-              resultSizes))) {
+      if (succeeded(tilableOp.getResultTilePosition(builder, index, offsets,
+                                                    tileSizes, resultOffsets,
+                                                    resultSizes))) {
         SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
                                                 builder.getIndexAttr(1));
         Value insertSlice = builder.create<tensor::InsertSliceOp>(
-            loc, ret.op.back()->getResult(result.index()),
-            outputs[result.index()], resultOffsets, resultSizes, resultStrides);
+            loc, tiledOps->tiledValues[index], outputs[index], resultOffsets,
+            resultSizes, resultStrides);
         ret.results.push_back(insertSlice);
       }
     }
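On the caller side the migration is from "empty vector means failure" to an explicit `failed()` check, as the hunk above shows. A minimal consumer sketch assuming only the `TilingResult` shape (the `tileToValues` helper is illustrative, not part of this patch):

    #include "mlir/IR/Builders.h"
    #include "mlir/Interfaces/TilingInterface.h"

    using namespace mlir;

    // Tile `op` and return the values that replace its results.
    static FailureOr<SmallVector<Value>>
    tileToValues(OpBuilder &builder, TilingInterface op,
                 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
      FailureOr<TilingResult> result =
          op.getTiledImplementation(builder, offsets, sizes);
      if (failed(result)) {
        op.emitOpError("failed to get tiled implementation");
        return failure();
      }
      return result->tiledValues;
    }
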
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
index 7069bed..a73a90b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
@@ -38,12 +38,18 @@
       return failure();
 
     // Tile the producer.
-    FailureOr<Value> tiledProducer = producerOp.generateResultTileValue(
-        rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes());
-    if (failed(tiledProducer))
+    FailureOr<TilingResult> tileAndFuseResult =
+        producerOp.generateResultTileValue(rewriter, /*resultNumber=*/0,
+                                           sliceOp.getMixedOffsets(),
+                                           sliceOp.getMixedSizes());
+    if (failed(tileAndFuseResult))
       return failure();
-    fusedOps.push_back(cast<TilingInterface>(tiledProducer->getDefiningOp()));
+    for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
+      auto interfaceOp = dyn_cast<TilingInterface>(tileAndFusedOp);
+      if (!interfaceOp)
+        continue;
+      fusedOps.push_back(interfaceOp);
+    }
   }
 
   // Update the consumer in-place using the tiled producer results.
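The fusion hunk above reflects the same API shift: `generateResultTileValue` now returns a `TilingResult` rather than a single `Value`, and since tiling may create helper ops that do not implement `TilingInterface`, only the ones that do are kept as fusion candidates. Condensed into a standalone helper (illustrative, mirroring the loop above):

    #include "mlir/Interfaces/TilingInterface.h"

    using namespace mlir;

    // Keep only tiled ops that can themselves be tiled or fused further.
    static void collectFusableOps(const TilingResult &result,
                                  SmallVectorImpl<TilingInterface> &fusedOps) {
      for (Operation *tiledOp : result.tiledOps)
        if (auto interfaceOp = dyn_cast<TilingInterface>(tiledOp))
          fusedOps.push_back(interfaceOp);
    }
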
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
index de4bff6..7a6b068 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -195,15 +195,16 @@
 
 /// Second pattern to implement the switch of `TilingInterface ->
 /// tensor.extract_slice` to `tensor.extract_slice -> TilingInterface`.
-FailureOr<Operation *> SwapTilingInterfaceOp::returningMatchAndRewrite(
+FailureOr<TilingResult> SwapTilingInterfaceOp::returningMatchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
   auto sourceOp = sliceOp.getSource().getDefiningOp<TilingInterface>();
   if (!sourceOp)
     return failure();
-  SmallVector<Operation *> tiledOps = sourceOp.getTiledImplementation(
+  FailureOr<TilingResult> tilingResult = sourceOp.getTiledImplementation(
       rewriter, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes());
-  assert(tiledOps.size() && "expected single tiled op");
-  Operation *tiledOp = tiledOps.front();
-  rewriter.replaceOp(sliceOp, tiledOp->getResults());
-  return tiledOp;
+  if (failed(tilingResult)) {
+    return failure();
+  }
+  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+  return tilingResult.value();
 }
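A caller of the reworked pattern now receives the whole `TilingResult` and can seed further rewrites from every tiled op, not just one. A hedged usage sketch (assumes a live `PatternRewriter` and `tensor::ExtractSliceOp` from the driver; plumbing elided):

    // Illustrative call site; `rewriter` and `sliceOp` come from the driver.
    SwapTilingInterfaceOp swapPattern(sliceOp.getContext());
    FailureOr<TilingResult> swapped =
        swapPattern.returningMatchAndRewrite(sliceOp, rewriter);
    if (failed(swapped))
      return failure();
    for (Operation *tiledOp : swapped->tiledOps) {
      // e.g. enqueue tiledOp for further tiling or fusion.
      (void)tiledOp;
    }
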
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index ddbca39..474f4ab 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -295,10 +295,9 @@
     }
 
     // Add more fusion candidates to the worklist.
-    if (auto fusedProducerOp =
-            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
-      addCandidateSlices(fusedProducerOp, candidates);
-      tiledOps.push_back(fusedProducerOp);
+    for (auto tiledOp : fusedProducer->tiledOps) {
+      addCandidateSlices(tiledOp, candidates);
+      tiledOps.push_back(tiledOp);
     }
   }
 
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 75e86f4..411b1d8 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 75e86f472075c3c34a5b2466fa42e3248a432f78
+Subproject commit 411b1d8f079533860a990ee615abae3b0e6dbd8b
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 34b9f70..1f096a7 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 34b9f70f784ff67797fffddf6aad94a44321a09f
+Subproject commit 1f096a793ab7f73ae8f62deb8b6502c543763ca1