blob: 99abbbbeff1b8f82ebfd0f513d18f90a22c9ed38 [file] [log] [blame]
# Lint as: python3
# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Test MobileBERT.
Model topology and weights are from
https://github.com/google-research/google-research/tree/master/mobilebert
"""
import os
import posixpath
from absl import app
from absl import flags
from iree.tf.support import tf_test_utils
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS

# When True the test loads the 'quant_saved_model' SavedModel from the
# downloaded archive; otherwise it loads the 'float' one.
flags.DEFINE_boolean(
    name='use_quantized_weights',
    default=False,
    help='Whether to use quantized or floating point weights.')
MAX_SEQ_LENGTH = 384  # Max input sequence length used in mobilebert_squad.

# Base name of the model archive; the test also uses it to locate the
# directory the archive is extracted into.
FILE_NAME = 'mobilebert_squad_savedmodels'

# Public download location of the MobileBERT SQuAD SavedModels archive.
# (Previously wrapped in a single-argument posixpath.join, which is a no-op.)
MODEL_URL = ('https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/'
             f'{FILE_NAME}.tar.gz')
class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
  """Tests of MobileBertSquad.

  Downloads the MobileBERT SQuAD SavedModels archive, compiles the
  'serving_default' signature of either the float or the quantized SavedModel
  (chosen by --use_quantized_weights), and compares it across backends.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Name of the SavedModel subdirectory inside the extracted archive.
    model_type = 'quant_saved_model' if FLAGS.use_quantized_weights else 'float'

    # get_file downloads the archive from a publicly available bucket, caches
    # it under cache_dir=~/.keras/datasets/, extracts it (untar=True) and
    # returns a path inside that cache directory.
    downloaded_path = tf.keras.utils.get_file(FILE_NAME, MODEL_URL, untar=True)
    cache_dir = os.path.dirname(downloaded_path)
    # Drop any extension from the archive name to get the extracted folder.
    extracted_name = FILE_NAME.split('.')[0]
    saved_model_dir = os.path.join(cache_dir, extracted_name, model_type)

    self._modules = tf_test_utils.compile_tf_signature_def_saved_model(
        saved_model_dir=saved_model_dir,
        saved_model_tags={'serve'},
        module_name='MobileBertSquad',
        exported_name='serving_default',
        input_names=['input_ids', 'input_mask', 'segment_ids'],
        output_names=['start_logits', 'end_logits'])

  def test_serving_default(self):
    """Calls 'serving_default' with all-zero int32 inputs of shape (1, MAX_SEQ_LENGTH)."""

    def serving_default(module):
      input_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
      input_mask = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
      segment_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
      # NOTE(review): atol is presumably consumed by the traced-module wrapper
      # as the cross-backend comparison tolerance; confirm in tf_test_utils.
      module.serving_default(input_ids=input_ids,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             atol=1e0)

    self.compare_backends(serving_default, self._modules)
def main(argv):
  """absl entry point: enables TF2 behavior when available, then runs tests."""
  del argv  # Unused.
  # Older TF builds expose enable_v2_behavior; newer ones may not.
  enable_v2 = getattr(tf, 'enable_v2_behavior', None)
  if enable_v2 is not None:
    enable_v2()
  tf.test.main()


if __name__ == '__main__':
  # app.run parses absl flags (e.g. --use_quantized_weights) before main.
  app.run(main)