Automated sync from github.com/tensorflow/tensorflow (#2558)
BUG=automated sync from upstream
NO_CHECK_TFLITE_FILES=automated sync from upstream
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 5425030..3526810 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -921,11 +921,14 @@
case BuiltinOperator_STABLEHLO_PAD: {
return ParseStablehloPad(op, error_reporter, allocator, builtin_data);
}
+ case BuiltinOperator_STABLEHLO_COMPOSITE: {
+ return ParseStablehloComposite(op, error_reporter, allocator,
+ builtin_data);
+ }
// TODO: skip param parsing for now since ops below don't have kernels
case BuiltinOperator_STABLEHLO_SLICE:
case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
case BuiltinOperator_STABLEHLO_CONVOLUTION:
- case BuiltinOperator_STABLEHLO_COMPOSITE:
case BuiltinOperator_STABLEHLO_LOGISTIC:
case BuiltinOperator_STABLEHLO_ADD:
case BuiltinOperator_STABLEHLO_DIVIDE:
@@ -2382,6 +2385,31 @@
return kTfLiteError;
}
+TfLiteStatus ParseStablehloComposite(const Operator* op,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator,
+ void** builtin_data) {
+ CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+ SafeBuiltinDataAllocator safe_allocator(allocator);
+ auto params = safe_allocator.Allocate<TfLiteStablehloCompositeParams>();
+ const StableHLOCompositeOptions* schema_params =
+ op->builtin_options_2_as_StableHLOCompositeOptions();
+ if (schema_params) {
+ params->name = schema_params->name()->c_str();
+ params->version = schema_params->version();
+ params->subgraph_index = schema_params->decomposition_subgraph_index();
+ params->attributes = schema_params->composite_attributes()->data();
+ params->attributes_size = schema_params->composite_attributes()->size();
+ *builtin_data = params.release();
+ return kTfLiteOk;
+ }
+ TF_LITE_REPORT_ERROR(
+ error_reporter,
+ "Could not get 'stablehlo.composite' operation parameters.");
+ return kTfLiteError;
+}
+
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index 1c90e9f..c01e887 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -445,6 +445,11 @@
BuiltinDataAllocator* allocator,
void** builtin_data);
+TfLiteStatus ParseStablehloComposite(const Operator* op,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator,
+ void** builtin_data);
+
} // namespace tflite
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
diff --git a/tensorflow/lite/core/c/builtin_op_data.h b/tensorflow/lite/core/c/builtin_op_data.h
index 1ac385b..e1428e7 100644
--- a/tensorflow/lite/core/c/builtin_op_data.h
+++ b/tensorflow/lite/core/c/builtin_op_data.h
@@ -21,6 +21,7 @@
#define TENSORFLOW_LITE_CORE_C_BUILTIN_OP_DATA_H_
#include <stdbool.h>
+#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/core/c/common.h"
@@ -645,6 +646,14 @@
int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT];
} TfLiteStablehloPadParams;
+typedef struct {
+ const char* name;
+ int32_t subgraph_index;
+ int32_t version;
+ const uint8_t* attributes;
+ size_t attributes_size;
+} TfLiteStablehloCompositeParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus