Skip to content

Commit

Permalink
[Dy2St]Enhance @not_to_static API (PaddlePaddle#50453)
Browse files Browse the repository at this point in the history
* [Dy2St]Enhance @not_to_static API

* del breakpoint()
  • Loading branch information
Aurelius84 authored Feb 14, 2023
1 parent c5087da commit 842050f
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 84 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,4 @@ paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/op_function_impl.h
paddle/fluid/pybind/*final_state_op_function_impl.h
paddle/fluid/prim/api/generated/prim_api/*
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import unittest

import numpy as np
from test_program_translator import get_source_code

import paddle
import paddle.fluid as fluid
import paddle.jit.dy2static as _jst
from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS
from paddle.jit.dy2static.utils import func_to_source_code

SEED = 2020
np.random.seed(SEED)
Expand Down Expand Up @@ -216,103 +216,57 @@ def set_func(self):
# Situation 2 : test not_to_static


def func_sum(x):
res = paddle.sum(x)
return res


@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res

class NotToStaticHelper(paddle.nn.Layer):
def __init__(self):
super(NotToStaticHelper, self).__init__()

@paddle.jit.to_static
def func_convert_then_not_to_static(x):
y = func_not_to_static(x)
return y
def sum(self, x):
if x.shape[0] > 1:
res = x + 1
res = paddle.sum(x)
return res

def outer(self, x):
res = self.sum(x)
return res

class TestClass(paddle.nn.Layer):
@paddle.jit.not_to_static
def called_member(self, x):
return paddle.sum(x)

@paddle.jit.to_static
def forward(self, x):
y = self.called_member(x)
return y
def inner(self, x):
return self.outer(x)


class TestNotToConvert(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_not_to_static
self.net = NotToStaticHelper()
paddle.jit.not_to_static(self.net.sum)
self.dygraph_func = paddle.jit.to_static(self.net.outer)

def test_conversion_options(self):
options = getattr(self.dygraph_func, CONVERSION_OPTIONS, None)
options = getattr(self.net.sum, CONVERSION_OPTIONS, None)
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)


class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_convert_then_not_to_static


class TestNotToConvert3(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = TestClass()


class TestDynamicToStaticCode(unittest.TestCase):
def setUp(self):
self.set_func()
self.set_answer_func()

def set_func(self):
self.func = func_not_to_static

def set_answer_func(self):
class StaticCode:
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res

self.answer_func = StaticCode.func_not_to_static

def _get_answer_code(self):
return get_source_code(self.answer_func)

def _get_transformed_code(self):
transformed_func = _jst.Call(self.func)
return get_source_code(transformed_func)

def test_code(self):
transformed_code = self._get_transformed_code()
answer_code = self._get_answer_code()
self.assertEqual(
answer_code,
transformed_code,
msg="\ntransformed_code : \n{}\nanswer_code : \n{}".format(
transformed_code, answer_code
),
# check 'if statement' is not converted
self.assertIn(
"if x.shape[0] > 1", func_to_source_code(_jst.Call(self.net.sum))
)


class TestDynamicToStaticCode2(TestDynamicToStaticCode):
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.func = func_convert_then_not_to_static
self.net = NotToStaticHelper()
# for to_static(not_to_static(function)) == enable_static
paddle.jit.not_to_static(self.net.sum)
self.dygraph_func = paddle.jit.to_static(self.net.sum)

def set_answer_func(self):
class StaticCode:
def func_convert_then_not_to_static(x):
__return_value_0 = None
y = _jst.Call(func_not_to_static)(x)
__return_value_0 = y
return __return_value_0
def test_conversion_options(self):
options = getattr(self.net.sum, CONVERSION_OPTIONS, None)
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)

self.answer_func = StaticCode.func_convert_then_not_to_static
def test_code(self):
# check 'if statement' is not converted
self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)


if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/jit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from .dy2static import logging_utils
from .dy2static.convert_call_func import (
ConversionOptions,
CONVERSION_OPTIONS,
add_ignore_module,
)
from .dy2static.program_translator import (
Expand Down Expand Up @@ -348,7 +347,7 @@ def func(x):
return not_to_static

options = ConversionOptions(not_convert=True)
setattr(func, CONVERSION_OPTIONS, options)
options.attach(func)
return func


Expand Down
15 changes: 14 additions & 1 deletion python/paddle/jit/dy2static/convert_call_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

translator_logger = TranslatorLogger()

CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
CONVERSION_OPTIONS = "__jst_not_to_static"


class ConversionOptions:
Expand All @@ -58,6 +58,19 @@ class ConversionOptions:
def __init__(self, not_convert=False):
self.not_convert = not_convert

def attach(self, func):
if inspect.ismethod(func):
func = func.__func__

if inspect.isfunction(func):
setattr(func, CONVERSION_OPTIONS, self)
else:
translator_logger.warn(
"Only support @not_to_static to type(function) or type(method), but recevied {}".format(
type(func)
)
)


def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from . import error, logging_utils
from .ast_transformer import DygraphToStaticAst
from .convert_call_func import CONVERSION_OPTIONS
from .function_spec import (
FunctionSpec,
_hash_spec_names,
Expand Down Expand Up @@ -152,6 +153,12 @@ def convert_to_static(function):
"""
if getattr(function, ALREADY_D2S, None):
return function

# Return directly if decorated with @not_to_static and DO NOT Cache it
options = getattr(function, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
return function.__func__ if inspect.ismethod(function) else function

with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.convert_with_cache(function)
setattr(static_func, ALREADY_D2S, True)
Expand Down
8 changes: 6 additions & 2 deletions python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ast
import atexit
import copy
import functools
import importlib.util
import inspect
import os
Expand All @@ -23,7 +24,6 @@
import tempfile
import textwrap
import warnings
from functools import reduce
from importlib.machinery import SourceFileLoader

import astor
Expand Down Expand Up @@ -637,6 +637,8 @@ def func_to_source_code(function, dedent=True):
"""
Transforms function into raw string of source code.
"""
if isinstance(function, functools.partial):
function = function.func
if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError(
"The type of 'function' should be a function or method, but received {}.".format(
Expand Down Expand Up @@ -1429,7 +1431,9 @@ class GetterSetterHelper:
def __init__(self, getter_func, setter_func, *name_lists):
name_lists = map(lambda x: [] if x is None else x, name_lists)
name_sets = map(lambda x: set(x), name_lists)
self._union = list(reduce(lambda x, y: x | y, name_sets, set()))
self._union = list(
functools.reduce(lambda x, y: x | y, name_sets, set())
)
self._union.sort()
self.getter = getter_func
self.setter = setter_func
Expand Down

0 comments on commit 842050f

Please sign in to comment.