Skip to content

Commit

Permalink
cylinder3d steady case (PaddlePaddle#114)
Browse files Browse the repository at this point in the history
* cylinder3d steady case

* cylinder3d unsteady case

* cylinder3d unsteady case fix

* cylinder3d unsteady case fix

* cylinder3d unsteady case fix
  • Loading branch information
shjNT authored Jul 11, 2022
1 parent 007591b commit 7ff9265
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 16 deletions.
Binary file added tests/test_models/standard/cylinder3d_steady.npz
Binary file not shown.
Binary file not shown.
27 changes: 18 additions & 9 deletions tests/test_models/test_cylinder3d_steady.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import paddle
import pytest
import sys
from tool import compare


def cylinder3d_steady(static=True):
Expand Down Expand Up @@ -106,12 +107,12 @@ def cylinder3d_steady(static=True):
solver.feed_data_user(real_sol) # add real solution

res_shpae = [(2044, 4), (58, 4), (58, 4), (4, 4), (10000, 4)]
solution = solver.solve(num_epoch=1)
for i in range(len(solution)):
if i != 4:
assert solution[i].shape[1] == res_shpae[i][1]
else:
assert solution[i].shape == res_shpae[i]
solution = solver.solve(num_epoch=10)
res = [np.sum(item, axis=0) for item in solution]
return sum(res)


standard = np.load("./standard/cylinder3d_steady.npz", allow_pickle=True)


@pytest.mark.cylinder3d_steady
Expand All @@ -121,8 +122,14 @@ def test_cylinder3d_steady_0():
"""
test cylinder3d_steady
"""
cylinder3d_steady(static=False)
cylinder3d_steady()
dyn_standard, stc_standard = standard['dyn_solution'], standard[
'stc_solution']
dyn_rslt = cylinder3d_steady(static=False)
stc_rslt = cylinder3d_steady()

compare(dyn_rslt, stc_rslt)
compare(dyn_standard, dyn_rslt, mode="equal")
compare(stc_rslt, stc_rslt, mode="equal")


@pytest.mark.cylinder3d_steady
Expand All @@ -133,7 +140,9 @@ def test_cylinder3d_steady_1():
test cylinder3d_steady
distributed case: padding
"""
cylinder3d_steady()
dst_standard = standard["dst_solution"]
dst_rslt = cylinder3d_steady()
compare(dst_standard, dst_rslt, mode="equal")


if __name__ == '__main__':
Expand Down
31 changes: 24 additions & 7 deletions tests/test_models/test_cylinder3d_unsteady.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import zipfile
import sys
import pytest
from tool import compare


def cylinder3d_unsteady(static=True):
Expand Down Expand Up @@ -118,7 +119,7 @@ def GetRealPhyInfo(time, need_info=None):
activation='tanh')

# Loss
loss = psci.loss.L2(p=2)
loss = psci.loss.L2(p=2, data_weight=100.0)

# Algorithm
algo = psci.algorithm.PINNs(net=net, loss=loss)
Expand All @@ -134,6 +135,7 @@ def GetRealPhyInfo(time, need_info=None):
current_interior = np.zeros(
(len(pde_disc.geometry.interior), 3)).astype(np.float32)
current_user = GetRealPhyInfo(start_time, need_info='physic')[:, 0:3]
rslt = []
for next_time in range(
int(pde_disc.time_internal[0]) + 1,
int(pde_disc.time_internal[1]) + 1):
Expand All @@ -143,13 +145,18 @@ def GetRealPhyInfo(time, need_info=None):
solver.feed_data_user_next(
GetRealPhyInfo(
next_time, need_info='physic')) # add u(n+1) user
next_uvwp = solver.solve(num_epoch=1)
for i in range(len(next_uvwp)):
assert next_uvwp[i].shape == res_shape[i]
next_uvwp = solver.solve(num_epoch=10)

res = [np.sum(item, axis=0) for item in next_uvwp]
rslt.append(res)

# current_info need to be modified as follows: current_time -> next time
current_interior = np.array(next_uvwp[0])[:, 0:3]
current_user = np.array(next_uvwp[-1])[:, 0:3]
return np.mean(rslt, axis=0)


standard = np.load("./standard/cylinder3d_unsteady.npz", allow_pickle=True)


@pytest.mark.cylinder3d_unsteady
Expand All @@ -159,8 +166,16 @@ def test_cylinder3d_unsteady_0():
"""
test cylinder3d_steady
"""
cylinder3d_unsteady(static=False)
cylinder3d_unsteady(static=True)
dyn_standard, stc_standard = standard['dyn_solution'], standard[
'stc_solution']
dyn_rslt = cylinder3d_unsteady(static=False)
stc_rslt = cylinder3d_unsteady(static=True)

print(dyn_standard)
print(dyn_rslt)
compare(dyn_rslt, stc_rslt, delta=1e-5)
compare(dyn_standard, dyn_rslt, mode="equal")
compare(stc_standard, stc_rslt, mode="equal")


@pytest.mark.cylinder3d_unsteady
Expand All @@ -171,7 +186,9 @@ def test_cylinder3d_steady_1():
test cylinder3d_steady
distributed case: padding
"""
cylinder3d_unsteady()
dst_standard = standard['dst_solution']
dst_rslt = cylinder3d_unsteady(static=True)
compare(dst_standard, dst_rslt, mode="equal")


if __name__ == '__main__':
Expand Down

0 comments on commit 7ff9265

Please sign in to comment.