blob: 4f8daf52ab95ac1455973c9abd6bab3ae9b11045 [file] [log] [blame]
# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test MobileBERT.
Model topology and weights are from
https://github.com/google-research/google-research/tree/master/mobilebert
"""
import os
import posixpath
from absl import app
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_boolean("use_quantized_weights", False,
"Whether to use quantized or floating point weights.")
MAX_SEQ_LENGTH = 384 # Max input sequence length used in mobilebert_squad.
FILE_NAME = "mobilebert_squad_savedmodels.tar.gz"
MODEL_URL = posixpath.join(
"https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/",
FILE_NAME)
class MobileBertSquad(tf.Module):
"""Wrapper of MobileBertSquad saved model v1."""
def __init__(self):
self.model_path = self.get_model_path()
self.saved_model = tf.saved_model.load(self.model_path, tags=["serve"])
self.inference_func = self.saved_model.signatures["serving_default"]
@staticmethod
def get_model_path():
model_type = "quant_saved_model" if FLAGS.use_quantized_weights else "float"
# Get_file will download the model weights from a publicly available folder,
# save them to cache_dir=~/.keras/datasets/ and return a path to them.
model_path = tf.keras.utils.get_file(FILE_NAME, MODEL_URL, untar=True)
model_dir = os.path.dirname(model_path)
extracted_name = FILE_NAME.split(".")[0]
model_path = os.path.join(model_dir, extracted_name, model_type)
return model_path
@staticmethod
def get_legacy_tflite_saved_model_converter_kwargs():
return dict([("input_arrays", ["input_ids", "input_mask", "segment_ids"]),
("output_arrays", ["start_logits", "end_logits"]),
("exported_name", "predict"),
("model_path", MobileBertSquad.get_model_path())])
@tf.function(input_signature=[
tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
])
def predict(self, input_ids, input_mask, segment_ids):
inputs = {
"input_ids": input_ids,
"input_mask": input_mask,
"segment_ids": segment_ids,
}
return self.inference_func(**inputs)
class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
"""Tests of MobileBertSquad."""
def __init__(self, methodName="runTest"):
super(MobileBertSquadTest, self).__init__(methodName)
self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
exported_names=["predict"])
def test_predict(self):
def predict(module):
input_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
input_mask = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
segment_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
module.predict(input_ids, input_mask, segment_ids, atol=1e0)
self.compare_backends(predict, self._modules)
def main(argv):
del argv # Unused
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
if __name__ == '__main__':
app.run(main)