Remove sklearn dependency and run training tests on CI (#3610)

Replaces training on sklearn polynomial features with training on random
linear data and spiral classification data. Also makes pytype linting
work when the diff contains deleted files, and excludes failing targets
from TAP and Kokoro in `iree_e2e_cartesian_product_test_suite`.
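For context on the pytype change: `git diff --diff-filter=d` limits the listed
paths to files that still exist (the lowercase `d` excludes deleted files), so
pytype is no longer pointed at files that the change removes. A minimal sketch
of the behavior, with a hypothetical diff target and file names:

```shell
# Suppose the change deletes old_test.py and adds new_test.py.
git diff --name-only HEAD~1                   # old_test.py, new_test.py
git diff --diff-filter=d --name-only HEAD~1   # new_test.py only
```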
diff --git a/build_tools/pytype/check_diff.sh b/build_tools/pytype/check_diff.sh
index 2b9851a..8fa9139 100755
--- a/build_tools/pytype/check_diff.sh
+++ b/build_tools/pytype/check_diff.sh
@@ -28,7 +28,7 @@
 if [[ "${DIFF_TARGET?}" = "all" ]]; then
   FILES=$(find -name "*\.py" -not -path "./third_party/*")
 else
-  FILES=$(git diff --name-only "${DIFF_TARGET?}" | grep '.*\.py')
+  FILES=$(git diff --diff-filter=d --name-only "${DIFF_TARGET?}" | grep '.*\.py')
 fi
 
 
diff --git a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
index 0af8da8..0f90b31 100644
--- a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
@@ -227,7 +227,12 @@
             tests = tests,
             # Add "+failing" to only include tests in `tests` that have the
             # "failing" tag.
-            tags = tags + ["+failing"],
+            tags = tags + [
+                "+failing",
+                "manual",
+                "nokokoro",
+                "notap",
+            ],
             # If there are kwargs that need to be passed here which only apply
             # to the generated tests and not to test_suite, they should be
             # extracted into separate named arguments.
diff --git a/integrations/tensorflow/e2e/keras/train/BUILD b/integrations/tensorflow/e2e/keras/train/BUILD
index 177bcaf..7dc0bd6 100644
--- a/integrations/tensorflow/e2e/keras/train/BUILD
+++ b/integrations/tensorflow/e2e/keras/train/BUILD
@@ -16,6 +16,7 @@
     "//bindings/python:build_defs.oss.bzl",
     "INTREE_TENSORFLOW_PY_DEPS",
     "NUMPY_DEPS",
+    "iree_py_binary",
 )
 load(
     "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
@@ -28,11 +29,28 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
+[
+    iree_py_binary(
+        name = src.replace(".py", "_manual"),
+        srcs = [src],
+        main = src,
+        python_version = "PY3",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for src in glob(["*_test.py"])
+]
+
 iree_e2e_cartesian_product_test_suite(
-    name = "train_tests",
-    srcs = ["model_train_test.py"],
+    name = "classification_training_tests",
+    srcs = ["classification_training_test.py"],
     failing_configurations = [
         {
+            # TFLite doesn't support training.
+            "target_backends": "tflite",
+        },
+        {
             "target_backends": [
                 "tflite",
                 "iree_vmla",  # TODO(b/157581521)
@@ -43,9 +61,15 @@
     ],
     flags_to_values = {
         "reference_backend": "tf",
-        "optimizer_name": [
-            "sgd",
+        "optimizer": [
+            "adadelta",
+            "adagrad",
             "adam",
+            "adamax",
+            "ftrl",
+            "nadam",
+            "rmsprop",
+            "sgd",
         ],
         "target_backends": [
             "tf",
@@ -55,13 +79,49 @@
             "iree_vulkan",
         ],
     },
-    main = "model_train_test.py",
-    tags = [
-        "guitar",
-        "manual",
-        "nokokoro",
-        "notap",
+    main = "classification_training_test.py",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+    name = "regression_training_tests",
+    srcs = ["regression_training_test.py"],
+    failing_configurations = [
+        {
+            # TFLite doesn't support training.
+            "target_backends": "tflite",
+        },
+        {
+            "target_backends": [
+                "iree_vmla",  # TODO(b/157581521)
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "optimizer": [
+            "adadelta",
+            "adagrad",
+            "adam",
+            "adamax",
+            "ftrl",
+            "nadam",
+            "rmsprop",
+            "sgd",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "regression_training_test.py",
     deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
diff --git a/integrations/tensorflow/e2e/keras/train/README.md b/integrations/tensorflow/e2e/keras/train/README.md
deleted file mode 100644
index 86222b1..0000000
--- a/integrations/tensorflow/e2e/keras/train/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-# Keras Training Tests
-
-These tests require an additional python dependency on `sklearn`, which
-can be installed as follows:
-
-```shell
-python3 -m pip install sklearn
-```
-
-These tests are not checked by the OSS CI.
diff --git a/integrations/tensorflow/e2e/keras/train/classification_training_test.py b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
new file mode 100644
index 0000000..cd3305b
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
@@ -0,0 +1,119 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test training a classification model with Keras optimizers."""
+
+import os
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow as tf
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "optimizer", "sgd",
+    "One of 'adadelta', 'adagrad', 'adam', 'adamax', 'ftrl', 'nadam', "
+    "'rmsprop' or 'sgd'")
+
+SAMPLES_PER_SPIRAL = 8
+INPUT_DIM = 2
+NUM_CLASSES = 4
+BATCH_SIZE = NUM_CLASSES * SAMPLES_PER_SPIRAL
+
+
+def get_spiral_dataset(samples_per_spiral: int,
+                       noise_scale: float = 0,
+                       shuffle: bool = True):
+  """Creates a dataset with four spiral arms."""
+  t = np.linspace(0, 1, samples_per_spiral, dtype=np.float32)
+  sin_term = t * np.sin(2 * np.pi * t)
+  cos_term = t * np.cos(2 * np.pi * t)
+  spirals = [
+      np.stack([cos_term, sin_term], axis=-1),
+      np.stack([-cos_term, -sin_term], axis=-1),
+      np.stack([-sin_term, cos_term], axis=-1),
+      np.stack([sin_term, -cos_term], axis=-1)
+  ]
+  inputs = np.concatenate(spirals)
+  inputs = inputs + np.random.normal(scale=noise_scale, size=inputs.shape)
+  labels = np.concatenate([i * np.ones_like(t) for i in range(4)])
+
+  if shuffle:
+    # Shuffle by batch dim.
+    index = np.arange(inputs.shape[0])
+    np.random.shuffle(index)
+    inputs = inputs[index]
+    labels = labels[index]
+
+  return inputs, labels
+
+
+class ClassificationTrainingModule(tf.Module):
+  """A module for model training."""
+
+  def __init__(self):
+    inputs = tf.keras.layers.Input(INPUT_DIM)
+    x = tf.keras.layers.Dense(NUM_CLASSES)(inputs)
+    outputs = tf.keras.layers.Softmax()(x)
+    self.model = tf.keras.Model(inputs, outputs)
+    print(self.model)
+
+    self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
+    self.optimizer = tf.keras.optimizers.get(FLAGS.optimizer)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([BATCH_SIZE, INPUT_DIM], tf.float32),
+      tf.TensorSpec([BATCH_SIZE], tf.float32)
+  ])
+  def train_on_batch(self, inputs, labels):
+    with tf.GradientTape() as tape:
+      probs = self.model(inputs, training=True)
+      loss = self.loss(labels, probs)
+    variables = self.model.trainable_variables
+    gradients = tape.gradient(loss, variables)
+    self.optimizer.apply_gradients(zip(gradients, variables))
+    return loss
+
+
+class ClassificationTrainingTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(
+        ClassificationTrainingModule, exported_names=["train_on_batch"])
+
+  def test_train_on_batch(self):
+
+    def train_on_batch(module):
+      inputs, labels = get_spiral_dataset(SAMPLES_PER_SPIRAL, noise_scale=0.05)
+      module.train_on_batch(inputs, labels)
+
+    self.compare_backends(train_on_batch, self._modules)
+
+
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  ClassificationTrainingModule.__name__ = os.path.join(
+      ClassificationTrainingModule.__name__, FLAGS.optimizer)
+  tf.test.main()
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
deleted file mode 100644
index e074590..0000000
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test keras Model training."""
-
-from absl import app
-from absl import flags
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
-from sklearn.preprocessing import PolynomialFeatures
-import tensorflow as tf
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
-    "optimizer_name", "sgd",
-    "optimizer name: sgd, rmsprop, nadam, adamax, adam, adagrad, adadelta")
-
-_DEGREE = 3  # polynomial degree of input feature for regression test
-_FEATURE_SIZE = _DEGREE + 1  # input feature size
-_BATCH_SIZE = 8  # batch size has to be dynamic TODO(b/142948097)
-_INPUT_DATA_SHAPE = [_BATCH_SIZE, _FEATURE_SIZE]
-_OUTPUT_DATA_SHAPE = [_BATCH_SIZE, 1]
-
-
-class ModelTrain(tf.Module):
-  """A module for model training."""
-
-  @staticmethod
-  def CreateModule(input_dim=_FEATURE_SIZE, output_dim=1):
-    """Creates a module for regression model training.
-
-    Args:
-      input_dim: input dimensionality
-      output_dim: output dimensionality
-
-    Returns:
-      model for linear regression
-    """
-
-    tf_utils.set_random_seed()
-
-    # build a single layer model
-    inputs = tf.keras.layers.Input((input_dim))
-    outputs = tf.keras.layers.Dense(output_dim)(inputs)
-    model = tf.keras.Model(inputs, outputs)
-    return ModelTrain(model)
-
-  def __init__(self, model):
-    self.model = model
-    self.loss = tf.keras.losses.MeanSquaredError()
-    self.optimizer = tf.keras.optimizers.get(FLAGS.optimizer_name)
-
-  @tf.function(input_signature=[
-      tf.TensorSpec(_INPUT_DATA_SHAPE, tf.float32),
-      tf.TensorSpec(_OUTPUT_DATA_SHAPE, tf.float32)
-  ])
-  def train_step(self, inputs, targets):
-    with tf.GradientTape() as tape:
-      predictions = self.model(inputs, training=True)
-      loss_value = self.loss(predictions, targets)
-    gradients = tape.gradient(loss_value, self.model.trainable_variables)
-    self.optimizer.apply_gradients(
-        zip(gradients, self.model.trainable_variables))
-    return loss_value
-
-
-class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(
-        ModelTrain.CreateModule, exported_names=["train_step"])
-
-  def generate_regression_data(self, size=8):
-    x = np.arange(size) - size // 2
-    y = 1.0 * x**3 + 1.0 * x**2 + 1.0 * x + np.random.randn(size) * size
-    return x, y
-
-  def test_model_train(self):
-
-    # Generate input and output data for regression problem.
-    inputs, targets = self.generate_regression_data()
-
-    # Normalize data.
-    inputs = inputs / max(inputs)
-    targets = targets / max(targets)
-
-    # Generate polynomial features.
-    inputs = np.expand_dims(inputs, axis=1)
-    polynomial = PolynomialFeatures(_DEGREE)  # returns: [1, a, b, a^2, ab, b^2]
-    inputs = polynomial.fit_transform(inputs)
-
-    targets = np.expand_dims(targets, axis=1)
-
-    def train_step(module):
-      # Run one iteration of training step.
-      module.train_step(inputs, targets)
-
-    self.compare_backends(train_step, self._modules)
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, "enable_v2_behavior"):
-    tf.enable_v2_behavior()
-
-  tf.test.main()
-
-
-if __name__ == "__main__":
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train/regression_training_test.py b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
new file mode 100644
index 0000000..cd064d3
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
@@ -0,0 +1,97 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test training a regression model with Keras optimizers."""
+
+import os
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow as tf
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "optimizer", "sgd",
+    "One of 'adadelta', 'adagrad', 'adam', 'adamax', 'ftrl', 'nadam', "
+    "'rmsprop' or 'sgd'")
+
+np.random.seed(0)
+INPUT_DIM = 8
+OUTPUT_DIM = 2
+BATCH_SIZE = 4
+WEIGHTS = np.random.uniform(-1, 1, size=(INPUT_DIM, OUTPUT_DIM))
+BIASES = np.random.uniform(-1, 1, size=(OUTPUT_DIM,))
+
+
+def get_linear_data():
+  x = np.random.uniform(-1, 1, size=(BATCH_SIZE, INPUT_DIM))
+  y = np.dot(x, WEIGHTS) + BIASES
+  return x, y
+
+
+class RegressionTrainingModule(tf.Module):
+  """A module for model training."""
+
+  def __init__(self):
+    inputs = tf.keras.layers.Input(INPUT_DIM)
+    outputs = tf.keras.layers.Dense(OUTPUT_DIM)(inputs)
+    self.model = tf.keras.Model(inputs, outputs)
+
+    self.loss = tf.keras.losses.MeanSquaredError()
+    self.optimizer = tf.keras.optimizers.get(FLAGS.optimizer)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([BATCH_SIZE, INPUT_DIM], tf.float32),
+      tf.TensorSpec([BATCH_SIZE, OUTPUT_DIM], tf.float32)
+  ])
+  def train_on_batch(self, x, y_true):
+    with tf.GradientTape() as tape:
+      y_pred = self.model(x, training=True)
+      loss = self.loss(y_true, y_pred)
+    variables = self.model.trainable_variables
+    gradients = tape.gradient(loss, variables)
+    self.optimizer.apply_gradients(zip(gradients, variables))
+    return loss
+
+
+class RegressionTrainingTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(
+        RegressionTrainingModule, exported_names=["train_on_batch"])
+
+  def test_train_on_batch(self):
+
+    def train_on_batch(module):
+      x, y = get_linear_data()
+      module.train_on_batch(x, y)
+
+    self.compare_backends(train_on_batch, self._modules)
+
+
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  RegressionTrainingModule.__name__ = os.path.join(
+      RegressionTrainingModule.__name__, FLAGS.optimizer)
+  tf.test.main()
+
+
+if __name__ == "__main__":
+  app.run(main)
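
With the `iree_py_binary` rules added to the BUILD file, each training test can
also be run by hand. A hypothetical invocation (the target name follows the
`src.replace(".py", "_manual")` pattern above; `--optimizer`,
`--target_backends`, and `--reference_backend` are the flags the test suite
already passes, so exact values may need adjusting):

```shell
# Run the regression training test directly against one backend/optimizer pair.
bazel run //integrations/tensorflow/e2e/keras/train:regression_training_test_manual -- \
    --reference_backend=tf --target_backends=iree_vmla --optimizer=adam
```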