From 1fdb7e8fbe85809f6658e2e9bb9191700b8352af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=A9=AC=E4=BA=BF?= <18222860970@163.com>
Date: Thu, 21 Apr 2022 14:45:08 +0800
Subject: [PATCH] Update README.md

---
 offline-rl-algorithms/README.md | 80 +++++++++++++++++++++++++++++++++
 1 file changed, 80 insertions(+)

diff --git a/offline-rl-algorithms/README.md b/offline-rl-algorithms/README.md
index d0d3869..c4c4a68 100644
--- a/offline-rl-algorithms/README.md
+++ b/offline-rl-algorithms/README.md
@@ -92,6 +92,86 @@ To clone this repo:
 git clone git@github.com:TJU-DRL-LAB/offline-rl-algorithms.git
 ```
 
+## User Guidance
+Here we introduce how to configure your own dataset and how to modify the algorithms based on your own design (two short sketches below illustrate the expected raw data layout and end-to-end usage).
+
+### Dataset
+```
+# Rewrite tjuOfflineRL.get_dataset.py to add a get_your_data_function branch to the get_dataset function.
+def get_dataset(
+    env_name: str, create_mask: bool = False, mask_size: int = 1
+) -> Tuple[MDPDataset, gym.Env]:
+
+    if env_name == "existing datasets":
+        return get_existing_datasets()
+    elif env_name == "your own datasets":
+        return get_your_data_function(create_mask, mask_size)
+    raise ValueError(f"Unrecognized env_name: {env_name}.")
+
+# Load your dataset and transform it into the MDPDataset format.
+def get_your_data_function(create_mask: bool = False, mask_size: int = 1):
+    # Placeholders: create your environment and load your raw data here.
+    # `dataset` should be a dict of aligned arrays with the keys
+    # "observations", "actions", "rewards", "terminals" and "timeouts".
+    env = gym.make("your-env-id")
+    dataset = load_your_raw_data()
+
+    observations = []
+    actions = []
+    rewards = []
+    terminals = []
+    episode_terminals = []
+    episode_step = 0
+    cursor = 0
+    dataset_size = dataset["observations"].shape[0]
+    while cursor < dataset_size:
+        # collect data for step=t
+        observation = dataset["observations"][cursor]
+        action = dataset["actions"][cursor]
+        if episode_step == 0:
+            reward = 0.0
+        else:
+            reward = dataset["rewards"][cursor - 1]
+
+        observations.append(observation)
+        actions.append(action)
+        rewards.append(reward)
+        terminals.append(0.0)
+
+        # skip adding the last step when timeout
+        if dataset["timeouts"][cursor]:
+            episode_terminals.append(1.0)
+            episode_step = 0
+            cursor += 1
+            continue
+
+        episode_terminals.append(0.0)
+        episode_step += 1
+
+        if dataset["terminals"][cursor]:
+            # collect data for step=t+1
+            dummy_observation = observation.copy()
+            dummy_action = action.copy()
+            next_reward = dataset["rewards"][cursor]
+
+            # the last observation is rarely used
+            observations.append(dummy_observation)
+            actions.append(dummy_action)
+            rewards.append(next_reward)
+            terminals.append(1.0)
+            episode_terminals.append(1.0)
+            episode_step = 0
+
+        cursor += 1
+
+    mdp_dataset = MDPDataset(
+        observations=np.array(observations, dtype=np.float32),
+        actions=np.array(actions, dtype=np.float32),
+        rewards=np.array(rewards, dtype=np.float32),
+        terminals=np.array(terminals, dtype=np.float32),
+        episode_terminals=np.array(episode_terminals, dtype=np.float32),
+        create_mask=create_mask,
+        mask_size=mask_size,
+    )
+
+    return mdp_dataset, env
+
+```
+### Modify Algorithm
 ## TODO
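The conversion loop in `get_your_data_function` above iterates over a flat dict of aligned arrays. Below is a minimal sketch of that raw layout; the keys match the ones read in the loop, while the `load_your_raw_data` name and the concrete shapes are only illustrative assumptions, not part of the repository.

```
import numpy as np

# Illustrative sketch only: build the raw dict that get_your_data_function
# iterates over. The shapes (1000 steps, 17-dim observations, 6-dim actions)
# are arbitrary; replace this with your own logging/loading code.
def load_your_raw_data():
    num_steps, obs_dim, act_dim = 1000, 17, 6
    dataset = {
        "observations": np.random.randn(num_steps, obs_dim).astype(np.float32),
        "actions": np.random.randn(num_steps, act_dim).astype(np.float32),
        "rewards": np.random.randn(num_steps).astype(np.float32),
        # 1.0 marks a true environment terminal at that step
        "terminals": np.zeros(num_steps, dtype=np.float32),
        # 1.0 marks an episode cut off by a time limit
        "timeouts": np.zeros(num_steps, dtype=np.float32),
    }
    # e.g. end an episode every 200 steps via timeout so episodes are finite
    dataset["timeouts"][199::200] = 1.0
    return dataset
```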
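Once `get_dataset` knows about your data, training follows the usual flow. The sketch below assumes a d3rlpy-style interface for tjuOfflineRL (module path `tjuOfflineRL.get_dataset`, an algorithm class such as `CQL` under `tjuOfflineRL.algos`, and `fit`/`predict` methods); these entry points are assumptions to check against the repository, not a confirmed API.

```
import numpy as np

# Assumed entry points (d3rlpy-style layout); adjust to the actual package.
from tjuOfflineRL.get_dataset import get_dataset
from tjuOfflineRL.algos import CQL

# MDPDataset built by get_your_data_function, plus the matching environment.
dataset, env = get_dataset("your own datasets")

# Fit an offline RL algorithm on the converted dataset.
algo = CQL(use_gpu=False)
algo.fit(dataset, n_steps=10000)

# Query the learned policy with a single (batched) observation.
obs = env.reset()
action = algo.predict(np.array([obs]))[0]
print("greedy action:", action)
```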