Skip to content

Commit

Permalink
added aec waterworld rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
benblack769 committed Dec 1, 2020
1 parent 4282893 commit 3a9526c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 4 additions & 3 deletions pettingzoo/sisl/waterworld/waterworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def step(self, action):

is_last = self._agent_selector.is_last()
self.env.step(action, self.agent_name_mapping[agent], is_last)

for r in self.rewards:
self.rewards[r] = self.env.control_rewards[self.agent_name_mapping[r]]
if is_last:
for r in self.rewards:
self.rewards[r] = self.env.last_rewards[self.agent_name_mapping[r]]
else:
self._clear_rewards()
self.rewards[r] += self.env.last_rewards[self.agent_name_mapping[r]]

if self.env.frames >= self.env.max_cycles:
self.dones = dict(zip(self.agents, [True for _ in self.agents]))
Expand Down
7 changes: 4 additions & 3 deletions pettingzoo/sisl/waterworld/waterworld_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,9 @@ def step(self, action, agent_id, is_last):
poison.position[i] = np.clip(poison.position[i], 0, 1)
poison.velocity[i] = -1 * poison.velocity[i]

self.control_rewards[agent_id] = self.control_penalty * (action**2).sum()
control_reward = self.control_penalty * (action**2).sum()
self.control_rewards = (control_reward / self.n_pursuers) * np.ones(self.n_pursuers) * (1 - self.local_ratio)
self.control_rewards[agent_id] += control_reward * self.local_ratio

if is_last:
rewards = np.zeros(self.n_pursuers)
Expand All @@ -524,11 +526,10 @@ def step(self, action, agent_id, is_last):
sensorfeatures_Np_K_O, is_colliding_ev_Np_Ne, is_colliding_po_Np_Npo)
self.last_obs = obs_list

local_reward = rewards + np.array(self.control_rewards)
local_reward = rewards
global_reward = local_reward.mean()
self.last_rewards = local_reward * self.local_ratio + global_reward * (1 - self.local_ratio)

self.control_rewards = [0 for _ in range(self.n_pursuers)]
self.frames += 1

return self.observe(agent_id)
Expand Down

0 comments on commit 3a9526c

Please sign in to comment.