Skip to content

Commit

Permalink
Add a section to the tf.function guide about re-tracing, along with a…
Browse files Browse the repository at this point in the history
… few advices.

PiperOrigin-RevId: 284001183
  • Loading branch information
Dan Moldovan authored and copybara-github committed Dec 5, 2019
1 parent 4d099f7 commit 13dd618
Showing 1 changed file with 198 additions and 4 deletions.
202 changes: 198 additions & 4 deletions site/en/guide/function.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -746,17 +746,211 @@
"\n",
"square_if_positive_vectorized(tf.range(-5, 5))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "EXFVx-cJ57_F"
},
"source": [
"## Re-tracing\n",
"\n",
"Key points:\n",
"\n",
"* Exercise caution when calling functions with non-tensor arguments, or with arguments that change shapes.\n",
"* Decorate module-level functions, and methods of module-level classes, and avoid decorating local functions or methods.\n",
"\n",
"`tf.function` can give you significant speedup over eager execution, at the cost of a slower first-time execution. This is because when executed for the first time, the function is also *traced* into a TensorFlow graph. Constructing and optimizing a graph is usually much slower compared to actually executing it:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "iahT-4wT6vlA"
},
"outputs": [],
"source": [
"import timeit\n",
"\n",
"\n",
"@tf.function\n",
"def f(x, y):\n",
" return tf.matmul(x, y)\n",
"\n",
"print(\n",
" \"First invocation:\",\n",
" timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))\n",
"\n",
"print(\n",
" \"Second invocation:\",\n",
" timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "q0Wojo2Z7hKg"
},
"source": [
"You can easily tell when a function is traced by adding a `print` statement to the top of the function. Because any Python code is only executed at trace time, you will only see the otput of `print` when the function is traced:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2IHE7-jT7gZs"
},
"outputs": [],
"source": [
"@tf.function\n",
"def f():\n",
" print('Tracing!')\n",
" tf.print('Executing')\n",
"\n",
"print('First invocation:')\n",
"f()\n",
"\n",
"print('Second invocation:')\n",
"f()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "J2VBoQC58PdU"
},
"source": [
"`tf.function` may also *re-trace* when called with different non-tensor arguments:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "-c6VUwrz808l"
},
"outputs": [],
"source": [
"@tf.function\n",
"def f(n):\n",
" print(n, 'Tracing!')\n",
" tf.print(n, 'Executing')\n",
"\n",
"f(1)\n",
"f(1)\n",
"\n",
"f(2)\n",
"f(2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "aKOrjBLCE8cy"
},
"source": [
"A *re-trace* can also happen when tensor arguments change shape, unless you specified an `input_signature`:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "txhtkn0rE8dH"
},
"outputs": [],
"source": [
"@tf.function\n",
"def f(x):\n",
" print(x.shape, 'Tracing!')\n",
" tf.print(x, 'Executing')\n",
"\n",
"f(tf.constant([1]))\n",
"f(tf.constant([2]))\n",
"\n",
"f(tf.constant([1, 2]))\n",
"f(tf.constant([3, 4]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sdN40ZqT9XaG"
},
"source": [
"In addition, tf.function always creates a new graph function with its own set of traces whenever it is called:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GT1iBa5i9enE"
},
"outputs": [],
"source": [
"def f():\n",
" print('Tracing!')\n",
" tf.print('Executing')\n",
"\n",
"tf.function(f)()\n",
"tf.function(f)()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "n7_JDzFK9nnC"
},
"source": [
"This can lead to surprising behavior when using the `@tf.function` decorator in a nested function:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "P3pBG7Uf9u4g"
},
"outputs": [],
"source": [
"def outer():\n",
" @tf.function\n",
" def f():\n",
" print('Tracing!')\n",
" tf.print('Executing')\n",
" f()\n",
"\n",
"outer()\n",
"outer()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"Jxv6goXm7oGF"
],
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "function.ipynb",
"private_outputs": true,
"provenance": [
Expand Down

0 comments on commit 13dd618

Please sign in to comment.