E2E infrastructure/testing for FFT. The kernel is currently empty and will be implemented in a follow-up CL. (#3558)

diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 97ede05..ff88406 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -67,6 +67,7 @@
     "einsum_dynamic_test.py",
     "einsum_static_test.py",
     "einsum_vector_test.py",
+    "fft_test.py",
     "finite_test.py",
     "gather_test.py",
     "mandelbrot_test.py",
@@ -84,6 +85,7 @@
     "einsum_dynamic_test.py",
     "einsum_static_test.py",
     "einsum_vector_test.py",
+    "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
     "ring_buffer_test.py",  # TODO(b/148747011)
     "strings_test.py",
@@ -98,6 +100,7 @@
     "einsum_dynamic_test.py",
     "einsum_static_test.py",
     "einsum_vector_test.py",
+    "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
@@ -118,6 +121,7 @@
     "einsum_dynamic_test.py",
     "einsum_static_test.py",
     "einsum_vector_test.py",
+    "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
diff --git a/integrations/tensorflow/e2e/fft_test.py b/integrations/tensorflow/e2e/fft_test.py
new file mode 100644
index 0000000..590bff2
--- /dev/null
+++ b/integrations/tensorflow/e2e/fft_test.py
@@ -0,0 +1,75 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class FftModule(tf.Module):
+  # TODO(natashaknk): When multiple outputs are supported, merge into one test.
+  @tf.function(input_signature=[
+      tf.TensorSpec([4], tf.float32),
+      tf.TensorSpec([4], tf.float32)
+  ])
+  def fft_real(self, real_array, imag_array):
+    complex_in = tf.complex(real_array, imag_array)
+    complex_out = tf.signal.fft(complex_in)
+    return tf.math.real(complex_out)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([4], tf.float32),
+      tf.TensorSpec([4], tf.float32)
+  ])
+  def fft_imag(self, real_array, imag_array):
+    complex_in = tf.complex(real_array, imag_array)
+    complex_out = tf.signal.fft(complex_in)
+    return tf.math.imag(complex_out)
+
+
+class FftTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(FftModule)
+
+  def test_fft_real(self):
+
+    def fft_real(module):
+      real_array = np.array([9., 1., 4.5, -0.3], dtype=np.float32)
+      imag_array = np.array([0., -1., 17.7, 10.], dtype=np.float32)
+      module.fft_real(real_array, imag_array)
+
+    self.compare_backends(fft_real, self._modules)
+
+  def test_fft_imag(self):
+
+    def fft_imag(module):
+      real_array = np.array([9., 1., 4.5, -0.3], dtype=np.float32)
+      imag_array = np.array([0., -1., 17.7, 10.], dtype=np.float32)
+      module.fft_imag(real_array, imag_array)
+
+    self.compare_backends(fft_imag, self._modules)
+
+
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
+    tf.enable_v2_behavior()
+  tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 895a8e5..ef9238e 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -267,6 +267,16 @@
            getTypedTypeStr(op.dst_type());
   }
 };
+
+class VMLAFftImportOpConversion
+    : public VMLAImportOpConversion<IREE::VMLA::FftOp> {
+ public:
+  using VMLAImportOpConversion<IREE::VMLA::FftOp>::VMLAImportOpConversion;
+
+  std::string getImportSuffix(IREE::VMLA::FftOp op) const override {
+    return std::string(".") + getTypedTypeStr(op.real_element_type());
+  }
+};
 }  // namespace
 
 void populateVMLAToVMPatterns(MLIRContext *context,
@@ -343,6 +353,8 @@
       context, importSymbols, typeConverter, "vmla.batch.matmul");
   patterns.insert<VMLAConvImportOpConversion>(context, importSymbols,
                                               typeConverter, "vmla.conv");
+  patterns.insert<VMLAFftImportOpConversion>(context, importSymbols,
+                                             typeConverter, "vmla.fft");
 
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceSumOp, "vmla.reduce.sum");
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min");
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index ff575b0..2302fe1 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -347,6 +347,12 @@
   %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
   %dst : !vm.ref<!vmla.buffer>)
 
+vm.import @fft.f32(
+  %real_src : !vm.ref<!vmla.buffer>, %real_src_shape : i32 ...,
+  %imag_src : !vm.ref<!vmla.buffer>, %imag_src_shape : i32 ...,
+  %real_dst : !vm.ref<!vmla.buffer>,
+  %imag_dst : !vm.ref<!vmla.buffer>)
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: conversion
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index ba5b8bc..1d84e46 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -174,6 +174,15 @@
                         absl::Span<int32_t> dst_buffer, ShapeSpan src_shape);
 };
 
+struct Fft {
+  template <typename T>
+  static Status Execute(absl::Span<const T> real_src_buffer,
+                        absl::Span<const T> imag_src_buffer,
+                        absl::Span<T> real_dst_buffer,
+                        absl::Span<T> imag_dst_buffer, ShapeSpan real_src_shape,
+                        ShapeSpan imag_src_shape);
+};
+
 struct Broadcast {
   template <typename T>
   static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index 0b4f904..166a2b8 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <cmath>
 #include <iostream>
+#include <iterator>
 #include <numeric>
 
 #include "absl/container/flat_hash_set.h"
@@ -541,6 +542,18 @@
 }
 
 template <typename T>
+Status Fft::Execute(absl::Span<const T> real_src_buffer,
+                    absl::Span<const T> imag_src_buffer,
+                    absl::Span<T> real_dst_buffer,
+                    absl::Span<T> imag_dst_buffer, ShapeSpan real_src_shape,
+                    ShapeSpan imag_src_shape) {
+  // TODO(natashaknk): Implement the FFT kernel; placeholder fills outputs with constants.
+  std::fill(real_dst_buffer.begin(), real_dst_buffer.end(), 1);
+  std::fill(imag_dst_buffer.begin(), imag_dst_buffer.end(), 2);
+  return OkStatus();
+}
+
+template <typename T>
 Status Broadcast::Execute(absl::Span<const T> src_buffer,
                           absl::Span<T> dst_buffer) {
   for (size_t i = 0; i < dst_buffer.size(); ++i) {
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 5852de0..1d791c0 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -655,6 +655,19 @@
   IREE_VMLA_SORT_OP(SortI32, int32_t);
   IREE_VMLA_SORT_OP(SortF32, float);
 
+  Status FftF32(const vm::ref<Buffer>& real_src,
+                iree_vmla_shape_t real_src_shape,
+                const vm::ref<Buffer>& imag_src,
+                iree_vmla_shape_t imag_src_shape,
+                const vm::ref<Buffer>& real_dst,
+                const vm::ref<Buffer>& imag_dst) {
+    IREE_TRACE_SCOPE0("VMLAModuleState::FftF32");
+    IREE_RETURN_IF_ERROR(kernels::Fft::Execute<float>(
+        real_src->As<float>(), imag_src->As<float>(), real_dst->As<float>(),
+        imag_dst->As<float>(), real_src_shape, imag_src_shape));
+    return OkStatus();
+  }
+
   //===--------------------------------------------------------------------===//
   // VMLA Ops: conversion
   //===--------------------------------------------------------------------===//
@@ -987,6 +1000,7 @@
     vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16),
     vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32),
     vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32),
+    vm::MakeNativeFunction("fft.f32", &VMLAModuleState::FftF32),
     vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
 
     vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),