blob: 9bbe33bb75f700c57581f871593c4e5fb0ca2b3e [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from pyiree import tf_test_utils
import tensorflow.compat.v2 as tf
# TODO(silvasean): Get this working on IREE.
# Until then, only the "tf" backend is requested for this test.
# NOTE(review): presumably tf_test_utils reads IREE_TEST_BACKENDS to decide
# which backends the compiled modules run on -- confirm against tf_test_utils,
# including whether it is read at import time or at test setup.
os.environ["IREE_TEST_BACKENDS"] = "tf"
def complex_add(a_re, a_im, b_re, b_im):
  """Adds two complex numbers given as separate real/imaginary parts.

  Args:
    a_re: Real part of the first operand.
    a_im: Imaginary part of the first operand.
    b_re: Real part of the second operand.
    b_im: Imaginary part of the second operand.

  Returns:
    A (real, imaginary) tuple holding the component-wise sum.
  """
  sum_re = a_re + b_re
  sum_im = a_im + b_im
  return sum_re, sum_im
def complex_mul(a_re, a_im, b_re, b_im):
  """Multiplies two complex numbers given as separate real/imaginary parts.

  Uses (a_re + i*a_im) * (b_re + i*b_im)
      = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re).

  Args:
    a_re: Real part of the first operand.
    a_im: Imaginary part of the first operand.
    b_re: Real part of the second operand.
    b_im: Imaginary part of the second operand.

  Returns:
    A (real, imaginary) tuple holding the product.
  """
  return (
      a_re * b_re - a_im * b_im,
      a_re * b_im + a_im * b_re,
  )
# This is a fun and quite interesting example because the return value and most
# of the interior computations are dynamically shaped.
class MandelbrotModule(tf.Module):
  """Module exposing a single tf.function that renders a Mandelbrot view."""

  @tf.function(input_signature=[
      tf.TensorSpec([], tf.float32),
      tf.TensorSpec([], tf.float32),
      tf.TensorSpec([], tf.float32),
      tf.TensorSpec([], tf.int32),
      tf.TensorSpec([], tf.int32)
  ])
  def calculate(self, center_re, center_im, view_size, view_pixels,
                num_iterations):
    """Renders an escape mask for a square region of the complex plane.

    Every pixel is a point c; the recurrence z <- z*z + c is applied
    num_iterations times starting from z = 0, and the pixel is flagged when
    the final |z| is above 50 (or is NaN).

    Args:
      center_re: The center point of the view (real part).
      center_im: The center point of the view (imaginary part).
      view_size: Side length of the square region of the plane to display.
      view_pixels: Number of pixels along each side of the returned image.
      num_iterations: How many times the recurrence is applied to each point.

    Returns:
      A tensor of shape [view_pixels, view_pixels]. An entry is 1 where the
      orbit escaped (the point is outside the Mandelbrot set as far as this
      many iterations can tell) and 0 where it stayed bounded.
    """
    # Bounds of the view along each axis.
    half_extent = view_size / 2.
    re_lo = center_re - half_extent
    re_hi = center_re + half_extent
    im_lo = center_im - half_extent
    im_hi = center_im + half_extent

    re_axis = tf.linspace(re_lo, re_hi, view_pixels)
    im_axis = tf.linspace(im_lo, im_hi, view_pixels)

    # Build every (re, im) pair as two flat lists. Each axis is expanded into
    # an NxN grid with an Nx1 * 1xN matmul against a vector of ones: the real
    # coordinates are repeated down the rows, the imaginary coordinates across
    # the columns.
    re_row = tf.reshape(re_axis, [1, view_pixels])
    re_grid = tf.matmul(tf.ones([view_pixels, 1]), re_row)
    c_re = tf.reshape(re_grid, [-1])

    im_column = tf.reshape(im_axis, [view_pixels, 1])
    im_grid = tf.matmul(im_column, tf.ones([1, view_pixels]))
    c_im = tf.reshape(im_grid, [-1])

    # Iterate z <- z*z + c from z = 0 for every point at once.
    z_re = tf.zeros_like(c_re)
    z_im = tf.zeros_like(c_im)
    for _ in range(num_iterations):
      sq_re, sq_im = complex_mul(z_re, z_im, z_re, z_im)
      z_re, z_im = complex_add(sq_re, sq_im, c_re, c_im)

    # Decide which orbits diverged. A NaN magnitude is replaced by 100, which
    # is above the threshold of 50, so those points also count as escaped.
    magnitude = tf.sqrt(z_re**2 + z_im**2)
    nan_fill = 100. * tf.ones_like(magnitude)
    magnitude = tf.where(tf.math.is_nan(magnitude), nan_fill, magnitude)
    escaped = tf.where(magnitude > 50., tf.ones_like(magnitude),
                       tf.zeros_like(magnitude))

    # Fold the flat mask back into a square image.
    return tf.reshape(escaped, shape=[view_pixels, view_pixels])
@tf_test_utils.compile_modules(mandelbrot=MandelbrotModule)
class MandelbrotTest(tf_test_utils.SavedModelTestCase):
  """Exercises MandelbrotModule.calculate through the tf_test_utils harness."""

  def test_mandelbrot(self):
    # NOTE(review): `.all` and `assert_all_close()` presumably run the module
    # on every configured backend and compare the results -- confirm against
    # tf_test_utils.
    module = self.modules.mandelbrot.all

    # Wide view covering the whole set: 400x400 pixels, 100 iterations.
    overview = module.calculate(-0.7, 0.0, 3.0, 400, 100)
    overview.assert_all_close()

    # Deep zoom into a much smaller region, which needs far more iterations
    # to resolve: 400x400 pixels, 10000 iterations.
    zoomed = module.calculate(-0.7436447860, 0.1318252536, 0.0000029336,
                              400, 10000)
    zoomed.assert_all_close()
if __name__ == "__main__":
  # Only call enable_v2_behavior() when the imported tf namespace provides it.
  has_v2_switch = hasattr(tf, "enable_v2_behavior")
  if has_v2_switch:
    tf.enable_v2_behavior()
  tf.test.main()