Fix sequential rollout (sjtu-marl#14)
* Fix: negative shift for data rolling

* Fix: last step was ignored

* Use svg logo

* Update README

* Fix issue sjtu-marl#8
KornbergFresnel authored Jul 27, 2021
1 parent 96079c3 commit cc468cb
Showing 7 changed files with 355 additions and 41 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,5 +1,5 @@

-<div align=center><img src="docs/imgs/logo.png" width="35%"></div>
+<div align=center><img src="docs/imgs/logo.svg" width="35%"></div>


# MALib: A parallel framework for population-based multi-agent reinforcement learning
@@ -29,7 +29,7 @@ pip install -e .

External environments are integrated in MALib, such as StarCraftII and vizdoom, you can install them via `pip install -e .[envs]`. For users who wanna contribute to our repository, run `pip install -e .[dev]` to complete the development dependencies.

-**optional**: if you wanna use alpha-rank to solve meta-game, install open-spiel with its [installation guides](https://github.com/deepmind/open_spiel)
+**NOTE**: if you wanna use alpha-rank (default solver for meta game) to solve meta-game, install open-spiel with its [installation guides](https://github.com/deepmind/open_spiel)

## Quick Start

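The README change above makes open-spiel an optional dependency that the alpha-rank meta-game solver relies on. As a quick sanity check (a minimal sketch, assuming only that the `open_spiel.python.egt.alpharank` module is what needs to be importable), one can verify the dependency before launching a run:

```python
# Minimal availability check for the optional alpha-rank dependency.
# It only verifies the import; it does not solve a meta-game.
try:
    from open_spiel.python.egt import alpharank  # noqa: F401
    print("open_spiel found: the alpha-rank meta-game solver can be used.")
except ImportError:
    print("open_spiel missing: install it (see the guide linked above) to use alpha-rank.")
```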
321 changes: 321 additions & 0 deletions docs/imgs/logo.svg
2 changes: 1 addition & 1 deletion examples/async_simple.py
@@ -74,7 +74,7 @@
rollout={
"type": "async",
"stopper": "simple_rollout",
"metric_type": "simple",
"stopper_config" "metric_type": "simple",
"fragment_length": env_config["scenario_configs"]["max_cycles"],
"num_episodes": 100, # episode for each evaluation/training epoch
"terminate": "any",
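For context, the key touched by this hunk sits inside the example's rollout configuration. Below is a hedged sketch of that dictionary reconstructed only from the keys visible in the hunk; the values, the empty `stopper_config`, and the `env_config` stand-in are assumptions for illustration, not the file's actual contents:

```python
# Illustrative rollout configuration; key names follow the diff context above.
env_config = {"scenario_configs": {"max_cycles": 25}}  # stand-in for the example's env config

rollout = {
    "type": "async",              # asynchronous rollout workers
    "stopper": "simple_rollout",  # stopping rule for the rollout phase
    "stopper_config": {},         # extra options for the stopper (assumed empty here)
    "metric_type": "simple",      # metric aggregator used during rollouts
    "fragment_length": env_config["scenario_configs"]["max_cycles"],
    "num_episodes": 100,          # episodes per evaluation/training epoch
    "terminate": "any",
}
```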
4 changes: 2 additions & 2 deletions malib/algorithm/dqn/loss.py
@@ -13,7 +13,7 @@ class DQNLoss(LossFunc):
def setup_optimizers(self, *args, **kwargs):
self._policy: DQN
if self.optimizers is None:
-            optim_cls = getattr(torch.optim, self._params.get("optimizer", "Adam"))
+            optim_cls = getattr(torch.optim, self._params["optimizer"])
self.optimizers = optim_cls(
self.policy.critic.parameters(), lr=self._params["lr"]
)
@@ -28,7 +28,7 @@ def step(self) -> Any:

gradients = {
"critic": {
-                name: param.detach().numpy()
+                name: param.grad.numpy() * self._params["lr"]
for name, param in self.policy.critic.named_parameters()
},
}
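The second hunk changes what `step()` reports for the critic: instead of detached parameter values, the `gradients` dict now holds each parameter's gradient scaled by the learning rate. A self-contained sketch of that reporting pattern in plain PyTorch (the toy critic, loss, and `lr` value are made up; only the dict comprehension mirrors the diff):

```python
import torch
import torch.nn as nn

# Toy critic and a scalar loss, just to populate .grad on every parameter.
critic = nn.Linear(4, 2)
loss = critic(torch.randn(8, 4)).pow(2).mean()
loss.backward()

lr = 1e-3  # stands in for self._params["lr"] in the diff
gradients = {
    "critic": {
        # grad * lr approximates the step applied to each parameter
        name: param.grad.numpy() * lr
        for name, param in critic.named_parameters()
    },
}
print({name: g.shape for name, g in gradients["critic"].items()})
```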
11 changes: 5 additions & 6 deletions malib/backend/datapool/data_array.py
@@ -127,6 +127,11 @@ def capacity(self) -> int:
def size(self) -> int:
return self._length

+    def roll(self, shift: int, axis: int = 0):
+        self._data[: self._length] = np.roll(
+            self._data[: self._length], shift, axis=axis
+        )

def fill(self, data: DataTransferType, capacity: int = None) -> "NumpyDataArray":
"""
Flush fill the array with the input data.
@@ -214,7 +219,6 @@ def get_data(self) -> DataTransferType:
:return DataTransferType
"""
-        # FIXME(ming):
indices = np.roll(np.arange(self._length), self._offset)
return self._data[indices]

@@ -228,11 +232,6 @@ def nbytes(self) -> int:

return self._data.nbytes

-    def flush(self, data: DataTransferType):
-        self._length = len(data)
-        self._data = data.copy()
-        self._capacity = max(self._length, self._capacity)


class LinkDataArray(DataArray):
"""
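The new `roll` method shifts only the filled prefix of the backing buffer in place and leaves unused capacity untouched. A small numpy illustration of that behavior (buffer size, fill length, and values are made up):

```python
import numpy as np

capacity, length = 6, 4
data = np.zeros(capacity, dtype=int)
data[:length] = [10, 11, 12, 13]  # only the first `length` slots hold valid entries

# Equivalent of NumpyDataArray.roll(shift=-1): roll the valid prefix, keep the tail.
data[:length] = np.roll(data[:length], -1, axis=0)
print(data)  # [11 12 13 10  0  0] -- slot i now holds what was at slot i + 1, wrapping at the end
```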
27 changes: 9 additions & 18 deletions malib/backend/datapool/offline_dataset_server.py
@@ -221,11 +221,11 @@ def fill(self, **kwargs):
self._capacity = max(self._size, self._capacity)

def insert(self, **kwargs):
-        for column in self.columns:
-            assert self._size == len(self._data[column]), (
-                self._size,
-                {c: len(self._data[c]) for c in self.columns},
-            )
+        # for column in self.columns:
+        #     assert self._size == len(self._data[column]), (
+        #         self._size,
+        #         {c: len(self._data[c]) for c in self.columns},
+        #     )
for column in self.columns:
if isinstance(kwargs[column], NumpyDataArray):
assert kwargs[column]._data is not None, f"{column} has empty data"
@@ -321,19 +321,10 @@ def sample(self, idxes, size) -> Any:

def clean_data(self):
# check length
-        length = self._data[Episode.CUR_OBS].size
-        self._data[Episode.NEXT_OBS].flush(
-            np.roll(self._data[Episode.CUR_OBS]._data[:length], 1, axis=0)
-        )
-        self._data[Episode.REWARD].flush(
-            np.roll(self._data[Episode.REWARD]._data[:length], 1, axis=0)
-        )
-        _size = self._data[Episode.CUR_OBS].size
-        for colum in self.columns:
-            assert (
-                _size == self._data[colum].size
-            ), f"Expected size is {_size}, while accpeted {self._data[colum].size} for column={colum}"
-        self._size = _size
+        self._data[Episode.NEXT_OBS].insert(self._data[Episode.CUR_OBS].get_data())
+        self._data[Episode.NEXT_OBS].roll(-1)
+        self._data[Episode.REWARD].roll(-1)
+        self._data[Episode.DONE].roll(-1)


class MultiAgentEpisode(Episode):
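This is the heart of the data-rolling fix: `NEXT_OBS` is seeded with a copy of `CUR_OBS` and then rolled by -1 (as are `REWARD` and `DONE`), so row `t` of `NEXT_OBS` ends up holding the observation from step `t + 1`; the removed code rolled by +1, which shifted everything the wrong way. A toy numpy comparison of the two shifts (the integer "observations" are made up):

```python
import numpy as np

cur_obs = np.array([0, 1, 2, 3])   # observations recorded at steps 0..3

wrong_next = np.roll(cur_obs, 1)   # old behavior: shift by +1
fixed_next = np.roll(cur_obs, -1)  # new behavior: shift by -1

print(wrong_next)  # [3 0 1 2] -- step 0 would see the last observation as its "next"
print(fixed_next)  # [1 2 3 0] -- step t sees the observation from step t + 1;
                   # the wrap-around on the final row is the episode-boundary case
```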
27 changes: 15 additions & 12 deletions malib/rollout/rollout_func.py
Expand Up @@ -21,6 +21,8 @@
import ray
import numpy as np

+from collections import defaultdict

from malib import settings
from malib.utils.logger import get_logger, Log
from malib.utils.metrics import get_metric, Metric
@@ -54,7 +56,6 @@ def sequential(
evaluated_results = []

assert fragment_length > 0, fragment_length

for ith in range(num_episodes):
env.reset()
metric.reset()
@@ -82,23 +83,25 @@
action, action_dist, extra_info = agent_interfaces[aid].compute_action(
[observation], **info
)
-                if aid in agent_episodes:
-                    agent_episodes[aid].insert(
-                        **{
-                            Episode.CUR_OBS: [observation],
-                            Episode.ACTION_MASK: [action_mask],
-                            Episode.ACTION_DIST: action_dist,
-                            Episode.ACTION: action,
-                            Episode.REWARD: reward,
-                            Episode.DONE: done,
-                        }
-                    )
# convert action to scalar
action = action[0]
else:
info["policy_id"] = behavior_policies[aid]
action = None
env.step(action)
+            if action is None:
+                action = [agent_interfaces[aid].action_space.sample()]
+            if aid in agent_episodes:
+                agent_episodes[aid].insert(
+                    **{
+                        Episode.CUR_OBS: [observation],
+                        Episode.ACTION_MASK: [action_mask],
+                        Episode.ACTION_DIST: action_dist,
+                        Episode.ACTION: action,
+                        Episode.REWARD: reward,
+                        Episode.DONE: done,
+                    }
+                )
metric.step(
aid,
behavior_policies[aid],
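This reshaped loop is what the commit title refers to: the episode insert now happens after `env.step`, and an agent that has already terminated gets a sampled placeholder action, so the final `(obs, reward, done)` transition is no longer dropped. A condensed, environment-agnostic sketch of that control flow (the toy environment, `ToyAgentInterface`, and the tuple-based episode store are made-up stand-ins; only the ordering of the steps mirrors the diff):

```python
import random
from collections import defaultdict

class ToyAgentInterface:
    """Stand-in for the agent interface used in `sequential`."""

    def compute_action(self, obs):
        return [random.choice([0, 1])]  # batched action, as in the diff

    def sample_action(self):
        return [random.choice([0, 1])]  # placeholder for action_space.sample()

def toy_env_steps():
    """Yield (agent_id, obs, reward, done) like a turn-based env iterator."""
    for t in range(4):
        yield "agent_0", t, float(t), t == 3  # the last step carries done=True

agent_interfaces = {"agent_0": ToyAgentInterface()}
agent_episodes = defaultdict(list)

for aid, obs, reward, done in toy_env_steps():
    if not done:
        action = agent_interfaces[aid].compute_action(obs)
    else:
        action = None  # terminated agents pass None to env.step

    # env.step(action) would run here.

    if action is None:
        # Backfill a sampled action so the terminal transition can still be stored.
        action = agent_interfaces[aid].sample_action()

    # The insert now runs on every iteration, including the done=True step.
    agent_episodes[aid].append((obs, action[0], reward, done))

print(agent_episodes["agent_0"])  # 4 transitions; the terminal step is included
```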
