Skip to content

Commit

Permalink
update TraceLayer save_inference_model api doc and dygraph_to_static …
Browse files Browse the repository at this point in the history
…example code (PaddlePaddle#3406)

update TraceLayer save_inference_model api doc and the example code of dygraph_to_static guide doc
  • Loading branch information
CtfGo authored Apr 8, 2021
1 parent 9d09f51 commit d682d8e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 5 additions & 3 deletions doc/paddle/api/paddle/fluid/dygraph/jit/TracedLayer_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TracedLayer只能用于将data independent的动态图模型转换为静态图
print(out_static_graph[0].shape) # (2, 10)
# 将静态图模型保存为预测模型
static_layer.save_inference_model(dirname='./saved_infer_model')
static_layer.save_inference_model(path='./saved_infer_model')
.. py:method:: set_strategy(build_strategy=None, exec_strategy=None)
Expand Down Expand Up @@ -93,12 +93,14 @@ TracedLayer只能用于将data independent的动态图模型转换为静态图
static_layer.set_strategy(build_strategy=build_strategy, exec_strategy=exec_strategy)
out_static_graph = static_layer([in_var])
.. py:method:: save_inference_model(dirname, feed=None, fetch=None)
.. py:method:: save_inference_model(path, feed=None, fetch=None)
将TracedLayer保存为用于预测部署的模型。保存的预测模型可被C++预测接口加载。

``path`` 是存储目标的前缀,存储的模型结构 ``Program`` 文件的后缀为 ``.pdmodel``,存储的持久参数变量文件的后缀为 ``.pdiparams``.

参数:
- **dirname** (str) - 预测模型的保存目录
- **path** (str) - 存储模型的路径前缀。格式为 ``dirname/file_prefix`` 或者 ``file_prefix``
- **feed** (list(int), 可选) - 预测模型输入变量的索引。若为None,则TracedLayer的所有输入变量均会作为预测模型的输入。默认值为None。
- **fetch** (list(int), 可选) - 预测模型输出变量的索引。若为None,则TracedLayer的所有输出变量均会作为预测模型的输出。默认值为None。

Expand Down
2 changes: 2 additions & 0 deletions doc/paddle/guides/04_dygraph_to_static/basic_usage_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ PaddlePaddle主要的动转静方式是基于源代码级别转换的ProgramTran
.. code-block:: python
import paddle
import numpy as np
@paddle.jit.to_static
def func(input_var):
Expand Down Expand Up @@ -106,6 +107,7 @@ trace是指在模型运行时记录下其运行过哪些算子。TracedLayer就

.. code-block:: python
paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.Executor(place)
program, feed_vars, fetch_vars = paddle.static.load_inference_model(save_dirname, exe)
Expand Down

0 comments on commit d682d8e

Please sign in to comment.