Replace usages of Tensorflow DistributionStrategy method experimental_run_v2 with run. (#2332)
Also re-ran the notebook with TensorFlow 2.5.0 (on Windows).
PiperOrigin-RevId: 318519824
Co-authored-by: kfranko <kfranko@google.com>
diff --git a/colab/mnist_tensorflow.ipynb b/colab/mnist_tensorflow.ipynb
index df7c055..b791a3b 100644
--- a/colab/mnist_tensorflow.ipynb
+++ b/colab/mnist_tensorflow.ipynb
@@ -51,7 +51,7 @@
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
],
- "execution_count": 0,
+ "execution_count": 1,
"outputs": []
},
{
@@ -95,7 +95,7 @@
"base_uri": "https://localhost:8080/",
"height": 51
},
- "outputId": "d471e3cd-ac1d-42f0-c344-4490a51f3a54"
+ "outputId": "fe0d703a-2ad7-4d14-9aef-c69b4c342a16"
},
"source": [
"import os\n",
@@ -119,8 +119,8 @@
{
"output_type": "stream",
"text": [
- "TensorFlow version: 2.1.0-dev20191126\n",
- "Numpy version: 1.17.4\n"
+ "TensorFlow version: 2.5.0-dev20200626\n",
+ "Numpy version: 1.18.4\n"
],
"name": "stdout"
}
@@ -132,11 +132,11 @@
"id": "43BH_9YcsGs8",
"colab_type": "code",
"cellView": "form",
- "outputId": "fc788105-8739-45d6-8df2-e245d5ff6cb1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
- }
+ },
+ "outputId": "46a560cb-5947-4cfa-f71e-9065bf2c07ab"
},
"source": [
"#@title Notebook settings { run: \"auto\" }\n",
@@ -170,7 +170,7 @@
{
"output_type": "stream",
"text": [
- "Using IREE compiler backend 'vulkan-spirv' and runtime driver 'vulkan'\n"
+ "Using IREE compiler backend 'vulkan-spirv' and runtime driver 'vulkan'\r\n"
],
"name": "stdout"
}
@@ -194,11 +194,11 @@
"id": "GXZIrReTbTHN",
"colab_type": "code",
"cellView": "form",
- "outputId": "5faf02eb-aab4-4e4c-9df9-e7665ce5a810",
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 506
- }
+ "height": 486
+ },
+ "outputId": "9c01fab1-f8cb-4a63-fff1-6b82a9b6b49d"
},
"source": [
"#@title Load MNIST dataset, setup training and evaluation\n",
@@ -258,8 +258,8 @@
{
"output_type": "stream",
"text": [
- "Loaded MNIST dataset!\n",
- "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ('/device:CPU:0',), variable_device = '/device:CPU:0'\n",
+ "Loaded MNIST dataset!\r\n",
+ "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:localhost/replica:0/task:0/device:CPU:0'], variable_device = '/job:localhost/replica:0/task:0/device:CPU:0'\n",
"Configured data for training and evaluation!\n",
" sample shape: (28, 28, 1)\n",
" training samples: 60000\n",
@@ -278,7 +278,7 @@
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAANzUlEQVR4nO3df6zV9X3H8dcL5IdFVBiMMSRaLMRiF6G9oXV1m8a1s/xRbLK5ks5hY3O7rG5tQtIat6Q2/RGzVN2WNV1oJaWLP+L8UVlqOpHaOFuCXhwFhLZQhyvsChJuB24ZcK/v/XG/NFe93++5nPM9P+T9fCQ355zv+3y/33eOvvie8/2c7/k4IgTg7Dep2w0A6AzCDiRB2IEkCDuQBGEHkjinkzub6mkxXTM6uUsglf/T/+hknPB4tZbCbvs6SX8nabKkb0bEHVXPn64Zeq+vbWWXACpsjc2ltabfxtueLOlrkj4kaamk1baXNrs9AO3Vymf2FZL2RcSLEXFS0gOSVtXTFoC6tRL2BZJ+MebxgWLZ69jutz1ge+CUTrSwOwCtaPvZ+IhYFxF9EdE3RdPavTsAJVoJ+0FJC8c8vqhYBqAHtRL25yQttv1221MlfVTSxnraAlC3pofeImLY9i2S/lWjQ2/rI+KF2joDUKuWxtkj4nFJj9fUC4A24uuyQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dGfkkZz9n/pysr6yPTyyTnnXv5K5bpbrni4qZ5Ou/T7H6+sz3z23NLavL//UUv7xpnhyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3gOGvru4sr5r2T+0bd+nyofoJ+Qn13yzsn5v3/zS2oObfq9y3ZE9e5vqCePjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3gGNxtF/uOyBtu37H3+5qLJ+15YPVNYvubj6evgnlj5SWf/YzMHS2pdvmlO57qLPMc5ep5bCbnu/pOOSRiQNR0RfHU0BqF8dR/ZrIuJIDdsB0EZ8ZgeSaDXsIekJ29ts94/3BNv9tgdsD5zSiRZ3B6BZrb6NvyoiDtr+dUmbbP8kIp4e+4SIWCdpnSSd79ktXnYBoFktHdkj4mBxe1jSo5JW1NEUgPo1HXbbM2zPPH1f0gcl7aqrMQD1auVt/DxJj9o+vZ37IuJ7tXT1FjN87Xsq69+/4msNtjClsvq3Q0sq60/9ccWI538drlx3ydBAZX3S9OmV9a9s/a3K+m1zdpbWhmcNV66LejUd9oh4UdIVNfYCoI0YegOSIOxAEoQdSIKwA0kQdiAJLnGtwasLplbWJzX4N7XR0NoPPlw9vDXy4k8r663Y94XllfX7Zt/ZYAvTSisXfY9jTSfxagNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyz1+DCb2+prP/hwJ9U1j10rLI+PLj/DDuqzydWPllZP29S+Tg6egtHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2DhjZ/bNut1Bq/5evrKzffOFXG2yh+qem1w6+r7Q288k9leuONNgzzgxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2s9wvb6weR//hn1aPo18wqXocfcuJyZX17V8q/935c489W7ku6tXwyG57ve3DtneNWTbb9ibbe4vbWe1tE0CrJvI2/luSrnvDslslbY6IxZI2F48B9LCGYY+IpyUdfcPiVZI2FPc3SLq+5r4A1KzZz+zzImKwuP+ypHllT7TdL6lfkqbrbU3uDkCrWj4bHxEhKSrq6yKiLyL6plRM8gegvZoN+yHb8yWpuD1cX0sA2qHZsG+UtKa4v0bSY/W0A6BdGn5mt32/pKslzbF9QNLnJd0h6UHbN0t6SdIN7WwSzTvy7tJPWJIaj6M3suYHn6isL/kOY+m9omHYI2J1SenamnsB0EZ8XRZIgrADSRB2IAnCDiRB2IEkuMT1LHBy08WltS2X3dlg7eqhtyu2rKmsv3Ptzyvr/Bx07+DIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM7+FnDOoksq6198xz+X1mY1uIR124nqfV/8xeqR8pGhoeoNoGdwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnfwu49MGDlfXlU5v/N3v15j+rrC/58XNNbxu9hSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHsPGFpzZWX9C/Ma/fb7tNLKmv2/X7nmOz+7r7LO776fPRoe2W2vt33Y9q4xy263fdD29uJvZXvbBNCqibyN/5ak68ZZfndELCv+Hq+3LQB1axj2iHha0tEO9AKgjVo5QXeL7R3F2/xZZU+y3W97wPbAKTX4wTMAbdNs2L8u6VJJyyQNSio9gxQR6yKiLyL6plScSALQXk2FPSIORcRIRLwm6RuSVtTbFoC6NRV22/PHPPyIpF1lzwXQGxqOs9u+X9LVkubYPiDp85Kutr1MUkjaL+mTbezxLe+cBb9ZWf+dv9xaWT9vUvMff7bsfkdlfckQ16tn0TDsEbF6nMX3tKEXAG3E12WBJAg7kARhB5Ig7EAShB1IgktcO2DPbQsr69/5jX9pafvX7Pyj0hqXsOI0juxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7B2w7cN3N3hGa7/gc8Gfv1ZaGx4aamnbOHtwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwucmndBaW3KyQUd7OTNRl45UlqLE9XTgXla9fcPJs+d01RPkjQy98LK+t61U5ve9kTEiEtrl/1Fg98gOHasqX1yZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwt896H13W6h1G//+3iTAI86cuj8ynVnzT1eWd/6nvua6qnXLf3rWyrriz67pantNjyy215o+ynbu22/YPvTxfLZtjfZ3lvczmqqAwAdMZG38cOS1kbEUknvk/Qp20sl3Sppc0QslrS5eAygRzUMe0QMRsTzxf3jkvZIWiBplaQNxdM2SLq+XU0CaN0ZfWa3fYmk5ZK2SpoXEYNF6WVJ80rW6ZfUL0nT9bZm+wTQogmfjbd9nqSHJX0mIl73TfyICEkx3noRsS4i+iKib0qLP6wIoHkTCrvtKRoN+r0R8Uix+JDt+UV9vqTD7WkRQB0avo23bUn3SNoTEXeNKW2UtEbSHcXtY23p8CywavfHKuub3/VQhzrpvB8tv79r+/7fOFlaOxXlP789ESt33FRZ/+/tzV9+u+CZ4abXrTKRz+zvl3SjpJ22txfLbtNoyB+0fbOklyTd0JYOAdSiYdgj4hlJZVfaX1tvOwDaha/LAkkQdiAJwg4kQdiBJAg7kASXuHbAuX/wH5X1y79SfUljtPG/0szLjlbW23kZ6eX/9vHKevznjJa2v+ihV8uLz+5saduztLelejdwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJDz6IzOdcb5nx3vNhXJAu2yNzToWR8e9SpUjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTRMOy2F9p+yvZu2y/Y/nSx/HbbB21vL/5Wtr9dAM2ayPQDw5LWRsTztmdK2mZ7U1G7OyK+2r72ANRlIvOzD0oaLO4ft71H0oJ2NwagXmf0md32JZKWS9paLLrF9g7b623PKlmn3/aA7YFTOtFSswCaN+Gw2z5P0sOSPhMRxyR9XdKlkpZp9Mh/53jrRcS6iOiLiL4pmlZDywCaMaGw256i0aDfGxGPSFJEHIqIkYh4TdI3JK1oX5sAWjWRs/GWdI+kPRFx15jl88c87SOSdtXfHoC6TORs/Psl3Shpp+3txbLbJK22vUxSSNov6ZNt6RBALSZyNv4ZSeP9DvXj9bcDoF34Bh2QBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0Tndma/IumlMYvmSDrSsQbOTK/21qt9SfTWrDp7uzgi5o5X6GjY37RzeyAi+rrWQIVe7a1X+5LorVmd6o238UAShB1IotthX9fl/Vfp1d56tS+J3prVkd66+pkdQOd0+8gOoEMIO5BEV8Ju+zrbP7W9z/at3eihjO39tncW01APdLmX9bYP2941Ztls25ts7y1ux51jr0u99cQ03hXTjHf1tev29Ocd/8xue7Kkn0n6gKQDkp6TtDoidne0kRK290vqi4iufwHD9u9KelXStyPiXcWyv5F0NCLuKP6hnBURn+uR3m6X9Gq3p/EuZiuaP3aacUnXS7pJXXztKvq6QR143bpxZF8haV9EvBgRJyU9IGlVF/roeRHxtKSjb1i8StKG4v4Gjf7P0nElvfWEiBiMiOeL+8clnZ5mvKuvXUVfHdGNsC+Q9Isxjw+ot+Z7D0lP2N5mu7/bzYxjXkQMFvdfljSvm82Mo+E03p30hmnGe+a1a2b681Zxgu7NroqId0v6kKRPFW9Xe1KMfgbrpbHTCU3j3SnjTDP+K9187Zqd/rxV3Qj7QUkLxzy+qFjWEyLiYHF7WNKj6r2pqA+dnkG3uD3c5X5+pZem8R5vmnH1wGvXzenPuxH25yQttv1221MlfVTSxi708Sa2ZxQnTmR7hqQPqvemot4oaU1xf42kx7rYy+v0yjTeZdOMq8uvXdenP4+Ijv9JWqnRM/I/l/RX3eihpK9Fkn5c/L3Q7d4k3a/Rt3WnNHpu42ZJvyZps6S9kp6UNLuHevsnSTsl7dBosOZ3qberNPoWfYek7cXfym6/dhV9deR14+uyQBKcoAOSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJP4fcKgKSEIBgPIAAAAASUVORK5CYII=\n"
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANxUlEQVR4nO3de4xU93nG8ecBc7EwtqFgSjGygwOycSpDsiJx3YstN6nDH8GRckOJgyNHpGrcJhJSYrmV4igXWVVst1WjVCRGIZUvcn2JqWIlJsSR6wRhLy4BbJJAXOpgVmDEpuBWhd312z/2UG3wzpll5sycMe/3I41m5rxzznk18OyZmd+c+TkiBODsN6nuBgB0B2EHkiDsQBKEHUiCsANJnNPNnU31tJiuGd3cJZDK/+q/dTJOeLxaW2G3fYOkv5c0WdK3IuLOssdP1wy909e3s0sAJbbFloa1ll/G254s6euS3itpqaTVtpe2uj0AndXOe/YVkvZFxEsRcVLSg5JWVdMWgKq1E/YFkn495v6BYtlvsb3Wdr/t/iGdaGN3ANrRTtjH+xDgDd+9jYj1EdEXEX1TNK2N3QFoRzthPyBp4Zj7F0s62F47ADqlnbA/J2mx7bfYnirpI5I2VdMWgKq1PPQWEcO2b5X0A40OvW2IiBcq6wxApdoaZ4+IJyQ9UVEvADqIr8sCSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEV39KGq3Z/+WrS+sj0xtPzjn3yldL19161SMt9XTKZT/6RGl95rPnNqzN+4eftrVvnBmO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsPWDwe4tL67uX/WPH9j3UeIh+Qn5+3bdK6/f1zW9Ye2jzn5SuO7Jnb0s9YXwc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZu6DZOPpPlj3YsX3/028Wldbv3vru0vqll5SfD//k0kdL6x+dOdCw9pWb55Suu+jzjLNXqa2w294v6bikEUnDEdFXRVMAqlfFkf26iDhSwXYAdBDv2YEk2g17SHrS9nbba8d7gO21tvtt9w/pRJu7A9Cqdl/GXxMRB21fJGmz7Z9HxNNjHxAR6yWtl6TzPbvN0y4AtKqtI3tEHCyuD0t6TNKKKpoCUL2Ww257hu2Zp25Leo+k3VU1BqBa7byMnyfpMduntnN/RHy/kq7eZIavf0dp/UdXfb3JFqaUVv9ucElp/akPl4x4Hjxcuu6Swf7S+qTp00vrX932+6X12+fsalgbnjVcui6q1XLYI+IlSVdV2AuADmLoDUiCsANJEHYgCcIOJEHYgSQ4xbUCry2YWlqf1ORvarOhtR+/r3x4a+SlX5TW27Hvi8tL6/fPvqvJFqY1rFz8fY413cSzDSRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5egQu/s7W0/oH+j5XWPXistD48sP8MO6rOJ1f+sLR+3qTG4+joLRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtm7YOTFX9bdQkP7v3J1af2WC7/WZAvlPzW9buBdDWszf7indN2RJnvGmeHIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5+lvvNTeXj6D/5ePk4+gWTysfRt56YXFrf8eXGvzt/7rFnS9dFtZoe2W1vsH3Y9u4xy2bb3mx7b3E9q7NtAmjXRF7Gf1vSDactu03SlohYLGlLcR9AD2sa9oh4WtLR0xavkrSxuL1R0o0V9wWgYq1+QDcvIgYkqbi+qNEDba+13W+7f0gnWtwdgHZ1/NP4iFgfEX0R0TelZJI/AJ3VatgP2Z4vScX14epaAtAJrYZ9k6Q1xe01kh6vph0AndJ0nN32A5KulTTH9gFJX5B0p6SHbN8i6WVJH+xkk2jdkbdHab3ZOHoza378ydL6ku8ylt4rmoY9IlY3KF1fcS8AOoivywJJEHYgCcIOJEHYgSQIO5AEp7ieBU5uvqRhbevldzVZu3zo7aqta0rrV6z7VWmdn4PuHRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtnfBM5ZdGlp/Utv/ZeGtVlNTmHd3uSXwi75UvlI+cjgYPkG0DM4sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzvwlc9tArpfXlU1v/m716y5+X1pf87LmWt43ewpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0HDK65urT+xXnNfvt9WsPKmv1/WrrmFZ/bV1rnd9/PHk2P7LY32D5se/eYZXfYfsX2juKysrNtAmjXRF7Gf1vSDeMsvycilhWXJ6ptC0DVmoY9Ip6WdLQLvQDooHY+oLvV9s7iZf6sRg+yvdZ2v+3+ITX5wTMAHdNq2L8h6TJJyyQNSGr4CVJErI+Ivojom1LyQRKAzmop7BFxKCJGIuJ1Sd+UtKLatgBUraWw254/5u77Je1u9FgAvaHpOLvtByRdK2mO7QOSviDpWtvLJIWk/ZI+1cEe3/TOWfB7pfU/+qttpfXzJrX+9mfri28trS8Z5Hz1LJqGPSJWj7P43g70AqCD+LoskARhB5Ig7EAShB1IgrADSXCKaxfsuX1haf27v/uvbW3/ul0fbFjjFFacwpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0Ltr/vniaPaO8XfC74i9cb1oYHB9vaNs4eHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2c8CQ/MuaFibcnJBFzt5o5FXjzSsxYny6cA8rfz7B5PnzmmpJ0kamXthaX3vuqktb3siYsQNa5f/ZZPfIDh2rKV9cmQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZz8LfO/hDXW30NAf/Pt4kwCPOnLo/NJ1Z809Xlrf9o77W+qp1y39m1tL64s+t7Wl7TY9stteaPsp23tsv2D7M8Xy2bY3295bXM9qqQMAXTGRl/HDktZFxBWS3iXp07aXSrpN0paIWCxpS3EfQI9qGvaIGIiI54vbxyXtkbRA0ipJG4uHbZR0Y6eaBNC+M/qAzvalkpZL2iZpXkQMSKN/ECRd1GCdtbb7bfcPqfy70AA6Z8Jht32epEckfTYiJvxN/IhYHxF9EdE3pc0fVgTQugmF3fYUjQb9voh4tFh8yPb8oj5f0uHOtAigCk2H3mxb0r2S9kTE3WNKmyStkXRncf14Rzo8C6x68aOl9S1ve7hLnXTfT5c/UNu+/ydONqwNReOf356IlTtvLq3/147WT79d8Mxwy+uWmcg4+zWSbpK0y/aOYtntGg35Q7ZvkfSypMaThAOoXdOwR8QzkhqdaX99te0A6BS+LgskQdiBJAg7kARhB5Ig7EASnOLaBef+2X+U1q/8avkpjdHBf6WZlx8trXfyNNIr/+0TpfV4eUZb21/08GuNi8/uamvbs7S3rXodOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKOiK7t7HzPjneaE+WATtkWW3Qsjo57lipHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiiadhtL7T9lO09tl+w/Zli+R22X7G9o7is7Hy7AFo1kekHhiWti4jnbc+UtN325qJ2T0R8rXPtAajKROZnH5A0UNw+bnuPpAWdbgxAtc7oPbvtSyUtl7StWHSr7Z22N9ie1WCdtbb7bfcP6URbzQJo3YTDbvs8SY9I+mxEHJP0DUmXSVqm0SP/XeOtFxHrI6IvIvqmaFoFLQNoxYTCbnuKRoN+X0Q8KkkRcSgiRiLidUnflLSic20CaNdEPo23pHsl7YmIu8csnz/mYe+XtLv69gBUZSKfxl8j6SZJu2zvKJbdLmm17WWSQtJ+SZ/qSIcAKjGRT+OfkTTe71A/UX07ADqFb9ABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeScER0b2f2q5L+c8yiOZKOdK2BM9OrvfVqXxK9tarK3i6JiLnjFboa9jfs3O6PiL7aGijRq731al8SvbWqW73xMh5IgrADSdQd9vU1779Mr/bWq31J9NaqrvRW63t2AN1T95EdQJcQdiCJWsJu+wbbv7C9z/ZtdfTQiO39tncV01D319zLBtuHbe8es2y27c229xbX486xV1NvPTGNd8k047U+d3VPf9719+y2J0v6paR3Szog6TlJqyPixa420oDt/ZL6IqL2L2DY/mNJr0n6TkS8rVj2t5KORsSdxR/KWRHx+R7p7Q5Jr9U9jXcxW9H8sdOMS7pR0s2q8bkr6etD6sLzVseRfYWkfRHxUkSclPSgpFU19NHzIuJpSUdPW7xK0sbi9kaN/mfpuga99YSIGIiI54vbxyWdmma81ueupK+uqCPsCyT9esz9A+qt+d5D0pO2t9teW3cz45gXEQPS6H8eSRfV3M/pmk7j3U2nTTPeM89dK9Oft6uOsI83lVQvjf9dExFvl/ReSZ8uXq5iYiY0jXe3jDPNeE9odfrzdtUR9gOSFo65f7GkgzX0Ma6IOFhcH5b0mHpvKupDp2bQLa4P19zP/+ulabzHm2ZcPfDc1Tn9eR1hf07SYttvsT1V0kckbaqhjzewPaP44ES2Z0h6j3pvKupNktYUt9dIerzGXn5Lr0zj3WiacdX83NU+/XlEdP0iaaVGP5H/laS/rqOHBn0tkvSz4vJC3b1JekCjL+uGNPqK6BZJvyNpi6S9xfXsHurtnyXtkrRTo8GaX1Nvf6jRt4Y7Je0oLivrfu5K+urK88bXZYEk+AYdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTxfy43Cn7d/BIFAAAAAElFTkSuQmCC\n"
},
"metadata": {
"tags": [],
@@ -289,7 +289,7 @@
"output_type": "stream",
"text": [
"\n",
- "Ground truth labels: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n"
+ "Ground truth labels: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\r\n"
],
"name": "stdout"
}
@@ -318,7 +318,7 @@
" model.add(tf.keras.layers.Dense(10, activation='softmax'))\n",
" return model"
],
- "execution_count": 0,
+ "execution_count": 5,
"outputs": []
},
{
@@ -327,11 +327,11 @@
"id": "7Gdxh7qWcPSO",
"colab_type": "code",
"cellView": "form",
- "outputId": "9886da4f-1916-4540-9f53-cc82f8ab113d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 374
- }
+ },
+ "outputId": "50b0aede-9a8f-4ce5-b340-783be5fbfc06"
},
"source": [
"#@title Train the Keras model\n",
@@ -364,7 +364,7 @@
" training_loss.update_state(loss)\n",
" training_accuracy.update_state(labels, logits)\n",
"\n",
- " strategy.experimental_run_v2(step_fn, args=(next(iterator),))\n",
+ " strategy.run(step_fn, args=(next(iterator),))\n",
"\n",
" @tf.function\n",
" def test_step(iterator):\n",
@@ -378,7 +378,7 @@
" test_loss.update_state(loss)\n",
" test_accuracy.update_state(labels, logits)\n",
"\n",
- " strategy.experimental_run_v2(step_fn, args=(next(iterator),))\n",
+ " strategy.run(step_fn, args=(next(iterator),))\n",
"\n",
" for epoch in range(0, num_epochs):\n",
" tf.print(\"Running epoch #%s\" % (epoch + 1))\n",
@@ -413,24 +413,24 @@
"text": [
"Constructed Keras MNIST model, training...\n",
"Running epoch #1\n",
- " Training loss: 0.733241, accuracy: 81.463333\n",
- " Test loss : 0.388786, accuracy: 89.590004\n",
+ " Training loss: 0.732439, accuracy: 81.403336\n",
+ " Test loss : 0.390855, accuracy: 89.490005\n",
"Running epoch #2\n",
- " Training loss: 0.362705, accuracy: 89.995003\n",
- " Test loss : 0.313008, accuracy: 91.320000\n",
+ " Training loss: 0.365308, accuracy: 89.811668\n",
+ " Test loss : 0.315630, accuracy: 91.119995\n",
"Running epoch #3\n",
- " Training loss: 0.308992, accuracy: 91.294998\n",
- " Test loss : 0.279223, accuracy: 92.150002\n",
+ " Training loss: 0.312111, accuracy: 91.129997\n",
+ " Test loss : 0.281829, accuracy: 92.040001\n",
"Running epoch #4\n",
- " Training loss: 0.278085, accuracy: 92.198334\n",
- " Test loss : 0.256378, accuracy: 92.629997\n",
+ " Training loss: 0.281028, accuracy: 92.038330\n",
+ " Test loss : 0.258432, accuracy: 92.629997\n",
"Running epoch #5\n",
- " Training loss: 0.255360, accuracy: 92.818329\n",
- " Test loss : 0.238477, accuracy: 93.099998\n",
+ " Training loss: 0.257909, accuracy: 92.753334\n",
+ " Test loss : 0.240058, accuracy: 93.229996\n",
"Completed training!\n",
"\n",
"Sample prediction:\n",
- "[1.24276197 0.00645936839 94.5438538 1.92194772 5.14127169e-06 1.25558567 0.77046895 1.55102152e-05 0.25887084 2.61111491e-05]\n",
+ "[0.243134052 0.00337268948 95.5214081 0.925373673 2.25061958e-05 0.992091119 2.20864391 3.87712953e-06 0.105901182 4.44369543e-05]\n",
"\n"
],
"name": "stdout"
@@ -442,11 +442,11 @@
"metadata": {
"id": "DmespEaFcSEL",
"colab_type": "code",
- "outputId": "34ecdcca-5ceb-43f0-c793-e209e25dafce",
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 122
- }
+ "height": 153
+ },
+ "outputId": "3c8579db-7b3c-4164-ff91-394346595107"
},
"source": [
"#@title Export the trained model as a SavedModel, with IREE-compatible settings\n",
@@ -472,11 +472,14 @@
{
"output_type": "stream",
"text": [
- "Exporting SavedModel to /tmp/mnist.sm\n",
- "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/resource_variable_ops.py:1788: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
+ "Exporting SavedModel to /tmp/mnist.sm\r\n",
+ "WARNING:tensorflow:From c:\\users\\scott\\scoop\\apps\\python\\current\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
- "If using Keras pass *_constraint arguments to layers.\n",
- "INFO:tensorflow:Assets written to: /tmp/mnist.sm/assets\n"
+ "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
+ "WARNING:tensorflow:From c:\\users\\scott\\scoop\\apps\\python\\current\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
+ "INFO:tensorflow:Assets written to: /tmp/mnist.sm\\assets\n"
],
"name": "stdout"
}
@@ -497,11 +500,11 @@
"metadata": {
"id": "rqwIx4j4gS1a",
"colab_type": "code",
- "outputId": "28c7516c-903c-4579-c4bc-8299712a9edd",
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 751
- }
+ "height": 836
+ },
+ "outputId": "24cded90-c436-47ce-b4f4-7a5da46ea38a"
},
"source": [
"#@title Load the SavedModel into IREE's compiler as MLIR xla_hlo\n",
@@ -524,16 +527,16 @@
"Imported MLIR:\n",
" \n",
"\n",
- "module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 175 : i32}} {\n",
- " flow.variable @\"__iree_flow___sm_node15__model.layer-2.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<784x128xf32>\n",
- " flow.variable @\"__iree_flow___sm_node16__model.layer-2.bias\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128xf32>\n",
- " flow.variable @\"__iree_flow___sm_node21__model.layer-3.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128x10xf32>\n",
- " flow.variable @\"__iree_flow___sm_node22__model.layer-3.bias\" dense<[-0.0719004869, 0.1290171, 0.0102811698, -0.106104381, 0.0260288324, 0.166622087, -8.5693755E-4, 0.070880115, -0.222566754, -0.00140074058]> : tensor<10xf32>\n",
- " func @predict(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x10xf32> attributes {iree.module.export, iree.reflection = {abi = \"sip\", abiv = 1 : i32, sip = \"I8!S5!k0_0R3!_0\"}, tf._input_shapes = [\"tfshape$dim { size: 1 } dim { size: 28 } dim { size: 28 } dim { size: 1 }\", \"tfshape$unknown_rank: true\", \"tfshape$unknown_rank: true\", \"tfshape$unknown_rank: true\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful} {\n",
- " %0 = flow.variable.address @\"__iree_flow___sm_node15__model.layer-2.kernel\" : !iree.ptr<tensor<784x128xf32>>\n",
- " %1 = flow.variable.address @\"__iree_flow___sm_node16__model.layer-2.bias\" : !iree.ptr<tensor<128xf32>>\n",
- " %2 = flow.variable.address @\"__iree_flow___sm_node21__model.layer-3.kernel\" : !iree.ptr<tensor<128x10xf32>>\n",
- " %3 = flow.variable.address @\"__iree_flow___sm_node22__model.layer-3.bias\" : !iree.ptr<tensor<10xf32>>\n",
+ "module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 443 : i32}} {\n",
+ " flow.variable @\"__iree_flow___sm_node14__model.layer-1.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<784x128xf32> attributes {sym_visibility = \"private\"}\n",
+ " flow.variable @\"__iree_flow___sm_node15__model.layer-1.bias\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128xf32> attributes {sym_visibility = \"private\"}\n",
+ " flow.variable @\"__iree_flow___sm_node20__model.layer-2.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128x10xf32> attributes {sym_visibility = \"private\"}\n",
+ " flow.variable @\"__iree_flow___sm_node21__model.layer-2.bias\" dense<[-0.114143081, 0.0953421518, 4.84912744E-5, -0.0384164825, 0.0063888072, 0.218958765, 0.0256200824, 0.0551806651, -0.22108613, -0.0278935507]> : tensor<10xf32> attributes {sym_visibility = \"private\"}\n",
+ " func @predict(%arg0: tensor<1x28x28x1xf32> {tf._user_specified_name = \"x\"}) -> tensor<1x10xf32> attributes {iree.module.export, iree.reflection = {abi = \"sip\", abiv = 1 : i32, sip = \"I8!S5!k0_0R3!_0\"}, tf._input_shapes = [#tf.shape<1x28x28x1>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {\n",
+ " %0 = flow.variable.address @\"__iree_flow___sm_node14__model.layer-1.kernel\" : !iree.ptr<tensor<784x128xf32>>\n",
+ " %1 = flow.variable.address @\"__iree_flow___sm_node15__model.layer-1.bias\" : !iree.ptr<tensor<128xf32>>\n",
+ " %2 = flow.variable.address @\"__iree_flow___sm_node20__model.layer-2.kernel\" : !iree.ptr<tensor<128x10xf32>>\n",
+ " %3 = flow.variable.address @\"__iree_flow___sm_node21__model.layer-2.bias\" : !iree.ptr<tensor<10xf32>>\n",
" %4 = xla_hlo.constant dense<0xFF800000> : tensor<f32>\n",
" %5 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>\n",
" %6 = flow.variable.load.indirect %3 : !iree.ptr<tensor<10xf32>> -> tensor<10xf32>\n",
@@ -542,27 +545,32 @@
" %9 = flow.variable.load.indirect %0 : !iree.ptr<tensor<784x128xf32>> -> tensor<784x128xf32>\n",
" %10 = \"xla_hlo.reshape\"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>\n",
" %11 = \"xla_hlo.dot\"(%10, %9) : (tensor<1x784xf32>, tensor<784x128xf32>) -> tensor<1x128xf32>\n",
- " %12 = \"xla_hlo.add\"(%11, %8) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>\n",
- " %13 = \"xla_hlo.maximum\"(%12, %5) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x128xf32>, tensor<f32>) -> tensor<1x128xf32>\n",
- " %14 = \"xla_hlo.dot\"(%13, %7) : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>\n",
- " %15 = \"xla_hlo.add\"(%14, %6) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<10xf32>) -> tensor<1x10xf32>\n",
- " %16 = \"xla_hlo.reduce\"(%15, %4) ( {\n",
- " ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):\t// no predecessors\n",
- " %21 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>\n",
- " \"xla_hlo.return\"(%21) : (tensor<f32>) -> ()\n",
+ " %12 = \"xla_hlo.broadcast_in_dim\"(%8) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>\n",
+ " %13 = xla_hlo.add %11, %12 : tensor<1x128xf32>\n",
+ " %14 = \"xla_hlo.broadcast_in_dim\"(%5) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x128xf32>\n",
+ " %15 = xla_hlo.maximum %14, %13 : tensor<1x128xf32>\n",
+ " %16 = \"xla_hlo.dot\"(%15, %7) : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>\n",
+ " %17 = \"xla_hlo.broadcast_in_dim\"(%6) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>) -> tensor<1x10xf32>\n",
+ " %18 = xla_hlo.add %16, %17 : tensor<1x10xf32>\n",
+ " %19 = \"xla_hlo.reduce\"(%18, %4) ( {\n",
+ " ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors\n",
+ " %26 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>\n",
+ " \"xla_hlo.return\"(%26) : (tensor<f32>) -> ()\n",
" }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>\n",
- " %17 = \"xla_hlo.subtract\"(%15, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1xf32>) -> tensor<1x10xf32>\n",
- " %18 = \"xla_hlo.exponential\"(%17) : (tensor<1x10xf32>) -> tensor<1x10xf32>\n",
- " %19 = \"xla_hlo.reduce\"(%18, %5) ( {\n",
- " ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):\t// no predecessors\n",
- " %21 = xla_hlo.add %arg1, %arg2 : tensor<f32>\n",
- " \"xla_hlo.return\"(%21) : (tensor<f32>) -> ()\n",
+ " %20 = \"xla_hlo.broadcast_in_dim\"(%19) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>\n",
+ " %21 = xla_hlo.subtract %18, %20 : tensor<1x10xf32>\n",
+ " %22 = \"xla_hlo.exponential\"(%21) : (tensor<1x10xf32>) -> tensor<1x10xf32>\n",
+ " %23 = \"xla_hlo.reduce\"(%22, %5) ( {\n",
+ " ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors\n",
+ " %26 = xla_hlo.add %arg1, %arg2 : tensor<f32>\n",
+ " \"xla_hlo.return\"(%26) : (tensor<f32>) -> ()\n",
" }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>\n",
- " %20 = \"xla_hlo.divide\"(%18, %19) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1xf32>) -> tensor<1x10xf32>\n",
- " return %20 : tensor<1x10xf32>\n",
+ " %24 = \"xla_hlo.broadcast_in_dim\"(%23) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>\n",
+ " %25 = xla_hlo.divide %22, %24 : tensor<1x10xf32>\n",
+ " return %25 : tensor<1x10xf32>\n",
" }\n",
- "}\n",
- "Wrote MLIR to path '/usr/local/google/home/scotttodd/saved_models/mnist.mlir'\n"
+ "}\r\n",
+ "Wrote MLIR to path 'C:\\Users\\Scott\\saved_models\\mnist.mlir'\n"
],
"name": "stdout"
}
@@ -573,11 +581,11 @@
"metadata": {
"id": "IDHI7h3khJr9",
"colab_type": "code",
- "outputId": "f2abfd0c-5ada-4533-ab2f-a47f77ccd260",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
- }
+ },
+ "outputId": "b8958b7f-c7bb-4fbd-b800-e58c46134086"
},
"source": [
"#@title Compile the xla_hlo MLIR and prepare a context to execute it\n",
@@ -596,8 +604,8 @@
{
"output_type": "stream",
"text": [
- "Created IREE driver vulkan: <pyiree.rt.binding.HalDriver object at 0x7f58bc3d27f0>\n",
- "SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x7f58bc3d27f0>\n"
+ "Created IREE driver vulkan: <pyiree.rt.binding.HalDriver object at 0x000001DC44C47370>\n",
+ "SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x000001DC44C47370>\n"
],
"name": "stderr"
}
@@ -608,11 +616,11 @@
"metadata": {
"id": "SKflpnLtkLYE",
"colab_type": "code",
- "outputId": "59c52840-5cc6-4038-c330-707bd9313f4b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
- }
+ },
+ "outputId": "337ddf79-6746-4685-89f5-65787280c0af"
},
"source": [
"#@title Execute the compiled module and compare the results with TensorFlow\n",
@@ -632,10 +640,10 @@
"output_type": "stream",
"text": [
"IREE prediction ('vulkan-spirv' backend, 'vulkan' driver):\n",
- "[1.24276221 0.00645937538 94.5438538 1.92194688 5.14127305e-06 1.25558555 0.770468712 1.55102171e-05 0.25887078 2.61111945e-05]\n",
+ "[0.243133873 0.00337268622 95.5214233 0.92537272 2.25061631e-05 0.992090821 2.20864058 3.87712225e-06 0.105901062 4.44369434e-05]\n",
"\n",
"TensorFlow prediction:\n",
- "[1.24276197 0.00645936839 94.5438538 1.92194772 5.14127169e-06 1.25558567 0.77046895 1.55102152e-05 0.25887084 2.61111491e-05]\n"
+ "[0.243134052 0.00337268948 95.5214081 0.925373673 2.25061958e-05 0.992091119 2.20864391 3.87712953e-06 0.105901182 4.44369543e-05]\n"
],
"name": "stdout"
}