Disabling inlining on the torch async function. (#16739)
Without this, the async function gets inlined into the sync function,
effectively doubling the compile time of the global opt/flow stages, because
we can't discard the async function that was inlined. Marking the async
function as never-inline ensures that the sync function remains a simple
wrapper that calls the async function.
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index 00a1b05..6b4cdbe 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -598,10 +598,8 @@
convertedFuncInfo.funcOp = asyncFuncOp;
asyncFuncOp.setSymVisibilityAttr(torchFunc.getSymVisibilityAttr());
// Handle defacto attrs to specialized ones.
- if (torchFunc->hasAttr("noinline")) {
- asyncFuncOp.setInliningPolicyAttr(
- rewriter.getAttr<IREE::Util::InlineNeverAttr>());
- }
+ asyncFuncOp.setInliningPolicyAttr(
+ rewriter.getAttr<IREE::Util::InlineNeverAttr>());
retainFunctionAttributes(torchFunc, asyncFuncOp);
asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
asyncFuncOp->setAttr("iree.abi.model",
diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
index 09ce2b4..d64bc2c 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
@@ -5,10 +5,11 @@
// There shouldn't be much need to further verify the sync wrapper function.
// CHECK-LABEL: @immutable_import_export
// CHECK: util.func public @main$async(
-// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
-// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) ->
-// CHECK-SAME: (!hal.buffer_view, !hal.buffer_view)
-// CHECK-SAME: attributes {iree.abi.model = "coarse-fences", iree.abi.stub}
+// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
+// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) ->
+// CHECK-SAME: (!hal.buffer_view, !hal.buffer_view)
+// CHECK-SAME: iree.abi.model = "coarse-fences"
+// CHECK-SAME: iree.abi.stub
// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import wait(%arg2) => %arg0 : !hal.buffer_view -> tensor<4x5xi32>
// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import wait(%arg2) => %arg1 : !hal.buffer_view -> tensor<5x4xf32>
// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]] : tensor<4x5xi32> -> !torch.vtensor<[4,5],si32>
@@ -22,8 +23,9 @@
// CHECK-DAG: %[[FUNC_RESULT1:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#1
// CHECK: util.return %[[FUNC_RESULT0]], %[[FUNC_RESULT1]]
//
-// CHECK: util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
-// CHECK-SAME: -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub}
+// CHECK: util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
+// CHECK-SAME: -> (!hal.buffer_view, !hal.buffer_view)
+// CHECK-SAME: iree.abi.stub
// CHECK-DAG: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CONSTANT1:.+]] = arith.constant -1 : i32
// CHECK-DAG: %[[DEVICE0:.+]] = hal.devices.get %[[CONSTANT0]] : !hal.device
@@ -33,7 +35,7 @@
// CHECK: %[[AWAIT_STATUS:.+]] = hal.fence.await until([%[[NEW_FENCE]]]) timeout_millis(%[[CONSTANT1]])
// CHECK: util.return %[[CALL_RESULTS]]#0, %[[CALL_RESULTS]]#1 : !hal.buffer_view, !hal.buffer_view
builtin.module @immutable_import_export {
-func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.vtensor<[5,4],f32>)
+func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.vtensor<[5,4],f32>)
-> (!torch.vtensor<[4,5],si32>, !torch.vtensor<[5,4],f32>) {
%0 = torch.operator "foobar0"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
%1 = torch.operator "foobar1"(%arg1) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32>
@@ -61,7 +63,7 @@
// immutable.
// CHECK-LABEL: @mutable_input_overwrite_no_return
// CHECK: util.func public @main$async(
-// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
+// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view
// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import wait(%arg2) => %arg0
// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]]
@@ -77,7 +79,7 @@
// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#1 :
// CHECK: util.return %[[EXPORT_RESULT0]]
builtin.module @mutable_input_overwrite_no_return {
-func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.tensor<[5,4],f32>)
+func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.tensor<[5,4],f32>)
-> (!torch.vtensor<[4,5],si32>) {
%0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32>
%1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32>
@@ -95,7 +97,7 @@
// Not a good idea to do but legal. This verifies that if returning a mutated
// tensor's intermediate value, you will get two exports, indicating a copy.
// CHECK-LABEL: @mutable_input_overwrite_return_alias_copies
-// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%{{.*}}, %{{.*}} : tensor<5x4xf32>, tensor<5x4xf32>)
+// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%{{.*}}, %{{.*}} : tensor<5x4xf32>, tensor<5x4xf32>)
// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#0 into(%arg0 : !hal.buffer_view)
// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#1 :
builtin.module @mutable_input_overwrite_return_alias_copies {
@@ -114,11 +116,11 @@
// CHECK: util.func public @main(
// CHECK-SAME: iree.reflection = {some.attr = 4 : index}
builtin.module @retained_attribute_reflection {
-func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
+func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
attributes {
iree.reflection = {
some.attr = 4 : index
- }
+ }
}
{
return %arg0 : !torch.vtensor<[4,5],si32>
@@ -130,7 +132,7 @@
// CHECK: util.func public @main$async(
// CHECK-NOT: iree.nonretained
builtin.module @retained_attribute_ignored {
-func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
+func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
attributes {
iree.nonretained = "dummy"
}
@@ -146,7 +148,7 @@
// CHECK: util.func public @main(
// CHECK-NOT: inlining_policy
builtin.module @retained_attribute_noinline {
-func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
+func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
attributes {
noinline
}
@@ -160,7 +162,7 @@
// CHECK: util.func private @main$async
// CHECK: util.func private @main
builtin.module @private_visibility {
-func.func private @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
+func.func private @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
{
return %arg0 : !torch.vtensor<[4,5],si32>
}
@@ -172,7 +174,7 @@
// CHECK: util.func public @main(%arg0: !hal.buffer_view) -> !hal.buffer_view
// CHECK: = util.call @main$async{{.*}} -> %arg0
builtin.module @tied_operand {
-func.func @main(%arg0: !torch.vtensor<[4,5],si32>) ->
+func.func @main(%arg0: !torch.vtensor<[4,5],si32>) ->
(!torch.vtensor<[4,5],si32> {iree.abi.tied = 0})
{
return %arg0 : !torch.vtensor<[4,5],si32>
@@ -185,7 +187,7 @@
// CHECK: hal.buffer_view.dim<%arg0
// CHECK: hal.buffer_view.dim<%arg1
builtin.module @immutable_import_export {
-func.func @main(%arg0: !torch.vtensor<[4,?],si32>, %arg1: !torch.vtensor<[?,4],f32>)
+func.func @main(%arg0: !torch.vtensor<[4,?],si32>, %arg1: !torch.vtensor<[?,4],f32>)
-> (!torch.vtensor<[4,?],si32>, !torch.vtensor<[?,4],f32>) {
%0 = torch.operator "foobar0"(%arg0) : (!torch.vtensor<[4,?],si32>) -> !torch.vtensor<[4,?],si32>
%1 = torch.operator "foobar1"(%arg1) : (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>