Skip to content

Commit

Permalink
[Example]fix bug of transformer4sr (PaddlePaddle#1017)
Browse files Browse the repository at this point in the history
* [Example]fix bug of transformer4sr

* Update examples/transformer4sr/conf/transformer4sr.yaml

---------

Co-authored-by: HydrogenSulfate <[email protected]>
  • Loading branch information
lijialin03 and HydrogenSulfate authored Nov 14, 2024
1 parent 474bf7b commit 615686a
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/transformer4sr/conf/transformer4sr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DATA_GENERATE:
num_zfill: 8
DATA:
data_path: "./data_generated/" # ${DATA_GENERATE.data_path}
data_path_srsd: ["./srsd-feynman_easy/", "./srsd-feynman_medium/", "./srsd-feynman_hard/"]
data_path_srsd: ["./srsd-feynman_easy/"]
ratio: [0.8,0.1,0.1]
sampling_times: ${DATA_GENERATE.sampling_times}
seq_length_max: 30 # ${DATA_GENERATE.seq_length_max}
Expand Down
3 changes: 1 addition & 2 deletions examples/transformer4sr/transformer4sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def evaluate(cfg: DictConfig):
vocab_size=vocab_size,
seq_length=data_funcs.seq_length_max,
)
param_dict = paddle.load(cfg.EVAL.pretrained_model_path)
model.set_state_dict(param_dict)
ppsci.utils.save_load.load_pretrain(model, path=cfg.EVAL.pretrained_model_path)
model.eval()

# evaluate
Expand Down
1 change: 0 additions & 1 deletion ppsci/arch/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ def __init__(
self.act = act
self.dropout = dropout

super().__init__()
self.encoder = Encoder(
num_layers_enc, num_var_max, d_model, heads, act="relu", dropout=dropout
)
Expand Down

0 comments on commit 615686a

Please sign in to comment.