Integrate llvm-project at 0c852dc88e9276b74532fd7d233dd23ec1bbed6f (#13486)

* Reset third_party/llvm-project:
0c852dc88e9276b74532fd7d233dd23ec1bbed6f (2023-05-09 07:42:20 +0000):
[clang][dataflow][NFC] Remove `SkipPast` param from `getValue(const
ValueDecl &)`.
* Reset third_party/mlir-hlo: 768b11e9baa9b2e8b153b84c6e2403b048935dfe
(2023-05-09 08:24:44 -0700): [XLA:CPU Next] Outline reduce fusion
clusters.

Update ops implementing the `CallOpInterface`: upstream now also requires
`setCalleeFromCallable`, so implement it on the Flow, Stream, and VM
call-like ops. Other mechanical updates in this integrate:
* `SplitHandlesOp` is now `SplitHandleOp` (`transform.split_handles` ->
  `transform.split_handle`); the `in [N]` clause is gone and the number of
  produced handles is implied by the result types.
* TOSA CHECK lines now expect the properties syntax `<{...}>` for inherent
  attributes instead of `{...}`.
* `transform.loop.outline` now produces two result handles.
* Various `transform.iree.*` ops in the test specs no longer return new
  handles and instead spell out `(!pdl.operation) -> ()` signatures;
  `map_nested_forall_to_gpu_threads` now takes `workgroup_dims = [...]`
  rather than a `workgroup_size` attribute.
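
For context, a minimal sketch (not part of this change) of what the new
interface method enables; `redirectCall` is a hypothetical helper:

    // Retarget any call-like op through the updated CallOpInterface.
    #include "mlir/Interfaces/CallInterfaces.h"

    static void redirectCall(mlir::CallOpInterface callOp,
                             mlir::SymbolRefAttr newCallee) {
      // Previously the interface only exposed getCallableForCallee(), so
      // the callee could be read but not rewritten generically.
      // setCalleeFromCallable() is the newly required method; the ops
      // patched below implement it by resetting their "callee" attribute.
      // (CallInterfaceCallable is a PointerUnion<SymbolRefAttr, Value>,
      // so the SymbolRefAttr converts implicitly.)
      callOp.setCalleeFromCallable(newCallee);
    }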
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
index bbe2abd..81f19da 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
@@ -440,7 +440,7 @@
 //          CHECK:     scf.for
 //          CHECK:       scf.for
 //          CHECK:         vector.outerproduct
-//          CHECK:       %{{.+}} = "tosa.apply_scale"({{.+}}) {double_round = true} : (vector<8x32xi32>, vector<8x32xi32>, vector<8x32xi8>) -> vector<8x32xi32>
+//          CHECK:       %{{.+}} = "tosa.apply_scale"({{.+}}) <{double_round = true}> : (vector<8x32xi32>, vector<8x32xi32>, vector<8x32xi8>) -> vector<8x32xi32>
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
index a08ff4a..7dfb916 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
@@ -42,7 +42,7 @@
 using transform::LowerTransferOp;
 using transform::LowerTransposeOp;
 using transform::MatchOp;
-using transform::SplitHandlesOp;
+using transform::SplitHandleOp;
 using transform::SplitTransferFullPartialOp;
 using transform::TransferToScfOp;
 using transform_ext::AllDims;
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
index 007a1d2..2992a9a 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
@@ -39,7 +39,7 @@
 using transform::MergeHandlesOp;
 using transform::PrintOp;
 using transform::SequenceOp;
-using transform::SplitHandlesOp;
+using transform::SplitHandleOp;
 using transform::SplitReductionOp;
 using transform::TileToForallOp;
 using transform::VectorizeOp;
@@ -53,8 +53,8 @@
 auto matchAndUnpack(ImplicitLocOpBuilder &b, Value targetH,
                     MatchingArgs... args) {
   Value matchedH = b.create<MatchOp>(targetH, args...);
-  auto matchOp = b.create<SplitHandlesOp>(matchedH,
-                                          /*numHandles=*/N);
+  auto matchOp = b.create<SplitHandleOp>(matchedH,
+                                         /*numHandles=*/N);
   assert(matchOp->getNumResults() == N && "Unexpected number of results");
   std::array<Value, N> a;
   for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 1738937..7f081ad 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -947,6 +947,11 @@
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
 
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
+
     // StreamableOpInterface:
     bool isTransfer() { return false; }
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index d7e570f..9ef9324 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2156,6 +2156,11 @@
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
 
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
+
     Value getOperandSize(unsigned idx) {
       return findValueSizeInList(idx, getOperands(), getResourceOperandSizes());
     }
@@ -2816,6 +2821,11 @@
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
 
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
+
     Value getOperandSize(unsigned idx) {
       return findValueSizeInList(idx, getOperands(), getResourceOperandSizes());
     }
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index a2ce605..6b63555 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -3798,6 +3798,11 @@
     CallInterfaceCallable getCallableForCallee() {
       return getOperation()->getAttrOfType<FlatSymbolRefAttr>("callee");
     }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
   }];
 }
 
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir b/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
index 449b01e..1b811d3 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
+++ b/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
@@ -7,7 +7,7 @@
   // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x1xi32>
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
   // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<1x2x1xi32>)
-  // CHECK-DAG: %[[CONCAT:.+]] = "tosa.concat"(%[[FILL]], %[[EXPANDIDX]]) {axis = 2 : i64}
+  // CHECK-DAG: %[[CONCAT:.+]] = "tosa.concat"(%[[FILL]], %[[EXPANDIDX]]) <{axis = 2 : i64}>
   // CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
   // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<1x2x2xi32> into tensor<2x2xi32>
   // CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
@@ -26,13 +26,12 @@
 
 // -----
 
-
 // CHECK-LABEL: @scatter_static_batched
 func.func @scatter_static_batched(%arg0 : tensor<2x4x5xf32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<2x2x5xf32>) ->  tensor<2x4x5xf32> {
   // CHECK: %[[EXPANDIDX:.+]] = tensor.expand_shape %arg1
   // CHECK-SAME{literal}: [[0], [1, 2]] : tensor<2x2xi32> into tensor<2x2x1xi32>
   // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x1xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
   // CHECK-SAME: indexing_maps = [#map, #map]
   // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
   // CHECK-SAME: ins(%[[EXPANDIDX]] : tensor<2x2x1xi32>)
@@ -40,7 +39,7 @@
   // CHECK:   %[[IDX:.+]] = linalg.index 0 : index
   // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]] : index to i32
   // CHECK:   linalg.yield %[[CAST]] : i32
-  // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPANDIDX]]) {axis = 2 : i64}
+  // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPANDIDX]]) <{axis = 2 : i64}>
   // CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
   // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<2x2x2xi32> into tensor<4x2xi32>
   // CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
@@ -67,15 +66,15 @@
   // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPAND]], %[[C0]] : tensor<?x?x1xi32>
   // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPAND]], %[[C1]] : tensor<?x?x1xi32>
   // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x1xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
   // CHECK-SAME: ins(%[[EXPAND]] : tensor<?x?x1xi32>) outs(%[[EMPTY]] : tensor<?x?x1xi32>) {
-  // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPAND]]) {axis = 2 : i64}
+  // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPAND]]) <{axis = 2 : i64}>
   // CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
   // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<?x?x2xi32> into tensor<?x2xi32>
   // CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
   // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
-  // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter 
-  // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<?x?xf32>, tensor<?x2xi32>) outs(%arg0 : tensor<?x?x?xf32>) 
+  // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+  // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<?x?xf32>, tensor<?x2xi32>) outs(%arg0 : tensor<?x?x?xf32>)
   %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<?x?x?xf32>)  -> (tensor<?x?x?xf32>)
 
   // CHECK: return %[[SCATTER]]
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
index c678ec4..2b02c8c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
@@ -20,7 +20,7 @@
       %0 = pdl_match @some_operation in %arg1 : (!pdl.operation) -> !pdl.operation
       // Make sure we don't crash on wrong operation type.
       // expected-error@below {{failed to outline}}
-      transform.loop.outline %0 {func_name = "outlined"} : (!pdl.operation) -> !pdl.operation
+      transform.loop.outline %0 {func_name = "outlined"} : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
     }
   }
 }
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
index 7fa9a71..3855f35 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
@@ -7,7 +7,7 @@
   // Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism.
   // ===========================================================================
   %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %eltwise, %reduction = transform.split_handles %0 in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %eltwise, %reduction = transform.split_handle %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   %init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op =
     transform.structured.split_reduction %reduction
       { split_factor = 2, insert_split_dimension = 1 }
@@ -27,11 +27,11 @@
   %func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   transform.iree.apply_patterns %func { bubble_expand } : (!pdl.operation) -> ()
   %fills = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %fill_2, %more_parallel_fill_2 = transform.split_handles %fills in [2]
+  %fill_2, %more_parallel_fill_2 = transform.split_handle %fills
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %expanded_eltwise, %more_parallel_2, %combiner_2 =
-    transform.split_handles %generics in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+    transform.split_handle %generics : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %not_combiner = transform.merge_handles %fill_2, %more_parallel_fill_2, %more_parallel_2, %expanded_eltwise : !pdl.operation
   transform.structured.fuse_into_containing_op %not_combiner into %forall_grid_2
@@ -41,17 +41,17 @@
   // ===========================================================================
   %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
   %forall_block_combiner_op, %block_combiner_op =
-    transform.structured.tile_to_forall_op %combiner_2 tile_sizes [1] 
+    transform.structured.tile_to_forall_op %combiner_2 tile_sizes [1]
     ( mapping = [#gpu.thread<z>] )
   transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op
 
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
-  %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]} : (!pdl.operation) -> !pdl.operation
+  %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]} : (!pdl.operation) -> !pdl.operation
+  %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %forall_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
   transform.structured.fuse_into_containing_op %grid_eltwise_op into %forall_block_more_parallel_op
@@ -64,27 +64,26 @@
 
  // Step 5. Bufferize and drop HAL descriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
-  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  transform.iree.erase_hal_descriptor_type_from_memref %memref_func
+  transform.iree.eliminate_empty_tensors %variant_op : (!pdl.operation) -> ()
+  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> !pdl.operation
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+  transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_5 = transform.iree.forall_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_forall_to_gpu_threads %func_5
-      { workgroup_size = [32, 2, 1] }
+  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+  transform.iree.forall_to_workgroup %func_4 : (!pdl.operation) -> ()
+  transform.iree.map_nested_forall_to_gpu_threads %func_4 workgroup_dims = [32, 2, 1] : (!pdl.operation) -> ()
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  transform.iree.apply_patterns %func_6 {  rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
+  transform.iree.apply_patterns %func_4 {  rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
-  transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
+  transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
     transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   }
-  transform.iree.vector.warp_distribute %func_6
+  transform.iree.vector.warp_distribute %func_4 : (!pdl.operation) -> ()
 }
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
index 12e549d..3931171 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -7,7 +7,7 @@
   // Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism.
   // ===========================================================================
   %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %leading_eltwise, %reduction, %trailing_eltwise = transform.split_handles %0 in [3]
+  %leading_eltwise, %reduction, %trailing_eltwise = transform.split_handle %0
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   %init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op =
     transform.structured.split_reduction %reduction
@@ -29,13 +29,13 @@
   %func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   transform.iree.apply_patterns %func { bubble_expand } : (!pdl.operation) -> ()
   %fills = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %fill_2, %more_parallel_fill_2 = transform.split_handles %fills in [2]
+  %fill_2, %more_parallel_fill_2 = transform.split_handle %fills
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %expanded_eltwise, %more_parallel_2, %combiner_2, %trailing_eltwise_2 =
-    transform.split_handles %generics in [4]
+    transform.split_handle %generics
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
-  %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op
+  %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %not_trailing = transform.merge_handles %fill_2, %more_parallel_fill_2,
     %more_parallel_2, %expanded_eltwise, %combiner_2 : !pdl.operation
   transform.structured.fuse_into_containing_op %not_trailing into %forall_grid_2
@@ -45,7 +45,7 @@
   // ===========================================================================
   %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
   %forall_trailing_eltwise_op, %block_trailing_eltwise_op =
-    transform.structured.tile_to_forall_op %trailing_eltwise_2 tile_sizes [1] 
+    transform.structured.tile_to_forall_op %trailing_eltwise_2 tile_sizes [1]
     ( mapping = [#gpu.thread<z>] )
   %block_combiner_op = transform.structured.match ops{["linalg.generic"]}
     attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
@@ -58,7 +58,7 @@
   %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]}
     attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %forall_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
   transform.structured.fuse_into_containing_op %grid_eltwise_op into %forall_block_more_parallel_op
@@ -66,32 +66,31 @@
   // Step 4. Rank-reduce and vectorize.
   // ===========================================================================
   %func_1 = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %func_2 = transform.iree.apply_patterns %func_1 {  rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
-  %func_3 = transform.structured.vectorize %func_2
+  transform.iree.apply_patterns %func_1 {  rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
+  %func_2 = transform.structured.vectorize %func_1
 
  // Step 5. Bufferize and drop HAL descriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
-  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  transform.iree.erase_hal_descriptor_type_from_memref %memref_func
+  transform.iree.eliminate_empty_tensors %variant_op : (!pdl.operation) -> ()
+  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> !pdl.operation
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+  transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_5 = transform.iree.forall_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_forall_to_gpu_threads %func_5
-      { workgroup_size = [32, 2, 1] }
+  %func_3 = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+  transform.iree.forall_to_workgroup %func_3 : (!pdl.operation) -> ()
+  transform.iree.map_nested_forall_to_gpu_threads %func_3 workgroup_dims = [32, 2, 1] : (!pdl.operation) -> ()
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
+  transform.iree.apply_patterns %func_3 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
-  transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
+  transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
     transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   }
-  transform.iree.vector.warp_distribute %func_7
+  transform.iree.vector.warp_distribute %func_3 : (!pdl.operation) -> ()
 }
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
index 798fbe4..a60b7ee 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
@@ -8,7 +8,7 @@
   %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %matmul = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %variant_op : (!pdl.operation) -> !pdl.operation
   %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
-  %reduce, %broadcast = transform.split_handles %generics in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %reduce, %broadcast = transform.split_handle %generics : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
 
   // Step 2. Tile the matmul and fuse the fill
   // ===========================================================================
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir
index 26f04c0..995cb5e 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir
@@ -6,7 +6,7 @@
     in %variant_op : (!pdl.operation) -> !pdl.operation
 
   %fill0, %fill1, %matmul, %reduce, %broadcast =
-    transform.split_handles %ops in [5]
+    transform.split_handle %ops
       : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
                              !pdl.operation, !pdl.operation)
 
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 2c5acfe..4543ede 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -2,14 +2,14 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%variant_op: !pdl.operation):
-  %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op 
+  %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
     : (!pdl.operation) -> !pdl.operation
-  
+
   // Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism.
   // ===========================================================================
-  %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op 
+  %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op
     : (!pdl.operation) -> !pdl.operation
-  %reduction, %eltwise = transform.split_handles %0 in [2] 
+  %reduction, %eltwise = transform.split_handle %0
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   %init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op =
     transform.structured.split_reduction %reduction
@@ -22,10 +22,10 @@
   // Step 2. First level of tiling + fusion parallelizes to blocks. Tile the
   // trailing elementwise the same way we want to tile the reduction.
   // ===========================================================================
-  %grid_loop, %eltwise_grid_op = transform.structured.tile_to_forall_op %eltwise 
+  %grid_loop, %eltwise_grid_op = transform.structured.tile_to_forall_op %eltwise
     tile_sizes [1] (mapping = [#gpu.block<x>])
   transform.iree.populate_workgroup_count_region_using_num_threads_slice %grid_loop : (!pdl.operation) -> ()
-  %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op 
+  %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op
     : !pdl.operation
   transform.structured.fuse_into_containing_op %not_eltwise into %grid_loop
 
@@ -35,13 +35,13 @@
 
   // Step 3. Second level of tiling + fusion parallelizes to threads.
   // ===========================================================================
-  %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op 
+  %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op
     : (!pdl.operation) -> !pdl.operation
   %eltwise_block_loop, %eltwise_block_op =
     transform.structured.tile_to_forall_op %eltwise_grid_op tile_sizes [1]
     ( mapping = [#gpu.thread<z>] )
   %block_combiner_op = transform.structured.match ops{["linalg.generic"]}
-    attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op 
+    attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op
       : (!pdl.operation) -> !pdl.operation
   %combined_and_fill = transform.merge_handles %fill_1d, %block_combiner_op : !pdl.operation
   transform.structured.fuse_into_containing_op %combined_and_fill into %eltwise_block_loop
@@ -50,13 +50,13 @@
   transform.iree.apply_patterns %variant_op
     { canonicalization, tiling_canonicalization, licm, cse } : (!pdl.operation) -> ()
 
-  %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op 
+  %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
     : (!pdl.operation) -> !pdl.operation
   %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
-    attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op 
+    attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op
       : (!pdl.operation) -> !pdl.operation
   %forall_block_more_parallel_op, %block_more_parallel_op =
-    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] 
+    transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
     ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
   transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
 
@@ -75,7 +75,7 @@
   transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } : (!pdl.operation) -> ()
   transform.iree.eliminate_empty_tensors %variant_op: (!pdl.operation) -> ()
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> (!pdl.operation)
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
     : (!pdl.operation) -> !pdl.operation
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func: (!pdl.operation) -> ()
 
@@ -89,7 +89,7 @@
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
   transform.iree.apply_patterns %func_5 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
     : (!pdl.operation) -> !pdl.operation
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 3f9b939..6e375a5 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -10,7 +10,7 @@
   %exps_sum_fill,
   %exps,
   %exps_sum,
-  %div = transform.split_handles %ops in [6]
+  %div = transform.split_handle %ops
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
                            !pdl.operation, !pdl.operation, !pdl.operation)
 
@@ -44,7 +44,7 @@
   %tiled_exps_sum_fill,
   %tiled_exp_and_exps_sum,
   %tiled_exp_and_exps_sum_2,
-  %tiled_div = transform.split_handles %tiled_ops in [6]
+  %tiled_div = transform.split_handle %tiled_ops
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
                            !pdl.operation, !pdl.operation, !pdl.operation)
   // Leaving the reduction untiled on threadIdx.x makes it sequential on
@@ -73,17 +73,16 @@
 
  // Step 4. Bufferize and drop HAL descriptor from memref ops.
   // =========================================================
-  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
-  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  transform.iree.eliminate_empty_tensors %variant_op : (!pdl.operation) -> ()
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> !pdl.operation
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  transform.iree.erase_hal_descriptor_type_from_memref %memref_func
+  transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
 
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
-  %func_3 = transform.iree.forall_to_workgroup %func_2
-  transform.iree.map_nested_forall_to_gpu_threads %func_3
-    { workgroup_size = [32, 4, 1] }
+  transform.iree.forall_to_workgroup %func_2 : (!pdl.operation) -> ()
+  transform.iree.map_nested_forall_to_gpu_threads %func_2 workgroup_dims = [32, 4, 1] : (!pdl.operation) -> ()
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
diff --git a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
index 804ce2d..55105a6 100644
--- a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
@@ -7,7 +7,7 @@
     in %variant_op : (!pdl.operation) -> !pdl.operation
 
   %input_max_fill, %input_max, %exps_sum_fill, %exps, %exps_sum, %div =
-    transform.split_handles %ops in [6]
+    transform.split_handle %ops
       : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
                              !pdl.operation, !pdl.operation, !pdl.operation)
 
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 6682418..e09f305 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -9,14 +9,14 @@
   %input_max,
   %exps_sum_fill,
   %exp_and_exps_sum,
-  %div = transform.split_handles %ops in [5]
+  %div = transform.split_handle %ops
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
                            !pdl.operation, !pdl.operation)
 
   // Step 1. First level of tiling + fusion parallelizes to blocks.
   // ==============================================================
   %forall, %_ =
-  transform.structured.tile_to_forall_op %div tile_sizes [1, 4]  
+  transform.structured.tile_to_forall_op %div tile_sizes [1, 4]
     ( mapping = [#gpu.block<x>, #gpu.block<y>] )
   transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall : (!pdl.operation) -> ()
 
@@ -36,7 +36,7 @@
   // Canonicalizations.
   transform.iree.apply_patterns %variant_op
     { canonicalization, tiling_canonicalization, licm, cse } : (!pdl.operation) -> ()
-  
+
 
   // Step 2. Second level of tiling + fusion parallelizes to threads.
   // ================================================================
@@ -46,7 +46,7 @@
   %tiled_input_max,
   %tiled_exps_sum_fill,
   %tiled_exp_and_exps_sum,
-  %tiled_div = transform.split_handles %tiled_ops in [5]
+  %tiled_div = transform.split_handle %tiled_ops
     : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
                            !pdl.operation, !pdl.operation)
   // Leaving the reduction untiled on threadIdx.x makes it sequential on
@@ -92,7 +92,7 @@
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
   transform.iree.apply_patterns %memref_func { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
     : (!pdl.operation) -> !pdl.operation
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %memref_func
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 1eb7d58..0c852dc 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 1eb7d5865292e0b384e4f97977374399cdcd944d
+Subproject commit 0c852dc88e9276b74532fd7d233dd23ec1bbed6f
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index db8ea95..768b11e 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit db8ea95cc7349d14ba53c35f93b4f1f77b6ad667
+Subproject commit 768b11e9baa9b2e8b153b84c6e2403b048935dfe