Skip to content

Commit

Permalink
【Fix PIR JIT SaveLoad Unittest No.16,17】modify test error and test_pi…
Browse files Browse the repository at this point in the history
…r_translated_layer (PaddlePaddle#64475)

* modify test_modbile_net.py

* delete print

* modify ci

* code style

* modify error and translatelayer
  • Loading branch information
xiaoguoguo626807 authored May 21, 2024
1 parent c01b2a4 commit b704fe5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 49 deletions.
49 changes: 0 additions & 49 deletions test/dygraph_to_static/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ def func_decorated_by_other_2():
return 1


class LayerErrorInCompiletime(paddle.nn.Layer):
def __init__(self, fc_size=20):
super().__init__()
self._linear = paddle.nn.Linear(fc_size, fc_size)

@paddle.jit.to_static(
input_spec=[paddle.static.InputSpec(shape=[20, 20], dtype='float32')],
full_graph=True,
)
def forward(self, x):
y = self._linear(x)
z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")
out = paddle.mean(y[z])
return out


class LayerErrorInCompiletime2(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -348,39 +332,6 @@ def set_message(self):
]


class TestJitSaveInCompiletime(TestErrorBase):
def setUp(self):
self.reset_flags_to_default()
self.set_func_call()
self.filepath = inspect.getfile(inspect.unwrap(self.func_call))
self.set_exception_type()
self.set_message()

def set_exception_type(self):
self.exception_type = TypeError

def set_message(self):
self.expected_message = [
'def forward(self, x):',
'y = self._linear(x)',
'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'out = paddle.mean(y[z])',
'return out',
]

def set_func_call(self):
layer = LayerErrorInCompiletime()
self.func_call = lambda: paddle.jit.save(
layer, path="./test_dy2stat_error/model"
)

def test_error(self):
# TODO(pir-save-load): Open this test after we support PIR save load
...
# self._test_raise_new_exception()


@paddle.jit.to_static(full_graph=True)
def func_ker_error(x):
d = {'x': x}
Expand Down
8 changes: 8 additions & 0 deletions test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,14 @@ foreach(PIR_COVERAGE_TEST ${PIR_COVERAGE_TESTS})
message(STATUS "PIR Copied OpTest: ${PIR_COVERAGE_TEST}_pir in legacy_test")
endforeach()

set(PIR_ONLY_TEST_FILES test_pir_translated_layer)
foreach(ITEST ${PIR_ONLY_TEST_FILES})
if(TEST ${ITEST})
set_tests_properties(${ITEST} PROPERTIES ENVIRONMENT
"FLAGS_enable_pir_api=True")
endif()
endforeach()

set_tests_properties(test_imperative_optimizer_static_build PROPERTIES TIMEOUT
250)
set_tests_properties(test_sync_batch_norm_op_static_build
Expand Down

0 comments on commit b704fe5

Please sign in to comment.