Integrate llvm-project at 0c852dc88e9276b74532fd7d233dd23ec1bbed6f (#13486)
* Reset third_party/llvm-project: 0c852dc88e9276b74532fd7d233dd23ec1bbed6f
  (2023-05-09 07:42:20 +0000): [clang][dataflow][NFC] Remove `SkipPast` param
  from `getValue(const ValueDecl &)`.
* Reset third_party/mlir-hlo: 768b11e9baa9b2e8b153b84c6e2403b048935dfe
  (2023-05-09 08:24:44 -0700): [XLA:CPU Next] Outline reduce fusion clusters.
Update ops implementing the `CallOpInterface` to provide the newly expected `setCalleeFromCallable` method.
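
For context, a minimal sketch of what this adapts to, assuming the upstream `CallOpInterface` now exposes `setCalleeFromCallable` (the `retargetCall` helper below is purely illustrative, not part of this change):

    #include "mlir/Interfaces/CallInterfaces.h"

    // Retarget any call-like op to a new callee symbol through the interface.
    // CallInterfaceCallable is a PointerUnion<SymbolRefAttr, Value>, so a
    // SymbolRefAttr converts to it implicitly.
    void retargetCall(mlir::CallOpInterface callOp, mlir::SymbolRefAttr newCallee) {
      callOp.setCalleeFromCallable(newCallee);
    }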
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
index bbe2abd..81f19da 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
@@ -440,7 +440,7 @@
// CHECK: scf.for
// CHECK: scf.for
// CHECK: vector.outerproduct
-// CHECK: %{{.+}} = "tosa.apply_scale"({{.+}}) {double_round = true} : (vector<8x32xi32>, vector<8x32xi32>, vector<8x32xi8>) -> vector<8x32xi32>
+// CHECK: %{{.+}} = "tosa.apply_scale"({{.+}}) <{double_round = true}> : (vector<8x32xi32>, vector<8x32xi32>, vector<8x32xi8>) -> vector<8x32xi32>
// -----
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
index a08ff4a..7dfb916 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
@@ -42,7 +42,7 @@
using transform::LowerTransferOp;
using transform::LowerTransposeOp;
using transform::MatchOp;
-using transform::SplitHandlesOp;
+using transform::SplitHandleOp;
using transform::SplitTransferFullPartialOp;
using transform::TransferToScfOp;
using transform_ext::AllDims;
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
index 007a1d2..2992a9a 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
@@ -39,7 +39,7 @@
using transform::MergeHandlesOp;
using transform::PrintOp;
using transform::SequenceOp;
-using transform::SplitHandlesOp;
+using transform::SplitHandleOp;
using transform::SplitReductionOp;
using transform::TileToForallOp;
using transform::VectorizeOp;
@@ -53,8 +53,8 @@
auto matchAndUnpack(ImplicitLocOpBuilder &b, Value targetH,
MatchingArgs... args) {
Value matchedH = b.create<MatchOp>(targetH, args...);
- auto matchOp = b.create<SplitHandlesOp>(matchedH,
- /*numHandles=*/N);
+ auto matchOp = b.create<SplitHandleOp>(matchedH,
+ /*numHandles=*/N);
assert(matchOp->getNumResults() == N && "Unexpected number of results");
std::array<Value, N> a;
for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 1738937..7f081ad 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -947,6 +947,11 @@
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+ /// Set the callee for this operation.
+ void setCalleeFromCallable(CallInterfaceCallable callee) {
+ (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ }
+
// StreamableOpInterface:
bool isTransfer() { return false; }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index d7e570f..9ef9324 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2156,6 +2156,11 @@
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+ /// Set the callee for this operation.
+ void setCalleeFromCallable(CallInterfaceCallable callee) {
+ (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ }
+
Value getOperandSize(unsigned idx) {
return findValueSizeInList(idx, getOperands(), getResourceOperandSizes());
}
@@ -2816,6 +2821,11 @@
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+ /// Set the callee for this operation.
+ void setCalleeFromCallable(CallInterfaceCallable callee) {
+ (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ }
+
Value getOperandSize(unsigned idx) {
return findValueSizeInList(idx, getOperands(), getResourceOperandSizes());
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index a2ce605..6b63555 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -3798,6 +3798,11 @@
CallInterfaceCallable getCallableForCallee() {
return getOperation()->getAttrOfType<FlatSymbolRefAttr>("callee");
}
+
+ /// Set the callee for this operation.
+ void setCalleeFromCallable(CallInterfaceCallable callee) {
+ (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ }
}];
}
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir b/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
index 449b01e..1b811d3 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
+++ b/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
@@ -7,7 +7,7 @@
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x1xi32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<1x2x1xi32>)
- // CHECK-DAG: %[[CONCAT:.+]] = "tosa.concat"(%[[FILL]], %[[EXPANDIDX]]) {axis = 2 : i64}
+ // CHECK-DAG: %[[CONCAT:.+]] = "tosa.concat"(%[[FILL]], %[[EXPANDIDX]]) <{axis = 2 : i64}>
// CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
// CHECK-SAME{literal}: [[0, 1], [2]] : tensor<1x2x2xi32> into tensor<2x2xi32>
// CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
@@ -26,13 +26,12 @@
// -----
-
// CHECK-LABEL: @scatter_static_batched
func.func @scatter_static_batched(%arg0 : tensor<2x4x5xf32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<2x2x5xf32>) -> tensor<2x4x5xf32> {
// CHECK: %[[EXPANDIDX:.+]] = tensor.expand_shape %arg1
// CHECK-SAME{literal}: [[0], [1, 2]] : tensor<2x2xi32> into tensor<2x2x1xi32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x1xi32>
- // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#map, #map]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[EXPANDIDX]] : tensor<2x2x1xi32>)
@@ -40,7 +39,7 @@
// CHECK: %[[IDX:.+]] = linalg.index 0 : index
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] : index to i32
// CHECK: linalg.yield %[[CAST]] : i32
- // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPANDIDX]]) {axis = 2 : i64}
+ // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPANDIDX]]) <{axis = 2 : i64}>
// CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
// CHECK-SAME{literal}: [[0, 1], [2]] : tensor<2x2x2xi32> into tensor<4x2xi32>
// CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
@@ -67,15 +66,15 @@
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPAND]], %[[C0]] : tensor<?x?x1xi32>
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPAND]], %[[C1]] : tensor<?x?x1xi32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x1xi32>
- // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x?x1xi32>) outs(%[[EMPTY]] : tensor<?x?x1xi32>) {
- // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPAND]]) {axis = 2 : i64}
+ // CHECK: %[[CONCAT:.+]] = "tosa.concat"(%[[GENERIC]], %[[EXPAND]]) <{axis = 2 : i64}>
// CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
// CHECK-SAME{literal}: [[0, 1], [2]] : tensor<?x?x2xi32> into tensor<?x2xi32>
// CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
// CHECK-SAME{literal}: [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
- // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
- // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<?x?xf32>, tensor<?x2xi32>) outs(%arg0 : tensor<?x?x?xf32>)
+ // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+ // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<?x?xf32>, tensor<?x2xi32>) outs(%arg0 : tensor<?x?x?xf32>)
%0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
// CHECK: return %[[SCATTER]]
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
index c678ec4..2b02c8c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
@@ -20,7 +20,7 @@
%0 = pdl_match @some_operation in %arg1 : (!pdl.operation) -> !pdl.operation
// Make sure we don't crash on wrong operation type.
// expected-error@below {{failed to outline}}
- transform.loop.outline %0 {func_name = "outlined"} : (!pdl.operation) -> !pdl.operation
+ transform.loop.outline %0 {func_name = "outlined"} : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}
}
}
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
index 7fa9a71..3855f35 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
@@ -7,7 +7,7 @@
// Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism.
// ===========================================================================
%0 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %eltwise, %reduction = transform.split_handles %0 in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+ %eltwise, %reduction = transform.split_handle %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
%init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op =
transform.structured.split_reduction %reduction
{ split_factor = 2, insert_split_dimension = 1 }
@@ -27,11 +27,11 @@
%func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
transform.iree.apply_patterns %func { bubble_expand } : (!pdl.operation) -> ()
%fills = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %fill_2, %more_parallel_fill_2 = transform.split_handles %fills in [2]
+ %fill_2, %more_parallel_fill_2 = transform.split_handle %fills
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)
%generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%expanded_eltwise, %more_parallel_2, %combiner_2 =
- transform.split_handles %generics in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+ transform.split_handle %generics : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
%forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%not_combiner = transform.merge_handles %fill_2, %more_parallel_fill_2, %more_parallel_2, %expanded_eltwise : !pdl.operation
transform.structured.fuse_into_containing_op %not_combiner into %forall_grid_2
@@ -41,17 +41,17 @@
// ===========================================================================
%fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
%forall_block_combiner_op, %block_combiner_op =
- transform.structured.tile_to_forall_op %combiner_2 tile_sizes [1]
+ transform.structured.tile_to_forall_op %combiner_2 tile_sizes [1]
( mapping = [#gpu.thread<z>] )
transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op
%fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
- %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]} : (!pdl.operation) -> !pdl.operation
+ %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]} : (!pdl.operation) -> !pdl.operation
+ %grid_eltwise_op = transform.structured.match ops{["linalg.generic"]}
attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
%forall_block_more_parallel_op, %block_more_parallel_op =
- transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
+ transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
transform.structured.fuse_into_containing_op %grid_eltwise_op into %forall_block_more_parallel_op
@@ -64,27 +64,26 @@
// Step 5. Bufferize and drop HAL descriptor from memref ops.
// ===========================================================================
- %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
- %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
- %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
- transform.iree.erase_hal_descriptor_type_from_memref %memref_func
+ transform.iree.eliminate_empty_tensors %variant_op : (!pdl.operation) -> ()
+ %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> !pdl.operation
+ %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+ transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
// Step 6. Post-bufferization mapping to blocks and threads.
// ===========================================================================
- %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
- %func_5 = transform.iree.forall_to_workgroup %func_4
- %func_6 = transform.iree.map_nested_forall_to_gpu_threads %func_5
- { workgroup_size = [32, 2, 1] }
+ %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+ transform.iree.forall_to_workgroup %func_4 : (!pdl.operation) -> ()
+ transform.iree.map_nested_forall_to_gpu_threads %func_4 workgroup_dims = [32, 2, 1] : (!pdl.operation) -> ()
// Step 7. Post-bufferization vector distribution with rank-reduction.
// ===========================================================================
- transform.iree.apply_patterns %func_6 { rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
+ transform.iree.apply_patterns %func_4 { rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
// Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
// at this point.
- transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
+ transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
^bb0(%arg0: !pdl.operation):
transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
}
- transform.iree.vector.warp_distribute %func_6
+ transform.iree.vector.warp_distribute %func_4 : (!pdl.operation) -> ()
}
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
index 12e549d..3931171 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -7,7 +7,7 @@
// Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism.
// ===========================================================================
%0 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %leading_eltwise, %reduction, %trailing_eltwise = transform.split_handles %0 in [3]
+ %leading_eltwise, %reduction, %trailing_eltwise = transform.split_handle %0
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
%init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op =
transform.structured.split_reduction %reduction
@@ -29,13 +29,13 @@
%func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
transform.iree.apply_patterns %func { bubble_expand } : (!pdl.operation) -> ()
%fills = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %fill_2, %more_parallel_fill_2 = transform.split_handles %fills in [2]
+ %fill_2, %more_parallel_fill_2 = transform.split_handle %fills
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)
%generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%expanded_eltwise, %more_parallel_2, %combiner_2, %trailing_eltwise_2 =
- transform.split_handles %generics in [4]
+ transform.split_handle %generics
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
- %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op
+ %forall_grid_2 = transform.structured.match ops{["scf.forall"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%not_trailing = transform.merge_handles %fill_2, %more_parallel_fill_2,
%more_parallel_2, %expanded_eltwise, %combiner_2 : !pdl.operation
transform.structured.fuse_into_containing_op %not_trailing into %forall_grid_2
@@ -45,7 +45,7 @@
// ===========================================================================
%fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation
%forall_trailing_eltwise_op, %block_trailing_eltwise_op =
- transform.structured.tile_to_forall_op %trailing_eltwise_2 tile_sizes [1]
+ transform.structured.tile_to_forall_op %trailing_eltwise_2 tile_sizes [1]
( mapping = [#gpu.thread<z>] )
%block_combiner_op = transform.structured.match ops{["linalg.generic"]}
attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation
@@ -58,7 +58,7 @@
%grid_eltwise_op = transform.structured.match ops{["linalg.generic"]}
attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>]} in %variant_op : (!pdl.operation) -> !pdl.operation
%forall_block_more_parallel_op, %block_more_parallel_op =
- transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
+ transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
transform.structured.fuse_into_containing_op %grid_eltwise_op into %forall_block_more_parallel_op
@@ -66,32 +66,31 @@
// Step 4. Rank-reduce and vectorize.
// ===========================================================================
%func_1 = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %func_2 = transform.iree.apply_patterns %func_1 { rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
- %func_3 = transform.structured.vectorize %func_2
+ transform.iree.apply_patterns %func_1 { rank_reducing_linalg, rank_reducing_vector } : (!pdl.operation) -> ()
+ %func_2 = transform.structured.vectorize %func_1
// Step 5. Bufferize and drop HAL descriptor from memref ops.
// ===========================================================================
- %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
- %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
- %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
- transform.iree.erase_hal_descriptor_type_from_memref %memref_func
+ transform.iree.eliminate_empty_tensors %variant_op : (!pdl.operation) -> ()
+ %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> !pdl.operation
+ %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+ transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
// Step 6. Post-bufferization mapping to blocks and threads.
// ===========================================================================
- %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
- %func_5 = transform.iree.forall_to_workgroup %func_4
- %func_6 = transform.iree.map_nested_forall_to_gpu_threads %func_5
- { workgroup_size = [32, 2, 1] }
+ %func_3 = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
+ transform.iree.forall_to_workgroup %func_3 : (!pdl.operation) -> ()
+ transform.iree.map_nested_forall_to_gpu_threads %func_3 workgroup_dims = [32, 2, 1] : (!pdl.operation) -> ()
// Step 7. Post-bufferization vector distribution with rank-reduction.
// ===========================================================================
- %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
+ transform.iree.apply_patterns %func_3 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
// Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
// at this point.
- transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
+ transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
^bb0(%arg0: !pdl.operation):
transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
}
- transform.iree.vector.warp_distribute %func_7
+ transform.iree.vector.warp_distribute %func_3 : (!pdl.operation) -> ()
}
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
index 798fbe4..a60b7ee 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
@@ -8,7 +8,7 @@
%fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%matmul = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %variant_op : (!pdl.operation) -> !pdl.operation
%generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
- %reduce, %broadcast = transform.split_handles %generics in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+ %reduce, %broadcast = transform.split_handle %generics : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
// Step 2. Tile the matmul and fuse the fill
// ===========================================================================
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir
index 26f04c0..995cb5e 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_dispatch_spec.mlir
@@ -6,7 +6,7 @@
in %variant_op : (!pdl.operation) -> !pdl.operation
%fill0, %fill1, %matmul, %reduce, %broadcast =
- transform.split_handles %ops in [5]
+ transform.split_handle %ops
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation)
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 2c5acfe..4543ede 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -2,14 +2,14 @@
transform.sequence failures(propagate) {
^bb1(%variant_op: !pdl.operation):
- %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
: (!pdl.operation) -> !pdl.operation
-
+
// Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism.
// ===========================================================================
- %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op
+ %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op
: (!pdl.operation) -> !pdl.operation
- %reduction, %eltwise = transform.split_handles %0 in [2]
+ %reduction, %eltwise = transform.split_handle %0
: (!pdl.operation) -> (!pdl.operation, !pdl.operation)
%init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op =
transform.structured.split_reduction %reduction
@@ -22,10 +22,10 @@
// Step 2. First level of tiling + fusion parallelizes to blocks. Tile the
// trailing elementwise the same way we want to tile the reduction.
// ===========================================================================
- %grid_loop, %eltwise_grid_op = transform.structured.tile_to_forall_op %eltwise
+ %grid_loop, %eltwise_grid_op = transform.structured.tile_to_forall_op %eltwise
tile_sizes [1] (mapping = [#gpu.block<x>])
transform.iree.populate_workgroup_count_region_using_num_threads_slice %grid_loop : (!pdl.operation) -> ()
- %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op
+ %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op
: !pdl.operation
transform.structured.fuse_into_containing_op %not_eltwise into %grid_loop
@@ -35,13 +35,13 @@
// Step 3. Second level of tiling + fusion parallelizes to threads.
// ===========================================================================
- %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op
+ %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op
: (!pdl.operation) -> !pdl.operation
%eltwise_block_loop, %eltwise_block_op =
transform.structured.tile_to_forall_op %eltwise_grid_op tile_sizes [1]
( mapping = [#gpu.thread<z>] )
%block_combiner_op = transform.structured.match ops{["linalg.generic"]}
- attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op
+ attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op
: (!pdl.operation) -> !pdl.operation
%combined_and_fill = transform.merge_handles %fill_1d, %block_combiner_op : !pdl.operation
transform.structured.fuse_into_containing_op %combined_and_fill into %eltwise_block_loop
@@ -50,13 +50,13 @@
transform.iree.apply_patterns %variant_op
{ canonicalization, tiling_canonicalization, licm, cse } : (!pdl.operation) -> ()
- %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
+ %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
: (!pdl.operation) -> !pdl.operation
%grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]}
- attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op
+ attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op
: (!pdl.operation) -> !pdl.operation
%forall_block_more_parallel_op, %block_more_parallel_op =
- transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
+ transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1]
( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op
@@ -75,7 +75,7 @@
transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } : (!pdl.operation) -> ()
transform.iree.eliminate_empty_tensors %variant_op: (!pdl.operation) -> ()
%variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> (!pdl.operation)
- %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
+ %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
: (!pdl.operation) -> !pdl.operation
transform.iree.erase_hal_descriptor_type_from_memref %memref_func: (!pdl.operation) -> ()
@@ -89,7 +89,7 @@
// Step 7. Post-bufferization vector distribution with rank-reduction.
// ===========================================================================
transform.iree.apply_patterns %func_5 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
: (!pdl.operation) -> !pdl.operation
// Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
// at this point.
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 3f9b939..6e375a5 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -10,7 +10,7 @@
%exps_sum_fill,
%exps,
%exps_sum,
- %div = transform.split_handles %ops in [6]
+ %div = transform.split_handle %ops
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation, !pdl.operation)
@@ -44,7 +44,7 @@
%tiled_exps_sum_fill,
%tiled_exp_and_exps_sum,
%tiled_exp_and_exps_sum_2,
- %tiled_div = transform.split_handles %tiled_ops in [6]
+ %tiled_div = transform.split_handle %tiled_ops
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation, !pdl.operation)
// Leaving the reduction untiled on threadIdx.x makes it sequential on
@@ -73,17 +73,16 @@
// Step 4. Bufferize and drop HAL descriptor from memref ops.
// =========================================================
- %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
- %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+ transform.iree.eliminate_empty_tensors %variant_op : (!pdl.operation) -> ()
+ %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!pdl.operation) -> !pdl.operation
%memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
- transform.iree.erase_hal_descriptor_type_from_memref %memref_func
+ transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!pdl.operation) -> ()
// Step 5. Post-bufferization mapping to blocks and threads.
// =========================================================
%func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation
- %func_3 = transform.iree.forall_to_workgroup %func_2
- transform.iree.map_nested_forall_to_gpu_threads %func_3
- { workgroup_size = [32, 4, 1] }
+ transform.iree.forall_to_workgroup %func_2 : (!pdl.operation) -> ()
+ transform.iree.map_nested_forall_to_gpu_threads %func_2 workgroup_dims = [32, 4, 1] : (!pdl.operation) -> ()
// Step 6. Post-bufferization vector distribution with rank-reduction.
// ===================================================================
diff --git a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
index 804ce2d..55105a6 100644
--- a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
@@ -7,7 +7,7 @@
in %variant_op : (!pdl.operation) -> !pdl.operation
%input_max_fill, %input_max, %exps_sum_fill, %exps, %exps_sum, %div =
- transform.split_handles %ops in [6]
+ transform.split_handle %ops
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation, !pdl.operation)
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 6682418..e09f305 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -9,14 +9,14 @@
%input_max,
%exps_sum_fill,
%exp_and_exps_sum,
- %div = transform.split_handles %ops in [5]
+ %div = transform.split_handle %ops
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation)
// Step 1. First level of tiling + fusion parallelizes to blocks.
// ==============================================================
%forall, %_ =
- transform.structured.tile_to_forall_op %div tile_sizes [1, 4]
+ transform.structured.tile_to_forall_op %div tile_sizes [1, 4]
( mapping = [#gpu.block<x>, #gpu.block<y>] )
transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall : (!pdl.operation) -> ()
@@ -36,7 +36,7 @@
// Canonicalizations.
transform.iree.apply_patterns %variant_op
{ canonicalization, tiling_canonicalization, licm, cse } : (!pdl.operation) -> ()
-
+
// Step 2. Second level of tiling + fusion parallelizes to threads.
// ================================================================
@@ -46,7 +46,7 @@
%tiled_input_max,
%tiled_exps_sum_fill,
%tiled_exp_and_exps_sum,
- %tiled_div = transform.split_handles %tiled_ops in [5]
+ %tiled_div = transform.split_handle %tiled_ops
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
!pdl.operation, !pdl.operation)
// Leaving the reduction untiled on threadIdx.x makes it sequential on
@@ -92,7 +92,7 @@
// Step 6. Post-bufferization vector distribution with rank-reduction.
// ===================================================================
transform.iree.apply_patterns %memref_func { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } : (!pdl.operation) -> ()
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
: (!pdl.operation) -> !pdl.operation
%warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
transform.iree.vector.warp_distribute %memref_func
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 1eb7d58..0c852dc 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 1eb7d5865292e0b384e4f97977374399cdcd944d
+Subproject commit 0c852dc88e9276b74532fd7d233dd23ec1bbed6f
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index db8ea95..768b11e 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit db8ea95cc7349d14ba53c35f93b4f1f77b6ad667
+Subproject commit 768b11e9baa9b2e8b153b84c6e2403b048935dfe