Skip to content

Commit

Permalink
【PPSCI Export&Infer No.14】add export and inference function for deepc…
Browse files Browse the repository at this point in the history
…fd (PaddlePaddle#994)

* add export and infer for deepcfd

* reset default mode to train

* remove comments

* some changes

* Update deepcfd.yaml
  • Loading branch information
GoldenStain authored Sep 26, 2024
1 parent 5eeede5 commit 755754b
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 2 deletions.
18 changes: 18 additions & 0 deletions docs/zh/examples/deepcfd.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@
python deepcfd.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/deepcfd/deepcfd_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python deepcfd.py mode=export
```

=== "模型推理命令"

``` sh
# linux
wget -nc -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl
wget -nc -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl
# windows
# curl --create-dirs -o ./datasets/dataX.pkl https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl
# curl --create-dirs -o ./datasets/dataX.pkl https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl
python deepcfd.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [deepcfd_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/deepcfd/deepcfd_pretrained.pdparams) | MSE.Total_MSE(mse_validator): 1.92947<br>MSE.Ux_MSE(mse_validator): 0.70684<br>MSE.Uy_MSE(mse_validator): 0.21337<br>MSE.p_MSE(mse_validator): 1.00926 |
Expand Down
19 changes: 18 additions & 1 deletion examples/deepcfd/conf/deepcfd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ hydra:
subdir: ./

# general settings
mode: train # running mode: train/eval
mode: train # running mode: train/eval/export/infer
seed: 2023
output_dir: ${hydra:run.dir}
log_freq: 20
Expand Down Expand Up @@ -60,3 +60,20 @@ EVAL:
pretrained_model_path: null
eval_with_no_grad: true
batch_size: 8

INFER:
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/deepcfd/deepcfd_pretrained.pdparams"
export_path: ./inference/deepcfd
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 6000
gpu_id: 0
max_batch_size: 100
num_cpu_threads: 4
batch_size: 100
255 changes: 254 additions & 1 deletion examples/deepcfd/deepcfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,267 @@ def metric_expr(
predict_and_save_plot(test_x, test_y, 0, solver, PLOT_DIR)


def export(cfg: DictConfig):
model = ppsci.arch.UNetEx(**cfg.MODEL)

solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)

from paddle.static import InputSpec

input_spec = [
{
key: InputSpec(
[None, cfg.CHANNEL_SIZE, cfg.X_SIZE, cfg.Y_SIZE], "float32", name=key
)
for key in model.input_keys
},
]

solver.export(input_spec, cfg.INFER.export_path)
print(f"Model has been exported to {cfg.INFER.export_path}")


def predict_and_save_plot_infer(
x: np.ndarray,
y: np.ndarray,
pred_y: np.ndarray,
index: int,
plot_dir: str,
):
"""Make prediction and save visualization of result during inference.
Args:
x (np.ndarray): Input of test dataset.
y (np.ndarray): Ground truth output of test dataset.
pred_y (np.ndarray): Predicted output from inference.
index (int): Index of data to visualize.
plot_dir (str): Directory to save plot.
"""

# Extract the true and predicted values for each channel
u_true = y[index, 0, :, :]
v_true = y[index, 1, :, :]
p_true = y[index, 2, :, :]

u_pred = pred_y[index, 0, :, :]
v_pred = pred_y[index, 1, :, :]
p_pred = pred_y[index, 2, :, :]

# Compute the absolute error between true and predicted values
error_u = np.abs(u_true - u_pred)
error_v = np.abs(v_true - v_pred)
error_p = np.abs(p_true - p_pred)

# Calculate the min and max values for each channel
min_u, max_u = u_true.min(), u_true.max()
min_v, max_v = v_true.min(), v_true.max()
min_p, max_p = p_true.min(), p_true.max()

min_error_u, max_error_u = error_u.min(), error_u.max()
min_error_v, max_error_v = error_v.min(), error_v.max()
min_error_p, max_error_p = error_p.min(), error_p.max()

# Start plotting
plt.figure(figsize=(15, 10))

# Plot Ux channel (True, Predicted, and Error)
plt.subplot(3, 3, 1)
plt.title("OpenFOAM Ux", fontsize=18)
plt.imshow(
np.transpose(u_true),
cmap="jet",
vmin=min_u,
vmax=max_u,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")
plt.ylabel("Ux", fontsize=18)

plt.subplot(3, 3, 2)
plt.title("DeepCFD Ux", fontsize=18)
plt.imshow(
np.transpose(u_pred),
cmap="jet",
vmin=min_u,
vmax=max_u,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")

plt.subplot(3, 3, 3)
plt.title("Error Ux", fontsize=18)
plt.imshow(
np.transpose(error_u),
cmap="jet",
vmin=min_error_u,
vmax=max_error_u,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")

# Plot Uy channel (True, Predicted, and Error)
plt.subplot(3, 3, 4)
plt.imshow(
np.transpose(v_true),
cmap="jet",
vmin=min_v,
vmax=max_v,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")
plt.ylabel("Uy", fontsize=18)

plt.subplot(3, 3, 5)
plt.imshow(
np.transpose(v_pred),
cmap="jet",
vmin=min_v,
vmax=max_v,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")

plt.subplot(3, 3, 6)
plt.imshow(
np.transpose(error_v),
cmap="jet",
vmin=min_error_v,
vmax=max_error_v,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")

# Plot pressure channel p (True, Predicted, and Error)
plt.subplot(3, 3, 7)
plt.imshow(
np.transpose(p_true),
cmap="jet",
vmin=min_p,
vmax=max_p,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")
plt.ylabel("p", fontsize=18)

plt.subplot(3, 3, 8)
plt.imshow(
np.transpose(p_pred),
cmap="jet",
vmin=min_p,
vmax=max_p,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")

plt.subplot(3, 3, 9)
plt.imshow(
np.transpose(error_p),
cmap="jet",
vmin=min_error_p,
vmax=max_error_p,
origin="lower",
extent=[0, 260, 0, 120],
)
plt.colorbar(orientation="horizontal")

plt.tight_layout()
plt.savefig(os.path.join(plot_dir, f"cfd_{index}.png"), bbox_inches="tight")
plt.close()


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

# Load test dataset from serialized files
with open(cfg.DATAX_PATH, "rb") as file:
x = pickle.load(file)
with open(cfg.DATAY_PATH, "rb") as file:
y = pickle.load(file)

# Split data into training and test sets
_, test_dataset = split_tensors(x, y, ratio=cfg.SLIPT_RATIO)
test_x, test_y = test_dataset

input_dict = {cfg.MODEL.input_key: test_x}

# Initialize the PINN predictor model
predictor = pinn_predictor.PINNPredictor(cfg)

# Run inference and get predictions
output_dict = predictor.predict(input_dict, batch_size=cfg.INFER.batch_size)

# Handle model's output key structure
actual_output_key = cfg.MODEL.output_key

output_keys = (
actual_output_key
if isinstance(actual_output_key, (list, tuple))
else [actual_output_key]
)
if len(output_keys) != len(output_dict):
raise ValueError(
"The number of output_keys does not match the number of output_dict keys."
)

# Map model output keys to values
output_dict = {
origin: value for origin, value in zip(output_keys, output_dict.values())
}

concat_output = output_dict[actual_output_key]

if concat_output.ndim != 4 or concat_output.shape[1] != 3:
raise ValueError(
f"Unexpected shape of '{actual_output_key}': {concat_output.shape}. Expected (batch_size, 3, x_size, y_size)."
)

try:
# Extract Ux, Uy, and pressure from the predicted output
u_pred = concat_output[:, 0, :, :] # Ux
v_pred = concat_output[:, 1, :, :] # Uy
p_pred = concat_output[:, 2, :, :] # p
except IndexError as e:
print(f"Error in splitting '{actual_output_key}': {e}")
raise

# Combine the predictions into one array for further processing
pred_y = np.stack([u_pred, v_pred, p_pred], axis=1)

PLOT_DIR = os.path.join(cfg.output_dir, "infer_visual")
os.makedirs(PLOT_DIR, exist_ok=True)

# Visualize and save the first five predictions
for index in range(min(5, pred_y.shape[0])):
predict_and_save_plot_infer(test_x, test_y, pred_y, index, PLOT_DIR)

print(f"Inference completed. Results are saved in {PLOT_DIR}")


@hydra.main(version_base=None, config_path="./conf", config_name="deepcfd.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down

0 comments on commit 755754b

Please sign in to comment.