E2E infrastructure/testing for fft. The kernel is currently empty, will be implemented in follow-up cl. (#3558)
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD index 97ede05..ff88406 100644 --- a/integrations/tensorflow/e2e/BUILD +++ b/integrations/tensorflow/e2e/BUILD
@@ -67,6 +67,7 @@ "einsum_dynamic_test.py", "einsum_static_test.py", "einsum_vector_test.py", + "fft_test.py", "finite_test.py", "gather_test.py", "mandelbrot_test.py", @@ -84,6 +85,7 @@ "einsum_dynamic_test.py", "einsum_static_test.py", "einsum_vector_test.py", + "fft_test.py", # TODO(natashaknk): Get this working after kernel is in. "mandelbrot_test.py", # TODO(silvasean): Get this working on IREE. "ring_buffer_test.py", # TODO(b/148747011) "strings_test.py", @@ -98,6 +100,7 @@ "einsum_dynamic_test.py", "einsum_static_test.py", "einsum_vector_test.py", + "fft_test.py", # TODO(natashaknk): Get this working after kernel is in. "fill_test.py", # TODO(jennik): Get this test working on IREE. "logical_ops_test.py", "mandelbrot_test.py", # TODO(silvasean): Get this working on IREE. @@ -118,6 +121,7 @@ "einsum_dynamic_test.py", "einsum_static_test.py", "einsum_vector_test.py", + "fft_test.py", # TODO(natashaknk): Get this working after kernel is in. "fill_test.py", # TODO(jennik): Get this test working on IREE. "logical_ops_test.py", "mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
diff --git a/integrations/tensorflow/e2e/fft_test.py b/integrations/tensorflow/e2e/fft_test.py new file mode 100644 index 0000000..590bff2 --- /dev/null +++ b/integrations/tensorflow/e2e/fft_test.py
@@ -0,0 +1,75 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import app +import numpy as np +from pyiree.tf.support import tf_test_utils +import tensorflow.compat.v2 as tf + + +class FftModule(tf.Module): + # TODO(natashaknk) when multiple outputs are supported, make into one test. + @tf.function(input_signature=[ + tf.TensorSpec([4], tf.float32), + tf.TensorSpec([4], tf.float32) + ]) + def fft_real(self, real_array, imag_array): + complex_in = tf.complex(real_array, imag_array) + complex_out = tf.signal.fft(complex_in) + return tf.math.real(complex_out) + + @tf.function(input_signature=[ + tf.TensorSpec([4], tf.float32), + tf.TensorSpec([4], tf.float32) + ]) + def fft_imag(self, real_array, imag_array): + complex_in = tf.complex(real_array, imag_array) + complex_out = tf.signal.fft(complex_in) + return tf.math.imag(complex_out) + + +class FftTest(tf_test_utils.TracedModuleTestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(FftModule) + + def test_fft_real(self): + + def fft_real(module): + real_array = np.array([9., 1., 4.5, -0.3], dtype=np.float32) + imag_array = np.array([0., -1., 17.7, 10.], dtype=np.float32) + module.fft_real(real_array, imag_array) + + self.compare_backends(fft_real, self._modules) + + def test_fft_imag(self): + + def fft_imag(module): + real_array = np.array([9., 1., 4.5, -0.3], dtype=np.float32) + imag_array = np.array([0., -1., 17.7, 10.], dtype=np.float32) + module.fft_imag(real_array, imag_array) + + self.compare_backends(fft_imag, self._modules) + + +def main(argv): + del argv # Unused + if hasattr(tf, 'enable_v2_behavior'): + tf.enable_v2_behavior() + tf.test.main() + + +if __name__ == '__main__': + app.run(main)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp index 895a8e5..ef9238e 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -267,6 +267,16 @@ getTypedTypeStr(op.dst_type()); } }; + +class VMLAFftImportOpConversion + : public VMLAImportOpConversion<IREE::VMLA::FftOp> { + public: + using VMLAImportOpConversion<IREE::VMLA::FftOp>::VMLAImportOpConversion; + + std::string getImportSuffix(IREE::VMLA::FftOp op) const override { + return std::string(".") + getTypedTypeStr(op.real_element_type()); + } +}; } // namespace void populateVMLAToVMPatterns(MLIRContext *context, @@ -343,6 +353,8 @@ context, importSymbols, typeConverter, "vmla.batch.matmul"); patterns.insert<VMLAConvImportOpConversion>(context, importSymbols, typeConverter, "vmla.conv"); + patterns.insert<VMLAFftImportOpConversion>(context, importSymbols, + typeConverter, "vmla.fft"); VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceSumOp, "vmla.reduce.sum"); VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min");
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir index ff575b0..2302fe1 100644 --- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir +++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -347,6 +347,12 @@ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., %dst : !vm.ref<!vmla.buffer>) +vm.import @fft.f32( + %real_src : !vm.ref<!vmla.buffer>, %real_src_shape : i32 ..., + %imag_src : !vm.ref<!vmla.buffer>, %imag_src_shape : i32 ..., + %real_dst : !vm.ref<!vmla.buffer>, + %imag_dst : !vm.ref<!vmla.buffer>) + //===----------------------------------------------------------------------===// // VMLA Ops: conversion //===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h index ba5b8bc..1d84e46 100644 --- a/iree/hal/vmla/op_kernels.h +++ b/iree/hal/vmla/op_kernels.h
@@ -174,6 +174,15 @@ absl::Span<int32_t> dst_buffer, ShapeSpan src_shape); }; +struct Fft { + template <typename T> + static Status Execute(absl::Span<const T> real_src_buffer, + absl::Span<const T> imag_src_buffer, + absl::Span<T> real_dst_buffer, + absl::Span<T> imag_dst_buffer, ShapeSpan real_src_shape, + ShapeSpan imag_src_shape); +}; + struct Broadcast { template <typename T> static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h index 0b4f904..166a2b8 100644 --- a/iree/hal/vmla/op_kernels_generic.h +++ b/iree/hal/vmla/op_kernels_generic.h
@@ -18,6 +18,7 @@ #include <algorithm> #include <cmath> #include <iostream> +#include <iterator> #include <numeric> #include "absl/container/flat_hash_set.h" @@ -541,6 +542,18 @@ } template <typename T> +Status Fft::Execute(absl::Span<const T> real_src_buffer, + absl::Span<const T> imag_src_buffer, + absl::Span<T> real_dst_buffer, + absl::Span<T> imag_dst_buffer, ShapeSpan real_src_shape, + ShapeSpan imag_src_shape) { + // TODO (natashaknk): implement + std::fill(real_dst_buffer.begin(), real_dst_buffer.end(), 1); + std::fill(imag_dst_buffer.begin(), imag_dst_buffer.end(), 2); + return OkStatus(); +} + +template <typename T> Status Broadcast::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { for (size_t i = 0; i < dst_buffer.size(); ++i) {
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc index 5852de0..1d791c0 100644 --- a/iree/hal/vmla/vmla_module.cc +++ b/iree/hal/vmla/vmla_module.cc
@@ -655,6 +655,19 @@ IREE_VMLA_SORT_OP(SortI32, int32_t); IREE_VMLA_SORT_OP(SortF32, float); + Status FftF32(const vm::ref<Buffer>& real_src, + iree_vmla_shape_t real_src_shape, + const vm::ref<Buffer>& imag_src, + iree_vmla_shape_t imag_src_shape, + const vm::ref<Buffer>& real_dst, + const vm::ref<Buffer>& imag_dst) { + IREE_TRACE_SCOPE0("VMLAModuleState::FftF32"); + IREE_RETURN_IF_ERROR(kernels::Fft::Execute<float>( + real_src->As<float>(), imag_src->As<float>(), real_dst->As<float>(), + imag_dst->As<float>(), real_src_shape, imag_src_shape)); + return OkStatus(); + } + //===--------------------------------------------------------------------===// // VMLA Ops: conversion //===--------------------------------------------------------------------===// @@ -987,6 +1000,7 @@ vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16), vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32), vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32), + vm::MakeNativeFunction("fft.f32", &VMLAModuleState::FftF32), vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32), vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),