diff --git a/site/en/guide/function.ipynb b/site/en/guide/function.ipynb index 44fd0e4a23c..5d7d45b4afc 100644 --- a/site/en/guide/function.ipynb +++ b/site/en/guide/function.ipynb @@ -746,6 +746,204 @@ "\n", "square_if_positive_vectorized(tf.range(-5, 5))" ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EXFVx-cJ57_F" + }, + "source": [ + "## Re-tracing\n", + "\n", + "Key points:\n", + "\n", + "* Exercise caution when calling functions with non-tensor arguments, or with arguments that change shapes.\n", + "* Decorate module-level functions, and methods of module-level classes, and avoid decorating local functions or methods.\n", + "\n", + "`tf.function` can give you significant speedup over eager execution, at the cost of a slower first-time execution. This is because when executed for the first time, the function is also *traced* into a TensorFlow graph. Constructing and optimizing a graph is usually much slower compared to actually executing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "iahT-4wT6vlA" + }, + "outputs": [], + "source": [ + "import timeit\n", + "\n", + "\n", + "@tf.function\n", + "def f(x, y):\n", + " return tf.matmul(x, y)\n", + "\n", + "print(\n", + " \"First invocation:\",\n", + " timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))\n", + "\n", + "print(\n", + " \"Second invocation:\",\n", + " timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "q0Wojo2Z7hKg" + }, + "source": [ + "You can easily tell when a function is traced by adding a `print` statement to the top of the function. Because any Python code is only executed at trace time, you will only see the otput of `print` when the function is traced:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2IHE7-jT7gZs" + }, + "outputs": [], + "source": [ + "@tf.function\n", + "def f():\n", + " print('Tracing!')\n", + " tf.print('Executing')\n", + "\n", + "print('First invocation:')\n", + "f()\n", + "\n", + "print('Second invocation:')\n", + "f()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "J2VBoQC58PdU" + }, + "source": [ + "`tf.function` may also *re-trace* when called with different non-tensor arguments:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-c6VUwrz808l" + }, + "outputs": [], + "source": [ + "@tf.function\n", + "def f(n):\n", + " print(n, 'Tracing!')\n", + " tf.print(n, 'Executing')\n", + "\n", + "f(1)\n", + "f(1)\n", + "\n", + "f(2)\n", + "f(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aKOrjBLCE8cy" + }, + "source": [ + "A *re-trace* can also happen when tensor arguments change shape, unless you specified an `input_signature`:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "txhtkn0rE8dH" + }, + "outputs": [], + "source": [ + "@tf.function\n", + "def f(x):\n", + " print(x.shape, 'Tracing!')\n", + " tf.print(x, 'Executing')\n", + "\n", + "f(tf.constant([1]))\n", + "f(tf.constant([2]))\n", + "\n", + "f(tf.constant([1, 2]))\n", + "f(tf.constant([3, 4]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "sdN40ZqT9XaG" + }, + "source": [ + "In addition, tf.function always creates a new graph function with its own set of traces whenever it is called:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "GT1iBa5i9enE" + }, + "outputs": [], + "source": [ + "def f():\n", + " print('Tracing!')\n", + " tf.print('Executing')\n", + "\n", + "tf.function(f)()\n", + "tf.function(f)()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "n7_JDzFK9nnC" + }, + "source": [ + "This can lead to surprising behavior when using the `@tf.function` decorator in a nested function:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "P3pBG7Uf9u4g" + }, + "outputs": [], + "source": [ + "def outer():\n", + " @tf.function\n", + " def f():\n", + " print('Tracing!')\n", + " tf.print('Executing')\n", + " f()\n", + "\n", + "outer()\n", + "outer()" + ] } ], "metadata": { @@ -753,10 +951,6 @@ "collapsed_sections": [ "Jxv6goXm7oGF" ], - "last_runtime": { - "build_target": "", - "kind": "local" - }, "name": "function.ipynb", "private_outputs": true, "provenance": [