fix Issue 207 and update evaluation results of PPO (#211)
* fix issue 207

* update evaluation results of PPO
aaa123git authored Jul 30, 2021
1 parent 93bb6c2 commit f822c20
Showing 2 changed files with 6 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -115,7 +115,7 @@ Performance (the first row is the default config for each module. Empty entries
 | BERTNLU | RuleDST | RulePolicy | **SCLSTM** | 48.5 | 40.2 | 56.9 | 62.3/62.5/58.7 | 11.9/27.1 |
 | BERTNLU | RuleDST | **MLEPolicy** | TemplateNLG | 42.7 | 35.9 | 17.6 | 62.8/69.8/62.9 | 12.1/24.1 |
 | BERTNLU | RuleDST | **PGPolicy** | TemplateNLG | 37.4 | 31.7 | 17.4 | 57.4/63.7/56.9 | 11.0/25.3 |
-| BERTNLU | RuleDST | **PPOPolicy** | TemplateNLG | 61.1 | 44.0 | 44.6 | 63.9/76.8/67.2 | 12.5/20.8 |
+| BERTNLU | RuleDST | **PPOPolicy** | TemplateNLG | 75.5 | 71.7 | 86.6 | 69.4/85.8/74.1 | 13.1/17.8 |
 | BERTNLU | RuleDST | **GDPLPolicy** | TemplateNLG | 49.4 | 38.4 | 20.1 | 64.5/73.8/65.6 | 11.5/21.3 |
 | None | **TRADE** | RulePolicy | TemplateNLG | 32.4 | 20.1 | 34.7 | 46.9/48.5/44.0 | 11.4/23.9 |
 | None | **SUMBT** | RulePolicy | TemplateNLG | 34.5 | 29.4 | 62.4 | 54.1/50.3/48.3 | 11.0/28.1 |
@@ -158,7 +158,7 @@ By running `convlab2/policy/evalutate.py --model_name $model`
 | --------- | ----------------- |
 | MLE | 0.56 |
 | PG | 0.54 |
-| PPO | 0.74 |
+| PPO | 0.89 |
 | GDPL | 0.58 |
 
 ### NLG
9 changes: 4 additions & 5 deletions convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -776,11 +776,10 @@ def _update_current_domain(self, sys_action, goal: Goal):
                 self.cur_domain = domain
 
     def _setdefault_current_domain_by_usraction(self, usr_action):
-        if self.cur_domain is None:
-            for diaact in usr_action.keys():
-                domain, _ = diaact.split('-')
-                if domain in ['attraction', 'hotel', 'restaurant', 'taxi', 'train']:
-                    self.cur_domain = domain
+        for diaact in usr_action.keys():
+            domain, _ = diaact.split('-')
+            if domain in ['attraction', 'hotel', 'restaurant', 'taxi', 'train']:
+                self.cur_domain = domain
 
     def _remove_item(self, diaact, slot=DEF_VAL_UNK):
         for idx in range(len(self.__stack)):
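
The effect of dropping the `if self.cur_domain is None:` guard can be illustrated with a minimal standalone sketch. It is not ConvLab-2 code: the `DomainTracker` class, its two method names, and the sample dialogue acts are hypothetical, and the sketch ignores `_update_current_domain`, which also sets `cur_domain` from system actions in the real policy. Under those assumptions, the old logic keeps the first domain it sees, while the new logic follows whichever task domain appears in the latest user action.

```python
# Task domains checked in the diff above.
DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train']


class DomainTracker:
    """Hypothetical stand-in for the agenda object; not ConvLab-2 code."""

    def __init__(self):
        self.cur_domain = None

    def set_domain_old(self, usr_action):
        # Pre-commit behaviour: only fill in the domain when none is set yet.
        if self.cur_domain is None:
            self.set_domain_new(usr_action)

    def set_domain_new(self, usr_action):
        # Post-commit behaviour: always follow the domains in the user action.
        for diaact in usr_action.keys():
            domain, _ = diaact.split('-')
            if domain in DOMAINS:
                self.cur_domain = domain


old, new = DomainTracker(), DomainTracker()

# Two illustrative user turns: the user starts with a hotel, then moves to a train.
turns = [
    {'hotel-inform': [['area', 'north']]},
    {'general-thank': [], 'train-inform': [['day', 'monday']]},
]
for usr_action in turns:
    old.set_domain_old(usr_action)
    new.set_domain_new(usr_action)

print(old.cur_domain)  # hotel
print(new.cur_domain)  # train
```

Acts whose domain is outside the list (such as `general-thank` here) leave `cur_domain` unchanged in both versions.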
