Merge main -> google
* f0a07f63 Allow iree.placeholder as an op between splittable ops. (#2999)
* 6bd3b1dc Use hidden visibility for all symbols by default (#2997)
* d6500826 VMLA Dynamic Iota support with Shape dialect work. (#2965)
* c724ee45 Add function to hide SavedModel roundtrip (#2995)
* 2fedce64 Merge pull request #2996
* c5f7030b Merge branch 'main' into google-to-main
* cf3211d9 [vulkan] Reset TimePointFence status when releasing back to pool (#2905)
* ebc6a833 Add lint action to check for tabs (#2984)
* 932339ce Merge pull request #2964 from silvasean/add-dynamic-dot-example
* df727ae3 Integrate MLIR-EmitC at iml130/mlir-emitc@560cd8c (#2990)
* e824d364 Refactor for explicit dialect registration (#2978)
* 38b7d2cf Allow ModelRunner to receive an array of extra symbols available during JITing..
* 264a97df Fuse linalg.tensor_reshape operations with hal.interface* operations. (#2973)
* 0e3c7373 Fold flow.tensor.update when all operands are constant (#2982)
* c7a21c2a Merge pull request #2983 from rsuderman/google-to-main
* fd55a860 Add a dynamically shaped mhlo.dot lowering example
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/3003 from rsuderman:main-to-google f0a07f6316245209f3902e4d2a9a3d349e693794
PiperOrigin-RevId: 328574119
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index fff6e2f..6f98806 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -117,6 +117,7 @@
CLANG_OR_GCC
"-Wno-unused-parameter"
"-Wno-undef"
+ "-fvisibility=hidden"
MSVC_OR_CLANG_CL
"/DWIN32_LEAN_AND_MEAN"
"/wd4624"
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index e5d7330..659c1ef 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -31,11 +31,14 @@
"tf_load_saved_model",
"tf_load_signature_def_saved_model",
"tf_compile_saved_model",
+ "tf_module_to_compiler_module",
]
+import tempfile
from typing import Collection, Optional, Sequence
from . import binding as binding
+import tensorflow as tf
# Native aliases (matches those in the generic compiler).
llvm = binding.llvm
@@ -179,3 +182,34 @@
input_module = tf_load_saved_model(saved_model_dir, compiler_context,
exported_names, pass_pipeline)
return input_module.compile(target_backends=target_backends)
+
+
+def tf_module_to_compiler_module(module: tf.Module,
+ exported_names: Collection[str] = (),
+ sm_path: str = None):
+ """Converts a tf.Module into a MLIR module.
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ sm_path: the path to save the tf.Module to, if any. Defaults to None.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+
+ def _convert(sm_path):
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(module, sm_path, options=options)
+ return tf_load_saved_model(
+ sm_path, exported_names=exported_names, pass_pipeline=())
+
+ if sm_path is None:
+ with tempfile.TemporaryDirectory() as sm_path:
+ compiler_module = _convert(sm_path)
+ else:
+ compiler_module = _convert(sm_path)
+ return compiler_module
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 3137d81..b216236 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -19,7 +19,6 @@
import os
import random
import re
-import tempfile
from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
from absl import flags
@@ -155,73 +154,61 @@
artifacts_dir is provided.
"""
- def _compile_from_path(sm_path: str) -> compiler.binding.OpaqueBlob:
- """Helper function for compile_tf_module."""
- if artifacts_dir is not None:
- # Set up a crash reproducer for debugging.
- compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backends_string}.mlir")
- try:
- # We break up the compilation here so we can save intermediary artifacts.
- compiler_context = compiler.Context()
-
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_load_saved_model(
- sm_path,
- exported_names=exported_names,
- compiler_context=compiler_context,
- pass_pipeline=())
-
- if artifacts_dir is not None:
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- # Now run the passes manually that tf_load_saved_model would usually do.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- target_backends = []
- for backend_info in backend_infos:
- target_backends.extend(backend_info.compiler_targets)
- compiled_module = compiler_module.compile(target_backends=target_backends)
-
- compiled_path = None
- if artifacts_dir is not None:
- compiled_path = _get_backends_path("compiled", backend_infos,
- artifacts_dir)
- compiled_path = f"{compiled_path}.vmfb"
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
-
- return compiled_module, compiled_path
- except Exception: # pylint: disable=broad-except
- if artifacts_dir is not None:
- # Disable the crash reproducer (to avoid inadvertently overwriting it).
- compiler.Context.default_crash_reproducer_path = None
- raise
-
- options = tf.saved_model.SaveOptions(save_debug_info=True)
- backends_string = backends_to_str(backend_infos)
if artifacts_dir is not None and FLAGS.keep_saved_model:
# Create a saved model for these target backends to avoid a race condition
# when running a test suite.
# TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
sm_path = _get_backends_path("saved_model", backend_infos, artifacts_dir)
- tf.saved_model.save(tf_module, sm_path, options=options)
- return _compile_from_path(sm_path)
else:
- # Round-trip the saved model through a temporary directory.
- with tempfile.TemporaryDirectory() as sm_path:
- tf.saved_model.save(tf_module, sm_path, options=options)
- return _compile_from_path(sm_path)
+ sm_path = None
+
+ if artifacts_dir is not None:
+ # Set up a crash reproducer for debugging.
+ backends_string = backends_to_str(backend_infos)
+ compiler.Context.default_crash_reproducer_path = os.path.join(
+ artifacts_dir, f"reproducer__{backends_string}.mlir")
+
+ try:
+ # Convert the tf_module into raw TF input MLIR.
+ compiler_module = compiler.tf_module_to_compiler_module(
+ tf_module, exported_names, sm_path)
+
+ if artifacts_dir is not None:
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+ logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+ with open(tf_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ # Now run the passes manually that tf_load_saved_model would usually do.
+ compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+ if artifacts_dir is not None:
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+ logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+ with open(iree_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ target_backends = []
+ for backend_info in backend_infos:
+ target_backends.extend(backend_info.compiler_targets)
+ compiled_module = compiler_module.compile(target_backends=target_backends)
+
+ compiled_path = None
+ if artifacts_dir is not None:
+ compiled_path = _get_backends_path("compiled", backend_infos,
+ artifacts_dir)
+ compiled_path = f"{compiled_path}.vmfb"
+ logging.info("Saving compiled IREE module to: %s", compiled_path)
+ with open(compiled_path, "wb") as f:
+ f.write(compiled_module)
+
+ except Exception: # pylint: disable=broad-except
+ if artifacts_dir is not None:
+ # Disable the crash reproducer (to avoid inadvertently overwriting it).
+ compiler.Context.default_crash_reproducer_path = None
+ raise
+
+ return compiled_module, compiled_path
class CompiledModule(object):
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 38d13b4..d5e0fef 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -66,12 +66,14 @@
# keep sorted
LLVM_FAILING = [
+ "broadcast_to_test.py",
"broadcasting_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"matrix_ops_test.py",
+ "range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
"strings_test.py",
@@ -80,6 +82,7 @@
# keep sorted
VULKAN_FAILING = [
"bool_test.py",
+ "broadcast_to_test.py",
"broadcasting_test.py",
"control_flow_test.py",
"dynamic_mlp_relu_test.py",
@@ -87,6 +90,7 @@
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"matrix_ops_test.py",
+ "range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
"strings_test.py",
diff --git a/integrations/tensorflow/e2e/broadcast_to_test.py b/integrations/tensorflow/e2e/broadcast_to_test.py
new file mode 100644
index 0000000..6d57d6f
--- /dev/null
+++ b/integrations/tensorflow/e2e/broadcast_to_test.py
@@ -0,0 +1,49 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class BroadcastToModule(tf.Module):
+
+ def __init__(self):
+ pass
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([], tf.float32),
+ tf.TensorSpec([2], tf.int32)
+ ])
+ def scalar_broadcast_to(self, x, shape):
+ return tf.broadcast_to(x, shape)
+
+
+@tf_test_utils.compile_module(BroadcastToModule)
+class BroadcastToTest(tf_test_utils.TracedModuleTestCase):
+
+ def test_scalar_broadcast_to(self):
+
+ def scalar_broadcast_to(module):
+ x = np.array(1, dtype=np.float32)
+ shape = np.array([3, 3], dtype=np.int32)
+ result = module.scalar_broadcast_to(x, shape)
+
+ self.compare_backends(scalar_broadcast_to)
+
+
+if __name__ == "__main__":
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+ tf.test.main()
diff --git a/integrations/tensorflow/e2e/range_test.py b/integrations/tensorflow/e2e/range_test.py
new file mode 100644
index 0000000..f1e093c
--- /dev/null
+++ b/integrations/tensorflow/e2e/range_test.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class RangeModule(tf.Module):
+
+ def __init__(self):
+ pass
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([], tf.float32),
+ tf.TensorSpec([], tf.float32),
+ tf.TensorSpec([], tf.float32)
+ ])
+ def range(self, start, stop, delta):
+ return tf.range(start, stop, delta)
+
+
+@tf_test_utils.compile_module(RangeModule)
+class RangeTest(tf_test_utils.TracedModuleTestCase):
+
+ def test_range(self):
+
+ def range(module):
+ start = np.array(3., dtype=np.float32)
+ stop = np.array(12., dtype=np.float32)
+ delta = np.array(3, dtype=np.float32)
+ result = module.range(start, stop, delta)
+
+ self.compare_backends(range)
+
+
+if __name__ == "__main__":
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+ tf.test.main()
diff --git a/iree/base/dynamic_library_test_library.cc b/iree/base/dynamic_library_test_library.cc
index 1d237d2..4d99552 100644
--- a/iree/base/dynamic_library_test_library.cc
+++ b/iree/base/dynamic_library_test_library.cc
@@ -21,7 +21,7 @@
#if defined(_WIN32)
#define IREE_SYM_EXPORT __declspec(dllexport)
#else
-#define IREE_SYM_EXPORT
+#define IREE_SYM_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
IREE_API_EXPORT int IREE_SYM_EXPORT times_two(int value) { return value * 2; }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 177c0c0..9a98f6f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -68,7 +68,8 @@
for (auto currOp = ops.begin(), nextOp = std::next(ops.begin());
nextOp != ops.end(); ++currOp, ++nextOp) {
Operation *iter = (*currOp)->getNextNode();
- while (iter != *nextOp && MemoryEffectOpInterface::hasNoEffect(iter))
+ while (iter != *nextOp && (MemoryEffectOpInterface::hasNoEffect(iter) ||
+ isa<IREE::PlaceholderOp>(iter)))
iter = iter->getNextNode();
if (iter != *nextOp) return false;
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 717f758..dbffa87 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -229,3 +229,48 @@
// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x4xf32>
// CHECK: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x4xf32>
// CHECK: linalg.generic {{.*}} %[[IN]], %[[OUT]]
+
+// -----
+
+module {
+ func @predict_ex_dispatch_0() {
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x512x1xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1} : memref<4x8x16xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x512x1xf32>
+ linalg.copy(%2, %0) : memref<1x512x1xf32>, memref<1x512x1xf32>
+ %3 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<4x8x16xf32>
+ linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1, d2) -> (-d0 + 3, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]} %3, %1 {
+ ^bb0(%arg0: f32, %arg1: f32): // no predecessors
+ linalg.yield %arg0 : f32
+ }: memref<4x8x16xf32>, memref<4x8x16xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+// CHECK: module attributes {vkspv.entry_point_schedule =
+// CHECK-SAME: ["predict_ex_dispatch_0_dispatch_0",
+// CHECK-SAME: "predict_ex_dispatch_0_dispatch_1"]}
+// CHECK: func @predict_ex_dispatch_0_dispatch_1
+// CHECK-NEXT: iree.placeholder
+// CHECK-SAME: binding = @legacy_io::@ret1
+// CHECK-NEXT: iree.placeholder
+// CHECK-SAME: binding = @legacy_io::@arg1
+// CHECK-NEXT: linalg.generic
+// CHECK: linalg.yield
+// CHECK-NOT: linalg
+// CHECK: return
+// CHECK: func @predict_ex_dispatch_0_dispatch_0
+// CHECK-NEXT: iree.placeholder
+// CHECK-SAME: binding = @legacy_io::@ret0
+// CHECK-NEXT: iree.placeholder
+// CHECK-SAME: binding = @legacy_io::@arg0
+// CHECK-NEXT: linalg.copy
+// CHECK-NOT: linalg
+// CHECK: return
diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
index a0731df..4674425 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
@@ -18,6 +18,7 @@
#include "iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h"
#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
@@ -89,6 +90,10 @@
mhlo::SetupMaterializeBroadcastsLegality(context, &conversionTarget);
mhlo::PopulateMaterializeBroadcastsPatterns(context, &conversionPatterns);
+ Shape::populateShapeToStandardConversionPatterns(conversionPatterns,
+ context);
+ Shape::setupShapeToStandardLegality(conversionTarget);
+
// Early conversion of ops that have matches we want to route through.
// For example, DynamicUpdateSlice should end up as a stream operation.
setupDirectHLOToFlowLegality(context, conversionTarget);
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
index b0b11be..86db03f 100644
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Utils/PatternUtils.h"
#include "llvm/Support/Debug.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
index de5a06f..1afc5ca 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -49,8 +49,11 @@
auto shapedType = op.operand().getType().dyn_cast<ShapedType>();
auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
if (shapedType && shapedType.hasRank() && rsType) {
- if (!shapedType.getShape().equals(rsType.getAllDims())) {
- return op.emitOpError("dims must match between tensor and shape");
+ for (auto it : llvm::zip(shapedType.getShape(), rsType.getAllDims())) {
+ if ((std::get<0>(it) != -1 && std::get<1>(it) != -1) &&
+ std::get<0>(it) != std::get<1>(it)) {
+ return op.emitOpError("dims must match between tensor and shape");
+ }
}
}
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
index 72845f1..539c4c5 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td
+++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
@@ -360,6 +360,27 @@
}
//===----------------------------------------------------------------------===//
+// Iota operations.
+//===----------------------------------------------------------------------===//
+
+def Shape_IotaOp : Shape_PureOp<"iota"> {
+ let summary = "Creates an iota of the desired 1-D shape.";
+ let description = [{
+ Creates an iota of the desired 1-D shape.
+
+ Usage:
+ %0 = shapex.iota %shp0 : !shapex.ranked_shape<...>
+ }];
+
+ let arguments = (ins Shape_RankedShape:$result_shape);
+ let results = (outs AnyRankedTensor:$result);
+
+ // TODO: Custom parser/printer
+ let parser = ?;
+ let printer = ?;
+}
+
+//===----------------------------------------------------------------------===//
// Shape manipulations.
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
index 7f5395b..5f40847 100644
--- a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
+++ b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
@@ -1,13 +1,27 @@
// RUN: iree-opt -split-input-file -verify-diagnostics %s
// -----
-func @tie_shape_mismatch_type(%arg0 : tensor<2x?x4xf32>, %arg1 : !shapex.ranked_shape<[1]>) {
+func @tie_shape_mismatch_rank(%arg0 : tensor<2x?x4xf32>, %arg1 : !shapex.ranked_shape<[1]>) {
// expected-error @+1 {{dims must match between tensor and shape}}
%0 = shapex.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shapex.ranked_shape<[1]>
return
}
// -----
+func @tie_shape_dynamic_tensor(%arg0 : tensor<2x?x4xf32>, %arg1 : !shapex.ranked_shape<[2, 1]>) {
+ %0 = shapex.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shapex.ranked_shape<[2, 1]>
+ return
+}
+
+// -----
+
+func @tie_shape_mistmatc_dim(%arg0 : tensor<2x?x4xf32>, %arg1 : !shapex.ranked_shape<[1, 1]>) {
+ // expected-error @+1 {{dims must match between tensor and shape}}
+ %0 = shapex.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shapex.ranked_shape<[1, 1]>
+ return
+}
+
+// -----
func @get_ranked_shape_same_rank(%arg0 : tensor<2x?x4xf32>) {
// expected-error @+1 {{op operand and result must be of same rank}}
%0 = shapex.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shapex.ranked_shape<[2]>
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index 9024e88..7a589f0 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -158,6 +158,13 @@
return bidOp.result_shape();
}
+Value rewriteShapexIota(RankedShapeType resultShape,
+ iree_compiler::Shape::IotaOp iotaOp,
+ OpBuilder &builder) {
+ if (!iotaOp) return nullptr;
+ return iotaOp.result_shape();
+}
+
Value rewriteTranspose(RankedShapeType resultShape, TransposeOp transposeOp,
OpBuilder &builder) {
if (!transposeOp) return nullptr;
@@ -460,6 +467,7 @@
b.insertOpRankedShapeBuilder<DotOp>(rewriteXlaDotOpShape);
b.insertOpRankedShapeBuilder<RankedBroadcastInDimOp>(
rewriteShapexRankedBroadcastInDim);
+ b.insertOpRankedShapeBuilder<iree_compiler::Shape::IotaOp>(rewriteShapexIota);
b.insertOpRankedShapeBuilder<ReduceOp>(rewriteReduce);
b.insertOpRankedShapeBuilder<TransposeOp>(rewriteTranspose);
b.insertOpRankedShapeBuilder<mhlo::DotGeneralOp>(rewriteDotGeneral);
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
index 82c50b2..6d26b42 100644
--- a/iree/compiler/Dialect/Shape/Transforms/BUILD
+++ b/iree/compiler/Dialect/Shape/Transforms/BUILD
@@ -23,6 +23,7 @@
srcs = [
"CleanupPlaceholdersPass.cpp",
"ConvertHLOToShapeDialectPass.cpp",
+ "ConvertShapeToStandard.cpp",
"FunctionSignatureExpansionPass.cpp",
"HoistShapeCalculationsPass.cpp",
"MaterializeShapeCalculations.cpp",
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
index e3a1449..081065c 100644
--- a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
@@ -23,6 +23,7 @@
SRCS
"CleanupPlaceholdersPass.cpp"
"ConvertHLOToShapeDialectPass.cpp"
+ "ConvertShapeToStandard.cpp"
"FunctionSignatureExpansionPass.cpp"
"HoistShapeCalculationsPass.cpp"
"MaterializeShapeCalculations.cpp"
diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
index 306e7f7..851556e 100644
--- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
@@ -47,6 +47,23 @@
}
};
+class ConvertDynamicIota : public OpConversionPattern<mhlo::DynamicIotaOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mhlo::DynamicIotaOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultTy = op.getType().cast<ShapedType>();
+ if (resultTy.getRank() != 1) {
+ return failure();
+ }
+
+ auto rankedShape = rewriter.create<Shape::FromExtentTensorOp>(
+ op.getLoc(), op.getOperand());
+ rewriter.replaceOpWithNewOp<Shape::IotaOp>(op, op.getType(), rankedShape);
+ return success();
+ }
+};
+
class ConvertHLOToShapePass
: public PassWrapper<ConvertHLOToShapePass, FunctionPass> {
void runOnFunction() override {
@@ -60,6 +77,9 @@
conversionTarget.addIllegalOp<mhlo::DynamicBroadcastInDimOp>();
conversionPatterns.insert<ConvertDynamicBroadcastInDim>(&getContext());
+ conversionTarget.addIllegalOp<mhlo::DynamicIotaOp>();
+ conversionPatterns.insert<ConvertDynamicIota>(&getContext());
+
if (failed(applyPartialConversion(getFunction(), conversionTarget,
conversionPatterns))) {
return signalPassFailure();
diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertShapeToStandard.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertShapeToStandard.cpp
new file mode 100644
index 0000000..856b815
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/ConvertShapeToStandard.cpp
@@ -0,0 +1,79 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace Shape {
+
+namespace {
+
+class ConvertFromExtent : public OpConversionPattern<FromExtentTensorOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ FromExtentTensorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto input = op.extent_tensor();
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ if (!inputTy.hasRank() || inputTy.getRank() != 1) {
+ return failure();
+ }
+
+ llvm::SmallVector<Value, 4> extracted_elements;
+ auto valueCount = inputTy.getDimSize(0);
+ extracted_elements.reserve(valueCount);
+ for (int i = 0; i < valueCount; i++) {
+ auto index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
+ Value dim = rewriter.create<ExtractElementOp>(
+ op.getLoc(), inputTy.getElementType(), input, index.getResult());
+ if (!dim.getType().isIndex()) {
+ dim = rewriter.create<IndexCastOp>(op.getLoc(), rewriter.getIndexType(),
+ dim);
+ }
+ extracted_elements.push_back(dim);
+ }
+
+ SmallVector<int64_t, 4> dims;
+ dims.resize(valueCount, -1);
+ rewriter.replaceOpWithNewOp<Shape::MakeRankedShapeOp>(
+ op, Shape::RankedShapeType::get(dims, op.getContext()),
+ extracted_elements);
+
+ return success();
+ }
+};
+
+} // namespace
+
+// Populates patterns that will convert shape calculations into standard ops.
+void populateShapeToStandardConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<ConvertFromExtent>(context);
+}
+
+// Sets up legality for shape calculation materialization conversions.
+void setupShapeToStandardLegality(ConversionTarget &target) {
+ target.addIllegalOp<FromExtentTensorOp>();
+ target.addLegalOp<Shape::MakeRankedShapeOp>();
+}
+
+} // namespace Shape
+} // namespace iree_compiler
+} // namespace mlir
\ No newline at end of file
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
index 6c59951..0a78738 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp
@@ -69,6 +69,7 @@
RankedDimOp::getCanonicalizationPatterns(patterns, context);
RankedDimsOp::getCanonicalizationPatterns(patterns, context);
TieShapeOp::getCanonicalizationPatterns(patterns, context);
+ FromExtentTensorOp::getCanonicalizationPatterns(patterns, context);
applyPatternsAndFoldGreedily(getOperation(), patterns);
}
};
diff --git a/iree/compiler/Dialect/Shape/Transforms/Patterns.h b/iree/compiler/Dialect/Shape/Transforms/Patterns.h
index 7492ce4..41042ee 100644
--- a/iree/compiler/Dialect/Shape/Transforms/Patterns.h
+++ b/iree/compiler/Dialect/Shape/Transforms/Patterns.h
@@ -31,6 +31,13 @@
void populateMaterializeShapeCalculationsConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
+// Sets up legality for shape calculation materialization conversions.
+void setupShapeToStandardLegality(ConversionTarget &target);
+
+// Populates patterns that will convert shape calculations into standard ops.
+void populateShapeToStandardConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context);
+
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index d213051..48c6fb7 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
@@ -144,6 +145,40 @@
TypeConverter &typeConverter;
};
+struct IotaOpConversion : public OpConversionPattern<Shape::IotaOp> {
+ IotaOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+ LogicalResult matchAndRewrite(
+ Shape::IotaOp op, ArrayRef<Value> operandValues,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultTy = op.getResult().getType().cast<ShapedType>();
+
+ int32_t elementSize = VMLATypeConverter::getRoundedElementByteWidth(
+ resultTy.getElementType());
+ auto elementSizeValue =
+ rewriter.createOrFold<mlir::ConstantIndexOp>(op.getLoc(), elementSize);
+
+ auto shapeDim0 = rewriter.createOrFold<Shape::RankedDimOp>(
+ op.getLoc(), rewriter.getIndexType(), op.getOperand(),
+ rewriter.getI64IntegerAttr(0));
+
+ auto bufferSize = rewriter.createOrFold<mlir::MulIOp>(
+ op.getLoc(), elementSizeValue, shapeDim0);
+
+ auto dst = rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
+ op.getLoc(), IREE::VMLA::BufferType::get(rewriter.getContext()),
+ bufferSize);
+
+ rewriter.createOrFold<IREE::VMLA::IotaOp>(
+ op.getLoc(), dst, TypeAttr::get(resultTy.getElementType()));
+ rewriter.replaceOp(op, {dst});
+
+ return success();
+ }
+
+ TypeConverter &typeConverter;
+};
+
struct CanonicalizeBroadcastOp : public OpRewritePattern<mhlo::BroadcastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::BroadcastOp op,
@@ -791,6 +826,7 @@
patterns.insert<ScatterOpConversion>(context, typeConverter);
patterns.insert<SliceOpConversion>(context, typeConverter);
patterns.insert<DynamicSliceOpConversion>(context, typeConverter);
+ patterns.insert<IotaOpConversion>(context, typeConverter);
// Tensor-level canonicalizations to reduce the op surface area of the
// runtime.
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 513e18a..6e967b9 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -284,6 +284,7 @@
VMLA_SIZED_IMPORT_OP(IREE::VMLA::GatherOp, "vmla.gather");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::ScatterOp, "vmla.scatter");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::BroadcastOp, "vmla.broadcast");
+ VMLA_TYPED_IMPORT_OP(IREE::VMLA::IotaOp, "vmla.iota");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::TileOp, "vmla.tile");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::NotOp, "vmla.not");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index c9f32d1..8b5c425 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -282,6 +282,17 @@
}];
}
+def VMLA_IotaOp : VMLA_ElementTypeOp<"iota"> {
+ let arguments = (ins
+ VMLA_Buffer:$dst,
+ VMLA_AnyTypeAttr:$element_type
+ );
+
+ let assemblyFormat = [{
+ `out` $dst attr-dict `:` $element_type
+ }];
+}
+
def VMLA_TileOp : VMLA_ElementTypeOp<"tile", [VMLA_IncludeShapes]> {
let arguments = (ins
VMLA_Buffer:$src,
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index f66ef63..b9ec7c2 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -113,11 +113,12 @@
// This is skating on thin ice.
// TODO(silvasean): Legalize ToExtentTensorOp and FromExtentTensorOp.
conversionTarget.addIllegalOp<Shape::FromExtentTensorOp>();
- // RankedBroadcastInDimOp is an logically something that should be an
- // mhlo op (or in a dialect at a similar level of abstraction), but since
- // it isn't technically in that dialect, we need to special-case mark it as
- // illegal here.
+ // IotaOp and RankedBroadcastInDimOp is an logically something that should
+ // be an mhlo op (or in a dialect at a similar level of abstraction), but
+ // since it isn't technically in that dialect, we need to special-case mark
+ // it as illegal here.
// TODO(silvasean): Reconcile the dialect layering here.
+ conversionTarget.addIllegalOp<Shape::IotaOp>();
conversionTarget.addIllegalOp<Shape::RankedBroadcastInDimOp>();
if (failed(applyPartialConversion(getOperation(), conversionTarget,
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 5914fbf..5f2177d 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -217,6 +217,11 @@
%dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
)
+vm.import @iota.i8(%dst : !vm.ref<!vmla.buffer>)
+vm.import @iota.i16(%dst : !vm.ref<!vmla.buffer>)
+vm.import @iota.i32(%dst : !vm.ref<!vmla.buffer>)
+vm.import @iota.f32(%dst : !vm.ref<!vmla.buffer>)
+
vm.import @tile.x8(
%src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
%dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index 6cf3b27..bdacd77 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -168,6 +168,11 @@
absl::Span<T> dst_buffer);
};
+struct Iota {
+ template <typename T>
+ static Status Execute(absl::Span<T> dst_buffer);
+};
+
struct Tile {
template <typename T>
static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index c436798..1d08230 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -510,6 +510,16 @@
}
template <typename T>
+Status Iota::Execute(absl::Span<T> dst_buffer) {
+ T value = 0;
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = value;
+ value += 1;
+ }
+ return OkStatus();
+}
+
+template <typename T>
Status Tile::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer,
ShapeSpan src_shape, ShapeSpan dst_shape) {
// This implementation is .... not fast.
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index e74fb3f..6354eaf 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -330,6 +330,12 @@
// Common helpers for defining ops
//===--------------------------------------------------------------------===//
+#define IREE_VMLA_NONARY_OP(name, kernel, type) \
+ Status name(vm::ref<Buffer> dst) { \
+ IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
+ return kernel::Execute<type>(dst->As<type>()); \
+ }
+
#define IREE_VMLA_UNARY_OP(name, kernel, type) \
Status name(vm::ref<Buffer> src, vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
@@ -506,6 +512,11 @@
IREE_VMLA_BROADCAST_OP(BroadcastX16, uint16_t);
IREE_VMLA_BROADCAST_OP(BroadcastX32, uint32_t);
+ IREE_VMLA_NONARY_OP(IotaI8, kernels::Iota, int8_t);
+ IREE_VMLA_NONARY_OP(IotaI16, kernels::Iota, int16_t);
+ IREE_VMLA_NONARY_OP(IotaI32, kernels::Iota, int32_t);
+ IREE_VMLA_NONARY_OP(IotaF32, kernels::Iota, float_t);
+
#define IREE_VMLA_TILE_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape) { \
@@ -832,6 +843,10 @@
vm::MakeNativeFunction("scatter.x8", &VMLAModuleState::ScatterX8),
vm::MakeNativeFunction("scatter.x16", &VMLAModuleState::ScatterX16),
vm::MakeNativeFunction("scatter.x32", &VMLAModuleState::ScatterX32),
+ vm::MakeNativeFunction("iota.i8", &VMLAModuleState::IotaI8),
+ vm::MakeNativeFunction("iota.i16", &VMLAModuleState::IotaI16),
+ vm::MakeNativeFunction("iota.i32", &VMLAModuleState::IotaI32),
+ vm::MakeNativeFunction("iota.f32", &VMLAModuleState::IotaF32),
vm::MakeNativeFunction("tile.x8", &VMLAModuleState::TileX8),
vm::MakeNativeFunction("tile.x16", &VMLAModuleState::TileX16),
vm::MakeNativeFunction("tile.x32", &VMLAModuleState::TileX32),
diff --git a/iree/hal/vulkan/timepoint_util.cc b/iree/hal/vulkan/timepoint_util.cc
index 076eebc..aad2ae3 100644
--- a/iree/hal/vulkan/timepoint_util.cc
+++ b/iree/hal/vulkan/timepoint_util.cc
@@ -29,6 +29,7 @@
// static
void TimePointFence::Delete(TimePointFence* ptr) {
+ ptr->ResetStatus();
ptr->pool()->ReleaseResolved(ptr);
}
@@ -41,6 +42,11 @@
return status_;
}
+void TimePointFence::ResetStatus() {
+ absl::MutexLock lock(&status_mutex_);
+ status_ = VK_NOT_READY;
+}
+
// static
StatusOr<ref_ptr<TimePointFencePool>> TimePointFencePool::Create(
ref_ptr<VkDeviceHandle> logical_device) {
diff --git a/iree/hal/vulkan/timepoint_util.h b/iree/hal/vulkan/timepoint_util.h
index 4135326..fc11246 100644
--- a/iree/hal/vulkan/timepoint_util.h
+++ b/iree/hal/vulkan/timepoint_util.h
@@ -67,6 +67,9 @@
// under the hood.
VkResult GetStatus();
+ // Resets the status to unsignaled (VK_NOT_READY).
+ void ResetStatus();
+
// Returns the pool from which this fence comes.
TimePointFencePool* pool() const { return pool_; }