Buffered Batch Iterator + Tests (minerllabs#535)
* Buffered batch iter and tests

* Appease pep8

* Remove excess printouts and change return logic to finish all batches

* Add docstrings and other documentation to buffered batch iterator method

* Include data in test_build_deploy script

Co-authored-by: Brandon Houghton <[email protected]>
decodyng and brandonhoughton authored Jun 25, 2021
1 parent 94feb40 commit 9e489d7
Showing 5 changed files with 184 additions and 9 deletions.
19 changes: 15 additions & 4 deletions docs/source/tutorials/data_sampling.rst
@@ -75,13 +75,24 @@ Sampling the Dataset with :code:`batch_iter`

 Now we can build the dataset for :code:`MineRLObtainDiamond-v0`

-.. code-block:: python
+There are two ways of sampling from the MineRL dataset: the deprecated but still supported
+`batch_iter`, and `buffered_batch_iter`. `batch_iter` is the legacy method, which we've kept
+in the code to avoid breaking changes, but we have recently realized that, when using
+`batch_size > 1`, `batch_iter` can fail to return a substantial portion of the data in
+the epoch.

-    data = minerl.data.make('MineRLObtainDiamond-v0')
+**If you are not already using `data_pipeline.batch_iter`, we recommend against it,
+because of these issues.**

-    for current_state, action, reward, next_state, done \
-        in data.batch_iter(
-            batch_size=1, num_epochs=1, seq_len=32):
+The recommended way of sampling from the dataset is:
+
+.. code-block:: python
+
+    from minerl.data import BufferedBatchIter
+    data = minerl.data.make('MineRLObtainDiamond-v0')
+    iterator = BufferedBatchIter(data)
+    for current_state, action, reward, next_state, done \
+            in iterator.buffered_batch_iter(batch_size=1, num_epochs=1):

             # Print the POV @ the first step of the sequence
             print(current_state['pov'][0])
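
Since :code:`buffered_batch_iter` also accepts a :code:`num_batches` argument (mutually
exclusive with :code:`num_epochs`), sampling can instead be capped by batch count; a minimal
sketch along the same lines:

.. code-block:: python

    from minerl.data import BufferedBatchIter
    data = minerl.data.make('MineRLObtainDiamond-v0')
    iterator = BufferedBatchIter(data)
    # Draw exactly 200 batches of 10 transitions each, regardless of
    # where epoch boundaries fall
    for current_state, action, reward, next_state, done \
            in iterator.buffered_batch_iter(batch_size=10, num_batches=200):
        print(reward[-1])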
1 change: 1 addition & 0 deletions minerl/data/__init__.py
@@ -1,4 +1,5 @@
 from minerl.data.data_pipeline import DataPipeline
+from minerl.data.buffered_batch_iter import BufferedBatchIter
 from minerl.data.download import download
 import os

145 changes: 145 additions & 0 deletions minerl/data/buffered_batch_iter.py
@@ -0,0 +1,145 @@
import minerl
import os
import time
from copy import deepcopy
import numpy as np
from minerl.data.util import multimap
import random

MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT')


def stack(*args):
    return np.stack(args)


class BufferedBatchIter:
    """
    A class that maintains and exposes an iterator which loads trajectories into a
    configurably-sized buffer, samples batches from that buffer, and refills the buffer
    when necessary.
    """
    def __init__(self,
                 data_pipeline,
                 buffer_target_size=50000):
        """
        Args:
            data_pipeline: A data pipeline object that you want to construct an iterator from
            buffer_target_size: How large you'd like your data buffer to be (in units of timesteps).
                Note that this is not an exact cap, since we don't know how large a trajectory
                will be until we load it in. This implementation tries to maintain a buffer size
                by keeping track of the average size of trajectories in this data pipeline, and
                loading a new trajectory when the size of the buffer is more than <average_size>
                below the target.
        """
        self.data_pipeline = data_pipeline
        self.data_buffer = []
        self.buffer_target_size = buffer_target_size
        self.traj_sizes = []
        self.avg_traj_size = 0
        self.all_trajectories = self.data_pipeline.get_trajectory_names()
        # available_trajectories is a dynamic, per-epoch list that will keep track of
        # which trajectories we haven't yet used in a given epoch
        self.available_trajectories = deepcopy(self.all_trajectories)
        random.shuffle(self.available_trajectories)

    def optionally_fill_buffer(self):
        """
        This method is run after every batch, but only actually executes a buffer
        refill and re-shuffle if more data is needed
        """
        buffer_updated = False

        # Add trajectories to the buffer if the remaining space is
        # greater than our anticipated trajectory size (in the form of the empirical average)
        while (self.buffer_target_size - len(self.data_buffer)) > self.avg_traj_size:
            if len(self.available_trajectories) == 0:
                return
            traj_to_load = self.available_trajectories.pop()
            data_loader = self.data_pipeline.load_data(traj_to_load)
            traj_len = 0
            for data_tuple in data_loader:
                traj_len += 1
                self.data_buffer.append(data_tuple)

            self.traj_sizes.append(traj_len)
            self.avg_traj_size = np.mean(self.traj_sizes)
            buffer_updated = True
        if buffer_updated:
            random.shuffle(self.data_buffer)

    def get_batch(self, batch_size):
        """A simple utility method for constructing a return batch in the expected format"""
        ret_dict_list = []
        for _ in range(batch_size):
            data_tuple = self.data_buffer.pop()
            ret_dict = dict(obs=data_tuple[0],
                            act=data_tuple[1],
                            reward=data_tuple[2],
                            next_obs=data_tuple[3],
                            done=data_tuple[4])
            ret_dict_list.append(ret_dict)
        return multimap(stack, *ret_dict_list)

    def buffered_batch_iter(self, batch_size, num_epochs=None, num_batches=None):
        """
        The actual generator method that returns batches. You can specify either
        a desired number of batches, or a desired number of epochs, but not both,
        since they might conflict.

        ** You must specify one or the other **

        Args:
            batch_size: The number of transitions/timesteps to be returned in each batch
            num_epochs: Optional, how many full passes through all trajectories to return
            num_batches: Optional, how many batches to return
        """
        assert num_batches is not None or num_epochs is not None, "One of num_epochs or " \
                                                                  "num_batches must be non-None"
        assert num_batches is None or num_epochs is None, "You cannot specify both " \
                                                          "num_batches and num_epochs"

        epoch_count = 0
        batch_count = 0

        while True:
            # If we've hit the desired number of epochs
            if num_epochs is not None and epoch_count >= num_epochs:
                return
            # If we've hit the desired number of batches
            if num_batches is not None and batch_count >= num_batches:
                return
            # Refill the buffer if we need to
            # (doing this before getting batch so it'll run on the first iteration)
            self.optionally_fill_buffer()
            ret_batch = self.get_batch(batch_size=batch_size)
            batch_count += 1
            if len(self.data_buffer) < batch_size:
                assert len(self.available_trajectories) == 0, "You've reached the end of your " \
                                                              "data buffer while still having " \
                                                              "trajectories available; " \
                                                              "something seems to have gone wrong"
                epoch_count += 1
                self.available_trajectories = deepcopy(self.all_trajectories)
                random.shuffle(self.available_trajectories)

            keys = ('obs', 'act', 'reward', 'next_obs', 'done')
            yield tuple([ret_batch[key] for key in keys])


if __name__ == "__main__":

    env = "MineRLBasaltMakeWaterfall-v0"
    test_batch_size = 32

    start_time = time.time()
    data_pipeline = minerl.data.make(env, MINERL_DATA_ROOT)
    bbi = BufferedBatchIter(data_pipeline, buffer_target_size=10000)
    num_timesteps = 0
    # buffered_batch_iter yields (obs, act, reward, next_obs, done) tuples
    for obs, act, reward, next_obs, done in bbi.buffered_batch_iter(
            batch_size=test_batch_size, num_epochs=1):
        num_timesteps += len(obs['pov'])

    print(f"{num_timesteps} timesteps found for env {env} using buffered_batch_iter")
    end_time = time.time()
    print(f"Total time: {end_time - start_time} seconds")
2 changes: 1 addition & 1 deletion scripts/test_build_deploy.sh
@@ -18,7 +18,7 @@ mkdir -p "$MINERL_DATA_ROOT"
 pip install .

 # Copy data to the ci machines if needed for tests
-#az storage copy -s $AZ_MINERL_DATA -d $MINERL_DATA_ROOT --recursive --subscription sci
+az storage copy -s $AZ_MINERL_DATA -d $MINERL_DATA_ROOT --recursive --subscription sci

 # Note tests that launch Minecraft MUST be marked serial via the "@pytest.mark.serial" annotation
 pytest . -n 4
26 changes: 22 additions & 4 deletions tests/unit/test_batch_iter.py
@@ -4,16 +4,34 @@

 import logging
 import numpy as np
 import tqdm
+from minerl.data import BufferedBatchIter


-def _test_batch_iter():
+def test_batch_iter():
     dat = minerl.data.make('MineRLTreechopVectorObf-v0')

     act_vectors = []
     i = 0
-    for _ in tqdm.tqdm(dat.batch_iter(1, 32, 1, preload_buffer_size=2)):
+    for (current_state, action,
+         reward, next_state, done) in tqdm.tqdm(dat.batch_iter(1, 32, 1, preload_buffer_size=2)):
         i += 1
         if i > 100:
             # assert False
             break
-        pass
-        print(_)
+
+
+def test_buffered_batch_iter():
+    dat = minerl.data.make('MineRLTreechopVectorObf-v0')
+    bbi = BufferedBatchIter(dat)
+    i = 0
+    for (current_state, action,
+         reward, next_state, done) in tqdm.tqdm(bbi.buffered_batch_iter(batch_size=10,
+                                                                        num_batches=200)):
+        print(current_state['pov'][0])
+        print(reward[-1])
+        print(done[-1])
+        i += 1
+        if i > 100:
+            # assert False
+            break
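
A stricter variant of test_buffered_batch_iter could assert batch shapes directly; a sketch,
assuming the env's POV observations are 64x64x3 arrays (that shape is an assumption here,
not something this commit verifies):

def test_buffered_batch_shapes():
    # Sketch only: checks that every yielded batch has the requested
    # leading batch dimension; (64, 64, 3) POV shape is assumed
    dat = minerl.data.make('MineRLTreechopVectorObf-v0')
    bbi = BufferedBatchIter(dat)
    for (current_state, action,
         reward, next_state, done) in bbi.buffered_batch_iter(batch_size=10, num_batches=20):
        assert current_state['pov'].shape == (10, 64, 64, 3)
        assert reward.shape == (10,)
        assert done.shape == (10,)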