Skip to content

Commit

Permalink
MADDPG critic should take all observations
Browse files Browse the repository at this point in the history
  • Loading branch information
shariqiqbal2810 committed Feb 22, 2018
1 parent b5224c0 commit e43f45b
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions algorithms/maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def update(self, sample, agent_i, parallel=False, logger=None):
else:
all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies,
next_obs)]
trgt_vf_in = torch.cat((next_obs[agent_i], *all_trgt_acs), dim=1)
trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
else: # DDPG
if self.discrete_action:
trgt_vf_in = torch.cat((next_obs[agent_i],
Expand All @@ -116,7 +116,7 @@ def update(self, sample, agent_i, parallel=False, logger=None):
(1 - dones[agent_i].view(-1, 1)))

if self.alg_types[agent_i] == 'MADDPG':
vf_in = torch.cat((obs[agent_i], *acs), dim=1)
vf_in = torch.cat((*obs, *acs), dim=1)
else: # DDPG
vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
actual_value = curr_agent.critic(vf_in)
Expand Down Expand Up @@ -145,12 +145,11 @@ def update(self, sample, agent_i, parallel=False, logger=None):
for i, pi, ob in zip(range(self.nagents), self.policies, obs):
if i == agent_i:
all_pol_acs.append(curr_pol_out)
elif self.discrete_action:
all_pol_acs.append(onehot_from_logits(pi(ob)))
else:
if self.discrete_action:
all_pol_acs.append(onehot_from_logits(pi(ob)))
else:
all_pol_acs.append(pi(ob))
vf_in = torch.cat((obs[agent_i], *all_pol_acs), dim=1)
all_pol_acs.append(pi(ob))
vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
else: # DDPG
vf_in = torch.cat((obs[agent_i], curr_pol_out),
dim=1)
Expand Down Expand Up @@ -243,12 +242,14 @@ def init_from_env(cls, env, agent_alg="MADDPG", adversary_alg="MADDPG",
discrete_action = True
get_shape = lambda x: x.n
num_out_pol = get_shape(acsp)
num_in_critic = obsp.shape[0]
if algtype == "MADDPG":
num_in_critic = 0
for oobsp in env.observation_space:
num_in_critic += oobsp.shape[0]
for oacsp in env.action_space:
num_in_critic += get_shape(oacsp)
else:
num_in_critic += get_shape(acsp)
num_in_critic = obsp.shape[0] + get_shape(acsp)
agent_init_params.append({'num_in_pol': num_in_pol,
'num_out_pol': num_out_pol,
'num_in_critic': num_in_critic})
Expand Down

0 comments on commit e43f45b

Please sign in to comment.