|  | # Copyright 2019 Google LLC | 
|  | # | 
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | # you may not use this file except in compliance with the License. | 
|  | # You may obtain a copy of the License at | 
|  | # | 
|  | #      https://www.apache.org/licenses/LICENSE-2.0 | 
|  | # | 
|  | # Unless required by applicable law or agreed to in writing, software | 
|  | # distributed under the License is distributed on an "AS IS" BASIS, | 
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | # See the License for the specific language governing permissions and | 
|  | # limitations under the License. | 
|  |  | 
|  | from __future__ import absolute_import | 
|  | from __future__ import division | 
|  | from __future__ import print_function | 
|  |  | 
|  | import os | 
|  | from pyiree import tf_test_utils | 
|  | import tensorflow.compat.v2 as tf | 
|  |  | 
# TODO(silvasean): Get this working on IREE.
# Until the TODO above is resolved, pin the IREE_TEST_BACKENDS environment
# variable to "tf" for this process. This overwrites any value inherited from
# the caller's environment.
# NOTE(review): presumably tf_test_utils reads this variable to decide which
# backends to compile and compare -- confirm, and confirm whether it must be
# set before the compile_modules decorator below is evaluated.
os.environ["IREE_TEST_BACKENDS"] = "tf"
|  |  | 
|  |  | 
def complex_add(a_re, a_im, b_re, b_im):
  """Adds two complex numbers given as separate real and imaginary parts.

  Args:
    a_re: Real part of the first operand.
    a_im: Imaginary part of the first operand.
    b_re: Real part of the second operand.
    b_im: Imaginary part of the second operand.

  Returns:
    A (real, imaginary) tuple holding the component-wise sum.
  """
  sum_re = a_re + b_re
  sum_im = a_im + b_im
  return sum_re, sum_im
|  |  | 
|  |  | 
def complex_mul(a_re, a_im, b_re, b_im):
  """Multiplies two complex numbers given as separate real/imaginary parts.

  Uses (a_re + i*a_im) * (b_re + i*b_im)
      = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re).

  Args:
    a_re: Real part of the first operand.
    a_im: Imaginary part of the first operand.
    b_re: Real part of the second operand.
    b_im: Imaginary part of the second operand.

  Returns:
    A (real, imaginary) tuple holding the product.
  """
  return (a_re * b_re - a_im * b_im,
          a_re * b_im + a_im * b_re)
|  |  | 
|  |  | 
# This is a small but interesting example because the return value and most of
# the interior computations are dynamically shaped.
class MandelbrotModule(tf.Module):
  """Module exposing an escape-time rendering of the Mandelbrot set."""

  @tf.function(input_signature=[
      tf.TensorSpec([], tf.float32),
      tf.TensorSpec([], tf.float32),
      tf.TensorSpec([], tf.float32),
      tf.TensorSpec([], tf.int32),
      tf.TensorSpec([], tf.int32)
  ])
  def calculate(self, center_re, center_im, view_size, view_pixels,
                num_iterations):
    """Calculates an image which represents the Mandelbrot set.

    Args:
      center_re: The center point of the view (real part).
      center_im: The center point of the view (imaginary part).
      view_size: The view will display a square with this size.
      view_pixels: The returned image will be a square with this many pixels on
        a side.
      num_iterations: The number of iterations to use for determining escape.

    Returns:
      A tensor of pixels with shape [view_pixels, view_pixels] which represents
      the Mandelbrot set. Rows follow the imaginary axis (starting at
      center_im - view_size / 2) and columns follow the real axis (starting at
      center_re - view_size / 2). A pixel is 1.0 when the orbit of its point
      escaped within num_iterations (the point is outside the set) and 0.0
      when it did not.
    """
    half_size = view_size / 2.
    re_min = center_re - half_size
    re_max = center_re + half_size
    im_min = center_im - half_size
    im_max = center_im + half_size
    re_coords = tf.linspace(re_min, re_max, view_pixels)
    im_coords = tf.linspace(im_min, im_max, view_pixels)

    # Generate flat list of real and imaginary parts of the points to test.
    # This requires taking all pairs of re_coords and im_coords, which we
    # do by broadcasting into a 2d matrix (the real parts are replicated down
    # the rows and the imaginary parts are replicated across the columns).
    # We use a Nx1 * 1xN -> NxN matmul to do the broadcast.
    c_re = tf.reshape(
        tf.matmul(
            tf.ones([view_pixels, 1]), tf.reshape(re_coords, [1, view_pixels])),
        [-1])
    c_im = tf.reshape(
        tf.matmul(
            tf.reshape(im_coords, [view_pixels, 1]), tf.ones([1, view_pixels])),
        [-1])

    # Iterate z <- z^2 + c starting from z = 0 for every point at once.
    # num_iterations is a tensor, so this loop relies on tf.function's AutoGraph
    # conversion of `for ... in range(tensor)`.
    z_re = tf.zeros_like(c_re)
    z_im = tf.zeros_like(c_im)
    for _ in range(num_iterations):
      square_re, square_im = complex_mul(z_re, z_im, z_re, z_im)
      z_re, z_im = complex_add(square_re, square_im, c_re, c_im)

    # Determine which points escaped (that is, their orbit under the recurrence
    # relationship has diverged). Escaped points are NOT in the set.
    z_abs = tf.sqrt(z_re**2 + z_im**2)
    # A NaN magnitude is treated as escaped: replace it with a value above the
    # threshold used below.
    z_abs = tf.where(tf.math.is_nan(z_abs), 100. * tf.ones_like(z_abs), z_abs)
    escaped = tf.where(z_abs > 50., tf.ones_like(z_abs), tf.zeros_like(z_abs))
    # Return an image
    return tf.reshape(escaped, shape=[view_pixels, view_pixels])
|  |  | 
|  |  | 
@tf_test_utils.compile_modules(mandelbrot=MandelbrotModule)
class MandelbrotTest(tf_test_utils.SavedModelTestCase):
  """Exercises MandelbrotModule.calculate on a wide view and a deep zoom."""

  def test_mandelbrot(self):
    module = self.modules.mandelbrot.all

    # Each entry is (center_re, center_im, view_size, view_pixels,
    # num_iterations), evaluated in order.
    views = [
        # Basic view of the entire set.
        (-0.7, 0.0, 3.0, 400, 100),
        # This is a much more detailed view, so more iterations are needed.
        (-0.7436447860, 0.1318252536, 0.0000029336, 400, 10000),
    ]
    for view in views:
      image = module.calculate(*view)
      image.assert_all_close()
|  |  | 
|  |  | 
if __name__ == "__main__":
  # Only call enable_v2_behavior() when the imported tf module provides it.
  has_v2_switch = hasattr(tf, "enable_v2_behavior")
  if has_v2_switch:
    tf.enable_v2_behavior()
  # Hand control to the TensorFlow test runner.
  tf.test.main()