Skip to content

Commit

Permalink
avoid multiplication with weight 1 if weight is None for reducing com…
Browse files Browse the repository at this point in the history
…putation overhead (PaddlePaddle#714)
  • Loading branch information
HydrogenSulfate authored Dec 22, 2023
1 parent 15ed976 commit dafef83
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 24 deletions.
3 changes: 2 additions & 1 deletion ppsci/constraint/boundary_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ def __init__(
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

# prepare weight
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
weight = None
if weight_dict is not None:
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
weight[key] = np.full_like(next(iter(label.values())), value)
Expand Down
3 changes: 2 additions & 1 deletion ppsci/constraint/initial_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ def __init__(
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

# prepare weight
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
weight = None
if weight_dict is not None:
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
weight[key] = np.full_like(next(iter(label.values())), value)
Expand Down
3 changes: 2 additions & 1 deletion ppsci/constraint/integral_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ def __init__(

# prepare weight
# shape of each weight is [batch_size, ndim]
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
weight = None
if weight_dict is not None:
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
weight[key] = np.full_like(next(iter(label.values())), value)
Expand Down
3 changes: 2 additions & 1 deletion ppsci/constraint/interior_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def __init__(
raise NotImplementedError(f"type of {type(value)} is invalid yet.")

# prepare weight
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
weight = None
if weight_dict is not None:
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
for key, value in weight_dict.items():
if isinstance(value, str):
if value == "sdf":
Expand Down
3 changes: 2 additions & 1 deletion ppsci/constraint/periodic_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ def __init__(
)

# # prepare weight, keep weight the same shape as input_periodic
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
weight = None
if weight_dict is not None:
weight = {key: np.ones_like(next(iter(label.values()))) for key in label}
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
weight[key] = np.full_like(next(iter(label.values())), value)
Expand Down
16 changes: 10 additions & 6 deletions ppsci/data/dataset/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@ def __init__(
}

# prepare weights
self.weight = {
key: np.ones_like(next(iter(self.label.values()))) for key in self.label
}
self.weight = (
{key: np.ones_like(next(iter(self.label.values()))) for key in self.label}
if weight_dict is not None
else {}
)
if weight_dict is not None:
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
Expand Down Expand Up @@ -231,9 +233,11 @@ def __init__(
}

# prepare weights
self.weight = {
key: np.ones_like(next(iter(self.label.values()))) for key in self.label
}
self.weight = (
{key: np.ones_like(next(iter(self.label.values()))) for key in self.label}
if weight_dict is not None
else {}
)
if weight_dict is not None:
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
Expand Down
16 changes: 10 additions & 6 deletions ppsci/data/dataset/mat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@ def __init__(
}

# prepare weights
self.weight = {
key: np.ones_like(next(iter(self.label.values()))) for key in self.label
}
self.weight = (
{key: np.ones_like(next(iter(self.label.values()))) for key in self.label}
if weight_dict is not None
else {}
)
if weight_dict is not None:
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
Expand Down Expand Up @@ -231,9 +233,11 @@ def __init__(
}

# prepare weights
self.weight = {
key: np.ones_like(next(iter(self.label.values()))) for key in self.label
}
self.weight = (
{key: np.ones_like(next(iter(self.label.values()))) for key in self.label}
if weight_dict is not None
else {}
)
if weight_dict is not None:
for key, value in weight_dict.items():
if isinstance(value, (int, float)):
Expand Down
2 changes: 1 addition & 1 deletion ppsci/loss/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
label_dict[key],
"none",
)
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]

if self.reduction == "sum":
Expand Down
4 changes: 2 additions & 2 deletions ppsci/loss/l1.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
losses = 0.0
for key in label_dict:
loss = F.l1_loss(output_dict[key], label_dict[key], "none")
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]

if "area" in output_dict:
Expand Down Expand Up @@ -181,7 +181,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
loss = F.l1_loss(
output_dict[key][:n_output], output_dict[key][n_output:], "none"
)
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]
if "area" in output_dict:
loss *= output_dict["area"]
Expand Down
4 changes: 2 additions & 2 deletions ppsci/loss/l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
losses = 0.0
for key in label_dict:
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]

if "area" in output_dict:
Expand Down Expand Up @@ -181,7 +181,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
loss = F.mse_loss(
output_dict[key][:n_output], output_dict[key][n_output:], "none"
)
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]

if "area" in output_dict:
Expand Down
2 changes: 1 addition & 1 deletion ppsci/loss/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
losses = 0.0
for key in label_dict:
loss = F.l1_loss(output_dict[key], label_dict[key], "none")
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]

if "area" in output_dict:
Expand Down
2 changes: 1 addition & 1 deletion ppsci/loss/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
losses = 0.0
for key in label_dict:
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
if weight_dict:
if weight_dict and key in weight_dict:
loss *= weight_dict[key]

if "area" in output_dict:
Expand Down

0 comments on commit dafef83

Please sign in to comment.