Commit 006118d
Optional param input in function definition (PaddlePaddle#151)
* remove print
* add padding in create input from loss
* fix nproc typo
* update name param to params
* update param in algo
xingfeng01 authored Aug 1, 2022
1 parent 9b9d483 commit 006118d
Showing 2 changed files with 15 additions and 18 deletions.
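
In short: `params` moves from a trailing optional keyword (`params=None`) to the first positional argument of `compute` and `compute_forward`, so every caller must now state it explicitly; the solver call sites below pass `None`. A minimal before/after sketch (the `Algo*` classes are stand-ins, not the real PaddleScience classes):

```python
# Before: params was keyword-optional and easy to omit silently.
class AlgoBefore:
    def compute_forward(self, *inputs, params=None):
        return [f"out({x}, params={params})" for x in inputs]

# After: params is the first positional argument and must be passed,
# even when it is just None.
class AlgoAfter:
    def compute_forward(self, params, *inputs):
        return [f"out({x}, params={params})" for x in inputs]

old = AlgoBefore().compute_forward(1.0, 2.0)        # params defaults to None
new = AlgoAfter().compute_forward(None, 1.0, 2.0)   # None is now explicit
assert old == new
```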
22 changes: 8 additions & 14 deletions paddlescience/algorithm/algorithm_pinns.py
@@ -511,7 +511,7 @@ def feed_data_user_next(self, labels, labels_attr, data):
         # print("idx user next: ", idx)
         return labels
 
-    def compute_forward(self, *inputs, params=None):
+    def compute_forward(self, params, *inputs):
 
         outs = list()
 
@@ -521,14 +521,8 @@ def compute_forward(self, *inputs, params=None):
 
         return outs
 
-    def compute(self,
-                *inputs_labels,
-                ninputs,
-                inputs_attr,
-                nlabels,
-                labels_attr,
-                pde,
-                params=None):
+    def compute(self, params, *inputs_labels, ninputs, inputs_attr, nlabels,
+                labels_attr, pde):
 
         outs = list()
 
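Note a side effect of the new `compute` signature: every argument declared after `*inputs_labels` (`ninputs`, `inputs_attr`, `nlabels`, `labels_attr`, `pde`) is keyword-only in Python, so callers have to spell them out by name. A tiny standalone illustration with a simplified argument list:

```python
def compute(params, *inputs_labels, ninputs, pde):
    # Arguments after *inputs_labels can only be passed by keyword.
    return params, inputs_labels, ninputs, pde

compute(None, "in0", "lbl0", ninputs=1, pde="laplace")  # OK
# compute(None, "in0", "lbl0", 1, "laplace")            # TypeError
```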
@@ -561,7 +555,7 @@
                 labels,
                 labels_attr["interior"],
                 bs=-1,
-                params=None) # TODO: bs is not used
+                params=params) # TODO: bs is not used
             loss_eq += loss_i
             outs.append(out_i)
             n += 1
@@ -581,7 +575,7 @@
                 labels,
                 labels_attr,
                 bs=-1,
-                params=None) # TODO: bs is not used
+                params=params) # TODO: bs is not used
             loss_bc += loss_b
             outs.append(out_b)
             n += 1
@@ -597,7 +591,7 @@
                 labels,
                 labels_attr,
                 bs=-1,
-                params=None)
+                params=params)
             loss_ic += loss_it
             outs.append(out_it)
             n += 1
@@ -617,7 +611,7 @@
                 labels,
                 labels_attr["user"],
                 bs=-1,
-                params=None)
+                params=params)
             loss_eq += loss_id
 
         # data loss
@@ -629,7 +623,7 @@
                 labels,
                 labels_attr["user"],
                 bs=-1,
-                params=None) # TODO: bs is not used
+                params=params) # TODO: bs is not used
             loss_data += loss_d
             outs.append(out_id)
 
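Taken together, the algorithm-side change is that `compute` now receives `params` and forwards it to each per-term loss (equation, boundary, initial, user/data) instead of pinning `params=None`. A hedged sketch of that flow; `eq_loss` is a hypothetical stand-in, since the diff does not show the real callee names:

```python
def eq_loss(labels, labels_attr, bs=-1, params=None):
    # Stand-in for the interior-equation loss called in the diff.
    return 0.0, f"eq_out(params={params})"

def compute(params, *inputs_labels, labels_attr):
    outs, loss_eq = [], 0.0
    # The behavioral change: forward params instead of hard-coding None.
    loss_i, out_i = eq_loss(
        list(inputs_labels), labels_attr["interior"], bs=-1, params=params)
    loss_eq += loss_i
    outs.append(out_i)
    return loss_eq, outs

loss, outs = compute(None, "interior_points", labels_attr={"interior": {}})
```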
11 changes: 7 additions & 4 deletions paddlescience/solver/solver.py
@@ -18,7 +18,6 @@
 from paddle.distributed.auto_parallel.engine import Engine
 from paddle.incubate.optimizer.functional.lbfgs import minimize_lbfgs
 from paddle.incubate.optimizer.functional.bfgs import minimize_bfgs
-paddle.disable_static()
 from . import utils
 from .. import config
 from visualdl import LogWriter
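
The removed `paddle.disable_static()` was a module-level side effect: merely importing solver.py forced dynamic (dygraph) mode, which presumably conflicted with the static-graph paths (`__init_static`, `__predict_static`) below. For reference, the two modes use the standard Paddle API; this is generic Paddle usage, not code from this commit:

```python
import paddle

def dynamic_example():
    # Dygraph mode: ops execute eagerly on concrete tensors.
    paddle.disable_static()
    x = paddle.to_tensor([1.0, 2.0, 3.0])
    return (x * 2.0).numpy()

def static_example():
    # Static mode: ops are recorded into a Program and run by an Executor,
    # which is what the solver's static paths rely on.
    paddle.enable_static()
    return paddle.static.default_main_program()
```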
@@ -56,6 +55,7 @@ def forward(self, *inputs_labels):
             input.stop_gradient = False
 
         self.loss, self.outs, self.loss_details = self.algo.compute(
+            None,
             *inputs_labels,
             ninputs=self.ninputs,
             inputs_attr=self.inputs_attr,
@@ -210,6 +210,7 @@ def __solve_dynamic(self, num_epoch, bs, checkpoint_freq, checkpoint_path):
         # TODO: error out num_epoch==0
 
         loss, outs, loss_details = self.algo.compute(
+            None,
             *inputs_labels,
             ninputs=ninputs,
             inputs_attr=inputs_attr,
@@ -266,6 +267,7 @@ def __solve_dynamic(self, num_epoch, bs, checkpoint_freq, checkpoint_path):
         def _f(x):
             self.algo.net.reconstruct(x)
             loss, self.outs, self.loss_details = self.algo.compute(
+                None,
                 *inputs_labels,
                 ninputs=ninputs,
                 inputs_attr=inputs_attr,
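
For context, `_f` is the scalar objective handed to the functional optimizers imported at the top of solver.py (`minimize_lbfgs` / `minimize_bfgs`): it maps a flat parameter vector to a loss. A runnable toy version with a stand-in quadratic objective; the real closure instead rebuilds the network via `net.reconstruct(x)` and recomputes the PINN loss, as the diff shows:

```python
import paddle
from paddle.incubate.optimizer.functional.lbfgs import minimize_lbfgs

def _f(x):
    # Stand-in objective with its minimum at x = [3., 3.]; the solver's _f
    # calls self.algo.net.reconstruct(x) and self.algo.compute(None, ...).
    return paddle.sum((x - 3.0) ** 2)

x0 = paddle.zeros([2], dtype='float32')
results = minimize_lbfgs(_f, x0)  # optimizer state, incl. final position
```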
@@ -347,7 +349,7 @@ def __predict_dynamic(self):
             inputs[i] = paddle.to_tensor(
                 inputs[i], dtype=self._dtype, stop_gradient=False)
 
-        outs = self.algo.compute_forward(*inputs)
+        outs = self.algo.compute_forward(None, *inputs)
 
         for i in range(len(outs)):
             outs[i] = outs[i].numpy()
@@ -413,6 +415,7 @@ def __init_static(self):
             inputs_labels.append(label)
 
         self.loss, self.outs, self.loss_details = self.algo.compute(
+            None,
             *inputs_labels,
             ninputs=ninputs,
             inputs_attr=self.inputs_attr,
@@ -543,7 +546,7 @@ def __predict_static(self):
             input.stop_gradient = False
             ins.append(input)
 
-        self.outs_predict = self.algo.compute_forward(*ins)
+        self.outs_predict = self.algo.compute_forward(None, *ins)
 
         # startup program
         self.exe.run(self.startup_program)
@@ -652,7 +655,7 @@ def __solve_static_auto_dist(self, num_epoch, bs, checkpoint_freq):
             input.stop_gradient = False
             ins.append(input)
 
-        self.outs_predict = self.algo.compute_forward(*ins)
+        self.outs_predict = self.algo.compute_forward(None, *ins)
 
         # feeds inputs
         feeds = dict()