To clone this repo:
```
git clone git@github.com:TJU-DRL-LAB/offline-rl-algorithms.git
```

## User Guidance
Here we describe how to configure your own dataset and how to modify the algorithms based on your own designs.

### Dataset
```
# Rewrite tjuOfflineRL/get_dataset.py to add get_your_data_function to the
# get_dataset function.
from typing import Tuple

import gym
import numpy as np

# Adjust this import to wherever MDPDataset lives in your installation.
from tjuOfflineRL.dataset import MDPDataset


def get_dataset(
    env_name: str, create_mask: bool = False, mask_size: int = 1
) -> Tuple[MDPDataset, gym.Env]:
    if env_name == "existing datasets":
        # placeholder for the dataset loaders already shipped in this repo
        return get_existing_datasets()
    elif env_name == "your own datasets":
        return get_your_data_function()
    raise ValueError(f"Unrecognized env_name: {env_name}.")


# Load your dataset and transform it into the MDPDataset format.
def get_your_data_function():
    # Placeholders: load your raw data (a dict with "observations",
    # "actions", "rewards", "terminals" and "timeouts" arrays) and build
    # the environment it was collected in.
    dataset = ...
    env = ...

    observations = []
    actions = []
    rewards = []
    terminals = []
    episode_terminals = []

    episode_step = 0
    cursor = 0
    dataset_size = dataset["observations"].shape[0]
    while cursor < dataset_size:
        # collect data for step=t
        observation = dataset["observations"][cursor]
        action = dataset["actions"][cursor]
        # rewards are shifted by one step: the reward stored at step t is
        # the one received on entering s_t, so an episode's first step gets 0.0
        if episode_step == 0:
            reward = 0.0
        else:
            reward = dataset["rewards"][cursor - 1]

        observations.append(observation)
        actions.append(action)
        rewards.append(reward)
        terminals.append(0.0)

        # skip adding the last step when timeout
        if dataset["timeouts"][cursor]:
            episode_terminals.append(1.0)
            episode_step = 0
            cursor += 1
            continue

        episode_terminals.append(0.0)
        episode_step += 1

        if dataset["terminals"][cursor]:
            # collect data for step=t+1
            dummy_observation = observation.copy()
            dummy_action = action.copy()
            next_reward = dataset["rewards"][cursor]

            # the last observation is rarely used
            observations.append(dummy_observation)
            actions.append(dummy_action)
            rewards.append(next_reward)
            terminals.append(1.0)
            episode_terminals.append(1.0)
            episode_step = 0

        cursor += 1

    mdp_dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=np.array(actions, dtype=np.float32),
        rewards=np.array(rewards, dtype=np.float32),
        terminals=np.array(terminals, dtype=np.float32),
        episode_terminals=np.array(episode_terminals, dtype=np.float32),
        create_mask=create_mask,
        mask_size=mask_size,
    )
    return mdp_dataset, env
```
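
In the `MDPDataset` format, `terminals` marks true environment terminations, while `episode_terminals` additionally marks timeouts, so episodes are split at timeouts without treating the timeout state as absorbing. Once your loader is wired in, training on the converted data might look like the sketch below (the `CQL` import path and the `fit` arguments are assumptions in the style of d3rlpy-like APIs, not this repo's confirmed interface):

```
from tjuOfflineRL.algos import CQL  # hypothetical import path
from tjuOfflineRL.get_dataset import get_dataset

# get_dataset returns the converted MDPDataset and the matching gym.Env.
dataset, env = get_dataset("your own datasets")

# Fit any offline RL algorithm shipped in this repo on the dataset;
# the exact fit() signature may differ from this sketch.
cql = CQL()
cql.fit(dataset, n_epochs=10)
```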
### Modify Algorithm
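A detailed guide for this part is still in progress (see the TODO below). As a rough sketch under assumed names, a d3rlpy-style codebase usually lets you subclass an existing algorithm and override the pieces you want to change; `CQL` and every other name below are illustrative assumptions, not this repo's confirmed API:

```
from tjuOfflineRL.algos import CQL  # hypothetical base algorithm


class MyCQL(CQL):
    """A CQL variant with one extra hyperparameter (illustrative only)."""

    def __init__(self, *args, my_weight: float = 5.0, **kwargs):
        super().__init__(*args, **kwargs)
        self._my_weight = my_weight

    # To change the learning rule itself, override the loss/update hooks
    # of the base class; check the base implementation in this repo to
    # find the right method to override before copying this pattern.
```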


## TODO
