Skip to content

Commit

Permalink
[Fix&Feature]convert dtype to float32 uniformly during visualize (Pad…
Browse files Browse the repository at this point in the history
…dlePaddle#498)

* convert dtype to float32 uniformly during visualize; correct dot in user_guide.md

* add colorlog for logging.warn/err/debug; move jointContribution/math.py to ppsci/experimental/math.py
  • Loading branch information
HydrogenSulfate authored Aug 22, 2023
1 parent cafdaa5 commit 859ac31
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/zh/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
4. `latest.pdstates`,该文件保存了 latest 对应 epoch 的所有评估指标以及 epoch 数。
5. `latest.pdscaler`(可选),在开启自动混合精度(AMP)功能时,该文件保存了 `GradScaler` 梯度缩放器内部的参数。

因此我们只需要在 `Solver` 时指定 `checkpoint_path` 参数为 `latest*` 的所在路径,即可自动载入上述的几个文件,并从 `latest` 中记录的 epoch 开始继续训练。
因此我们只需要在 `Solver` 时指定 `checkpoint_path` 参数为 `latest.*` 的所在路径,即可自动载入上述的几个文件,并从 `latest` 中记录的 epoch 开始继续训练。

``` py hl_lines="9"
import ppsci
Expand Down
17 changes: 0 additions & 17 deletions jointContribution/math.py

This file was deleted.

1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ nav:
- ppsci.utils: zh/api/utils.md
- ppsci.validate: zh/api/validate.md
- ppsci.visualize: zh/api/visualize.md
- ppsci.experimental: zh/api/experimental.md
- 使用指南: zh/user_guide.md
- 开发与复现指南:
- 开发指南: zh/development.md
Expand Down
4 changes: 2 additions & 2 deletions ppsci/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ppsci import visualize # isort:skip
from ppsci import validate # isort:skip
from ppsci import solver # isort:skip
from jointContribution import math # isort:skip
from ppsci import experimental # isort:skip

__all__ = [
"arch",
Expand All @@ -41,5 +41,5 @@
"visualize",
"validate",
"solver",
"math",
"experimental",
]
29 changes: 29 additions & 0 deletions ppsci/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module is for experimental API
"""

from ppsci.experimental.math import bessel_i0
from ppsci.experimental.math import bessel_i0e
from ppsci.experimental.math import bessel_i1
from ppsci.experimental.math import bessel_i1e

__all__ = [
"bessel_i0",
"bessel_i0e",
"bessel_i1",
"bessel_i1e",
]
31 changes: 31 additions & 0 deletions ppsci/experimental/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


def bessel_i0(x: paddle.Tensor) -> paddle.Tensor:
return paddle.i0(x)


def bessel_i0e(x: paddle.Tensor) -> paddle.Tensor:
return paddle.i0e(x)


def bessel_i1(x: paddle.Tensor) -> paddle.Tensor:
return paddle.i1(x)


def bessel_i1e(x: paddle.Tensor) -> paddle.Tensor:
return paddle.i1e(x)
14 changes: 8 additions & 6 deletions ppsci/solver/visu.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,25 @@ def visualize_func(solver: "solver.Solver", epoch_id: int):
batch_input_dict[key].stop_gradient = False

# forward
with solver.no_grad_context_manager(solver.eval_with_no_grad):
with solver.autocast_context_manager(
solver.use_amp, solver.amp_level
), solver.no_grad_context_manager(solver.eval_with_no_grad):
batch_output_dict = solver.forward_helper.visu_forward(
_visualizer.output_expr, batch_input_dict, solver.model
)

# collect batch data
# collect batch data with float32 dtype
for key, batch_input in batch_input_dict.items():
all_input[key].append(
batch_input.detach()
batch_input.detach().astype("float32")
if solver.world_size == 1
else misc.all_gather(batch_input.detach())
else misc.all_gather(batch_input.detach().astype("float32"))
)
for key, batch_output in batch_output_dict.items():
all_output[key].append(
batch_output.detach()
batch_output.detach().astype("float32")
if solver.world_size == 1
else misc.all_gather(batch_output.detach())
else misc.all_gather(batch_output.detach().astype("float32"))
)

# concate all data
Expand Down
13 changes: 11 additions & 2 deletions ppsci/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
import sys
from typing import Optional

import colorlog
import paddle.distributed as dist

_logger = None


COLORLOG_CONFIG = {
"DEBUG": "green",
"WARNING": "yellow",
"ERROR": "red",
}


def init_logger(
name: str = "ppsci",
log_file: Optional[str] = None,
Expand Down Expand Up @@ -53,9 +61,10 @@ def init_logger(
_logger.handlers.clear()

# add stream_handler, output to stdout such as terminal
stream_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
stream_formatter = colorlog.ColoredFormatter(
"%(log_color)s[%(asctime)s] %(name)s %(levelname)s: %(message)s",
datefmt="%Y/%m/%d %H:%M:%S",
log_colors=COLORLOG_CONFIG,
)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(stream_formatter)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ dependencies = [
"meshio==5.3.4",
"tqdm",
"imageio",
"colorlog",
]

[project.urls]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ tqdm
imageio
typing-extensions
seaborn
colorlog
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,6 @@ def readme():
"meshio==5.3.4",
"tqdm",
"imageio",
"colorlog",
],
)

0 comments on commit 859ac31

Please sign in to comment.