Add einsum_test.;y (#3367)

Einsum isn't currently supported by IREE, but these tests
should be helpful for enabling it.
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 5b89789..ed15f36 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -57,12 +57,16 @@
     "mobile_bert_squad_test.py",
 ]
 
+# keep sorted
 TFLITE_FAILING = [
     "broadcasting_test.py",
     "complex_test.py",
     "concat_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
+    "einsum_dynamic_test.py",
+    "einsum_static_test.py",
+    "einsum_vector_test.py",
     "finite_test.py",
     "gather_test.py",
     "mandelbrot_test.py",
@@ -77,6 +81,9 @@
 
 # keep sorted
 VMLA_FAILING = [
+    "einsum_dynamic_test.py",
+    "einsum_static_test.py",
+    "einsum_vector_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "ring_buffer_test.py",  # TODO(b/148747011)
     "strings_test.py",
@@ -88,6 +95,9 @@
     "broadcasting_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
+    "einsum_dynamic_test.py",
+    "einsum_static_test.py",
+    "einsum_vector_test.py",
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
@@ -106,6 +116,9 @@
     "broadcasting_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
+    "einsum_dynamic_test.py",
+    "einsum_static_test.py",
+    "einsum_vector_test.py",
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
diff --git a/integrations/tensorflow/e2e/einsum_dynamic_test.py b/integrations/tensorflow/e2e/einsum_dynamic_test.py
new file mode 100644
index 0000000..9b29d69
--- /dev/null
+++ b/integrations/tensorflow/e2e/einsum_dynamic_test.py
@@ -0,0 +1,145 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test matrix ops via einsum"""
+
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+LEFT_DIM = 6
+INNER_DIM = 3
+RIGHT_DIM = 6
+BATCH_DIM = 8
+
+
+class EinsumDynamicModule(tf.Module):
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None], tf.float32),
+  ])
+  def einsum_dynamic_dim_identity(self, x):
+    return tf.einsum('ij', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None, None], tf.float32),
+  ])
+  def einsum_dynamic_rank_identity(self, x):
+    return tf.einsum('...', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_dynamic_dim_transpose(self, x):
+    return tf.einsum('bij -> bji', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None, LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_dynamic_rank_diag(self, x):
+    return tf.einsum('...ii -> ...i', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None, LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_dynamic_dim_sum(self, x):
+    return tf.einsum('abij -> ab', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None], tf.float32),
+      tf.TensorSpec([None, None], tf.float32),
+  ])
+  def einsum_dynamic_dim_matmul(self, lhs, rhs):
+    return tf.einsum('ij, jk -> ik', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_dynamic_dim_lhs_batch(self, lhs, rhs):
+    return tf.einsum('bij, jk -> bik', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([None, None, 8, 6], tf.float32),
+      tf.TensorSpec([12, 6, 4], tf.float32),
+  ])
+  def einsum_dynamic_rank_split_heads(self, seq, weights):
+    # l: seq_len, m: d_model, h: num_heads, d: attention_depth
+    return tf.einsum('...lm, hmd -> ...hld', seq, weights)
+
+
+class EinsumDynamicTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super(EinsumDynamicTest, self).__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(EinsumDynamicModule)
+
+  # yapf: disable
+  def test_einsum_dynamic_dim_identity(self):
+    def einsum_dynamic_dim_identity(module):
+      module.einsum_dynamic_dim_identity(
+          tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_dim_identity, self._modules)
+
+  def test_einsum_dynamic_rank_identity(self):
+    def einsum_dynamic_rank_identity(module):
+      module.einsum_dynamic_rank_identity(
+          tf_utils.ndarange([BATCH_DIM, LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_rank_identity, self._modules)
+
+  def test_einsum_dynamic_dim_transpose(self):
+    def einsum_dynamic_dim_transpose(module):
+      module.einsum_dynamic_dim_transpose(
+          tf_utils.ndarange([BATCH_DIM, LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_dim_transpose, self._modules)
+
+  def test_einsum_dynamic_rank_diag(self):
+    def einsum_dynamic_rank_diag(module):
+      module.einsum_dynamic_rank_diag(
+          tf_utils.ndarange([BATCH_DIM, BATCH_DIM, LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_rank_diag, self._modules)
+
+  def test_einsum_dynamic_dim_sum(self):
+    def einsum_dynamic_dim_sum(module):
+      module.einsum_dynamic_dim_sum(
+           tf_utils.ndarange([BATCH_DIM, BATCH_DIM, LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_dim_sum, self._modules)
+
+  def test_einsum_dynamic_dim_matmul(self):
+    def einsum_dynamic_dim_matmul(module):
+      module.einsum_dynamic_dim_matmul(
+          tf_utils.ndarange([LEFT_DIM, INNER_DIM]),
+          tf_utils.ndarange([INNER_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_dim_matmul, self._modules)
+
+  def test_einsum_dynamic_dim_lhs_batch(self):
+    def einsum_dynamic_dim_lhs_batch(module):
+      module.einsum_dynamic_dim_lhs_batch(
+          tf_utils.ndarange([BATCH_DIM, LEFT_DIM, INNER_DIM]),
+          tf_utils.ndarange([INNER_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_dynamic_dim_lhs_batch, self._modules)
+
+  def test_einsum_dynamic_rank_split_heads(self):
+    def einsum_dynamic_rank_split_heads(module):
+      module.einsum_dynamic_rank_split_heads(
+          tf_utils.ndarange([BATCH_DIM, BATCH_DIM, 8, 6]),
+          tf_utils.ndarange([12, 6, 4]))
+    self.compare_backends(einsum_dynamic_rank_split_heads, self._modules)
+  # yapf: enable
+
+
+if __name__ == "__main__":
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  tf.test.main()
diff --git a/integrations/tensorflow/e2e/einsum_static_test.py b/integrations/tensorflow/e2e/einsum_static_test.py
new file mode 100644
index 0000000..b673ca9
--- /dev/null
+++ b/integrations/tensorflow/e2e/einsum_static_test.py
@@ -0,0 +1,226 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test matrix ops via einsum"""
+
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+LEFT_DIM = 6
+INNER_DIM = 3
+RIGHT_DIM = 6
+BATCH_DIM = 8
+
+
+class EinsumStaticModule(tf.Module):
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_identity(self, x):
+    return tf.einsum('ij', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_implicit_transpose(self, x):
+    return tf.einsum('ji', x)  # :woozy:
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_explicit_transpose(self, x):
+    return tf.einsum('ij -> ji', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_implicit_trace(self, x):
+    return tf.einsum('ii', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_explicit_trace(self, x):
+    return tf.einsum('ii ->', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_diag(self, x):
+    return tf.einsum('ii -> i', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_sum(self, x):
+    return tf.einsum('ij ->', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_sum_axis_0(self, x):
+    return tf.einsum('ij -> j', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_sum_axis_1(self, x):
+    return tf.einsum('ij -> i', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_matmul(self, lhs, rhs):
+    return tf.einsum('ij, jk -> ik', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([BATCH_DIM, LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_lhs_batch(self, lhs, rhs):
+    return tf.einsum('bij, jk -> bik', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([1, LEFT_DIM, INNER_DIM], tf.float32),
+      tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32),
+  ])
+  def einsum_broadcast_singleton_dimension(self, lhs, rhs):
+    return tf.einsum('lij, rjk -> rik', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([BATCH_DIM, 8, 6], tf.float32),
+      tf.TensorSpec([12, 6, 4], tf.float32),
+  ])
+  def einsum_split_heads(self, seq, weights):
+    # l: seq_len, m: d_model, h: num_heads, d: attention_depth
+    return tf.einsum('blm, hmd -> bhld', seq, weights)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([BATCH_DIM, 5, 3, 2, 6], tf.float32),
+      tf.TensorSpec([BATCH_DIM, 5, 6], tf.float32),
+  ])
+  def einsum_batched_high_rank_matrix_vector_mul(self, lhs, rhs):
+    return tf.einsum('bijxy, biy -> bijx', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([BATCH_DIM, 2, 6], tf.float32),
+      tf.TensorSpec([BATCH_DIM, 5, 3, 6], tf.float32),
+  ])
+  def einsum_batched_matrix_high_rank_vector_mul(self, lhs, rhs):
+    return tf.einsum('bxy, bijy -> bijx', lhs, rhs)
+
+
+class EinsumStaticTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super(EinsumStaticTest, self).__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(EinsumStaticModule)
+
+  # yapf: disable
+  def test_einsum_identity(self):
+    def einsum_identity(module):
+      module.einsum_identity(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_identity, self._modules)
+
+  def test_einsum_implicit_transpose(self):
+    def einsum_implicit_transpose(module):
+      module.einsum_implicit_transpose(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_implicit_transpose, self._modules)
+
+  def test_einsum_explicit_transpose(self):
+    def einsum_explicit_transpose(module):
+      module.einsum_explicit_transpose(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_explicit_transpose, self._modules)
+
+  def test_einsum_implicit_trace(self):
+    def einsum_implicit_trace(module):
+      module.einsum_implicit_trace(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_implicit_trace, self._modules)
+
+  def test_einsum_explicit_trace(self):
+    def einsum_explicit_trace(module):
+      module.einsum_explicit_trace(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_explicit_trace, self._modules)
+
+  def test_einsum_diag(self):
+    def einsum_diag(module):
+      module.einsum_diag(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_diag, self._modules)
+
+  def test_einsum_sum(self):
+    def einsum_sum(module):
+      module.einsum_sum(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_sum, self._modules)
+
+  def test_einsum_sum_axis_0(self):
+    def einsum_sum_axis_0(module):
+      module.einsum_sum_axis_0(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_sum_axis_0, self._modules)
+
+  def test_einsum_sum_axis_1(self):
+    def einsum_sum_axis_1(module):
+      module.einsum_sum_axis_1(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_sum_axis_1, self._modules)
+
+  def test_einsum_matmul(self):
+    def einsum_matmul(module):
+      module.einsum_matmul(tf_utils.ndarange([LEFT_DIM, INNER_DIM]),
+                           tf_utils.ndarange([INNER_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_matmul, self._modules)
+
+  def test_einsum_lhs_batch(self):
+    def einsum_lhs_batch(module):
+      module.einsum_lhs_batch(
+          tf_utils.ndarange([BATCH_DIM, LEFT_DIM, INNER_DIM]),
+          tf_utils.ndarange([INNER_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_lhs_batch, self._modules)
+
+  def test_einsum_broadcast_singleton_dimension(self):
+    def einsum_broadcast_singleton_dimension(module):
+      module.einsum_broadcast_singleton_dimension(
+          tf_utils.ndarange([1, LEFT_DIM, INNER_DIM]),
+          tf_utils.ndarange([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
+    self.compare_backends(einsum_broadcast_singleton_dimension, self._modules)
+
+  def test_einsum_split_heads(self):
+    def einsum_split_heads(module):
+      module.einsum_split_heads(tf_utils.ndarange([BATCH_DIM, 8, 6]),
+                                tf_utils.ndarange([12, 6, 4]))
+    self.compare_backends(einsum_split_heads, self._modules)
+
+  def test_einsum_batched_high_rank_matrix_vector_mul(self):
+    def einsum_batched_high_rank_matrix_vector_mul(module):
+      module.einsum_batched_high_rank_matrix_vector_mul(
+          tf_utils.ndarange([BATCH_DIM, 5, 3, 2, 6]),
+          tf_utils.ndarange([BATCH_DIM, 5, 6]))
+    self.compare_backends(einsum_batched_high_rank_matrix_vector_mul,
+                          self._modules)
+
+  def test_einsum_batched_matrix_high_rank_vector_mul(self):
+    def einsum_batched_matrix_high_rank_vector_mul(module):
+      module.einsum_batched_matrix_high_rank_vector_mul(
+          tf_utils.ndarange([BATCH_DIM, 2, 6]),
+          tf_utils.ndarange([BATCH_DIM, 5, 3, 6]))
+    self.compare_backends(einsum_batched_matrix_high_rank_vector_mul,
+                          self._modules)
+  # yapf: enable
+
+
+if __name__ == "__main__":
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  tf.test.main()
diff --git a/integrations/tensorflow/e2e/einsum_vector_test.py b/integrations/tensorflow/e2e/einsum_vector_test.py
new file mode 100644
index 0000000..910a43f
--- /dev/null
+++ b/integrations/tensorflow/e2e/einsum_vector_test.py
@@ -0,0 +1,113 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test matrix ops via einsum"""
+
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+VECTOR_DIM = 16
+
+
+class EinsumVectorModule(tf.Module):
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+  ])
+  def einsum_identity(self, x):
+    return tf.einsum('i', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+  ])
+  def einsum_sum(self, x):
+    return tf.einsum('i ->', x)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+  ])
+  def einsum_mul(self, lhs, rhs):
+    return tf.einsum('i, i -> i', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+  ])
+  def einsum_implicit_inner_product(self, lhs, rhs):
+    return tf.einsum('i, i', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+  ])
+  def einsum_explicit_inner_product(self, lhs, rhs):
+    return tf.einsum('i, i ->', lhs, rhs)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+      tf.TensorSpec([VECTOR_DIM], tf.float32),
+  ])
+  def einsum_outer_product(self, lhs, rhs):
+    return tf.einsum('i, j -> ij', lhs, rhs)
+
+
+class EinsumVectorTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super(EinsumVectorTest, self).__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(EinsumVectorModule)
+
+  # yapf: disable
+  def test_einsum_identity(self):
+    def einsum_identity(module):
+      module.einsum_identity(tf_utils.ndarange([VECTOR_DIM]))
+    self.compare_backends(einsum_identity, self._modules)
+
+  def test_einsum_sum(self):
+    def einsum_sum(module):
+      module.einsum_sum(tf_utils.ndarange([VECTOR_DIM]))
+    self.compare_backends(einsum_sum, self._modules)
+
+  def test_einsum_mul(self):
+    def einsum_mul(module):
+      module.einsum_mul(tf_utils.ndarange([VECTOR_DIM]),
+                        tf_utils.ndarange([VECTOR_DIM]))
+    self.compare_backends(einsum_mul, self._modules)
+
+  def test_einsum_implicit_inner_product(self):
+    def einsum_implicit_inner_product(module):
+      module.einsum_implicit_inner_product(tf_utils.ndarange([VECTOR_DIM]),
+                                           tf_utils.ndarange([VECTOR_DIM]))
+    self.compare_backends(einsum_implicit_inner_product, self._modules)
+
+  def test_einsum_explicit_inner_product(self):
+    def einsum_explicit_inner_product(module):
+      module.einsum_explicit_inner_product(tf_utils.ndarange([VECTOR_DIM]),
+                                           tf_utils.ndarange([VECTOR_DIM]))
+    self.compare_backends(einsum_explicit_inner_product, self._modules)
+
+  def test_einsum_outer_product(self):
+    def einsum_outer_product(module):
+      module.einsum_outer_product(tf_utils.ndarange([VECTOR_DIM]),
+                                  tf_utils.ndarange([VECTOR_DIM]))
+    self.compare_backends(einsum_outer_product, self._modules)
+  # yapf: enable
+
+
+if __name__ == "__main__":
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  tf.test.main()