Skip to content

Commit

Permalink
Address sample_sequence_length greater than min_length_time_axis (#…
Browse files Browse the repository at this point in the history
…45)

* Set min_length_time_axis correctly in tests

* Fix remaining sample seq len > min len time axis

* Replace add_batch_size with min_length

* fixup! Fix remaining sample seq len > min len time axis
  • Loading branch information
mickvangelderen authored Jan 14, 2025
1 parent 1352bfa commit f4aa2eb
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 24 deletions.
9 changes: 4 additions & 5 deletions flashbax/buffers/flat_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def test_sample(
rng_key1, rng_key2 = jax.random.split(rng_key)

# Fill buffer to the point that we can sample
fake_batch = get_fake_batch(fake_transition, int(min_length + 10))
fake_batch = get_fake_batch(fake_transition, min_length)

buffer = flat_buffer.make_flat_buffer(
max_length, min_length, sample_batch_size, False, int(min_length + 10)
max_length, min_length, sample_batch_size, False, add_batch_size=min_length
)
state = buffer.init(fake_transition)

Expand Down Expand Up @@ -222,9 +222,8 @@ def test_flat_replay_buffer_does_not_smoke(
):
"""Create the FlatBuffer NamedTuple, and check that it is pmap-able and does not smoke."""

add_batch_size = int(min_length + 5)
buffer = flat_buffer.make_flat_buffer(
max_length, min_length, sample_batch_size, False, add_batch_size
max_length, min_length, sample_batch_size, False, add_batch_size=min_length
)

# Initialise the buffer's state.
Expand All @@ -236,7 +235,7 @@ def test_flat_replay_buffer_does_not_smoke(
# Now fill the buffer above its minimum length.

fake_batch = jax.pmap(get_fake_batch, static_broadcasted_argnums=1)(
fake_transition_per_device, add_batch_size
fake_transition_per_device, min_length
)
# Add two items thereby giving a single transition.
state = jax.pmap(buffer.add)(state, fake_batch)
Expand Down
8 changes: 4 additions & 4 deletions flashbax/buffers/mixer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_mixed_trajectory_sample(
for i in range(3):
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=200 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_mixed_prioritised_trajectory_sample(
for i in range(3):
buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer(
max_length_time_axis=200 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_mixed_flat_buffer_sample(
for i in range(3):
buffer = flat_buffer.make_flat_buffer(
max_length=200 * (i + 1),
min_length=0,
min_length=add_batch_size,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
add_sequences=True,
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_mixed_buffer_does_not_smoke(
for i in range(3):
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=2000 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down
20 changes: 8 additions & 12 deletions flashbax/buffers/prioritised_flat_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,11 @@ def test_sample(
"""Test the random sampling from the buffer."""
rng_key1, rng_key2 = jax.random.split(rng_key)

add_batch_size = int(min_length + 10)
# Fill buffer to the point that we can sample
fake_batch = get_fake_batch(fake_transition, add_batch_size)
fake_batch = get_fake_batch(fake_transition, min_length)

buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length, min_length, sample_batch_size, False, add_batch_size
max_length, min_length, sample_batch_size, False, add_batch_size=min_length
)
state = buffer.init(fake_transition)

Expand Down Expand Up @@ -132,16 +131,15 @@ def test_adjust_priorities(
"""Test the adjustment of priorities in the buffer."""
rng_key1, rng_key2 = jax.random.split(rng_key)

add_batch_size = int(min_length + 10)
# Fill buffer to the point that we can sample.
fake_batch = get_fake_batch(fake_transition, add_batch_size)
fake_batch = get_fake_batch(fake_transition, min_length)
buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length,
min_length,
sample_batch_size,
False,
add_batch_size,
priority_exponent,
add_batch_size=min_length,
priority_exponent=priority_exponent,
)
state = buffer.init(fake_transition)

Expand Down Expand Up @@ -175,15 +173,13 @@ def test_prioritised_flat_buffer_does_not_smoke(
):
"""Create the FlatBuffer NamedTuple, and check that it is pmap-able and does not smoke."""

add_batch_size = int(min_length + 5)

buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length,
min_length,
sample_batch_size,
False,
add_batch_size,
priority_exponent,
add_batch_size=min_length,
priority_exponent=priority_exponent,
)

# Initialise the buffer's state.
Expand All @@ -195,7 +191,7 @@ def test_prioritised_flat_buffer_does_not_smoke(
# Now fill the buffer above its minimum length.

fake_batch = jax.pmap(get_fake_batch, static_broadcasted_argnums=1)(
fake_transition_per_device, add_batch_size
fake_transition_per_device, min_length
)
# Add two items thereby giving a single transition.
state = jax.pmap(buffer.add)(state, fake_batch)
Expand Down
4 changes: 2 additions & 2 deletions flashbax/buffers/trajectory_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_add_sample_max_capacity(
sample_sequence_length = add_sequence_length
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=add_sequence_length,
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_uniform_index_cal(

buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=max_length,
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ include = ["flashbax/*"]
[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax",
"ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax",
"ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax",
]
Expand Down

0 comments on commit f4aa2eb

Please sign in to comment.